diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml new file mode 100644 index 0000000..97cd358 --- /dev/null +++ b/.github/workflows/formatting.yml @@ -0,0 +1,21 @@ +name: Code Formatting + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + - name: Install dependencies + run: | + python -m pip install --upgrade pip isort + python -m pip install pre-commit + - name: Run pre-commit + run: python -m pre_commit run --all-files diff --git a/.gitignore b/.gitignore index b6e4761..baef139 100644 --- a/.gitignore +++ b/.gitignore @@ -98,6 +98,11 @@ __pypackages__/ celerybeat-schedule celerybeat.pid + +out +traces +*.npy +*.out # SageMath parsed files *.sage.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d476f32..f44eaca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,4 +14,4 @@ repos: rev: 5.13.2 hooks: - id: isort - name: isort (python) \ No newline at end of file + name: isort (python) diff --git a/benchmarks/bench_pm.py b/benchmarks/bench_pm.py new file mode 100644 index 0000000..2bf4534 --- /dev/null +++ b/benchmarks/bench_pm.py @@ -0,0 +1,261 @@ +import os + +os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax +import jax + +jax.distributed.initialize() + +rank = jax.process_index() +size = jax.process_count() + +import argparse +import time + +import jax.numpy as jnp +import jax_cosmo as jc +import numpy as np +from cupy.cuda.nvtx import RangePop, RangePush +from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, + PIDController, SaveAt, Tsit5, diffeqsolve) +from hpc_plotter.timer import Timer +from jax.experimental import mesh_utils +from jax.experimental.multihost_utils import sync_global_devices +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P + +from jaxpm.kernels import interpolate_power_spectrum +from jaxpm.painting import cic_paint_dx +from jaxpm.pm import linear_field, lpt, make_ode_fn + + +def run_simulation(mesh_shape, + box_size, + halo_size, + solver_choice, + iterations, + pdims=None): + + @jax.jit + def simulate(omega_c, sigma8): + # Create a small function to generate the matter power spectrum + k = jnp.logspace(-4, 1, 128) + pk = jc.power.linear_matter_power( + jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) + pk_fn = lambda x: interpolate_power_spectrum(x, k, pk) + + # Create initial conditions + initial_conditions = linear_field(mesh_shape, + box_size, + pk_fn, + seed=jax.random.PRNGKey(0)) + + # Create particles + cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) + dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) + if solver_choice == "Dopri5": + solver = Dopri5() + elif solver_choice == "LeapfrogMidpoint": + solver = LeapfrogMidpoint() + elif solver_choice == "Tsit5": + solver = Tsit5() + elif solver_choice == "lpt": + lpt_field = cic_paint_dx(dx, halo_size=halo_size) + print(f"TYPE of lpt_field: {type(lpt_field)}") + return lpt_field, {"num_steps": 0} + else: + raise ValueError( + "Invalid solver choice. Use 'Dopri5' or 'LeapfrogMidpoint'.") + # Evolve the simulation forward + ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) + term = ODETerm( + lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) + + if solver_choice == "Dopri5" or solver_choice == "Tsit5": + stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) + elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler": + stepsize_controller = ConstantStepSize() + res = diffeqsolve(term, + solver, + t0=0.1, + t1=1., + dt0=0.01, + y0=jnp.stack([dx, p], axis=0), + args=cosmo, + saveat=SaveAt(t1=True), + stepsize_controller=stepsize_controller) + + # Return the simulation volume at requested + state = res.ys[-1] + final_field = cic_paint_dx(state[0], halo_size=halo_size) + + return final_field, res.stats + + def run(): + # Warm start + chrono_fun = Timer() + RangePush("warmup") + final_field, stats = chrono_fun.chrono_jit(simulate, + 0.32, + 0.8, + ndarray_arg=0) + RangePop() + sync_global_devices("warmup") + for i in range(iterations): + RangePush(f"sim iter {i}") + final_field, stats = chrono_fun.chrono_fun(simulate, + 0.32, + 0.8, + ndarray_arg=0) + RangePop() + return final_field, stats, chrono_fun + + if jax.device_count() > 1: + devices = mesh_utils.create_device_mesh(pdims) + mesh = Mesh(devices.T, axis_names=('x', 'y')) + with mesh: + # Warm start + final_field, stats, chrono_fun = run() + else: + final_field, stats, chrono_fun = run() + + return final_field, stats, chrono_fun + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='JAX Cosmo Simulation Benchmark') + parser.add_argument('-m', + '--mesh_size', + type=int, + help='Mesh size', + required=True) + parser.add_argument('-b', + '--box_size', + type=float, + help='Box size', + required=True) + parser.add_argument('-p', + '--pdims', + type=str, + help='Processor dimensions', + default=None) + parser.add_argument( + '-pr', + '--precision', + type=str, + help='Precision', + choices=["float32", "float64"], + ) + parser.add_argument('-hs', + '--halo_size', + type=int, + help='Halo size', + required=True) + parser.add_argument('-s', + '--solver', + type=str, + help='Solver', + choices=[ + "Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5", + "LeapfrogMidpoint", "leapfrogmidpoint", "lfm", + "lpt" + ], + required=True) + parser.add_argument('-i', + '--iterations', + type=int, + help='Number of iterations', + default=10) + parser.add_argument('-o', + '--output_path', + type=str, + help='Output path', + default=".") + parser.add_argument('-f', + '--save_fields', + action='store_true', + help='Save fields') + parser.add_argument('-n', + '--nodes', + type=int, + help='Number of nodes', + default=1) + + args = parser.parse_args() + mesh_size = args.mesh_size + box_size = [args.box_size] * 3 + halo_size = args.halo_size + solver_choice = args.solver + iterations = args.iterations + output_path = args.output_path + os.makedirs(output_path, exist_ok=True) + + print(f"solver choice: {solver_choice}") + match solver_choice: + case "Dopri5" | "dopri5" | "d5": + solver_choice = "Dopri5" + case "Tsit5" | "tsit5" | "t5": + solver_choice = "Tsit5" + case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm": + solver_choice = "LeapfrogMidpoint" + case "lpt": + solver_choice = "lpt" + case _: + raise ValueError( + "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt" + ) + if args.precision == "float32": + jax.config.update("jax_enable_x64", False) + elif args.precision == "float64": + jax.config.update("jax_enable_x64", True) + + if args.pdims: + pdims = tuple(map(int, args.pdims.split("x"))) + else: + pdims = (1, 1) + + mesh_shape = [mesh_size] * 3 + + final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, + halo_size, solver_choice, + iterations, pdims) + + print( + f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}" + ) + + metadata = { + 'rank': rank, + 'function_name': f'JAXPM-{solver_choice}', + 'precision': args.precision, + 'x': str(mesh_size), + 'y': str(mesh_size), + 'z': str(stats["num_steps"]), + 'px': str(pdims[0]), + 'py': str(pdims[1]), + 'backend': 'NCCL', + 'nodes': str(args.nodes) + } + # Print the results to a CSV file + chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) + + # Save the final field + nb_gpus = jax.device_count() + pdm_str = f"{pdims[0]}x{pdims[1]}" + field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" + os.makedirs(field_folder, exist_ok=True) + with open(f'{field_folder}/jaxpm.log', 'w') as f: + f.write(f"Args: {args}\n") + f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") + for i, time in enumerate(chrono_fun.times): + f.write(f"Time {i}: {time:.4f} ms\n") + f.write(f"Stats: {stats}\n") + if args.save_fields: + np.save(f'{field_folder}/final_field_0_{rank}.npy', + final_field.addressable_data(0)) + + print(f"Finished! ") + print(f"Stats {stats}") + print(f"Saving to {output_path}/jax_pm_benchmark.csv") + print(f"Saving field and logs in {field_folder}") diff --git a/benchmarks/bench_pmwd.py b/benchmarks/bench_pmwd.py new file mode 100644 index 0000000..bd11303 --- /dev/null +++ b/benchmarks/bench_pmwd.py @@ -0,0 +1,159 @@ +import os + +# Change JAX GPU memory preallocation fraction +os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95' + +import argparse + +import jax +import matplotlib.pyplot as plt +import numpy as np +from hpc_plotter.timer import Timer +from pmwd import (Configuration, Cosmology, SimpleLCDM, boltzmann, growth, + linear_modes, linear_power, lpt, nbody, scatter, white_noise) +from pmwd.pm_util import fftinv +from pmwd.spec_util import powspec +from pmwd.vis_util import simshow + + +# Simulation configuration +def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver, iterations): + + @jax.jit + def simulate(omega_m, sigma8): + + conf = Configuration(ptcl_spacing, + ptcl_grid_shape=ptcl_grid_shape, + mesh_shape=1, + lpt_order=1, + a_nbody_maxstep=1 / 91) + print(conf) + print( + f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.' + ) + + cosmo = Cosmology(conf, + A_s_1e9=2.0, + n_s=0.96, + Omega_m=omega_m, + Omega_b=sigma8, + h=0.7) + print(cosmo) + + # Boltzmann calculation + cosmo = boltzmann(cosmo, conf) + print("Boltzmann calculation completed.") + + # Generate white noise field and scale with the linear power spectrum + seed = 0 + modes = white_noise(seed, conf) + modes = linear_modes(modes, cosmo, conf) + print("Linear modes generated.") + + # Solve LPT at some early time + ptcl, obsvbl = lpt(modes, cosmo, conf) + print("LPT solved.") + + if solver == "lfm": + # N-body time integration from LPT initial conditions + ptcl, obsvbl = jax.block_until_ready( + nbody(ptcl, obsvbl, cosmo, conf)) + print("N-body time integration completed.") + + # Scatter particles to mesh to get the density field + dens = scatter(ptcl, conf) + return dens + + chrono_timer = Timer() + final_field = chrono_timer.chrono_jit(simulate, 0.3, 0.05) + + for _ in range(iterations): + final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05) + + return final_field, chrono_timer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='PMWD Simulation') + parser.add_argument('-m', + '--mesh_size', + type=int, + help='Mesh size', + required=True) + parser.add_argument('-b', + '--box_size', + type=float, + help='Box size', + required=True) + parser.add_argument('-i', + '--iterations', + type=int, + help='Number of iterations', + default=10) + parser.add_argument('-o', + '--output_path', + type=str, + help='Output path', + default=".") + parser.add_argument('-f', + '--save_fields', + action='store_true', + help='Save fields') + parser.add_argument('-s', + '--solver', + type=str, + help='Solver', + choices=["lfm", "lpt"]) + parser.add_argument( + '-pr', + '--precision', + type=str, + help='Precision', + choices=["float32", "float64"], + ) + + args = parser.parse_args() + + mesh_shape = [args.mesh_size] * 3 + ptcl_spacing = args.box_size / args.mesh_size + iterations = args.iterations + solver = args.solver + output_path = args.output_path + if args.precision == "float32": + jax.config.update("jax_enable_x64", False) + elif args.precision == "float64": + jax.config.update("jax_enable_x64", True) + + os.makedirs(output_path, exist_ok=True) + + final_field, chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing, + solver, iterations) + print("PMWD simulation completed.") + + metadata = { + 'rank': 0, + 'function_name': f'PMWD-{solver}', + 'precision': args.precision, + 'x': str(mesh_shape[0]), + 'y': str(mesh_shape[0]), + 'z': str(mesh_shape[0]), + 'px': "1", + 'py': "1", + 'backend': 'NCCL', + 'nodes': "1" + } + chrono_fun.print_to_csv(f"{output_path}/pmwd.csv", **metadata) + field_folder = f"{output_path}/final_field/pmwd/1/{args.mesh_size}_{int(args.box_size)}/1x1/{args.solver}/halo_0" + os.makedirs(field_folder, exist_ok=True) + with open(f"{field_folder}/pmwd.log", "w") as f: + f.write(f"PMWD simulation completed.\n") + f.write(f"Args : {args}\n") + f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") + for i, time in enumerate(chrono_fun.times): + f.write(f"Time {i}: {time:.4f} ms\n") + if args.save_fields: + np.save(f"{field_folder}/final_field_0_0.npy", final_field) + print("Fields saved.") + + print(f"saving to {output_path}/pmwd.csv") + print(f"saving field and logs to {field_folder}/pmwd.log") diff --git a/benchmarks/particle_mesh_a100.slurm b/benchmarks/particle_mesh_a100.slurm new file mode 100644 index 0000000..a94019c --- /dev/null +++ b/benchmarks/particle_mesh_a100.slurm @@ -0,0 +1,179 @@ +#!/bin/bash +########################################## +## SELECT EITHER tkc@a100 OR tkc@v100 ## +########################################## +#SBATCH --account tkc@a100 +########################################## +#SBATCH --job-name=1N-FFT-Mesh # nom du job +# Il est possible d'utiliser une autre partition que celle par default +# en activant l'une des 5 directives suivantes : +########################################## +## SELECT EITHER a100 or v100-32g ## +########################################## +#SBATCH -C a100 +########################################## +#****************************************** +########################################## +## SELECT Number of nodes and GPUs per node +## For A100 ntasks-per-node and gres=gpu should be 8 +## For V100 ntasks-per-node and gres=gpu should be 4 +########################################## +#SBATCH --nodes=1 # nombre de noeud +#SBATCH --ntasks-per-node=8 # nombre de tache MPI par noeud (= nombre de GPU par noeud) +#SBATCH --gres=gpu:8 # nombre de GPU par nÅ“ud (max 8 avec gpu_p2, gpu_p5) +########################################## +## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant +## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant +## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: +########################################## +#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) +########################################## +# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm +#SBATCH --hint=nomultithread # hyperthreading desactive +#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) +#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie +#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) +##SBATCH --qos=qos_gpu-dev +## SBATCH --exclusive # ressources dediees +# Nettoyage des modules charges en interactif et herites par defaut +num_nodes=$SLURM_JOB_NUM_NODES +num_gpu_per_node=$SLURM_NTASKS_PER_NODE +OUTPUT_FOLDER_ARGS=1 +# Calculate the number of GPUs +nb_gpus=$(( num_nodes * num_gpu_per_node)) + +module purge + +echo "Job constraint: $SLURM_JOB_CONSTRAINT" +echo "Job partition: $SLURM_JOB_PARTITION" +# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" +# pour avoir acces aux modules compatibles avec cette partition + +if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then + module load cpuarch/amd + source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate + gpu_name=a100 +else + source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate + gpu_name=v100 +fi + +# Chargement des modules +module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake +module load nvidia-nsight-systems/2024.1.1.59 + + +echo "The number of nodes allocated for this job is: $num_nodes" +echo "The number of GPUs allocated for this job is: $nb_gpus" + +export ENABLE_PERFO_STEP=NVTX +export MPI4JAX_USE_CUDA_MPI=1 +function profile_python() { + if [ $# -lt 1 ]; then + echo "Usage: profile_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="prof_traces/$script_name" + local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="prof_traces/$script_name/$args" + report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" + fi + + mkdir -p "$output_dir" + mkdir -p "$report_dir" + + srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + +function run_python() { + if [ $# -lt 1 ]; then + echo "Usage: run_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="traces/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="traces/$script_name/$args" + fi + + mkdir -p "$output_dir" + + srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + + +# run or profile + +function slaunch() { + run_python "$@" +} + +function plaunch() { + profile_python "$@" +} + +# Echo des commandes lancees +set -x + +# Pour ne pas utiliser le /tmp +export TMPDIR=$JOBSCRATCH +# Pour contourner un bogue dans les versions actuelles de Nsight Systems +# il est également nécessaire de créer un lien symbolique permettant de +# faire pointer le répertoire /tmp/nvidia vers TMPDIR +ln -s $JOBSCRATCH /tmp/nvidia + +declare -A pdims_table +# Define the table +pdims_table[1]="1x1" +pdims_table[4]="2x2 1x4 4x1" +pdims_table[8]="2x4 1x8 8x1 4x2" +pdims_table[16]="4x4 1x16 16x1" +pdims_table[32]="4x8 8x4 1x32 32x1" +pdims_table[64]="8x8 16x4 1x64 64x1" +pdims_table[128]="8x16 16x8 4x32 32x4 1x128 128x1 2x64 64x2" +pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4" + + +# mpch=(128 256 512 1024 2048 4096) +grid=(256 512 1024 2048 4096) +precisions=(float32 float64) +pdim="${pdims_table[$nb_gpus]}" +solvers=(lpt lfm) +echo "pdims: $pdim" + +# Check if pdims is not empty +if [ -z "$pdim" ]; then + echo "pdims is empty" + echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160" + echo "Number of nodes selected: $num_nodes" + echo "Number of gpus per node: $num_gpu_per_node" + exit 1 +fi + +# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 +out_dir="pm_prof/$gpu_name/$nb_gpus" + +echo "Output dir is : $out_dir" + +for pr in "${precisions[@]}"; do + for g in "${grid[@]}"; do + for solver in "${solvers[@]}"; do + for p in $pdim; do + halo_size=$((g / 4)) + slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes + done + done + done +done diff --git a/benchmarks/particle_mesh_v100.slurm b/benchmarks/particle_mesh_v100.slurm new file mode 100644 index 0000000..9446b9b --- /dev/null +++ b/benchmarks/particle_mesh_v100.slurm @@ -0,0 +1,181 @@ +#!/bin/bash +########################################## +## SELECT EITHER tkc@a100 OR tkc@v100 ## +########################################## +#SBATCH --account tkc@v100 +########################################## +#SBATCH --job-name=V100Particle-Mesh # nom du job +# Il est possible d'utiliser une autre partition que celle par default +# en activant l'une des 5 directives suivantes : +########################################## +## SELECT EITHER a100 or v100-32g ## +########################################## +#SBATCH -C v100-32g +########################################## +#****************************************** +########################################## +## SELECT Number of nodes and GPUs per node +## For A100 ntasks-per-node and gres=gpu should be 8 +## For V100 ntasks-per-node and gres=gpu should be 4 +########################################## +#SBATCH --nodes=1 # nombre de noeud +#SBATCH --ntasks-per-node=4 # nombre de tache MPI par noeud (= nombre de GPU par noeud) +#SBATCH --gres=gpu:4 # nombre de GPU par nÅ“ud (max 8 avec gpu_p2, gpu_p5) +########################################## +## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant +## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant +## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: +########################################## +#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) +########################################## +# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm +#SBATCH --hint=nomultithread # hyperthreading desactive +#SBATCH --time=02:00:00 # temps d'execution maximum demande (HH:MM:SS) +#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie +#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) +#SBATCH --qos=qos_gpu-dev +#SBATCH --exclusive # ressources dediees +# Nettoyage des modules charges en interactif et herites par defaut +num_nodes=$SLURM_JOB_NUM_NODES +num_gpu_per_node=$SLURM_NTASKS_PER_NODE +OUTPUT_FOLDER_ARGS=1 +# Calculate the number of GPUs +nb_gpus=$(( num_nodes * num_gpu_per_node)) + +module purge + +echo "Job constraint: $SLURM_JOB_CONSTRAINT" +echo "Job partition: $SLURM_JOB_PARTITION" +# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" +# pour avoir acces aux modules compatibles avec cette partition + +if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then + module load cpuarch/amd + source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate + gpu_name=a100 +else + source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate + gpu_name=v100 +fi + +# Chargement des modules +module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake +module load nvidia-nsight-systems/2024.1.1.59 + + +echo "The number of nodes allocated for this job is: $num_nodes" +echo "The number of GPUs allocated for this job is: $nb_gpus" + +export EQX_ON_ERROR=nan +export CUDA_ALLOC=1 +export ENABLE_PERFO_STEP=NVTX +export MPI4JAX_USE_CUDA_MPI=1 +function profile_python() { + if [ $# -lt 1 ]; then + echo "Usage: profile_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="prof_traces/$script_name" + local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="prof_traces/$script_name/$args" + report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" + fi + + mkdir -p "$output_dir" + mkdir -p "$report_dir" + + srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + +function run_python() { + if [ $# -lt 1 ]; then + echo "Usage: run_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="traces/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="traces/$script_name/$args" + fi + + mkdir -p "$output_dir" + + srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + + +# run or profile + +function slaunch() { + run_python "$@" +} + +function plaunch() { + profile_python "$@" +} + +# Echo des commandes lancees +set -x + +# Pour ne pas utiliser le /tmp +export TMPDIR=$JOBSCRATCH +# Pour contourner un bogue dans les versions actuelles de Nsight Systems +# il est également nécessaire de créer un lien symbolique permettant de +# faire pointer le répertoire /tmp/nvidia vers TMPDIR +ln -s $JOBSCRATCH /tmp/nvidia + +declare -A pdims_table +# Define the table +pdims_table[1]="1x1" +pdims_table[4]="2x2 1x4 4x1" +pdims_table[8]="2x4 1x8 8x1 4x2" +pdims_table[16]="4x4 1x16 16x1" +pdims_table[32]="4x8 8x4 1x32 32x1" +pdims_table[64]="8x8 16x4 1x64 64x1" +pdims_table[128]="8x16 16x8 4x32 1x128 128x1" +pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4" + + +# mpch=(128 256 512 1024 2048 4096) +grid=(256 512 1024 2048 4096) +precisions=(float32 float64) +pdim="${pdims_table[$nb_gpus]}" +solvers=(lpt lfm) +echo "pdims: $pdim" + +# Check if pdims is not empty +if [ -z "$pdim" ]; then + echo "pdims is empty" + echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160" + echo "Number of nodes selected: $num_nodes" + echo "Number of gpus per node: $num_gpu_per_node" + exit 1 +fi + +# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 +out_dir="pm_prof/$gpu_name/$nb_gpus" + +echo "Output dir is : $out_dir" + +for pr in "${precisions[@]}"; do + for g in "${grid[@]}"; do + for solver in "${solvers[@]}"; do + for p in $pdim; do + halo_size=$((g / 4)) + slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes + done + done + done +done diff --git a/benchmarks/pmwd_a100.slurm b/benchmarks/pmwd_a100.slurm new file mode 100644 index 0000000..f57f0ac --- /dev/null +++ b/benchmarks/pmwd_a100.slurm @@ -0,0 +1,162 @@ +#!/bin/bash +########################################## +## SELECT EITHER tkc@a100 OR tkc@v100 ## +########################################## +#SBATCH --account tkc@a100 +########################################## +#SBATCH --job-name=1N-FFT-Mesh # nom du job +# Il est possible d'utiliser une autre partition que celle par default +# en activant l'une des 5 directives suivantes : +########################################## +## SELECT EITHER a100 or v100-32g ## +########################################## +#SBATCH -C a100 +########################################## +#****************************************** +########################################## +## SELECT Number of nodes and GPUs per node +## For A100 ntasks-per-node and gres=gpu should be 8 +## For V100 ntasks-per-node and gres=gpu should be 4 +########################################## +#SBATCH --nodes=1 # nombre de noeud +#SBATCH --ntasks-per-node=1 # nombre de tache MPI par noeud (= nombre de GPU par noeud) +#SBATCH --gres=gpu:1 # nombre de GPU par nÅ“ud (max 8 avec gpu_p2, gpu_p5) +########################################## +## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant +## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant +## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: +########################################## +#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) +########################################## +# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm +#SBATCH --hint=nomultithread # hyperthreading desactive +#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) +#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie +#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) +##SBATCH --qos=qos_gpu-dev +## SBATCH --exclusive # ressources dediees +# Nettoyage des modules charges en interactif et herites par defaut +num_nodes=$SLURM_JOB_NUM_NODES +num_gpu_per_node=$SLURM_NTASKS_PER_NODE +OUTPUT_FOLDER_ARGS=1 +# Calculate the number of GPUs +nb_gpus=$(( num_nodes * num_gpu_per_node)) + +module purge + +echo "Job constraint: $SLURM_JOB_CONSTRAINT" +echo "Job partition: $SLURM_JOB_PARTITION" +# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" +# pour avoir acces aux modules compatibles avec cette partition + +if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then + module load cpuarch/amd + source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate + gpu_name=a100 +else + source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate + gpu_name=v100 +fi + +# Chargement des modules +module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake +module load nvidia-nsight-systems/2024.1.1.59 + + +echo "The number of nodes allocated for this job is: $num_nodes" +echo "The number of GPUs allocated for this job is: $nb_gpus" + +export EQX_ON_ERROR=nan +export CUDA_ALLOC=1 +export ENABLE_PERFO_STEP=NVTX +export MPI4JAX_USE_CUDA_MPI=1 +function profile_python() { + if [ $# -lt 1 ]; then + echo "Usage: profile_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="prof_traces/$script_name" + local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="prof_traces/$script_name/$args" + report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" + fi + + mkdir -p "$output_dir" + mkdir -p "$report_dir" + + srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + +function run_python() { + if [ $# -lt 1 ]; then + echo "Usage: run_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="traces/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="traces/$script_name/$args" + fi + + mkdir -p "$output_dir" + + srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + + +# run or profile + +function slaunch() { + run_python "$@" +} + +function plaunch() { + profile_python "$@" +} + +# Echo des commandes lancees +set -x + +# Pour ne pas utiliser le /tmp +export TMPDIR=$JOBSCRATCH +# Pour contourner un bogue dans les versions actuelles de Nsight Systems +# il est également nécessaire de créer un lien symbolique permettant de +# faire pointer le répertoire /tmp/nvidia vers TMPDIR +ln -s $JOBSCRATCH /tmp/nvidia + +# mpch=(128 256 512 1024 2048 4096) +grid=(256 512 1024 2048 4096) +precisions=(float32 float64) +solvers=(lpt lfm) + +# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 + +if [ $num_gpu_per_node -eq 8 ]; then + gpu_name="a100" +else + gpu_name="v100" +fi + +out_dir="pm_prof/$gpu_name/$nb_gpus" + +echo "Output dir is : $out_dir" + +for pr in "${precisions[@]}"; do + for g in "${grid[@]}"; do + for solver in "${solvers[@]}"; do + launch bench_pmwd.py -m $g -b $g -p $p -pr $pr -s $solver -i 4 -o $out_dir -f + done + done +done diff --git a/benchmarks/pmwd_v100.slurm b/benchmarks/pmwd_v100.slurm new file mode 100644 index 0000000..9ca5f89 --- /dev/null +++ b/benchmarks/pmwd_v100.slurm @@ -0,0 +1,167 @@ +#!/bin/bash +########################################## +## SELECT EITHER tkc@a100 OR tkc@v100 ## +########################################## +#SBATCH --account tkc@v100 +########################################## +#SBATCH --job-name=16N-V100Particle-Mesh # nom du job +# Il est possible d'utiliser une autre partition que celle par default +# en activant l'une des 5 directives suivantes : +########################################## +## SELECT EITHER a100 or v100-32g ## +########################################## +#SBATCH -C v100-32g +########################################## +#****************************************** +########################################## +## SELECT Number of nodes and GPUs per node +## For A100 ntasks-per-node and gres=gpu should be 8 +## For V100 ntasks-per-node and gres=gpu should be 4 +########################################## +#SBATCH --nodes=1 # nombre de noeud +#SBATCH --ntasks-per-node=1 # nombre de tache MPI par noeud (= nombre de GPU par noeud) +#SBATCH --gres=gpu:1 # nombre de GPU par nÅ“ud (max 8 avec gpu_p2, gpu_p5) +########################################## +## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant +## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant +## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: +########################################## +#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) +########################################## +# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm +#SBATCH --hint=nomultithread # hyperthreading desactive +#SBATCH --time=02:00:00 # temps d'execution maximum demande (HH:MM:SS) +#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie +#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) +#SBATCH --qos=qos_gpu-dev +#SBATCH --exclusive # ressources dediees +# Nettoyage des modules charges en interactif et herites par defaut +num_nodes=$SLURM_JOB_NUM_NODES +num_gpu_per_node=$SLURM_NTASKS_PER_NODE +OUTPUT_FOLDER_ARGS=1 +# Calculate the number of GPUs +nb_gpus=$(( num_nodes * num_gpu_per_node)) + +module purge + +echo "Job constraint: $SLURM_JOB_CONSTRAINT" +echo "Job partition: $SLURM_JOB_PARTITION" +# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" +# pour avoir acces aux modules compatibles avec cette partition + +if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then + module load cpuarch/amd + source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate + gpu_name=a100 +else + source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate + gpu_name=v100 +fi + +# Chargement des modules +module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake +module load nvidia-nsight-systems/2024.1.1.59 + + +echo "The number of nodes allocated for this job is: $num_nodes" +echo "The number of GPUs allocated for this job is: $nb_gpus" + +export EQX_ON_ERROR=nan +export CUDA_ALLOC=1 +export ENABLE_PERFO_STEP=NVTX +export MPI4JAX_USE_CUDA_MPI=1 +function profile_python() { + if [ $# -lt 1 ]; then + echo "Usage: profile_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="prof_traces/$script_name" + local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="prof_traces/$script_name/$args" + report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" + fi + + mkdir -p "$output_dir" + mkdir -p "$report_dir" + + srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + +function run_python() { + if [ $# -lt 1 ]; then + echo "Usage: run_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="traces/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="traces/$script_name/$args" + fi + + mkdir -p "$output_dir" + + srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + + +# run or profile + +function slaunch() { + run_python "$@" +} + +function plaunch() { + profile_python "$@" +} + +# Echo des commandes lancees +set -x + +# Pour ne pas utiliser le /tmp +export TMPDIR=$JOBSCRATCH +# Pour contourner un bogue dans les versions actuelles de Nsight Systems +# il est également nécessaire de créer un lien symbolique permettant de +# faire pointer le répertoire /tmp/nvidia vers TMPDIR +ln -s $JOBSCRATCH /tmp/nvidia + + + + +# mpch=(128 256 512 1024 2048 4096) +grid=(256 512 1024 2048 4096) +precisions=(float32 float64) +pdim="${pdims_table[$nb_gpus]}" +solvers=(lpt lfm) + +# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 + +if [ $num_gpu_per_node -eq 8 ]; then + gpu_name="a100" +else + gpu_name="v100" +fi + +out_dir="pm_prof/$gpu_name/$nb_gpus" + +echo "Output dir is : $out_dir" + + +for pr in "${precisions[@]}"; do + for g in "${grid[@]}"; do + for solver in "${solvers[@]}"; do + slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f + done + done +done diff --git a/benchmarks/run_all_jobs.sh b/benchmarks/run_all_jobs.sh new file mode 100755 index 0000000..cfee491 --- /dev/null +++ b/benchmarks/run_all_jobs.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# Run all slurms jobs +nodes_v100=(1 2 4 8 16) +nodes_a100=(1 2 4 8 16) + + +for n in ${nodes_v100[@]}; do + sbatch --nodes=$n --job-name=v100_$n-JAXPM particle_mesh_v100.slurm +done + +for n in ${nodes_a100[@]}; do + sbatch --nodes=$n --job-name=a100_$n-JAXPM particle_mesh_a100.slurm +done + +# single GPUs +sbatch --job-name=JAXPM-1GPU-V100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 particle_mesh_v100.slurm +sbatch --job-name=JAXPM-1GPU-A100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 particle_mesh_a100.slurm +sbatch --job-name=PMWD-v100 pmwd_v100.slurm +sbatch --job-name=PMWD-a100 pmwd_a100.slurm diff --git a/dev/jaxdecomp.py b/dev/jaxdecomp.py new file mode 100644 index 0000000..ddb19e5 --- /dev/null +++ b/dev/jaxdecomp.py @@ -0,0 +1,69 @@ +import argparse + +import jax +import numpy as np + +# Setting up distributed jax +jax.distributed.initialize() +rank = jax.process_index() +size = jax.process_count() + +import jax.numpy as jnp +import jax_cosmo as jc +from jax.experimental import mesh_utils +from jax.sharding import Mesh + +from jaxpm.painting import cic_paint +from jaxpm.pm import linear_field, lpt + +mesh_shape = [256, 256, 256] +box_size = [256., 256., 256.] +snapshots = jnp.linspace(0.1, 1., 2) + + +@jax.jit +def run_simulation(omega_c, sigma8, seed): + # Create a cosmology + cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) + + # Create a small function to generate the matter power spectrum + k = jnp.logspace(-4, 1, 128) + pk = jc.power.linear_matter_power( + jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) + pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk + ).reshape(x.shape) + + # Create initial conditions + initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed) + + # Initialize particle displacements + dx, p, f = lpt(cosmo, initial_conditions, 1.0) + + field = cic_paint(jnp.zeros_like(initial_conditions), dx) + return field + + +def main(args): + # Setting up distributed random numbers + master_key = jax.random.PRNGKey(42) + key = jax.random.split(master_key, size)[rank] + + # Create computing mesh and sharding information + devices = mesh_utils.create_device_mesh((2, 2)) + mesh = Mesh(devices.T, axis_names=('x', 'y')) + + # Run the simulation on the compute mesh + with mesh: + field = run_simulation(0.32, 0.8, key) + + print('done') + np.save(f'field_{rank}.npy', field.addressable_data(0)) + + # Closing distributed jax + jax.distributed.shutdown() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Distributed LPT N-body simulation.") + args = parser.parse_args() + main(args) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py new file mode 100644 index 0000000..54377cb --- /dev/null +++ b/jaxpm/distributed.py @@ -0,0 +1,170 @@ +from typing import Any, Callable, Hashable + +Specs = Any +AxisName = Hashable + +try: + import jaxdecomp + distributed = True +except ImportError: + print("jaxdecomp not installed. Distributed functions will not work.") + distributed = False + +from functools import partial + +import jax +import jax.numpy as jnp +from jax._src import mesh as mesh_lib +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P + +# NOTE +# This should not be used as a decorator +# Must be used inside a function only +# Example +# BAD +# @autoshmap +# def foo(): +# pass +# GOOD +# def foo(): +# return autoshmap(foo_impl)() + + +def autoshmap(f: Callable, + in_specs: Specs, + out_specs: Specs, + check_rep: bool = True, + auto: frozenset[AxisName] = frozenset(), + in_fourrier_space=False) -> Callable: + """Helper function to wrap the provided function in a shard map if + the code is being executed in a mesh context.""" + mesh = mesh_lib.thread_resources.env.physical_mesh + if mesh.empty: + return f + else: + if in_fourrier_space and 1 in mesh.devices.shape: + in_specs, out_specs = switch_specs((in_specs, out_specs)) + return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) + + +def switch_specs(specs): + if isinstance(specs, P): + new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax + for ax in specs) + return P(*new_axes) + elif isinstance(specs, tuple): + return tuple(switch_specs(sub_spec) for sub_spec in specs) + else: + raise TypeError("Element must be either a PartitionSpec or a tuple") + + +def fft3d(x): + if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): + return jaxdecomp.pfft3d(x.astype(jnp.complex64)) + else: + return jnp.fft.fftn(x.astype(jnp.complex64)) + + +def ifft3d(x): + if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): + return jaxdecomp.pifft3d(x).real + else: + return jnp.fft.ifftn(x).real + + +def get_halo_size(halo_size): + mesh = mesh_lib.thread_resources.env.physical_mesh + if mesh.empty: + zero_ext = (0, 0, 0) + zero_tuple = (0, 0) + return (zero_tuple, zero_tuple, zero_tuple), zero_ext + else: + pdims = mesh.devices.shape + halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size) + halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size) + + halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2 + halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2 + return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext, 0)) + + +def halo_exchange(x, halo_extents, halo_periods=(True, True, True)): + mesh = mesh_lib.thread_resources.env.physical_mesh + if distributed and not (mesh.empty) and (halo_extents[0] > 0 + or halo_extents[1] > 0): + return jaxdecomp.halo_exchange(x, halo_extents, halo_periods) + else: + return x + + +def slice_unpad_impl(x, pad_width): + + halo_x, _ = pad_width[0] + halo_y, _ = pad_width[1] + # Apply corrections along x + x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2]) + x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:]) + # Apply corrections along y + x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2]) + x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:]) + + unpad_slice = [slice(None)] * 3 + if halo_x > 0: + unpad_slice[0] = slice(halo_x, -halo_x) + if halo_y > 0: + unpad_slice[1] = slice(halo_y, -halo_y) + + return x[tuple(unpad_slice)] + + +def slice_pad(x, pad_width): + mesh = mesh_lib.thread_resources.env.physical_mesh + if distributed and not (mesh.empty) and (pad_width[0][0] > 0 + or pad_width[1][0] > 0): + return autoshmap((partial(jnp.pad, pad_width=pad_width)), + in_specs=(P('x', 'y')), + out_specs=P('x', 'y'))(x) + else: + return x + + +def slice_unpad(x, pad_width): + mesh = mesh_lib.thread_resources.env.physical_mesh + if distributed and not (mesh.empty) and (pad_width[0][0] > 0 + or pad_width[1][0] > 0): + return autoshmap(partial(slice_unpad_impl, pad_width=pad_width), + in_specs=(P('x', 'y')), + out_specs=P('x', 'y'))(x) + else: + return x + + +def get_local_shape(mesh_shape): + """ Helper function to get the local size of a mesh given the global size. + """ + if mesh_lib.thread_resources.env.physical_mesh.empty: + return mesh_shape + else: + pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape + return [ + mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2] + ] + + +def normal_field(mesh_shape, seed=None): + """Generate a Gaussian random field with the given power spectrum.""" + if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): + local_mesh_shape = get_local_shape(mesh_shape) + if seed is None: + key = None + else: + size = jax.process_count() + rank = jax.process_index() + key = jax.random.split(seed, size)[rank] + return autoshmap( + partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'), + in_specs=P(None), + out_specs=P('x', 'y'))(key) # yapf: disable + else: + return jax.random.normal(shape=mesh_shape, key=seed) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 8447f8a..d954132 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,23 +1,86 @@ +from enum import Enum +from functools import partial + import jax.numpy as jnp +import jax_cosmo as jc import numpy as np +from jax._src import mesh as mesh_lib +from jax.sharding import PartitionSpec as P + +from jaxpm.distributed import autoshmap + +class PencilType(Enum): + NO_DECOMP = 0 + SLAB_XY = 1 + SLAB_YZ = 2 + PENCILS = 3 + + +def get_pencil_type(): + mesh = mesh_lib.thread_resources.env.physical_mesh + if mesh.empty: + pdims = None + else: + pdims = mesh.devices.shape[::-1] + + if pdims == (1, 1) or pdims == None: + return PencilType.NO_DECOMP + elif pdims[0] == 1: + return PencilType.SLAB_XY + elif pdims[1] == 1: + return PencilType.SLAB_YZ + else: + return PencilType.PENCILS + + +def fftk(shape, dtype=np.float32): + """ + Generate Fourier transform wave numbers for a given mesh. -def fftk(shape, symmetric=True, finite=False, dtype=np.float32): - """ Return k_vector given a shape (nc, nc, nc) and box_size + Args: + nc (int): Shape of the mesh grid. + + Returns: + list: List of wave number arrays for each dimension in + the order [kx, ky, kz]. """ - k = [] - for d in range(len(shape)): - kd = np.fft.fftfreq(shape[d]) - kd *= 2 * np.pi - kdshape = np.ones(len(shape), dtype='int') - if symmetric and d == len(shape) - 1: - kd = kd[:shape[d] // 2 + 1] - kdshape[d] = len(kd) - kd = kd.reshape(kdshape) + kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape] + + @partial(autoshmap, + in_specs=(P('x'), P('y'), P(None)), + out_specs=(P('x'), P(None, 'y'), P(None)), + in_fourrier_space=True) + def get_kvec(ky, kz, kx): + return (ky.reshape([-1, 1, 1]), + kz.reshape([1, -1, 1]), + kx.reshape([1, 1, -1])) # yapf: disable + + pencil_type = get_pencil_type() + # YZ returns Y pencil + # XY and pencils returns a Z pencil + # NO_DECOMP returns a X pencil + if pencil_type == PencilType.NO_DECOMP: + kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil + elif pencil_type == PencilType.SLAB_YZ: + kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil + elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS: + ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil + else: + raise ValueError("Unknown pencil type") + + # to the order of dimensions in the transposed FFT + return kx, ky, kz - k.append(kd.astype(dtype)) - del kd, kdshape - return k + +def interpolate_power_spectrum(input, k, pk): + + pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk + ).reshape(x.shape) + return autoshmap(pk_fn, + in_specs=P('x', 'y'), + out_specs=P('x', 'y'), + in_fourrier_space=True)(input) def gradient_kernel(kvec, direction, order=1): @@ -60,11 +123,7 @@ def laplace_kernel(kvec): Complex kernel """ kk = sum(ki**2 for ki in kvec) - mask = (kk == 0).nonzero() - kk[mask] = 1 - wts = 1. / kk - imask = (~(kk == 0)).astype(int) - wts *= imask + wts = jnp.where(kk == 0, 1., 1. / kk) return wts diff --git a/jaxpm/painting.py b/jaxpm/painting.py index fb5dbd5..975e43c 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -1,15 +1,28 @@ +from functools import partial + import jax import jax.lax as lax import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange, + slice_pad, slice_unpad) from jaxpm.kernels import cic_compensation, fftk +from jaxpm.painting_utils import gather, scatter -def cic_paint(mesh, positions, weight=None): +def cic_paint_impl(mesh, displacement, weight=None): """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ + mesh: [nx, ny, nz] + displacement field: [nx, ny, nz, 3] + """ + part_shape = displacement.shape + positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]), + jnp.arange(part_shape[1]), + jnp.arange(part_shape[2]), + indexing='ij'), + axis=-1) + displacement + positions = positions.reshape([-1, 3]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -34,11 +47,34 @@ def cic_paint(mesh, positions, weight=None): return mesh -def cic_read(mesh, positions): +@partial(jax.jit, static_argnums=(2, )) +def cic_paint(mesh, positions, halo_size=0, weight=None): + + halo_size, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_size) + mesh = autoshmap(cic_paint_impl, + in_specs=(P('x', 'y'), P('x', 'y'), P()), + out_specs=P('x', 'y'))(mesh, positions, weight) + mesh = halo_exchange(mesh, + halo_extents=halo_extents, + halo_periods=(True, True, True)) + mesh = slice_unpad(mesh, halo_size) + return mesh + + +def cic_read_impl(mesh, displacement): """ Paints positions onto mesh mesh: [nx, ny, nz] - positions: [npart, 3] + displacement: [nx,ny,nz, 3] """ + # Compute the position of the particles on a regular grid + part_shape = displacement.shape + positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]), + jnp.arange(part_shape[1]), + jnp.arange(part_shape[2]), + indexing='ij'), + axis=-1) + displacement + positions = positions.reshape([-1, 3]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -52,7 +88,23 @@ def cic_read(mesh, positions): jnp.array(mesh.shape)) return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1], - neighboor_coords[..., 3]] * kernel).sum(axis=-1) + neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape( + displacement.shape[:-1]) + + +@partial(jax.jit, static_argnums=(2, )) +def cic_read(mesh, displacement, halo_size=0): + + halo_size, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_size) + mesh = halo_exchange(mesh, + halo_extents=halo_extents, + halo_periods=(True, True, True)) + displacement = autoshmap(cic_read_impl, + in_specs=(P('x', 'y'), P('x', 'y')), + out_specs=P('x', 'y'))(mesh, displacement) + + return displacement def cic_paint_2d(mesh, positions, weight): @@ -84,6 +136,78 @@ def cic_paint_2d(mesh, positions, weight): return mesh +def cic_paint_dx_impl(displacements, halo_size): + + halo_x, _ = halo_size[0] + halo_y, _ = halo_size[1] + + original_shape = displacements.shape + particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32') + + # Padding is forced to be zero in a single gpu run + + a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]), + jnp.arange(particle_mesh.shape[1]), + jnp.arange(particle_mesh.shape[2]), + indexing='ij') + + particle_mesh = jnp.pad(particle_mesh, halo_size) + pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) + pmid = pmid.reshape([-1, 3]) + return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh) + + +@partial(jax.jit, static_argnums=(1, )) +def cic_paint_dx(displacements, halo_size=0): + + halo_size, halo_extents = get_halo_size(halo_size) + + mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), + in_specs=(P('x', 'y')), + out_specs=P('x', 'y'))(displacements) + + mesh = halo_exchange(mesh, + halo_extents=halo_extents, + halo_periods=(True, True, True)) + mesh = slice_unpad(mesh, halo_size) + return mesh + + +def cic_read_dx_impl(mesh, halo_size): + + halo_x, _ = halo_size[0] + halo_y, _ = halo_size[1] + + original_shape = [ + dim - 2 * halo[0] for dim, halo in zip(mesh.shape, halo_size) + ] + a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]), + jnp.arange(original_shape[1]), + jnp.arange(original_shape[2]), + indexing='ij') + + pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) + + pmid = pmid.reshape([-1, 3]) + + return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape) + + +@partial(jax.jit, static_argnums=(1, )) +def cic_read_dx(mesh, halo_size=0): + # return mesh + halo_size, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_size) + mesh = halo_exchange(mesh, + halo_extents=halo_extents, + halo_periods=(True, True, True)) + displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size), + in_specs=(P('x', 'y')), + out_specs=P('x', 'y'))(mesh) + + return displacements + + def compensate_cic(field): """ Compensate for CiC painting diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py new file mode 100644 index 0000000..1d929ea --- /dev/null +++ b/jaxpm/painting_utils.py @@ -0,0 +1,185 @@ +import jax +import jax.numpy as jnp +from jax.lax import scan + + +def _chunk_split(ptcl_num, chunk_size, *arrays): + """Split and reshape particle arrays into chunks and remainders, with the remainders + preceding the chunks. 0D ones are duplicated as full arrays in the chunks.""" + chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num) + remainder_size = ptcl_num % chunk_size + chunk_num = ptcl_num // chunk_size + + remainder = None + chunks = arrays + if remainder_size: + remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays] + chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays] + + # `scan` triggers errors in scatter and gather without the `full` + chunks = [ + x.reshape(chunk_num, chunk_size, *x.shape[1:]) + if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks + ] + + return remainder, chunks + + +def enmesh(i1, d1, a1, s1, b12, a2, s2): + """Multilinear enmeshing.""" + i1 = jnp.asarray(i1) + d1 = jnp.asarray(d1) + a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype) + if s1 is not None: + s1 = jnp.array(s1, dtype=i1.dtype) + b12 = jnp.float64(b12) + if a2 is not None: + a2 = jnp.float64(a2) + if s2 is not None: + s2 = jnp.array(s2, dtype=i1.dtype) + + dim = i1.shape[1] + neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >> + jnp.arange(dim, dtype=i1.dtype)) & 1 + + if a2 is not None: + P = i1 * a1 + d1 - b12 + P = P[:, jnp.newaxis] # insert neighbor axis + i2 = P + neighbors * a2 # multilinear + + if s1 is not None: + L = s1 * a1 + i2 %= L + + i2 //= a2 + d2 = P - i2 * a2 + + if s1 is not None: + d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected + + i2 = i2.astype(i1.dtype) + d2 = d2.astype(d1.dtype) + a2 = a2.astype(d1.dtype) + + d2 /= a2 + else: + i12, d12 = jnp.divmod(b12, a1) + i1 -= i12.astype(i1.dtype) + d1 -= d12.astype(d1.dtype) + + # insert neighbor axis + i1 = i1[:, jnp.newaxis] + d1 = d1[:, jnp.newaxis] + + # multilinear + d1 /= a1 + i2 = jnp.floor(d1).astype(i1.dtype) + i2 += neighbors + d2 = d1 - i2 + i2 += i1 + + if s1 is not None: + i2 %= s1 + + f2 = 1 - jnp.abs(d2) + + if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None + i2 = jnp.where(i2 < 0, s2, i2) + + f2 = f2.prod(axis=-1) + + return i2, f2 + + +def _scatter_chunk(carry, chunk): + mesh, offset, cell_size, mesh_shape = carry + pmid, disp, val = chunk + spatial_ndim = pmid.shape[1] + spatial_shape = mesh.shape + + # multilinear mesh indices and fractions + ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size, + spatial_shape) + # scatter + ind = tuple(ind[..., i] for i in range(spatial_ndim)) + mesh = mesh.at[ind].add(val * frac) + + carry = mesh, offset, cell_size, mesh_shape + return carry, None + + +def scatter(pmid, + disp, + mesh, + chunk_size=2**24, + val=1., + offset=0, + cell_size=1.): + + ptcl_num, spatial_ndim = pmid.shape + val = jnp.asarray(val) + mesh = jnp.asarray(mesh) + + remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) + carry = mesh, offset, cell_size, mesh.shape + if remainder is not None: + carry = _scatter_chunk(carry, remainder)[0] + carry = scan(_scatter_chunk, carry, chunks)[0] + mesh = carry[0] + return mesh + + +def _chunk_cat(remainder_array, chunked_array): + """Reshape and concatenate one remainder and one chunked particle arrays.""" + array = chunked_array.reshape(-1, *chunked_array.shape[2:]) + + if remainder_array is not None: + array = jnp.concatenate((remainder_array, array), axis=0) + + return array + + +def gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.): + ptcl_num, spatial_ndim = pmid.shape + + mesh = jnp.asarray(mesh) + + val = jnp.asarray(val) + + if mesh.shape[spatial_ndim:] != val.shape[1:]: + raise ValueError('channel shape mismatch: ' + f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}') + + remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) + + carry = mesh, offset, cell_size, mesh.shape + val_0 = None + if remainder is not None: + val_0 = _gather_chunk(carry, remainder)[1] + val = scan(_gather_chunk, carry, chunks)[1] + + val = _chunk_cat(val_0, val) + + return val + + +def _gather_chunk(carry, chunk): + mesh, offset, cell_size, mesh_shape = carry + pmid, disp, val = chunk + + spatial_ndim = pmid.shape[1] + + spatial_shape = mesh.shape[:spatial_ndim] + chan_ndim = mesh.ndim - spatial_ndim + chan_axis = tuple(range(-chan_ndim, 0)) + + # multilinear mesh indices and fractions + ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size, + spatial_shape) + + # gather + ind = tuple(ind[..., i] for i in range(spatial_ndim)) + frac = jnp.expand_dims(frac, chan_axis) + val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1) + + return carry, val diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 9b14a87..377df8e 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -1,42 +1,96 @@ +from functools import partial + import jax import jax.numpy as jnp import jax_cosmo as jc +from jax.sharding import PartitionSpec as P -from jaxpm.growth import dGfa, growth_factor, growth_rate +from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d, + normal_field) +from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second, + growth_rate, growth_rate_second) from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel, longrange_kernel) -from jaxpm.painting import cic_paint, cic_read +from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx -def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): +def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0): """ Computes gravitational forces on particles using a PM scheme """ if mesh_shape is None: + assert (delta is not None + ), "If mesh_shape is not provided, delta should be provided" mesh_shape = delta.shape kvec = fftk(mesh_shape) if delta is None: - delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions)) + delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size)) else: - delta_k = jnp.fft.rfftn(delta) + delta_k = fft3d(delta) # Computes gravitational potential pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) # Computes gravitational forces - return jnp.stack([ - cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions) - for i in range(3) + forces = jnp.stack([ + cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), + halo_size=halo_size) for i in range(3) ], - axis=-1) + axis=-1) + + return forces + + +def lpt2_source(mesh_size, initial_conditions): + + kvec = fftk(mesh_size) + # TODO : this has already been done for LPT1, we should reuse it + delta_k = fft3d(initial_conditions) + source = jnp.zeros_like(delta_k) -def lpt(cosmo, initial_conditions, positions, a): + D1 = [1, 2, 0] + D2 = [2, 0, 1] + + # laplace_kernel should be actually inv laplace_kernel + # adding a minus sign here that will be negated when computing forces + # because F = -grad(phi) + # and phi = -laplace_kernel(delta_k) + pot_k = delta_k * laplace_kernel(delta_k) + + nabla_i_nabla_i = [ + ifft3d(gradient_kernel(kvec, i)**2 * pot_k) for i in range(3) + ] + # for diagonal terms + source += nabla_i_nabla_i[D1[0]] * nabla_i_nabla_i[D2[0]] + source += nabla_i_nabla_i[D1[1]] * nabla_i_nabla_i[D2[1]] + source += nabla_i_nabla_i[D1[2]] * nabla_i_nabla_i[D2[2]] + + # off diag terms + for i in range(3): + nabla_i_nabla_j = gradient_kernel(kvec, D1[i]) * gradient_kernel( + kvec, D2[i]) + phi = ifft3d(nabla_i_nabla_j * pot_k) + source -= phi**2 + + return source + + +def lpt(cosmo, initial_conditions, a, halo_size=0): """ Computes first order LPT displacement """ - initial_force = pm_forces(positions, delta=initial_conditions) + local_mesh_shape = (*get_local_shape(initial_conditions.shape), 3) + displacement = autoshmap( + partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), + in_specs=(), + out_specs=P('x', 'y'))() # yapf: disable + + + initial_force = pm_forces(displacement, + delta=initial_conditions, + halo_size=halo_size) a = jnp.atleast_1d(a) dx = growth_factor(cosmo, a) * initial_force p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, @@ -46,6 +100,39 @@ def lpt(cosmo, initial_conditions, positions, a): return dx, p, f +# @Credit Hugo Simon https://github.com/hsimonfroy/montecosmo +def lpt2(cosmo, initial_conditions, dx, p, f, a, halo_size=0): + + mesh_size = initial_conditions.shape + local_mesh_shape = (*get_local_shape(initial_conditions.shape), 3) + # TODO + # Displacements have been created in the previous step + # find a way to reuse them + displacement = autoshmap( + partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), + in_specs=(), + out_specs=P('x', 'y'))() # yapf: disable + + lpt2_delta = lpt2_source(mesh_size, initial_conditions) + delta2_k = fft3d(lpt2_delta) + + lpt2_forces = pm_forces(displacement, + mesh_size, + delta_k=delta2_k, + halo_size=halo_size) + dx2 = 3 / 7 * growth_factor_second(cosmo, a) * lpt2_forces + p2 = a**2 * growth_rate_second(cosmo, a) * jnp.sqrt( + jc.background.Esqr(cosmo, a)) * dx2 + f2 = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGf2a(cosmo, + a) * lpt2_forces + + dx += dx2 + p += p2 + f += f2 + + return dx, p, f + + def linear_field(mesh_shape, box_size, pk, seed): """ Generate initial conditions. @@ -56,13 +143,14 @@ def linear_field(mesh_shape, box_size, pk, seed): pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / ( box_size[0] * box_size[1] * box_size[2]) - field = jax.random.normal(seed, mesh_shape) - field = jnp.fft.rfftn(field) * pkmesh**0.5 - field = jnp.fft.irfftn(field) + # Initialize a random field with one slice on each gpu + field = normal_field(mesh_shape, seed=seed) + field = fft3d(field) * pkmesh**0.5 + field = ifft3d(field) return field -def make_ode_fn(mesh_shape): +def make_ode_fn(mesh_shape, halo_size=0): def nbody_ode(state, a, cosmo): """ @@ -70,7 +158,8 @@ def nbody_ode(state, a, cosmo): """ pos, vel = state - forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m + forces = pm_forces(pos, mesh_shape=mesh_shape, + halo_size=halo_size) * 1.5 * cosmo.Omega_m # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel @@ -94,19 +183,23 @@ def pgd_correction(pos, mesh_shape, params): delta = cic_paint(jnp.zeros(mesh_shape), pos) alpha, kl, ks = params delta_k = jnp.fft.rfftn(delta) - PGD_range=PGD_kernel(kvec, kl, ks) - - pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range - - forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos) - for i in range(3)],axis=-1) - - dpos_pgd = forces_pgd*alpha - + PGD_range = PGD_kernel(kvec, kl, ks) + + pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range + + forces_pgd = jnp.stack([ + cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos) + for i in range(3) + ], + axis=-1) + + dpos_pgd = forces_pgd * alpha + return dpos_pgd def make_neural_ode_fn(model, mesh_shape): + def neural_nbody_ode(state, a, cosmo, params): """ state is a tuple (position, velocities) @@ -119,15 +212,19 @@ def neural_nbody_ode(state, a, cosmo, params): delta_k = jnp.fft.rfftn(delta) # Computes gravitational potential - pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0) + pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, + r_split=0) # Apply a correction filter - kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec)) - pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a))) + kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec)) + pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a))) # Computes gravitational forces - forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos) - for i in range(3)],axis=-1) + forces = jnp.stack([ + cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), pos) + for i in range(3) + ], + axis=-1) forces = forces * 1.5 * cosmo.Omega_m @@ -138,4 +235,5 @@ def neural_nbody_ode(state, a, cosmo, params): dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces return dpos, dvel + return neural_nbody_ode diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 1593ba0..7c6af44 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -83,53 +83,58 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): return kbins, P / norm -def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False): - """ +def cross_correlation_coefficients(field_a, + field_b, + kmin=5, + dk=0.5, + boxsize=False): + """ Calculate the cross correlation coefficients given two real space field - + Args: - - field_a: real valued field - field_b: real valued field + + field_a: real valued field + field_b: real valued field kmin: minimum k-value for binned powerspectra dk: differential in each kbin boxsize: length of each boxlength (can be strangly shaped?) - + Returns: - + kbins: the central value of the bins for plotting - P / norm: normalized cross correlation coefficient between two field a and b - + P / norm: normalized cross correlation coefficient between two field a and b + """ - shape = field_a.shape - nx, ny, nz = shape + shape = field_a.shape + nx, ny, nz = shape - #initialze values related to powerspectra (mode bins and weights) - dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) + #initialze values related to powerspectra (mode bins and weights) + dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) - #fast fourier transform - fft_image_a = jnp.fft.fftn(field_a) - fft_image_b = jnp.fft.fftn(field_b) + #fast fourier transform + fft_image_a = jnp.fft.fftn(field_a) + fft_image_b = jnp.fft.fftn(field_b) - #absolute value of fast fourier transform - pk = fft_image_a * jnp.conj(fft_image_b) + #absolute value of fast fourier transform + pk = fft_image_a * jnp.conj(fft_image_b) - #calculating powerspectra - real = jnp.real(pk).reshape([-1]) - imag = jnp.imag(pk).reshape([-1]) + #calculating powerspectra + real = jnp.real(pk).reshape([-1]) + imag = jnp.imag(pk).reshape([-1]) - Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j - Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) + Psum = jnp.bincount(dig, weights=(W.flatten() * imag), + length=xsum.size) * 1j + Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) - P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') + P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') - #normalization for powerspectra - norm = np.prod(np.array(shape[:])).astype('float32')**2 + #normalization for powerspectra + norm = np.prod(np.array(shape[:])).astype('float32')**2 - #find central values of each bin - kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 + #find central values of each bin + kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 - return kbins, P / norm + return kbins, P / norm def gaussian_smoothing(im, sigma): diff --git a/scripts/distributed_pm.py b/scripts/distributed_pm.py new file mode 100644 index 0000000..ad699c6 --- /dev/null +++ b/scripts/distributed_pm.py @@ -0,0 +1,101 @@ +import os + +os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax +import jax + +jax.distributed.initialize() + +rank = jax.process_index() +size = jax.process_count() + +import jax.numpy as jnp +import jax_cosmo as jc +import numpy as np +from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve +from jax.experimental import mesh_utils +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P + +from jaxpm.kernels import interpolate_power_spectrum +from jaxpm.painting import cic_paint_dx +from jaxpm.pm import linear_field, lpt, make_ode_fn + +size = 256 +mesh_shape = [size] * 3 +box_size = [float(size)] * 3 +snapshots = jnp.linspace(0.1, 1., 4) +halo_size = 64 +if jax.device_count() > 1: + + pdims = (4, 2) + devices = mesh_utils.create_device_mesh(pdims) + mesh = Mesh(devices.T, axis_names=('x', 'y')) + sharding = NamedSharding(mesh, P('x', 'y')) + + +@jax.jit +def run_simulation(omega_c, sigma8): + # Create a small function to generate the matter power spectrum + k = jnp.logspace(-4, 1, 128) + pk = jc.power.linear_matter_power( + jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) + + pk_fn = lambda x: interpolate_power_spectrum(x, k, pk) + # Create initial conditions + initial_conditions = linear_field(mesh_shape, + box_size, + pk_fn, + seed=jax.random.PRNGKey(0)) + + cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) + + # Initial displacement + dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) + + # Evolve the simulation forward + ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) + term = ODETerm( + lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) + solver = Dopri5() + + stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) + res = diffeqsolve(term, + solver, + t0=0.1, + t1=1., + dt0=0.01, + y0=jnp.stack([dx, p], axis=0), + args=cosmo, + saveat=SaveAt(ts=snapshots), + stepsize_controller=stepsize_controller) + + # Return the simulation volume at requested + states = res.ys + field = cic_paint_dx(dx, halo_size=halo_size) + final_fields = [ + cic_paint_dx(state[0], halo_size=halo_size) for state in states + ] + + return initial_conditions, field, final_fields, res.stats + + +# Run the simulation +if jax.device_count() > 1: + with mesh: + init, field, final_fields, stats = run_simulation(0.32, 0.8) + +else: + init, field, final_fields, stats = run_simulation(0.32, 0.8) + +# # Print the statistics +print(stats) + +# # save the final state +np.save(f'initial_conditions_{rank}.npy', init.addressable_data(0)) +np.save(f'field_{rank}.npy', field.addressable_data(0)) + +if final_fields is not None: + for i, final_field in enumerate(final_fields): + np.save(f'final_field_{i}_{rank}.npy', final_field.addressable_data(0)) + +print(f"Finished!!") diff --git a/scripts/eval_decomp_ode.ipynb b/scripts/eval_decomp_ode.ipynb new file mode 100644 index 0000000..c7ff38e --- /dev/null +++ b/scripts/eval_decomp_ode.ipynb @@ -0,0 +1,218 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "id": "2206b58f", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "9f910706", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "pdims=(1 , 1)\n", + "#for single gpu\n", + "# pdims = (1 , 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "8bc804ab", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(128, 128, 128)\n" + ] + } + ], + "source": [ + "folder = f'out/final_field/pmwd/1/128_128/{pdims[0]}x{pdims[1]}/lfm/halo_0'\n", + "folder = f'out/final_field/jaxpm/1/128_128/{pdims[0]}x{pdims[1]}/LeapfrogMidpoint/halo_32'\n", + "folder = f'out/final_field/jaxpm/1/128_128/{pdims[0]}x{pdims[1]}/lpt/halo_32'\n", + "\n", + "only_final_fields = True\n", + "\n", + "init_field_slices = []\n", + "field_slices = []\n", + "nb_solutions = 1\n", + "nb_to_plot = 1\n", + "final_slices = []\n", + "\n", + "for _ in range(nb_solutions):\n", + " final_slices.append([])\n", + "\n", + "for i in range(pdims[0]):\n", + " row_init_field = []\n", + " row_field = []\n", + " row_final_field = []\n", + " for _ in range(nb_solutions):\n", + " row_final_field.append([])\n", + " \n", + " for j in range(pdims[1]):\n", + " slice_index = i * pdims[1] + j \n", + " if not only_final_fields:\n", + " row_field.append(np.load(f'{folder}/field_{slice_index}.npy'))\n", + " row_init_field.append(np.load(f'{folder}/initial_conditions_{slice_index}.npy'))\n", + "\n", + " for sol_indx in range((nb_solutions - nb_to_plot) , nb_solutions):\n", + " row_final_field[sol_indx].append(np.load(f'{folder}/final_field_{sol_indx}_{slice_index}.npy'))\n", + " \n", + " if not only_final_fields:\n", + " field_slices.append(np.vstack(row_field))\n", + " init_field_slices.append(np.vstack(row_init_field))\n", + "\n", + " for sol_indx in range((nb_solutions - nb_to_plot) , nb_solutions):\n", + " final_slices[sol_indx].append(np.vstack(row_final_field[sol_indx]))\n", + "\n", + "if not only_final_fields:\n", + " field = np.hstack(field_slices)\n", + " initial_conditions = np.hstack(init_field_slices)\n", + "final_fields = []\n", + "\n", + "for sol_indx in range(nb_solutions - nb_to_plot , nb_solutions):\n", + " final_fields.append(np.hstack(final_slices[sol_indx]))\n", + "\n", + "if not only_final_fields:\n", + " print(field.shape)\n", + " box_size = field.shape\n", + "else:\n", + " print(final_fields[-1].shape)\n", + " box_size = final_fields[-1].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "bfa0c523", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABdEAAAH9CAYAAAD1dDVUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9ebRtV1XnP/c+fX/O7dvX573kvfShC0ISOhFUREArVCjEQrBK1HLYFIXgDyhgaIkgCBaCQ5HGsoJBq7BBBekFaRLSJy+vv+/dvjl9u8/e+/dHfGd+57z3vIQymG5+x7hjrHvOOmuvvZo519rnzM9ywjAMyWQymUwmk8lkMplMJpPJZDKZTCbTNrmPdgVMJpPJZDKZTCaTyWQymUwmk8lkeqzKHqKbTCaTyWQymUwmk8lkMplMJpPJNET2EN1kMplMJpPJZDKZTCaTyWQymUymIbKH6CaTyWQymUwmk8lkMplMJpPJZDINkT1EN5lMJpPJZDKZTCaTyWQymUwmk2mI7CG6yWQymUwmk8lkMplMJpPJZDKZTENkD9FNJpPJZDKZTCaTyWQymUwmk8lkGiJ7iG4ymUwmk8lkMplMJpPJZDKZTCbTENlDdJPJZDKZTCaTyWQymUwmk8lkMpmGyB6im0ymJ4Qcx6G3ve1t/+bXfdvb3kaO4/ybX9dkMplMpse7zHebTCaTyfT4kvlu05NZ9hDd9ITTPffcQ6961atodnaWEokEzczM0E033UT33HPPtrx/8id/Qo7jDP6SySTNzMzQC1/4Qvq93/s9qtfr2z5z3ngP+1tZWblg/Xq9Hr3//e+nq666ivL5PBWLRTpy5Ai9/vWvp/vvv3+Q7+tf/zq97W1vo0ql8q9uk3+Nvv71r9OznvUsSqfTNDU1Rb/4i79IjUbjUa3Tv7VarRa97W1voy996UuPdlW26TOf+QxdffXVlEwmadeuXfTWt76V+v3+o10tk8lk+p5kvvuRlfnux67vvvnmm+lVr3oVXXTRReQ4Dt1www2PdpVMJpPp/0nmux9Zme9+bPruzc1Neve7303XXXcdjY+PU7FYpGc84xl08803P9pVMz0Kij7aFTCZHkn9xV/8Bb3yla+kkZEReu1rX0t79+6l06dP0x/90R/RLbfcQv/7f/9v+vEf//Ftn/vv//2/0969e8nzPFpZWaEvfelL9Eu/9Ev03ve+lz7zmc/Q5Zdfvu0zH/rQhyibzW57vVgsXrCOL3/5y+mzn/0svfKVr6TXve515Hke3X///fTXf/3X9MxnPpMuvvhiInrQib797W+n17zmNQ9Z5vdLt99+Oz3vec+jSy65hN773vfSuXPn6Hd+53fo2LFj9NnPfvZRqdMwtdttika/Pyat1WrR29/+diKibZvdt7zlLfTf/tt/+75c96H02c9+ll760pfSDTfcQB/4wAforrvuone+8520trZGH/rQhx6VOplMJtP3KvPdj6zMdz+ox6rv/tCHPkS33norPfWpT6XNzc1HpQ4mk8n0r5X57kdW5rsf1GPRd3/jG9+gN7/5zfTiF7+Y3vKWt1A0GqVPf/rTdOONN9K99947qK/pSaLQZHqC6Pjx42E6nQ4vvvjicG1tTby3vr4eXnzxxWEmkwlPnDgxeP2jH/1oSETht7/97W3l/eM//mOYSqXC3bt3h61Wa/D6W9/61pCIwvX19e+5jt/61rdCIgrf9a53bXuv3++HGxsbg//f/e53h0QUnjp16nu+ziOlF73oReH09HRYrVYHr/3hH/5hSETh3//9339fr91sNr+v5X8vWl9fD4kofOtb3/poV0Xo8OHD4RVXXBF6njd47c1vfnPoOE543333PYo1M5lMpocn892PvMx3P6jHqu9eWFgIfd8PwzAMjxw5El5//fWPboVMJpPpe5T57kde5rsf1GPRd588eTI8ffq0eC0IgvC5z31umEgkwkaj8SjVzPRoyHAupieM3v3ud1Or1aKPfOQjND4+Lt4bGxujD3/4w9RsNum3f/u3H1Z5z33uc+k3fuM36MyZM/TJT37yEanjiRMniIjoB37gB7a9F4lEaHR0lIgeDF37tV/7NSIi2rt37yBk7fTp04P8n/zkJ+maa66hVCpFIyMjdOONN9LZs2dFmTfccANdeumldOutt9Izn/lMSqVStHfvXvqDP/iDh6xrrVajz33uc/SqV72K8vn84PVXv/rVlM1m6VOf+tQFP/+lL32JHMehm2++mX7913+dpqamKJPJ0Ete8pIL1vO6666jdDpNv/7rv05ERGtra/Ta176WJicnKZlM0hVXXEEf+9jHtl1vJzbb4uIi/cf/+B9pcnKSEokEHTlyhP74j/9422c7nQ697W1vo4MHD1IymaTp6Wl62cteRidOnKDTp08PxtPb3/72QV+cv9ZObLZ+v0/veMc7aP/+/ZRIJGjPnj3067/+69TtdkW+PXv20I/8yI/Q1772NXra055GyWSS9u3bRx//+Mcv2LZERPfeey/de++99PrXv178EuDnfu7nKAxDuuWWWx6yDJPJZHq0Zb7bfPeTyXcTEc3Pz5Pr2hbMZDI9fmW+23z3k8l37927l3bv3r2tDV760pdSt9ulkydPPmQZpieObAVnesLor/7qr2jPnj307Gc/e8f3r7vuOtqzZw/9zd/8zcMu8z/8h/9ARET/8A//sO29ra0t2tjYEH8PxVE7b3z/9E//9ILc6pe97GX0yle+koiIfvd3f5c+8YlP0Cc+8YmBU3nXu95Fr371q+miiy6i9773vfRLv/RL9I//+I903XXXbatDuVymF7/4xXTNNdfQb//2b9Pc3Bz95//8n3d0aqi77rqL+v0+PeUpTxGvx+NxuvLKK+m73/3uBT9/Xu9617vob/7mb+iNb3wj/eIv/iJ97nOfo+c///nUbrdFvs3NTXrRi15EV155Jb3vfe+j5zznOdRut+mGG26gT3ziE3TTTTfRu9/9bioUCvSa17yG3v/+91/wuqurq/SMZzyDPv/5z9PP//zP0/vf/346cOAAvfa1r6X3ve99g3y+79OP/MiP0Nvf/na65ppr6D3veQ/9l//yX6hardLdd99N4+PjAzTKj//4jw/64mUve9nQa//Mz/wM/X//3/9HV199Nf3u7/4uXX/99fSbv/mbdOONN27Le/z4cXrFK15BL3jBC+g973kPlUoles1rXrMjSxB1vv11/8zMzNDc3NzD7h+TyWR6NGW+23w36onuu00mk+mJIPPd5rtRT1bffZ7JPzY29v/0edPjVI/2T+FNpkdClUolJKLwx37sxy6Y7yUveUlIRGGtVgvD8MJhZedVKBTCq666avD/+bCynf4OHTp0wesHQRBef/31IRGFk5OT4Stf+crw93//98MzZ85syzssrOz06dNhJBLZFpp21113hdFoVLx+/lrvec97Bq91u93wyiuvDCcmJsJerze0rn/+538eElH4la98Zdt7P/ETPxFOTU1d8F6/+MUvhkQUzs7ODto7DMPwU5/6VEhE4fvf//5t9fyDP/gDUcb73ve+kIjCT37yk4PXer1eeO2114bZbFaUSyrs67WvfW04PT0tQvXCMAxvvPHGsFAoDEIF//iP/zgkovC9733vtnsIgiAMwwuHlZ0fD+d1++23h0QU/szP/IzI96u/+qshEYVf+MIXBq/t3r17Wxuvra2FiUQi/JVf+ZVt10KdHx8LCwvb3nvqU58aPuMZz7jg500mk+nRlvlu891PNt+tZTgXk8n0eJP5bvPdT3bfHYZhuLm5GU5MTITPfvazv+fPmh7fsl+im54QOn+ady6Xu2C+8+/XarWHXXY2m93xtPBPf/rT9LnPfU78ffSjH71gWY7j0N///d/TO9/5TiqVSvRnf/Zn9IY3vIF2795N/+7f/buHdSL4X/zFX1AQBPSTP/mT4tv4qakpuuiii+iLX/yiyB+NRulnf/ZnB//H43H62Z/9WVpbW6Nbb7116HXOf2OdSCS2vZdMJrd9oz1Mr371q0W/vOIVr6Dp6Wn627/9W5EvkUjQT//0T4vX/vZv/5ampqYGvw4gIorFYoOTyr/85S/veM0wDOnTn/40/eiP/iiFYSja6YUvfCFVq1W67bbbiOjBfhwbG6Nf+IVf2FaODhd7ODp/X7/8y78sXv+VX/kVIqJtv8g4fPiw+BXH+Pg4HTp06CHDwh6p/jGZTKZHS+a7zXejngy+22QymR7vMt9tvhv1ZPTdQRDQTTfdRJVKhT7wgQ98z/U2Pb71/TlS12T6N9Z5Z7GT00U9XKePajQaNDExse3166677v8pdCeRSNCb3/xmevOb30zLy8v05S9/md7//vfTpz71KYrFYg/JgTt27BiFYUgXXXTRju/HYjHx/8zMDGUyGfHawYMHiYjo9OnT9IxnPGPHclKpFBHRNp4Y0YMss/PvP5R0PR3HoQMHDgjOHBHR7OwsxeNx8dqZM2fooosu2sYOveSSSwbv76T19XWqVCr0kY98hD7ykY/smGdtbY2IHuTlHTp06BE7YfzMmTPkui4dOHBAvD41NUXFYnFbnXft2rWtjFKpROVy+YLXeaT6x2QymR4tme9mme9+cvhuk8lkerzLfDfLfPeT03f/wi/8Av3d3/0dffzjH6crrrjie6+46XEte4huekKoUCjQ9PQ03XnnnRfMd+edd9Ls7Kw4sONCOnfuHFWr1W2G+ZHS9PQ03XjjjfTyl7+cjhw5Qp/61KfoT/7kTy7oWIIgIMdx6LOf/SxFIpFt72ez2UesbkREy8vL295bXl6mmZmZR+Q65/VIPvQNgoCIiF71qlfRT/3UT+2Y5/LLL3/ErreTHu636Tv1IdGD3+pfSNg/8/Pz4r3l5WV62tOe9rCubzKZTI+WzHezzHc/OXy3yWQyPd5lvptlvvvJ57vf/va30//8n/+Tfuu3fmvA8Tc9uWQP0U1PGP3Ij/wI/eEf/iF97Wtfo2c961nb3v/qV79Kp0+fFiFWD6VPfOITRET0whe+8BGr506KxWJ0+eWX07FjxwYhYsOcwf79+ykMQ9q7d+/gm+0LaWlpiZrNpvhW/IEHHiCiB0+pHqZLL72UotEofec736Gf/MmfHLze6/Xo9ttvF69dSMeOHRP/h2FIx48ff1jOdPfu3XTnnXdSEATiW/H7779/8P5OGh8fp1wuR77v0/Of//wLXmP//v30zW9+kzzP2/ZrgvP6XsLLdu/eTUEQ0LFjxwbf3BM9eOBKpVIZWufvVVdeeSUREX3nO98RD8yXlpbo3Llz9PrXv/4RuY7JZDJ9P2W+e2eZ735i+m6TyWR6Ish8984y3/3E9t2///u/T29729vol37pl+iNb3zjI1q26fEjY6KbnjD6tV/7NUqlUvSzP/uztLm5Kd7b2tqi//Sf/hOl02n6tV/7tYdV3he+8AV6xzveQXv37qWbbrrpEanjsWPHaGFhYdvrlUqFvvGNb1CpVBqcBH7e+Wpe28te9jKKRCL09re/fdu3pmEYbrv3fr9PH/7whwf/93o9+vCHP0zj4+N0zTXXDK1roVCg5z//+fTJT35ShOt94hOfoEajQT/xEz/xsO754x//uPj8LbfcQsvLy/SiF73oIT/74he/mFZWVujmm28W9/OBD3yAstksXX/99Tt+LhKJ0Mtf/nL69Kc/TXffffe299fX1wfpl7/85bSxsUEf/OAHt+U7377pdJqItvfFsDoTkTiJnIjove99LxER/fAP//BDlvFwdOTIEbr44ovpIx/5CPm+P3j9Qx/6EDmOQ694xSsekeuYTCbT91Pmu813n9eTwXebTCbTE0Hmu813n9eTxXfffPPN9Iu/+It00003Dco3PTllv0Q3PWF00UUX0cc+9jG66aab6LLLLqPXvva1tHfvXjp9+jT90R/9EW1sbNCf/dmf0f79+7d99rOf/Szdf//91O/3aXV1lb7whS/Q5z73Odq9ezd95jOfoWQyue0zt9xyy44hXC94wQtocnJyxzrecccd9O///b+nF73oRfTsZz+bRkZGaHFxkT72sY/R0tISve997xuEGZ13tG9+85vpxhtvpFgsRj/6oz9K+/fvp3e+8530pje9iU6fPk0vfelLKZfL0alTp+gv//Iv6fWvfz396q/+6uCaMzMz9D/+x/+g06dP08GDB+nmm2+m22+/nT7ykY8M/Qb4vN71rnfRM5/5TLr++uvp9a9/PZ07d47e85730A/+4A/SD/3QD13ws+c1MjJCz3rWs+inf/qnaXV1ld73vvfRgQMH6HWve91Dfvb1r389ffjDH6bXvOY1dOutt9KePXvolltuoX/6p3+i973vfRdk7P3Wb/0WffGLX6SnP/3p9LrXvY4OHz5MW1tbdNttt9HnP/952traIqIHD2D5+Mc/Tr/8y79M3/rWt+jZz342NZtN+vznP08/93M/Rz/2Yz9GqVSKDh8+TDfffDMdPHiQRkZG6NJLL6VLL71023WvuOIK+qmf+in6yEc+QpVKha6//nr61re+RR/72MfopS99KT3nOc95WO32cPTud7+bXvKSl9AP/uAP0o033kh33303ffCDH6Sf+ZmfEd/Gm0wm02NV5rvNd6OeDL77K1/5Cn3lK18hogcfLjSbTXrnO99JRA9yf6+77rpH7Fomk8n0/ZD5bvPdqCe67/7Wt75Fr371q2l0dJSe97zn0Z/+6Z+K95/5zGfSvn37HpFrmR4HCk2mJ5juvPPO8JWvfGU4PT0dxmKxcGpqKnzlK18Z3nXXXdvyfvSjHw2JaPAXj8fDqamp8AUveEH4/ve/P6zVats+89a3vlV8Rv998YtfHFq31dXV8Ld+67fC66+/Ppyeng6j0WhYKpXC5z73ueEtt9yyLf873vGOcHZ2NnRdNySi8NSpU4P3Pv3pT4fPetazwkwmE2YymfDiiy8O3/CGN4RHjx4d5Ln++uvDI0eOhN/5znfCa6+9Nkwmk+Hu3bvDD37wgw+7Pb/61a+Gz3zmM8NkMhmOj4+Hb3jDG3ZsF60vfvGLIRGFf/Znfxa+6U1vCicmJsJUKhX+8A//cHjmzBmR93w9d9Lq6mr40z/90+HY2FgYj8fDyy67LPzoRz+6LR8RhW9961u3ffYNb3hDOD8/PxgLz3ve88KPfOQjIl+r1Qrf/OY3h3v37h3ke8UrXhGeOHFikOfrX/96eM0114TxeFxc6/x4QHmeF7797W8flDc/Px++6U1vCjudjsi3e/fu8Id/+Ie33cv1118fXn/99Tu2h9Zf/uVfhldeeWWYSCTCubm58C1veUvY6/Ue1mdNJpPpsSLz3ea78bNPZN99obGo28JkMpkeyzLfbb4bP/tE9d167Oq/ndrI9MSVE4Z2Ao7J9ETVDTfcQBsbGzuGVn2/9aUvfYme85zn0J//+Z9/39Eivu9TNBqld7zjHfSWt7zl+3otk8lkMpm+nzLfbTKZTCbT40vmu02mJ4eMiW4ymR73On+S+djY2KNcE5PJZDKZTA9H5rtNJpPJZHp8yXy36ckuY6KbTKbHtW655Rb6+Mc/To7jPKLMUpPJZDKZTN8fme82mUwmk+nxJfPdJpM9RDeZTI9z/df/+l/JcRz6oz/6Izp06NCjXR2TyWQymUwPIfPdJpPJZDI9vmS+22QiMia6yWQymUwmk8lkMplMJpPJZDKZTENkTHSTyWQymUwmk8lkMplMJpPJZDKZhshwLkQUBAEtLS1RLpcjx3Ee7eqYTCaTyURERGEYUr1ep5mZGXJd+94bZb7bZDKZTI9Fme8eLvPdJpPJZHos6uH6bnuITkRLS0s0Pz//aFfDZDKZTKYddfbsWZqbm3u0q/GYkvluk8lkMj2WZb57u8x3m0wmk+mxrIfy3fYQnYhyuRwREb2o+IsUcxLU9D3xfpkag/SEWxikd2ViIl/P5/TJVpM/79REvmSYHKQTxGWsu2sqX2aQbjn1QTqgYJA+4uwXn/GJEfdIu4+qb/pbQX+QjjsRSMtvXEaS/P94ist4aqkr8l06tT5I370yPkifasZFvq0el3GqzhVc73ZEvhq1B+n5KLf5RDoi8hWg+HvLfE+dwBf5JhKccSTJdaj3uA54f0RE+zPczmWP3+v4Mt8mNEXd4/L0byviUPUvN08O0lPBlMh3pMD9fq7B96TH5aFCiuue4NdzMXnMQQeaYoGHJfnqNISYyzUegXYdSciMo3Eu0At4fJxtcXpdDg/yAy7jSJFfv3qkKvJttHlupKN879/ayop8d2xx35zzuIwl96TIdyg8MkgfzHPZLU/dE4yJf7dnc5AuZFsi38IGV/7ALOdLTcjxdud3JwfpW8sZejiCJqK4+uIzF+U3jzX4zRNVvu54Sn5oLsP3lIHPZyPy3le7/LmFJr/X6YtsNAbzIwqDe6EeiHyVPo9TnHdzWTkjsLYReGtvRrbl1RPczrtfwJUKynKQ/dVneUNW83iybXTldcs9uC5UYr3N955SnnFXZme7dX9L2vYeceFNl9/rUlPky4Vjg/QUjQ7S2maMxOPkBV369ObvDvyUiXW+TUay15DrRMh1tE9mv1mK7R6ko5QQ+arB8iDdD9j3hCTHdtfjPi0kd/FnQum/Nup3D9IRl+10Ic11cB05yIouL9T2hlz2D0zIuq7Dpe6v8T8jMelrJ8AeoLmrdOU9LXfZxu1Osa16xpi0ExfneR3U7HM7/91KSuS7r8rltYltgU/KJ0d4POei3BZ6LZBw+b25DKdHEjxb9mXkPaWjfK1jda7rsZq8p/s7G4N0lvg+miRty1m6h+8j5Pf2OdeIfAUnPUh3Qr73gOR1t1y2aZvB6UG67W2KfCMJXt9d4Vw2SF9WkmMHS1/r7LwGJCJKguFGH1PpwmfkR2gC7P50kt+NKGOFn4u7/F/VkxnPtvh/vK4uD/+PDakrEZEHjnN3nu3+VkfmW2rzuNpweM0QD+W8mY7kB+lUlC+81ZPrr37IY2w6yWNnKi1vBPcFqy0epyMwP7PK3yy3dm6XQlyWnQNzl4xgWrcRp7EEP5Tl4boN26+nFotV8PGXltg+rbRkvmOd8iAdIa7gfEz6snKf/WYhyv0xk5FrmgLcL66XatA1altG47B+bfZxDSPrug7rhDKMsXW1EDpBC4N0n/i9cbWOr7hb5IceHat+2nz3DjrfJtfk/iNFnTjV1T55i84O0p0+z1ftN+cjVwzS41QcpLOuHAinAt6jLgZsz/uB9Deez2u1vs/vxaJs21OxUfEZtNvNzirni4+JfKn4yCCNvgOfAxAR5V3eq4wn+X47ah72YV5PpnmuJOU2mVowhHHPW/ekT07CgrgI/rUv3Sutwpw4F/C9d522yBcNue7jIbfZtWPclnuU78Z5/QDsdbTdx3tC/xBTe6c0+Dxc73fkrVMK2gzL8NS9781wPfB5RtUbng9tzVpXVhA/tz+7c9lERF2o7wT44csKvMYdycixnM9yf5SewX0RVHoiX/P0zkcjbq7LfXe5xePy/hq/d7Yt7wn3h9NJbkDt4xfhcxW04cofFuCZxkSCyzuQq4t87T5/sBfAcy1XdvbROq9z767w68cbcvwuuzyXZ8G+d8HuT4BdICKayfB1+9Cs6Qs8bUUfpZ8v1WGc4zv6GcF9Zb7HLZ/voxBJinwPhGcG6VrIzxz3wrMSIqKDafZZaHbqPTkhZrJ8vzjmG315HziPEjDXWtA1p2qy7C2Px+lcitcZcWXfcJ2bgzVSSz3DaEOHVGEflIzKxgzCkLygS3+59b6H9N32EJ1oEEoWcxIUcxMUDWSDRuDBSMyBjnTl4ht3EVGHey/iyHy4gY+C89T5IhSH9zifAxv7mCs32C6MJjSLUVcO6GjIozAGD9FjKmwh7vJ7CShDPcumHGzg0xGuUyoi65eM4CaO70MvaKND7hE31A/+z+mYAxt2RxpN7KsE3GMXNnsJ1UapCJfRFmNC5sM6YHnbHqJDPuzrqCPbKAELqJjL96THJebDRVMqsrND1HXQD9HjcP8XKi8d4bHdc3DhhmNFlt2H0YjlZaPy3tvwfzrK5SWVI4iBU8T20w/RcK5he/VdeU/Y91nYxOVi0gpnoH5izCeG59N1HyZcPOr2S0UCeA/nK3wZ5spJiXMN2xzLejAfLjY4X7CtjbhSUfGAQ861GIyJYfOOSC6oMI3ji4goB/2RhwcPQUuPS25zL8Cxo+cNXBfSaI/iyhZIu4WLY/mwDVsC57hLctGKtj0GY3S7zeB8FvK8Xefb5MEH6NFtG2wHfBvahojatOLnXPhMoHrEgbEtPkP6us5Df0bVFX08jgltP3D8ot+MqfUI2gmcynpDEQVfGXfRd8v5lYniLpHrinZVlxeFr8oc0nYC11LcFnotEIN2irt4XbRvsuw0/J+McLvgHCdSflisy2QdsH9D2tn36P/x3vVDdGEb4P5wvD74HowJB8eEHL/DHmDrh+jYZnFh+4Y/RJftPHxDPKwO3W3rlp3XgNsfovML8iGJeqpBuIbDsSLz4bga1u8Plo/jki8cUz8wQWsfF2tUmQ+HM/rKhDt8vTSsXXTZ8sE5p/XcxTIu9BA9McQfhqGek5wR5//2Nud2wYfoet8SFX2Nawa9puE0rpfwAdP2tRNnxPvVdkaOS7Crjnw6hr47JCxP3hOOMfPd23W+TaJOnKJOYtv+F22u9ONyTAybyzH1ED0SPjybi/+j75Z1GL7OGOb79eekr5B1lb6R3wtCOb9wVF3Invhij4p7a+WTXVy7w7pFFifmBLZ/RO27I9CH0SFrGu275T5o570JkdzDSf+gbdrOD9G3+0Ysg9P63tEH4t5i+0P5nZ9v4H5Lf25Y2USyrzFfJsp7i1xUjo9cjAvPw5cxgdqSRuI7PzPoRuWc7EV3fraj7TS2H/av9vHimcEFbDh+KYxru2xU7qvQZsQu8BAd6457vahzgfUhzEn/Ar4M1yARMZZpqOSX3upL+CFfgCfUMzh8FhANYI5rv0S4L915fhIRxcGv+yHONdlGeL/oa/tqbSHWMVB3f8j6iEiuC+JinSGyibmM893fZgfxPnD9JjPi/T6U77aH6KCQHuyMTfiVChFRJMRvt7jhv13bEPk6Dv8CK4DJqL+ddcEsd+EXT5769VMm5F/EJIm/9dvtTAzS+C01EdFCk8vohjypZuLy12L4bVdX/KJGOvNsbOdvo7d68rqfPcm/oluHb1r1t7in4Vera12ua0RvUHCBDAbhbEMuaH34yrIAX0+l1MYtD7OuCH7BC3beVBIRpcEhrXS57Lr6xrkLXh9/FbXWkQa+Br9WrDn8DWCG8iLfUpONRQN+8eOphxANDx0z30dbfZOJ37DjvVd60nGuwNd2K/AD7Bn1tXB7yFeqa/AleFTZnQL80nskrr4eBCXAQa50uB0iasMzDd/inIH70I5gM+Rft7jwS8NrRmV5B+Ab+wPX8bfb3pqsa7HOY7YFURab98pfm99T5f/vq/DrOD6IiIrgWUvQN+mEftDN/09CdEjb5764rCDHx+EC/3I0AId2uim/OceuSsM4mlKRGbiBzYCz1JuaOvwKfAYuNR7XC3ZO47fWc2lpL2ePcH+4e/mXvLVvnhX5TjbZduE6uqGG23qH2xKdJdoIfe8TsMJbhC93ttx1kS8T8FxGvxFRX+5s0blBuuwsDdKh2iRd1r2SvFDaEdN2eUGLHCdCMVf6uQhsWtsh/yqyGerFMvePD+O51VkZes1K5zR/Xj3AHs3yrzrSEf4FVimcGaR7al2A/t+FhduSDIYRdtqH+5hKyXk4neax3YL5VYzrhxBsq0oJ9JNy3Pkhv9cHvzkqTS5dO87lrcOvWdfaciJu9vn+N3z4Aotk/fAXhbhBnk/zZ54yLn/BfbRcHKTTkZ39MxGRBw8DKsR2xifp5DMO92ETfnm36i6IfGdCXjum3NIgvSeQEYPY1wGs06Jq/BZD/vVTCgzmalv7Q26XUhzXAiKb2Mjgeq4N7T+ZGr4hvrvCBexWkUXzKfhyHcbHSEzOtRqMP1x/9VRd8bq4kcyrhVoN3syLXyrLfCvwq7d8yOvprurr1T7/EjXr8+Auhw2RD8fOdMj9tqF+Ad/1cUMGn4ex2FVrtlGYh7i+1L+yzkKEGY7zsYRszA1Yv2LEoNYqDJgadEg2JsfESIIr8gD8lLLjy+uOudzO+Rjb4lH1U9mVHv4YgNu8KM2qeNhwtsltdqbBtgqvQ0TUh70ALmWj6iEEzge8977yFRMhRxk2HB4rHWXPe9TeZkdM29Vy2hRxfGqpX6IHsO/Dh9GNrvTJjUxlkL44yn1TUE+tgjZHSG86bJubPRkB3odfcfrwK/UIfsEMtp2IKBrnvUqzw1FtXfgFPRFRANFJy+lTg/RccEDky4ON7Is9pXxGgHvCsSTPtT0q+LUO/h+30HotgNFrm2DHyl05r08H3Gb5kH+peXFsUuTLghGPDLHnjvraFr/oQnetffeuLBd4WYFvaqsn7+mbG/glGv4ARtpcMCfifg8UpNEtwT6m7MGXrMqs4j680ec3z6kf/+C6KB/ja42rPSBs9YQf7/h8v72+vPdelz9Uv43HTnVTrjPOVnjfshuiw09VCjJfm+cA/oq8qfZYR2AvmoQHlxVPNhJGBVdhHx9RfVOGX+U3+/iDBvkr4d0ZnrstaItMUtrhMfjR21yGnUylJ79daHd53ZeNct+kYG+nxyWO8zHwX5PKJ+fhC47VLpetv4zBKDX8ZbvOtzsHP9BpcP/q+rWI+xcjaNskoxjaPtuTCfDX+ksqHPddsaaU9UMTgpFj+NxjQq09Q/jic63jweuy7M2AN0ke+NzpiBy/OVjH4G2c7EjfE1JA/VDhFIboAt+NmEwmk8lkMplMJpPJZDKZTCaTyfTklj1EN5lMJpPJZDKZTCaTyWQymUwmk2mI7CG6yWQymUwmk8lkMplMJpPJZDKZTENkTHTQil+jqJOgkmJU5yJ4cAaDdNb9isjnOQx46hDzE3PhiMiXCpm7tOYyD3fde0DkC4CZNBYwS7UEHHR9OOQKbfF1iK/T84cfbhgBKrI+/RuZ6ONwInRLEfvvq/H/yEjNxWR5VWAN4uFEuagciilgC+MBG5WeZFttdfDgF35d82EnAAFWiiETmT+fjsq64iFJeeCjN/vy3ovAIMVDDTa76nRtYCx5ITOc1l3Jdo57zMdCpmREMT63ushS5fbD+hARleDQkEXg6x6tKdguKAnlNTzZljVv59OPWwDsysdkHfB0bTwYJBWT/VkEHnakyryzri8Zbj4wyW4A5nClJ+faKlQQ8X970pJ3de1TmVEdmeSxd/RLclx+Y53LTwOnPK0OdGkCDy+PjFrFjtvs4OdwLKqTrbvI1+bXryzioS1y/B6ry5PVz0szCIswH2JDDp4jkoyzA1nmp11Vkje11GZbMxLn/p1IyfF2tMrtvNHjMX9bWdb77Bf2DtLPObU4SDdbEvyIZxXgYdt1TzEI4SyGZWK246XE5zo8c0x8hHbDuFzpcP321uZFvjTYsU5QHKRPh0siXxu4dBHo914o2+h259sU0PAzBEwPKuomyXWigj9ORNTymV8dcxjQv9k9Jj8fYfuCh39lE1MiH3JR2x772pA0Y53zTQTM8Y/DAWLIEiUi8oFvOpVRMGAQ8lIL8TS8LvMhb3oN+Kaa938gz+NvAjicuagcd8iYPN0afhjuwSx/LgPzodqVBmXJZZvbB0b4TLBX5JsDEPQ1YGuum2c27lZDnvPwQIMbA89E0KzNkZCZiW1x+K86ZAnWbG2H2fq1/qLIhwfMJYHHqM8zmQpmOe1wOqUODPXgvJoyrH30GiQOa9SeXhSC8OwUNIt4uNZGR47lsw3+v95HfyrXlCFsJ5CDui8j2fobPTiMPANndTRkvZGbWwAfqg8nw7NwkBPa7qvyoC3R5qZIThxkpOPaOFBzvARnFiFrt9KV+VLgjHZn4byQFH9GH8jphXpWPaiEyjed5LbFg7xaio27O835toD7WlXI7g60GZ6V1OjKPmzBfN1wmQFdDMdFvoti7EjxHJuqOo8nDjYXWcdFtX9owbqqCgOk4vN6pBfI/szHeS5PwpBd6cg2XoPBs+6xv28qVuyow3Y7Jw6AlHPylB+Sb+eZPKSSYZKilCBf7bvTDv9fi/LZM8mIygfzcLXH/XZOHbKQAt9bdJjd3U3Isw7qXbbpyESPRXm9OQk+nYgoA3vtCBzM0A7KNEzjAa8dd8XlWmAkgT6eXy8mpJ9bAhg1MocrnhzbuB1G7vlEUs0v8JX3ttk4IHOYiMhzef6niddcOMeJ5DocbTMeiXC6JfdY6L6WWvwhX5o+Gk1wxg04o+3uiszXgQ/mhzDaieSZCPjMQfsb5JtPAec6owo80Rjus1B54DQjKzqm7IkL+7scpHcVmeecTEiDHof6ba3z+N1U52Kdacr9NddBNvruNM8HF9ZI3YS8QXw+dLYFawuFmcbzTOoeHvQox8QUVA/P87t1S+a7C54ZYJtfVhj+e2E8t08zuXvABUc/3hfnlMiy8fkXnlmy1JFlV/s716mozpDZgvUSzg3NoV+GFzpwdmMzkD4oBgd09uAsqKojzxWqe3zuA64Psc+IiBIRfB7J+bLq6TKOEOSRL8P5PvpcHDwPcdnnPXPdrYh8HZdtOJ5DUgouE/kmYH09AmN2oyor26U+OQ/zN+b2S3STyWQymUwmk8lkMplMJpPJZDKZhsgeoptMJpPJZDKZTCaTyWQymUwmk8k0RIZzAa24p8h1YvQU9yrxOobbJAGzMEHTIt/xOoeSHXOPDtIOyTCOtsMhMa2Qw72ijgxDdOFzCQg5PdmuD9K+ChOuuKuDdDrksOioK8NtCjEOa8BQ1EBFAmPkxnicrzWuwoYw/LkE+Ta68t73QahbGkafp0K1xhL8/U65x3Uvq5BwDG+pQGWj6uuhVIQzFmIqZmSIjkNI+BKEJN1fleGdGHY1keT7G03I0Oxpd2KQdptXDtJ1qoh8iHApxLmReqpzMDQbw3A9hQLpQVh/CH3tq/DkPvzvBsND0aBrCCOSAig7oT6ECJgTTUAdJOWYL2R5Dl08uzFIj5UluqMEKJAGhC7XVRhzq8Dtd0mOQxKnMzI88cTdjGlJHOWwqP97TnI97i7z2BmFkMtrFf7jkhyPER9QAEerMh+iRZp96GtfmuaZNLdfEfA8I3Gu68mmbMtVCFfGufasMRm+Op1tDtLHK4wgqHqyDg2oX8fne798Zk3ku+Ygh49Fp/neN74m5939Vb4WTtdOIMfOuTZf9/88wCGwOYXQQWzDYlPbJ9YKcdjaJjFSouczVqEbSAMShbDG68bZ/s4kZVjkvTVulzb0YVCTvmI9PDVIJxwIGXRlqPJy7y4Kw4dnr57MykemyXVi27AqPZ/HOmJa/EDGlcYibF/iYH8TrkQLYYhfCOOv06+IfCH45dDhfImQfULSkf4hFeGx44Gt13QOxCxlAJtV94bnwzDJBYXNwDI6EIZb68sQ37qH+fj1osLIVDy8D35d+4SJHiPqsN/mVGj7kQK/99QJnrsz1/PrK5+Rdn8+xcZgy0OMhEKxJfha6EPvacow/Dzx+FgIuaEb3RWZL8n2KRMWB2lXQW8KEFKL2IyWwj/g+i4PKCEM1yUiWoawd/TxiF8hIhpLcGfloN+xnxqeLBuv1QAkXbkrO74IAy4PYcwV5Udw9OGSJqcQcLiMOZDlOugw6OU2Z8S5klLroOkkj+flDq8zMhF9H9zXlR63Xy+Q674ksf9a6XJ5Gt0zH2EbUgOGDiJqxhOyn7Z6PGYDGDvo74mIUoDGa4NPLsQUisnfOSRcr/cnAccwDfbyTF2Oyw6EiI8EjMaYUP6rBIvFEjSzp3z8rjTPB4yoV2QhgQPAfUsO5gZiIomImtDmGBq/0JD9hLikmsN+o+zI9Y0HyJo9Dqf1HM+HWeqHil9g2qaJSI5iToJias1bdSqDdJx4nYXh+kREW7CGc6Hv06FcmyGishQwqsBzJZIvk2I8yXJ4+yCN+3ONdsK90wQx6qXrTop8EUDKTID90L4RCRGI49yVlhM2BSgFtC0dtWTclQbsCGyOS2ovfAb6oBywTTvrStTsFCDXXECktdRipQj3lYPlzokat9cJeJ5BRDQXZ3s5khz+O09Ehn2Tt4pUU0wIxH2hT9BYCkRZYn/UFHpqsc3ljca5jIWW7MNTYDPBVAkcHxHRXIb/x6prXMc89D1irfJ59kvpKWX3V7g/W13uAO27UUnwKbMF2TedXkxnJyKi2ypyrm2C2UPM63JL1g+fneTgOVtWrQWmwD8iYqbSk/dxtAIoNhiXNU/WG9esiEHSCLhiYuc64XTVn0EsHWJt19pyvOFUmc1wOzT78t4R0zIGzrGpMKX4iA99UUT9VroQsn0LHHieEUhblQPmUhPusRfI+9gAY5OEPcy4QkVNAEptDZ5NlAF/t9yVz2Ww7ohwWenfJ/K5kK8U3QOvy7ZsArMJsYVpV46P8WiKvGD4HJHXNplMJpPJZDKZTCaTyWQymUwmk8m0o+whuslkMplMJpPJZDKZTCaTyWQymUxDZDgXUIzSFKHYtlORyxDSmYSQSXUoL0UhlCxJEBIeJmiYMFzcUac7I+plDULRCgGHY0QVKiYPoYZ5B04Md2RYQwinO48nuGwVAUtZiHJAhEvUlaEaiEjJRbm9Wir0dhZCkvAdFRFDezN8rV2Asvj2lgy7WGzydUU4sYo6KsGJx6tdvl+sQ12F0RyHkLPlNscn3evcLfK5IffBTJPD3ArQ/kRE+/P8/yUQMujA6eZERDm4Dx/ibZNqfGCb3V3l8NOltsw3keTxhyecByQbHUPHRyGk7oCM0KUYjJ1JCNnBMMF0VJbdgra9GyLlF1tFka93hv//gTEOh3vhU06LfOEJLu90jUPyq56cD9kI37AHiI4vr8o2R2RQEtA/ZxT6oN3nNuppAwAaT3KInVPjfl9py7DoB5z7B+miz+FUvaasHwHOqQBjOwNz7UBWtbnP18Vw/Q0Vyrfe5fGB7adDvRGJhCfU37M8LvKNlnmuVOBa6+q6TSijC+HdSWVbUIttDHuT7Y8h55ko34dGWeX7PF48d2rH69xVkWWfaPAk2JXmxtybkeHaSx1GBmzV+T70HSVV2Pt59agt/o+4CcO5PAxdFF5EMUrQKkkMRzTOYzsVMqYhk5b8pWLIcw/DwMccGaaKaAsfMC21mAxD3HCX+B8M5Y1wfbRPLvfZZmyCcS90pR9BP1D3OJ1RIbAQMUl5QEdsKkbCKeDApCEcMx+XcwBxXYi1a8shS+egeIgIJTUNKQM2LR/l9Gxa2vA6tMVt62wX/b/jvta2aleGKxU2eU4W43LJe64JYeUt9qFnHBkuinLA3/uBtOeIpak7jJ4phCWRLwX2KRlyO3sKpYI+uulzP3VJhkWveJu0k1z1O5lOh21fqc/jarXPSC/sF12HFO0czk0kMRy709wuq135mYrA83HZOoR7X5bb8nCB++ZsS87JtQ6Xj2HWGRWNC26JGsRzLRLocc5jxIPQ5WwoMUOIcDjhsh/XSKl07zJ+L+RK9QCx4gWyja4ucv1wLeGpumI4ez/km89G5Ti6G9YgbXhrVG1NMMwap/8Vozq0mfF36x3uw7m07MO9GR6naViLIU6SiOhYnT8H0d3bUFGIBhAooCigOhKyjRB7uAF1bfZlG2UAjbGb2D+0Aom/iwGSC/1B0pV2azKWJi+Qr5m2K+lGKOZGaNKRyMYYoBlHIozaQ4wPEZEX8oDJx7gPp1Ky7cfAjS62uA+P1RS2EGxDO8U+ptXf4jo4cuyUHPYxUZijdWV/c4ClScOY1ahOJI1sdgGXovaoiDBEHElGPdnBNXUuiv5e2qoz8EgoCnVHnA4RUYa4MUtxbku9I0L/j7YZcSnzCYnMS8M+MgPrjKeMSIzPqSbafbAtCpcymkQ7y6+3FXomDdfCfIibePB/TuPeYrEj7cRZd4HLDnm9P96Ra4FCHMcE4m9k/Q7l2Ef3wK48AMjRqZqsw7FNvlYTnl2NJ+S+JQKYljsqXNdETfq8HPiVjR6Plbjas+3NAFoI9nbTKWn3sWmBtLFtrbjYgbrDWnZPRvZNIsLlb4KtX2jIddrtPq8PZ+BZh97S12E9FgH8Io7LaeXz1oH6hvO4ptZ2fZjz2J+hsm9Nnyd5HtYMcxlZWXz+hXbhdF0aAxd8fAHWNJcX5TzE0tcA2eKr56OjMMnRX+fVMyB8vvbNDW5/xKv6jpzjXYfzdcLaIB1XmE3UZDDH6ZRcQOB8RXRMSqE1a55HXjgcCYuyX6KbTCaTyWQymUwmk8lkMplMJpPJNET2EN1kMplMJpPJZDKZTCaTyWQymUymIbKH6CaTyWQymUwmk8lkMplMJpPJZDINkTHRQfvCvRSlBFU8yU866zDfNNtjXlTdrYh8I8Ajnw1nBump5HCm6Va3OEi3fclMWqKNQfpc//ZBuuyeHaQzrmS7pojr5wPPUXPkkBNYA3DTeEp+r3JRluFKLrCw654cOnXg1xVj/BlJrJLsaWQp7slKrhQi4taAYY4MdCLJNcpEmZOkeWLI8lrqYF25hgtN8RHBRZtNMzyyDNxzIqKyuz5IJ2BKjSYkj6kHDDZsf4XGFfn25vhNzX19oMoVXHTPcb1J8s4qHR6LXYffW3dOiXwjxCypRKTIn/FlBZPQtkmX6zCT2pl3/2Dd+ZVSgstbasn+dIjf++o6t9/Fx0dEvlSceVU4LrWQ412FMYucQSKiYpzzJYDvFldwNmRtI8O1o1il3y0XBulGH/NJlm0tWBykkzB+Y47kjCML1AP26SYwUYsxyfDCqh+rcTtvKjY5stmQq1xQ+FvkxW3BWQyuI8s73mB7d7rJH4qocT4OPH1kuJUD3eac1qw81O4c1wNxmJqr6pe5nQNg8m4GzF9r1eTNTyS5EGT/ByRtew26AKuqzx+QZ2YwTzMXFkS+SXcX9cMufZtuJ9Nwhf/yN6raD5l/Hpz54KizRPBzuQj39e6cHAcezJUq+M2MJ/Nlfe7TmMPX2pXlsjvqIJBVYPR6xAMppri6AfiYRh/9sxzoBWAkIvt4KiUn0WKLP4c+pq8cDt47ctUbivWYF0xSrkNb3W8G+JXIVRxRnGZcM/SBV50HXud9dfmhIviosy2uQ1ku7cTap+HwAsBV46OPnNw+MxyjETn/e75kkg6uG5Ec1NE+c1CRjY9jhUjyMZHVO+FKVmk74Pd6Dqc7jgTWb8K4CvpcJxxvahlEJZfHcgY4vruy0uddU+K12FyuPkgfOzch8q21YZ0AQ7GvF4ugZp/Hykhc+rnZFNcpEwVOeUO2ZaXH4zRFmvHNqsIZSD7Y7aI6H2Ej5HuMgR/woJ+IiDaoMkjnAvbruA6aT8k5hOuJr2+wbWqptdh8iuuagLXOakfeO65pcLzpM1/qsOhFtjj22YPX4jLAHW7zz9hXIwmefGda0l6uAqjdRXuZltfF86oQadwEOziWkrbgQI7vCW3Bckuta4Fzm4KzISaakyJfDGzDGKzxN7rSuHQCIi+U6z3TdkXdB9eW+syny0Z4jZQBvq5eu5+FPWEWFqmXFKRBwTO9Kr3h55T1YZ2QcXh/7cDav6jOR5iDw8M2OlyHvifrUAR/kUL2tzoPogTmCdeUp+rSTiDXP+Gir5XzvwZ7nwNw1kQ0IccnzqndKWbUx9r75H3EuIKVHldQX7cOZyzhnncuw+krirIOHqz/8TwNvf/FuqZg7qajw/cPm7A17qjnLQGsVdCO6bNhNoEPvdpj/3rKPSbyRYnHWDbksTwal2uGYSz7IJT3cWeVy9iT5hvBdm2pvd1Sh//HZkm40v62wa+cbAz/be0o+IQm7Gv3ZWQfpsEP4zlb43HZiUsdPH+EX28qs7nS5mttdeEMupQs76klbpe7a9z+m105LjfCtUG622O7XSR5LkMczrnA5zQennO07VkTpzfFWWvqTDt4btGDw+rQLhARHS7yeJmAoVOISdsyEkd+O5c9nZTP6k41eR1Tgf3vhByWQgH4PL2OP5SH85vgOVvLl/dxDvztgsPPq2rQFy1/57N9iIgmogcH6elwv3gPn2vF4cyipLIFObCzSViDdJSfLkYTFH2Y55nYL9FNJpPJZDKZTCaTyWQymUwmk8lkGiJ7iG4ymUwmk8lkMplMJpPJZDKZTCbTEBnOBeSHATkU0LpTFq9vBCc5jzs/SPdIhsomaHqQPpjjkIlJFT6NwRBRl7ug58vwgVaL0SxrEUQ9cNmIbyEiqoYrg/Q+l/Ecu7Ky7E1AqQQQMrw3I0NEkhEIWY9C6KirQuWSHE5R6XGYhKLIUAuiJjCUuqXCwI7WOQypCuFs02kdYsFtIcLNu/L7oQNwX9gbaQiBxVAlIqIMhIJgRNfhUIZmL7a5DohwmUrL8hBnUevxddfaMqysBxiew0Vuy2xEtnkN2qLU5jDhJfeEyHeG7hqkm71Vro8rQxqjMf7/3jaPt4V2SuS7JM8hTxfluA4dCMNLurLjsZ1noLiYOzxkBoYH3VeRmIbROIdg1fs8h9KqjfoQEocha+MJmQ8RKVeVOEy7H8r5dVd5ZzyBDoGrQbwXjvN8VIbb7e0/dZDOBNwwvmq/rS5PnJU2l1GA0C8vkGVD1CHVAbmwrsKOp5Lc7xEILdyflW2EfdiFvtaEla3ezt/NYpsQEYXwSQzlW2vL62JYOdI1Or4sD8MaEzCsdGgghmrWneogHSEuvBfK0P0ZmCsYvqemrggxxdrVQ+krSoD+yhP3e9KVLjmgkLxQcXVMQ1WKS5uGYZKLfmXo57DdSzB4emqM4Thd7DC+wlGzYDbJ42cT5ttCg9OIhiIiikPIJGLBXFV2K2CHOJXgsTOZGu5v4mBPxuJyfmHINWKt5lMSm3Gswfbl9k2eVCl1H2c7AAQBskVCLTfRFiIeQqPYLslzPToQIroMocpnm7KNvljj9QiGamJ7EUlfGwu5vD10ichXdrYG6fEY49xa8brIV/UZz9X1a4N0ypX2JAvImzKEw2+GNZEvCm2WAPvUD2UfJiB0PESInlp/1eE+4iHH7+J1QgXh6wWA2oAQ7kJMjrdinNu83OayVfSvCL1He55SS4HFNpffAjzSdFKF3sa4fgfzTfiM9N0na1zedIL7QyNIWv2d/fVmV+HS4DdIu4I9g7Sv2q/icJ960G947/W+bMvbytwY55p83XxMNlIv4H47COjFjsK+IAamfQHCiMbX8XVkJ1YBjYP3tNqS9fMBVXZ1kcsejUvHmYJQd7S5iF8hkm3Wg3VBCCHc+g78cOfxJrFTRGVYYzUdvqcuyX4vw/4w0mVUUag2OxVqUj9U/CjTNm32ehR1HPJCOSZw7MxnuK+WWnJ+bcG87MBeoKP209dMMzag5jHi8kxDrhk8GNtj4YFBugl92VdzfLEFSBOfx5FHGpeyM9LzqpIcOz+0h/3IX56YHaTPafySz3VqwhyPugozCMPQQ0xBXvpknDvzgOtq9eXeYsNjx47rk6ovUVZVn8sfj3N/roJtPx2TZY8nuP1x+bXeHf64Cqee3j8gdrPc5fdqvsSeZnrcLohw0TZxs89r+SWX+6kTVkW+ksPPivYmGf8zo55hLMN4xn7S+JrRJGCCAh6z4wlul0MROd52AfZlA9A4Cy3Z5hWPy0a3rvdOQN2j/RmubF6hRHHuVcA2b3oRlY/TSy2+4ZZaNHR92afn1fblmCiCf0RUb1EheaZ6U1w22PcVZ13kKwX8rKcK4w/n8WpLfIQOFTnfWAIwrOoZ3EID0SeImpN1xbGNCBf9jCUOz+SigHPRGNUutHkZ/hlT2JcCDJFpWDYn1TotD+ieiofPg1Q+8L1PiTOOZbHDz02rEbmeRgxqF563pkja7LpTGaRXAe2SbMi9864Mfw4RzRVlW+q+R/2HiWKzX6KbTCaTyWQymUwmk8lkMplMJpPJNET2EN1kMplMJpPJZDKZTCaTyWQymUymITKcC2jLqVHEiVOUZLzCIXraID0dYZTFib4M/YjCdxJxiKHQ6AjEHWCohj71Fk+BLzoc0jUbcCjaSEyGNVQ8Do04VOT3MqqnnSTXz3WwrjL8sA1hOf2A611KyrCtCISWIM5Fh5JgqFDE2RnnoP/H8BEMWSUiSkf5xpYhHEhF4YuQTgxRuq+GoezD64AhV/rE30yE69CGUOC6J7+j2pPlz+G9n2rIe+qEHF703U0I1y3IcYmhQldkRwbpWqsi8pW9U4N0EPC1CvFdIl8xnBykGy6HIJ8J7hL5qnUOdd/qcDgr4nRiKgRuHGK1MRRqXh6GTaOAGliFU+476qTnYw2OL9oAdE9XjaNRwBPkIRRKzhqisxBemABs0VVFGV602uEQ8aMVzrcpp4M4YRuxPjhGiYhSfQ67bBOHIWVCORgREwRDjMowxqYjMhwW7z0Z4fbf8GX8WQHCLC8rcuEzSWkL7qpxXRHL1Atk39QB24JIpHXFPil3+b0pGB94gvaD5SN6it/TYwxDNYMQbCxJ4bwuAZqp4TAKwCdZ1/UOz8lFCIVURANRNmILJiM5kQ/rfriEIciyvJO1vgy/Ne2obuiRTy7tUjiXiMMdlPHYRi51JF6nDLidHISEZ1UHdyCsdCrOSAgdbFrzAA0AYeABOKaYI41fBNYPiISIK+RVHxAk9wGeK+ZOiXzXTfC1xhNcn9sqso0aEIl7OM/5slE5GBfBdqF90z6+7DHuoEzLg/QIzYl8WWJ/M5/hsg9kZJj1VJr75mvrRb6PTcA+tBviM2fdBf4H6pfs7RP50uC7IzDP0BZo5UJGi2VDOa+LDmOaejF2CrvcMZEPw8UrYGZ9R2EuALmCoa2LtCErtTOFg/qOXFvkQp4DRWJkmwf2btPZFJ9B3N9otDhI67DeFcC+rUHoeC+QlZsA0sBciq/bUD7+7gqnu4AqGIvL8q4e4/GWT3ObTzQlQueiAtcJXZFuuh6Mgwr4lHxMhgY3ejx2Yg5/ZjahUGABz3Oglog1atWTtWiAI1nvs78uK+zLRpd9UaXH83p/Vno93IO0oZ0RY0ck13BAtaB0VF4XEQd98LXJiMyXimjv+6CyUWkx9+a4LVfb/BmNE5hO8XsJqANiMWdl84s1IfZ7Vq0zcB62YJ1ccysi32ZwZpCuu4xHKoI9IyLKhhmKbFt9mLR6YZ8CitBx5x7x+lFADVzTY+xhKaGwAzAvYVu1DSewWmV7N5NiH3NZSQ6YpRbnwzGSjLDhur8iB2a9z/8j6uVCKLYMYEwuzkl/M3UN12/8HLfDtvVIwPeO+yrchxJJ5JIDSx/cFxBJpBbuHy4qyDaPw/4LSRsBSYxMudfdMd2GNVazL8u+rMSVOJDlSXlWIUhwrY0YyvWOXD8swqQvRtlG4vqNiGgUHBq2g36WgOg+REA66rlRLORrjSVxfyPL24IFVB1uylc8F9xHIk6zBfZ8qiT3qwdKvE645xjbp5G4wldAH7TAx2jMGCoX4zGfjsr5sNTme1/t8r1re74Gtv5Mkwdmg+SGugfIlRSOsaZcQzcAJzKf4TbKq2c7+zM8x7HNH/DkdRHHhL6x7nN9yqFcexZavNZLDMGjEUlMaRMwec2+XAMmYI4iwkWPy5NNvskarCewjYmIFls8pxBpdIkr17KI6xmD54JptS84Uec+QKuTj8rrZmAtkEffuwXo60DiFhFz48G+J+rKcbkZ8vjtgoFLqr3TvtzO/XFfRZbnh9FteLFhsl+im0wmk8lkMplMJpPJZDKZTCaTyTRE9hDdZDKZTCaTyWQymUwmk8lkMplMpiGyh+gmk8lkMplMJpPJZDKZTCaTyWQyDZEx0UG73TGKuQmKKd7ObJq5OshWaqyXRD5kF+WBe9XyJYcH8YIF4DuuKGYw8lyPuHsH6WIKryPLdlrIQmKGUBjKe0J23B7gkU6nJE8sH+f/82lONzsSMnVPuThIL7SQ8Svrh5gvZKmutiXDLQPst11Qv5biZiL6Dfl1mtO63OEOycWQK8UZ04rfuAnsM+SgdxS7vtFnPtaqw4zEWEcyaseT3C7IqVoPqyJfxWXW/maHOVXV3ozItxfgasgtnQwmRL61aJ52UtqR4zcAbh7y3JKu/Hw3ZG7YmZ6s+3nNxeRnlgCG1gl4nIeK/e1kgQsMzaznEHLQUciJJ5J9WooBX0uxQHMwKc81mPFViksueAHGDjLS8J6IJCd8LsP3eGlJMrpKLebrbsD8R74hkeSnlQBpPAYM+evml/Ej1PV4vKUjPCZONosi36EcXzcHvLOtnuLmATv5WJX7c9GTTLixCDPOMlG+X+RGEkleWQTSqikFHxbL00xeD+YofgbPfCAi2pXlD8aBf3lvdzgHGVXuct9cMyr7CcflnVvAxQaGHhHR3hSz3/ak+YZjrizPC6LUDfpEW2S6gJJOjGJOjFrKNifBEYwkuN89xd4706sM0us9tm8RVzI+D+S5jCww/+7Yktct97jvkw7baTzvQs9x5CpnXP5MUfEr43AOwiYASZfbkud4usXzEJnoxZi8bgra6Fybr/XPmxLeWQZGbQ+Mcy+QvjsOHO+CwxzObJgV+XIxmIdgzpu+nNhrbS5vue1AmtmHZacmPtMKmZMdAE/zbkfaqgP9w1wHYJpWSa6DkLPoh+zbIoqD6sLvUgIgRCJjmYhoFY6lqIT8T4mk35xK8DhFHr/blzYt7bCPicP6dd2X97vurnCdgOdecHisSCI6UQPatuNz/VY78t4LcD7NbmSTJ2S+Gvgl5KWXe+qsGRj2k0keb+mIHG97DnNfJ5/Ja67n/+0ZkS9+3/wgfW+NnaheK+IRNdiWmm7drqRoJ40n5dokMYSvi9dVyHEqwoQY73HfLPSVI4CmONfgC9V6skBkKaPf1OcANcFN4ZpXc8Y9ZPJCe+l1/B1Q3bbP81ifJYLrKslz1adN7Hwfo7AmGonLzyzCeTfIPe+pdQbaIw9A6j2S52cEYMP7IduJuiP7xiePfJLrR9N2bTllijhxSjrS9vWI7SKeB6HnK/rRLEyktY7MePMC73dw7GTVU5BdcHZV1OGyZ5I8YC6Sroy8kH3liQaP8zMNOcjWgNeNZ/h4an9euQfOcoE1eSambSnnawM/OKLWvJkY32QKzi3QbdmDqVMHHxNX8/Xiws6ca23HTtV5YuK+eTKN63hZid1p4C/DOU+FmFpnAG/ah/brhNLXirVBn/dbEwm5tsNaoD05VpOWvwuc7HxQHKRDV9Zvnvh8FGy/mtwK0DicTzef4XY515TXxfOSDgIr/mmw7xs5IO1Na4XLnitxO8QTso1uP8vrtGMNdgr6jCbpv7j9UhF5zs5Ci+9jDZalvUDfE/8fhedd2VD2TQ3mf8llv5uPy/lwtsUXW2sDW19B0XGc4jgaD4siH55HhHYmAqNFn6GIe088UzCvWPhZ8K+1Hp4RIBsd65qDZxhnWrLA4zDMm33tN1kl2E/kkSWu+roinh/geWNyvibgvSN5Hn8P1GUfbsKaBG0G7h+a6hnBVsB78j6c25MK5XhLw/k+4yEz6S8dl230tBFeD2/1eEwsqvMWyr2+8DkXkv0S3WQymUwmk8lkMplMJpPJZDKZTKYhsofoJpPJZDKZTCaTyWQymUwmk8lkMg2R4VxApUSE4m6UWioUAiNBMOQhF5OhAntzHLoxk+Iyqp78rmINogM9iBjQYdENCBU8GyxxfQDHMDME1UFElAK8jA6zmoDw2DggBFp9GZoyV+Q6dCDkpO3JoVP1+HMYBJFToeMxCJ2pQviuDj8ZEyGwnO++mgwlQUzI7hzXSYeL1iA0DUuYTfLnN3vy3tNwi4fzXL8t1Z/LS1xiCCEnngpd+tp6fZA+4d43SG90j4p8fgAInSSHIKdVOHy3ytgRDAnvOjIUfcbhkPVkyHG5oxDOTSTbBWvuqlDDeYfDZTC8cKXHdah4MqxsMsmhPRPwmXxc9ieGUmPznW3JOmA40OVFbvOxuAwHOtfmOVqGMarz7U5zjF0myu9VejLMZ6HJ9e0Gw8Mn8X8MkSyqkMQiRxeSUwC0kwqVX2pxYxyv4dzlfK4Kc5/Zz+Pt4jqPlZibEfnWAEFyHELC08o7YCjfgscYn5pbEfk8QJfkfR5vSVcWiE3WBEO41pZ9U+vjWOJxpO00KhnB8SLnoQPfHfdhkBVgPhzMyTZKR9BW8WcaCquwDJimPoT/9UnWFe3d0Tq3eU6F/E2lQhHyaNpZ1bBFUfKp3ZGxsqNRDv1EZFDPl/3RhxDdGoSRjwYyJLEp7BP3/f68HAd7cjzuT9TYTqxBCDf6QiKiiAMh1+Cwe6r/Maw0FfL9oS0gIlpoQLh4wOGPh/MybhOiT+l+GIvLLYWogjZD3FokkNeN93gQJ0IOod+blP5mbw7DyvmeNrqyvIvybMdmU9wf90W4vfxA9vsY7R6kO4BwWfXuF/nuchkFctB5+iCdDwsiH8JiEOHScloiX9VhFFsbkTKOHG9TPodPT0a4XSq+RPIsAmKqAGih+aS0TxkYL10YL0lf3sd+tzhINzzuX7yu58r1QxLWHR4g37a6clze3uf+WIe11DVFiclKgp86DugDHTqOa7jZFPdvLib9Q3wXlxH80HWDdOrWPxf5Tn6HfXkZXAquhYmIWmDTEbkwqjAt02nANEGbp5TfLMV3xgfivFNLRVFGHtZLu2hE5Gv63BZoT3TY/DngtGB5xYS8J/wcrsGT2zBjsA7y8XWRjco9vu53NwFRk5SNNJMGXCK0JY5RXT+seQL882pb2lXc+7T7uHaSdUX0ZB32XoEj6xCCL084PDdqgcTp9dyWQEmZdtZed5piToKOhgvi9T4gtc6GbFerLelHNty1QTrnFQfp2W5R5MvCuMe5rJEmiGm7JMd1eNo+3oNXKpJvtFTjcVCKsZ2ZTskF3SIg1vZleRz1A1mJ2xYYS3WuzeW1Fa4OnxnEAIeBKCYioqZAIgJyTNkdxDHhXJ5IycmCS5f7KrBfjcl5jfP/kgLb6UOwny4oez6TYl90tsXrm2REr/e5ThqRgirCuj4NmLyzXYk66wfoUxFzIRspIR6bQTqQtrkINg5RHlVVV2xL3P+WFV8jDs9zRgAzmily2lH76XtO8jqjBf45EZFln2ry+vBUne+3rgx6G/Yta+CkLinI644nOF+jz/3UkssbygCCd0+Ox0dG+dCFBo8D3DderB5/jQD+bgnWr2tqcaGRvOeVj8r9/ihwnwQqErFnvvxMCuYe9ntC+VDEjoWwprxrS7Y5ruMTLl9LP1vDdqkATnI0IRsT8Um4vjmr0FOIp87AWruqnv2twHrdB6zVYkuOCaBQ02YHULiAoWqpdXyaeFyuOGznET1MRFQEnMuu9M6oVCKiMbAtc3me/2dakyJfuRulnubKDpH9Et1kMplMJpPJZDKZTCaTyWQymUymIbKH6CaTyWQymUwmk8lkMplMJpPJZDINkeFcQF5A5FBIa10Zc5JucggWnqR8SVGGHc+n+Of/GFqxpkJYMEx6s8ehUJtOVeQ7R4z8aHZXB+mJBOM5psPD4jPTwGDYB1FvEwkZmjCfZqbMA3W+vzMtGb6+eo7DHBIRDsFwFCIBw7FHIXw1F5XXxdOKCUJYGgojg6EfCw0u74G6DJ9uQchfqceh4ykVo3dpnsNEPAipzUBY05Yn64ChbhhylYnIe09CGP2UzydytwIZprYKYYcR4rCXiCvDgboeh4F3+pVBuprYEvk2icdEOTgzSOfcKZFvOuDQ9j0JDu8+WJD3i6eBrwMWYTYcFfmi0BhxaGc8zViP5UnAcIzBHGqo0DZsc8SgJGVVaQKG6e40T7CaJ8MnTzf5e0I8cf1IQYbyzRQ5YL/V5TLqHXkK9Cic6L4S5bkWV+MNQyYxIm6146p8gDHp4GnzIhudbfI4b4fcaJMpnuRbNRniv/Cd4iC92OYQp9WurMMppiVQB5ANWcVEOgudhWFWrVC6kUXngUE6EV46SGcc2TeIUsGwskpf4gQ6xOGKdehfxFoQEbkAJCpC+Kki7dByC9AdgIrBUM+SCous9hB5xelzCjOE4Y4jEEbX9ocjNHCcN6XJoKI0DaYh2nDXKOLEqEfSP9T82UG62WD/oE9ex3DACIRFh2qM1WAcLEInInaDiGhXksvD8MkJsB/TaWnUcIwhgqCjxkQdxl8MlnBxV47FONhp9KcZ5ZPRdze3uIxyT87DLiBvArDnnsLQld1NyMfXmvCkfdoS2Bau62xK4WsC9DGsTJQ/H+/KdUsK6nd5YmaQvldhWjacxUHahetkHRmu3w8h5NRBdFqbhqnns49pR2vivbQ7N0hXfW7n0+7xoeVd5bItHVMOsQZGBLEe2o+gzcQ5gLgpn+SAGw2Lg3QpwTZSowW2OlzGKjSLS7LfD+f5TcSEFOOyvBFAn6UhlH8+L9syAHsevueWQfpLX5kT+ZYB86FcghDOcVyPaPRBAXxEAUrUoeh9WG9uwpTC6+CaiIhoKom+Ed+T/Y4+r97ndkgqtNNqn+1T02ffvaHWAvg5HDsaW4Y+awSWSGNJvbZgB7bY4ptfacvGdGBtgONX2+lzPV6s+GBbxly2t7GesoOAdkObWPdk2Zs9Xkduuhs0TFPuxYN0NmCeQMyVa8UutSiw36k9pMaTD2JU47094nVEQp51GKWy5si1uwtzogO2+Vgg8813p+Ga7B/03gLn63HAXPSOsz1BRAUR0WKHJz0iGw9mpX94xihgguA6p5vSRn63zPPhdB18j15HAtYgDbhE7R8QfbgJaAsvkPex0uY2n0nzvWs/sgbYl+WQ96WdruRrxByuRxaW/2MJtlVphRZBhEsP6tfypQ3CvR0ioJKONMCIOt0KeV/qOwod0Wdb1RUIDdlGWdivNwJurzylRL5UZOe5vystG7Pi8X2dAaSG3gc9JcNjZLbAdnBjhW3f4lGJOjrXxnEOKLae3It1YO2DGNq2QmYiphRRL55Cvs7E+T7qfVyHyj7E5yqjgDeZTMo1SDbK9UVfNJ+WbbQblm3frfD4vWNL9jW27UiU22hSrclzgJvBMbbl87wuRuTaswSINOzprkI2ITq1DVNgPCXrIFCn8CxR263RBJe/0dnZjxNJtFs54PsYi8g17zo8q5iF/fRkQvZNBTDPp+rch11lq/bk8HkJv16K8+sdhd0ZjfNcS3q8jo+ruYXbtDxw2hB7RES0jM9yAe2y3pF9k4zIsXkhmYc3mUwmk8lkMplMJpPJZDKZTCaTaYjsIbrJZDKZTCaTyWQymUwmk8lkMplMQ2QP0U0mk8lkMplMJpPJZDKZTCaTyWQaImOig5x/+ctFJIgWeWK9gGE++7KSF3W2ze/dvcXcoJYvoehrgssFDCxnVeTr9pn9OJLYP0jPB3sH6d1Fybbam0UeOfB545IJtdHle2z6yMPSzCr+HxmVyI0jIlpsI3eQX59UrMcEMDDHgdMeKtbjCqDkNoGXjIw1IqJ9wPi+tMive4plm4lyf+Ri3BZx4IR1FB+u7nEb3Vvj62pUEnLC2sRlFxzJyjrsMpsceUvnYrtFvmacGW4rIfOlk6Hk5iEPsxcBTmsoWYBN4AQ2POanLSue8+4sl9cEPnwqKu8YWdnI12oDn94jyaI612EmZzLCdTgoEbWCi4ZjsafGZQzGEXLzioqBdSjH94Es/F0TZZFv/EeZJdf4EjP+Vo5mRb6ZFNcPmbxaTeC2rbZ3ZpoREU0kkbnGryuMGaUjwA0D9ikcgUDVrmRyfmeL677Q3JkrTETkABMZ2e4NT+ZEDtlFKR6L0YYcR6t0Aq7FZSP7jEjeI3JQi4G8j/U+N8yGz2N5Nib7ZgoaA9ul0pOtmY9xv80BZxBZ5wXFIkc7eLzOxmmxI9soATxGvPdqKFndkT7X/RzgtEcTsi3LXcmfNe2sKEUpQjHqKz+SBZuZgvmKY4CIaD4LDH14XR0LQOMJ7oyv8REXdMy5Q+Qrd3cN0jnF4T6vhicZhGhP5rOcLio+/xYwkj3gV2t2J9rza0d5kE1nJb+9ANzWOyrsi+6qyzUDcoIXQ/5MgeSZGS3i9Q2e/YH8ViJ5DkJXrEHkfVS94iCN6w60GTNuCT8i+hq5lpfQhMh3Z5f9BbLOE+qchyzwTteI/WuPJPO2FHL52QjXOxnIvm6Cfzzr8nkmzUCymBMO+6X1PtevUlHrJZfXjshpPpSUfYO8zibYVbRVEbUt8IA9jUz18AJ2Cfnh6105fm+rcFsgIzwvl7KUAV87nmTHue9FclyGPW6L7/zD2CB9tK44+dBksylcG0vGZ8dn44/rDoXrlDzXhPbYrNOwHEPEbAkYpqnhSwl1foa8TifgSjWBz1tX+ZZhjM0GewZpPc5jcB7EBvi2Wk/24UQKzg+Ac100Ox35uskIsKsVLxWZ/qNJnLuyYSo+nDUDa21kuWuWqcLwDrTek3N3zVnnOgQ8jrokOc8ZOG8BebjTKt+W36Z+2KVTO1/e9C860WxR1PFpV0rayMkU29xie88gXfPk/MezQDp4dkUo9wI1YhtyrsnjvuXLMYbnnp0FV4n7wWtKcuxMJsGPQHmrneGH2iCXeqkj5+EqMMdXYZyOxyV3ezzJZXQBfjwts1EMzvspd4fbqlngoE+nh9u+qTS3eS9gn6fPkMG94iZc97Yyt8uRvLS/62DP8XmBfjYBR2FQHnz8SEI6kjrwoRMh319KzVdc+EX0QUpDhOt9fQ5FOCTdvYDfLMRxHZoamu+2NbZPBwv8nKiqzgTD825awPE/25bjDbnZ+3LA6lftgGNsLgO8enVuz9c3uB54rtOutD4bisuPQl07/vDf944Db10zr5t9vi/R/mrIRwnHLzzXCmQf4vkmu7M4lnkc4RwkIsrBv+jvtV+qQNXXAYqOa38iolmYh/iZTbWuwnY+kOcy9FkuHtwjzsmmL+1qDc4WuavC6WpG2rRcjMvIwDz0w+FzKAPPlOYz8JxnS5aN/YbPu/T+reZx+2GfaQ79dzfYx+D+a6GhbFC/Rf1Q8vaH6VH9JfpXvvIV+tEf/VGamZkhx3Ho//yf/zN4z/M8euMb30iXXXYZZTIZmpmZoVe/+tW0tLQkytja2qKbbrqJ8vk8FYtFeu1rX0uNRoNMJpPJZDI98jLfbTKZTCbT40vmu00mk8lk+tfrUX2I3mw26YorrqDf//3f3/Zeq9Wi2267jX7jN36DbrvtNvqLv/gLOnr0KL3kJS8R+W666Sa655576HOf+xz99V//NX3lK1+h17/+9f9Wt2AymUwm05NK5rtNJpPJZHp8yXy3yWQymUz/ej2qOJcXvehF9KIXvWjH9wqFAn3uc58Tr33wgx+kpz3tabSwsEC7du2i++67j/7u7/6Ovv3tb9NTnvIUIiL6wAc+QC9+8Yvpd37nd2hmZuZ7qk827lDCdUVoABHR6S6HJ/chVDbuyuY71+TYjSVAsTScusgXODvH8yRJ4gn2Rp82SO9xxgfpUYhX8FWoxgaEeGBUyEJLhkmsdPj7kzKEiJRU9BmGXdQhHKijwiQ2OhAODHErCRVLuRvCeaqADKnIqBw6B+EVTZ/TOswqDtfagpBTDDEhkiF20xkOi3YhhGirNzzECa+6pcLhMLQk5nIf6nDRlaDC9YYw9013TeTrQLh4s8dh2itx+Z1X0mFMgAsYg0rntMjnJzhMZzzEsHcZhoRhwwcKXF5BtSWG2GGY4GaZQ5zcUCE+4B7bTe7sZFSGm+dgvM2lOB+GPhIRtSD8v9EHPIdCFUykuQ8aPR7cmZIMXXJmRrjsKrf/P6zKOYljAsPPM1HZRk2IDoq6GOIk89WgGogT6frDwyKnM1yn506wbbmvJnE/GEa/BTGEaYXnKcR3Drtq+/L1FNzjRmd4TGLRZbtbCDncWV8XS0iDnUhHZV8nO3y/Z7vcN91geB0wpEvjMDB0vhBle3SswW90Vfgq9vW+HIdZbrRlRuzfZZ99QKAgOgHkawNWoansZcx1HpM4l8ea706HWYpQnDqO/DXcSIT76qpR7sScmq9JwHph+GOg4F0YHniowHNya+2pIl8F8D0+4DA8h23apurXlR4YgwbbozkV3omokrmAfcAVozLf/gxfay7PdmKjKcPmlxtsN9A/F1yJw+jC2mfdxbEt2/JIeCmXEQOci4or7fqIc+HXmwoJcQo+hv6/2mMj2w3lPJyHENu9YMJD9buR5iqPs7M+Y7wijsw3lWDfkehOD9JnaFHk2x3hMOupFBuhDRUPf6zP/rAecDqlsDS7gosG6Rj4+NPuCZGvHTKerEiXw32IbAJtkwa8hgu2/iLaKz7jwXq1A32m8UF1wIxh6O1aW+Zbg2VRDRYdpYT+TQ/34VaPx17nz+U4zyd4nH9+lfOV1ZoSx9+oQKnIiTib4r5a6VwoLJoelvBa6Nb3Z/k62ags7N4a3ztid3pqXYDtvOXwPiUk5cBAKWjXvvZLhH6J36urdYsfRiDN/VZT6DREwqFvzMWkj09HuYwKrFU2u3KdtuLyL6KzgMmq9bmzxxISBzed3hn1konJ9VKiAfsqaJeUWic3AAvihWwX5lSY+1iQo14Qo6/V6DGlx5rv7lCPouTQqbYci2lAoUwkOV1S/YtjrNLhvkE8FxHRoThjR2YzPCZw3U1EhGbo6hKXPQv7kWccPic+4wP6gGB/dPqs3N+cqLMz+tIaIhoVfhTsNGJu+moeInomBFuA+1oiooM52OfCDW4oYgGYdyrGAcXkyvIQcxcFxOqZhrQ7iEFEH9+G/ZHeZ0wm+M0KPCPQhJUpoJ0UATWrUQ9ByG3eByOu2xKfMyQBXZlTTL9NQHI6Yq2o1zeAyQTDszstbdqJJo/tg0CYcdXacw36aguQNzVA87YUBiUApEYV9sl6T3Ewy/deAvTvclutAaEPEZmpcT9tGM9TKfR/w/G+GdiLaZ+M6+6aB+vunrS5WD7SCE8GEpWcAlRO1uVG12vUMuyh5zL83qUltkEa+Yh4M1zLnm2qvT90At5tX/XNCCBu8fnIqYZsS+xTXCZo1E4eniNle3wfu9Kyr/cC1ueBKhf+1VW5sJpI4hqfXx9LynmIwrfy0eELqQqgu3Dtv9qRdSjCeqIEdutkU/ruNXhOiX4jqzqx5cfJ2wbV3VmPq4NFq9UqOY5DxWKRiIi+8Y1vULFYHDhyIqLnP//55LouffOb3xxaTrfbpVqtJv5MJpPJZDI98jLfbTKZTCbT40vmu00mk8lk2q7HzUP0TqdDb3zjG+mVr3wl5fMPfnO0srJCExPyoKhoNEojIyO0srIytKzf/M3fpEKhMPibn5//vtbdZDKZTKYno8x3m0wmk8n0+JL5bpPJZDKZdtajinN5uPI8j37yJ3+SwjCkD33oQ//q8t70pjfRL//yLw/+r9VqND8/T+stn2Jun6bTslkqNQ4Z2fA4lmd9c/jP/RHhMhbKkK6LcxxGmI1hCJYsA7EDowm+1mqHMy40ZCgEhiF/e4vT+tuSjQ7HeCx7jMC4qlAQ+RCLghFnOrQVES4+vHlKkmwoBqdZ5yCMo6lCTjC0aoU4zDoVyJCTsbA4SC+1MDRIoSMg7DLpcvuvQMjgtzY01oLLOwjNokOStiDMCg4apjM9GSt3OryNS8bwcxXNgugHr899k0zIvsnAvfsOt9dU6nKRLwv59mc4NPCAOpw8FeH7xUi3xba8300REgOnIjvclovOgvhMAeZABMLSN1Uc2BYM+jkI19OnLG9B+ORZQBVt9UZEvqeMcojzCKBdtpYl0iC9xGPs3Aa383JLdg7OG0QI+KEKRQXMTSEOYz4m72MUIlPxvaiKXcRoozE4nXz/BNe7pk5mX4RQPAwP1UikyST3J4ZFbnnSatyxxfkqHodTRZ3h38XW4YTrc02ZDzE340luo7GkrB+eKk8wpTqBHDsYnjUKIasXZWXo1xic6B6PcBm5GA+4sy3ZlnEIO/yJeR5Hf7UkQ8JP1bnsLHH769DMGKAQPMDSrKkwNSIiL9z+2uNF/1a+e84dpZiToIKfE/njELq4hmG4eoxBSDESehabcl5Xevwm9qGeAwVi+9KC/ssC3ijuyHDHXgj4CphfJU/6vEmwkWg/cO4SEd1VZbu42p0cpJ82UhX5TgPeJQbj/CljMmz+a2t4T8VBOh1KW3rFKNf3ACArPr8s2xzRIIhSIRntLNBnXYcNAIZP++pDudbsIL0/B7g15eJLCcBSdBilklUh4QUYR37A9Yn78kHQniy/14GY2pYvbVUM+nqejgzSxUDiw1DrDiNb6r58WNXo8v+nYXyk21erUtjGxYH1UnL59fmstH2Iw8IQ6aqKCfc9vl8P1jA1T7Yl2jsMr9dexA/4lWXwFWdbco7PprmM6gVM5RIsMhPgn7MRjYpD28yv6+hkDBvG5XDd2xmPRkSERL4JwOzFXdmWzT7XCfcFqahspTisRYtdCEsnuaCOwByqEfsvjRlr+PxeCcZiTrVRD/rtZI2vhXOaiCgONlKvO1CVLq6luOyztCzyrXj3DtLF2K5Bug3ojm53XHym4/PYxvmpw+ZjYMPXQ7aROZL2LQbb5rWAf00daxVFvolUdNue7vGkfyvfHVBIAYXb0ELdkMfVQpvnSsqRYxF9bxLe80LpvxCLhAiSvcrklmKwp4E9Ku4V73xgSn4myX5pfhfb6cmi3ACfBV+r9wIoREDhXrbhyY03/rcrDevIrrQT44CEuKbEc+WfN+XYBpKSQIEtKiQX4lhWYMG00ZV73n1xnnu4p6yA77i9Ig3rGPhkxGTuSkubdnmBF3S4J7+7Ktfk144jhpZfP6GeTZwChCziMOqK27XcZ2QgruvHIrIt5zJ8XVxXtX15v7tSiK/hz5xoyDZHd3tZgdt8CfZ5Ve1roV3OtdCPazwX27SLspxvj5obC7CHW4bnLZtd2UaHi1yPQ7D/2r6P53xn2zy/8gohizihCvjXpbZaq8AaGPswFaZEvj74x1wMMTcay8r/lwEXjHjPZER+pgA2ow/YM41oFXg5WFwoFy+edVyS47YcUXjfu6vch8ergEBWTgjxZohw0fvuq4q8Fqj0uP2OtSWmOA71Q/86k5b3uwTjr5ji96StkuvkUpzLw+Yrq2drSz1+TjbX4XWQtrBo07AdtiG9ehFyw+E4GtRj/iH6eUd+5swZ+sIXvjD4NpyIaGpqitbWJE+63+/T1tYWTU1N6aIGSiQSlFBcNZPJZDKZTI+MzHebTCaTyfT4kvluk8lkMpkurMc0zuW8Iz927Bh9/vOfp9FR+Yvua6+9liqVCt16662D177whS9QEAT09Kc//d+6uiaTyWQyPellvttkMplMpseXzHebTCaTyfTQelR/id5oNOj48eOD/0+dOkW33347jYyM0PT0NL3iFa+g2267jf76r/+afN8f8NZGRkYoHo/TJZdcQj/0Qz9Er3vd6+gP/uAPyPM8+vmf/3m68cYbv+cTwomINvsdijoh5eMyHAjxK8tw1O2yLw9GKTgczrPHmR6kZ1R47KVFDh0Yg5OoNSIlDScUY1hOs48nwsvvQRDvstHFE+vlKcY6vPi89EnDHQhzx6im+bQM30lDnATiZso9eVNbEBLTCThcoqLCgaohh5wFDr+XJHkfwzAVOiQG7+Nsm38NEYHjteMqeqOHxBVo/1kVppKGcLtzHFUisCVERG2PQ9YdCEF0HTkNR+MHuE4ux1OlSPJXMEQ3DaG3jitvvhRymDqeuDwel2PgRJPru9bme6x7Mt99PT7puuqsD9IX06FBGlEzREQTxHVoAJNDRTiJ0KjRBOd7oCF/wXKyDidgw8BcU2GHCTh5+9mzXO9mR46jtf/LIXpnmjx3x5KyvKMtHpdn3QcG6YgKMc2HHFJ8yOdf56SjckyUIfwcQ/RmUx2RD9EsGN71pdOMLUirE80xdCkBIV0qcokWmvxeAPZE9w2GRnmAsqmSDK1yIVR2wT3Gnwn2inyxgMd9E06o7/qyrxH7gtiBvJqw02mu+8gFTujGEEBsswaE8df7MhAM5/9uCAOb0yFrEO4YArIhqkLqyn3u3zrxmHJ3+F67/xjEuTzWfPf+fJQSbowyUdnXZ5vcP4tNDgn3A2lzC4CsuH2L+6YcNkU+DAONwvJpOiJtcwtQQ2XAcBTALkZUX/sQzh4D31FT4cRXjPDnihD2utqR5R2vcR0aEOY7kZDrGwypRaxNRMVCvnCG1zdfX50bpDWlAcOOj9a57ISyT1EXw2i5zdGnEBGNB4xMEX4EUFFR1ZaIeroTsHbjKZkP1w8BhDRPqHwYvovIkGZPYmQWW3xPKWhAjXOKQv/mnOHYjHN9xkps0blBetI9KPLNpS7j+4BxpJFBSVirJKF+2Rhfdzol6zqWAMQX+J5QtXnP5zKwp321sF3vs71LAwJuLCrnJH4K0TM1FYqegTVlF66l/VcB/AXW767q8LDdHhSiw6KdbcHCD0rRIMW6HsdRzeOMuq6IWPMusE5GJAmiLBKhrESbeH6tuoza64fSd0cd9r1uwLiUjFp7YvgzoudKCXVdqGAewub1/baR9QDt2nVk6LgfAM4JcB/YFWhviYhqPq/tlip8f7tiEo/YCnguox/ualSUw2uQDOxHPIX06/ohKeLRY0KPNd89Hy1QzE1sw/0gGuhYwLYvHUrGRB5wYg0Y5y2nIfKNE38O8aF7MmqPCkgN9F+476v05P7heANQZ4BVLCTlOn5fjhkiV02xn7ttWTLm7waELLp/jeAMQnyWwD6v1ldrC7jFsRTPqYM56W9wLYC2XmMtsTzst6xaf5XBXuF6YleWr6PLBiIiXQtkphv2nxP5Ymm+309+e/8gfaohx1EGqpQF85SXVaUE+EpEjoUXwFDlwF7GIxpXyekaIEjKPelvEAuGe7G1ttx3HyoApg0wlLeWeW+41pFj2R1Sd/06LI3FGnCtI/NttHd+btTsK3wYtGW9z+mTTXnv6BsbUIdzctlNCWjbcaAb6nGOYxbH2954UeTD/VgM0j21tsAyEF8zB9xe3GsSSSwuPmNZVzw4/BS2g4ZE41zLRBFDI/PVoS2iLg90jbhDzCDia5pqz3tfHXCwMG92JyRObwzuC5tvsyvLw+lxJ6Bh22DnO+q5JK6b0c4EqpUQX1eDZ47zGVmHkNBW8euluMzXD6LUCxRjeoge1Yfo3/nOd+g5z3nO4P/zvLSf+qmfore97W30mc98hoiIrrzySvG5L37xi3TDDTcQEdGf/umf0s///M/T8573PHJdl17+8pfT7/3e7/2b1N9kMplMpiebzHebTCaTyfT4kvluk8lkMpn+9XpUH6LfcMMNFIb6exfWhd47r5GREfpf/+t/PZLVMplMJpPJNETmu00mk8lkenzJfLfJZDKZTP96PaaZ6CaTyWQymUwmk8lkMplMJpPJZDI9mnpUf4n+WFOLuhQlonXFS/bhm/kqsDtHHMkWTUaYDRQDLnVSwUU7gABDFpJmGsdd/r/bRx4WfwhZW0REIbCCdmWY1zWdlvnOAYM42eWyq4phXu4Coxr40JqfFHc53yV5rncqIhlHyCSqAqc1r5hEiQ4DnyIhM5iSjmYu8nWxDM1zbQDeqNfi+0C2eyaq2hL6vQJcqdm0yEYXZbnwZp/rl2nKcXQgcT3ncypctn9W5AuI28wHZuW6f1zkK0SYhz0a7hmkGyRZ/RWH/z/b5Lbc6kpG1+k630e1z4Ay/csUZNQjl70DjMqoYrEiw6pL3JhxxRlNwjiqAndQf9vXh4mz5UF5EcnTTsKcSqc5XyopOZfHVvjwpJPQb7sUL/F0nbmK+f7lg3RLcas1s+u8NjtyPiQiPF5G4lynXiDvGN9DO7PQ4s+HisE/leR7d6AF7ynLOiB7kqA/pxXv++Iil7HR5Xwb4abIh+M3AwxjrRRcazPkMdpoyz5EBtuVo9w3GeW9juTZNuOoOtOS83AF2NHIoe8D9/wBiWIV5x58c4tZcTU5jCgLh0qkAIroqzlUBnvUdPneY6GsayEsUmTIWDKx9mdDSkUCykW1v+Exlgfusz7/4kyDP3fcOTlId5QtRWbwaMD814Iq0If50SLmoMZDHjsjwPcnIvJhHsbgzAzNXq6Dz0/BZZFdTUR0IM9vIgP2O1uKgwo+Pgr2N6ZscxbmwA9MIs9Z5kNfi34zqfzrWntn5qCr7FiCuM13Z2FNk+K21KhOOLqG8DY0B/WSPH8Qfc9+YNcSETWAX73W4XXfWWXPT/Q2BukZl22f5v2OR9iPZOCcjL7ih+Mas0V8Vkc2lFzKyQj/j2dFXJSXfmQcxkgZxpEXDD8zYxHOGUE+r0J8Ugr6F29DYUspgAsghzYe0esvTjcAtrnVleMm4XKnIj+0rPjheN7PucZw3iXa8FYfmb6yfps9XpvhWBxX56iUATu+BWn04z8wJu1MMsK2oe2zT9iUiGVa6PE4bQADGucMEdGmcw7e47FX8xdFPuTpp6IMYI71pbPFMwjwbIekL/Nlojhv2FkGyh/ifgl5xBo7n4vzHEg6XL98UIT6yIFZBUa673AdGn25f4vBWCyG/F7SlffUhfNgomCn++FjEID+ONBUJkIJN7rtTLAyMHUTXZ5fmnXuAw+34fAZEvNwlgYR0UQyBml+PebKC/sh9ymev7ALxmKjL8cY8ogjsD/adaQq8oU7I6Xp+tKC+D95lPd2S23k7ssJgXZ7A/ZzLWXe7mghl1oeFIuagrMY0G9q21dK4Nka/Hrdk7YP17141gaeFVDpycrmYshYhjV0oPwD8LWR473Yko2MZ5Pk48PX5Cnww+iK9Drokrg8S+G85D5KrqvwUi3lYFeg/NU2L5gKcWl3DmR39lkzKTwHT7Y/ngM0nuTysP+I5DoSz4Nab8t7uq3PZ1zlAl7fpCgp8pV77H9G4LnMumK25+FMGjw6sKKeQ2F/9GAcNFWT4PoOnw3ptQWuVfBckVJCth8O+3Xhv7gtD+fleMvAM6+VLreDDvApwzrmZMh+ONKSi9Qj4PMCOP8B+52IqOrppyQPKqHWafiMD5+/FOTWk07AErgKBwTWPNnow+wEnqunVYFnNuWQ7bnvyLILsM7dk4Z7z0jfHRL/j2fkfWtdbdBBh4rxoe99L+eZ2C/RTSaTyWQymUwmk8lkMplMJpPJZBoie4huMplMJpPJZDKZTCaTyWQymUwm0xAZzgU0GylQzE1QrS/RDFsQjo2h9yXaLfIlIxgGxmEJyy19pQjk49CNywryuldMMCbhzBaHENUhlMwPZaxGB0JY9uU4tGI6KWMTVtoQYtPnCtZU2OZkksN0MFrpdFOGx2x0dg57HUvKkAkdAjyogwrtbgPyw0UETCjjWQt9br+DEGWloqypDlEdqSF10HiI5Ra3JYZ2TMkofGr0MfSepZEelyc4/L/hTQ7St0a2DZAd1Qe0CxGRR9wWDWrD6zJfB8Ifj9Y4BGs6KW8Ew9EygBlp+wpBEnKIUizcOTQQwyqJiFrEdeg7iCYZE/nGEjwOTrf4Op6KDBqBsKsFiOVf6ch7v6fG93h5lcOB9j1Dhk/v63PIbxnmpMaq/Ng89/A/LEOYZSBDsBoehPxCvFPPlzfShf+3elwG4luIiI43uC1mUtxGu9KczkQ0Kob704Xw7iUVsubBPTbBfiy3FM4BbhFDY4P2pMi3EC7TTtqXKIr/EXPlQQimxldgGByG5E8m5P3uzvEYu7/Cod7TSWlbDub4fwwhRrvqOrKN0J6gHVxS8YSN/s4hlyMJOT7ygB3q+eODNCJuiIguKWSoG3To63K4mpSyUZ/SEV/gm4iIkCAC7llgyoik7YvBXIk54yJfG/AuUUIbKcvzILQ/CaGG4w47qYQrx9hchu0T4is0Zkyjys5rJC7nA4b2VqBZqiqcuNLlzzV9Hr+dUNqg6QTXLw2hz1NqbdH2+brl7vB5jddCzM2ucL/ItyvJ+Ik5wNIdyvEaIVRl31bmebQAqJ6IwoztBjTbrgz70GpPzsN7wY+MQ+TyVFuGlZ4CxEc54PIQYUZElCe+MNo3FTVPOUAQ7elzWK9eB2Fo+t4ct/++jLRH99d5zG5AaDWi8GLqpzUYlluDhZAOh8d1HyJRJpS/SUZ4DgDxinIxefMQeU8LDb7uSEIu1PKwxMyGeB+yr1eBcdCCsVeIKcbPELXUOgjXd2Ldp9YqaHcwlL8BYfOrHRkOXwPUYS6KiCVZ9pa7vmNdfZJz14e5HHUQa5UV+foBj99quMJvqOsiPqUBuEA3lBmzPQ7HLjg85rUJq/s8PxCrUgqk/fVdQN4Rr5vXXMbV7Amk/UiHfI8erD2bCsFXcLkPkjB2xpJyvCWjPODqMB9ycXnve7Phg+uGNTJdQGcbfYo5HmXU4C5Ae+7yeJ9Q8aXty7jcH3GX883m5LweBcJRLsZzaqUj7QQiyIpgF/OAiltTKMzNHue7t8prz7Fjcm838yxAQlQ5nXClD73uWkZ83nkrr6+/tZUX+U7U+bpRsL8ayYVoixrgHBJqMTEObXR1ket+cU7e71fXOSMugdVSRfjlRAT304gSk/fuB3yttQ7n+7N75fOWAxlAn8S4jLwaR4ibQP+11pHzH3Gk+wAXodEnaN+xLfH5z4P14M/BMojWFHuuCLwNLEOj3b60yvlGwSbh8wz9bAMRLogwGbaGfLCu/GZGucZYB/YtDqCNFV4Ol8OIz9T+C4XjSGNV0mCCq9BtZxvSzzV9/r8PaLKsK59DIRIR1xMb6sGRB32QBvwazqe7q7KRECGJ65a6Nxwhh4jcOq2I99a6/Mym2GYfWtuGTuI09q9+rjWZ4s/hCKvK6UBb8EwP0XVNtZZtdfgCiHnLRGS7TEInusT94ff4nmIKxYbPDy8vgc1Qa88K2N8aVK8eyOdBWP4i4KCKynfXPJ+8YAh7S8l+iW4ymUwmk8lkMplMJpPJZDKZTCbTENlDdJPJZDKZTCaTyWQymUwmk8lkMpmGyHAuoIjrUMTdHucyQhyq4sPx2hj+Q0Q0F4cwSTcGn5HlIdLBC4bH1RxdHxmkz7U5jAZDyWrqFGMMU20CmqWjroNIgjbgPyI0PBQKg6502BCGUGFIzKo6offiAn8QESk6VLYPJ67PRjicJRmRF57J8P/7MhzKs9qRQxtP6C5C6FfM4frpELguNFIVKnvUlWEq02m+dzwd+lkTMtQbtQWh9mvVg+I9D0JxqxBuu9Z/QOQLYPytuKcG6XL3lMg3luDy8xCKqsOnGxDihAidJdoQ+dac44N01pkYpHFuZEN5gnmf+L1cWITPyEY/3uAGxKi3roqs6UKYFYb/IoaGSIZm/uMShwbn7pBhPpEo9++1e5YG6ZMrIyLfOoRdv3ovl+GHsi2/W0nTTlqTkag0neL7uDjP2Kh1Fd690sFQJr7Hp5Sag3Sg6oB4mAYgFrKq3zFUDhFGXRVOeCDJ/z9thBsWUTNERPdV99BOGlFhkWgXewG3VyGuQ/lw7gK2APALRER9wNKg3XqgIUP5poEpdVmJsUNJPFW9I8MT76vsjDFIRPT30NyYWx6PD4dkHRB9gCFwezJy3OzKhNTRDsS0Ta2+S2Ho0nRSzutEhNv9njV+b0JhxnYBq6hRYWyGR9LwrEJ3RyA0cKXXFPkW3YVBugunz8+EHJrdVWHMIaDZsM81zuVAlu1+FUI6TzV1WCmrB3O5oozpap/r7sP99tW9t/tskyLC5sr7mAebdrTCr9f7MvQ2hHk0AZgb7eMRNVIGe35bmftQhwnj/1eN8j8tFVFbhvZbavP93VGRbXm0wh+MgXGpeDK0teoygq8SsB8JVVuOunsG6WTI64TRQPrNGGClMNQ7FZVrC1zfeNAd39qS/vAMcO3QjhW86I6vE0lbhbiZuForpyI7hwnrcGJcd6BpG0vI62ZhXI1CeLc2h/h/CpABo9It0eVFGAc+v6nLQ/dzV5nH4qby3RHoG8Q5Las1b7nL9zEH61V0r3+3LO0RYiRwjV/pyXFUCHh90nbYH8ZCOX5nideAJeI15VQ4JfKdjBwdpKseI1L6EWlXow53Koaid0jawYqzynUNLx+k9+Zk56y2uLxKn6815cr5kId1wlHnrkG65vNcy0aK4jNzxDY35uJaXQ7M0SS/h2MgrcYv4quwDzW+MepoyJRpJ8Uch2KuQ62+9CNoX0qAvCiRtH2IjsIe0MgK3Avgvrui9tA4xdD3rMGebbUzvGePN7ius5tF8V7qVmb7lJ7Kcz5UCNnYPp6jl1T5M2fvlOvDOqAn0RwrYohYQ+BeRWNPR+Ps5y6dYVTUfcsSuzmaAKwHdNsdm9LHI6InAwg4bPP5jLQFceg4xGEcV/uRNUCLPB32Iy+dk+Potgr79bvL/F4/lPlwrReADyzFZV/HXX4P7XRF4TBm05wPfcxCQ+ZDpOcI7E2aimG62eW+KXf5uthemahcCFVhMCOWbbdCHQFBR/TnRFLe+0SN99BpeA4ylx2O7cXydmXkPY3F4blKnT/kya4R4xRRKq1ALug2ASOL61c/kBikGVg/4b601Zf3u+6xT52NpeEzXNczDVnZ5TYvFEpxHqNp1TeIX3oq7R2kzyo8KrZzBtZO2r7B4xfR/nr3OAH7+OP8yEGsU4jkGno6xXMo7koMMNodHOd7svo5A795b4UL7/g8djTOcCKFNpdfr3naZiOuihui4tRFPsTcOT1+dlWMy+ctI4ko9YLh+B2U/RLdZDKZTCaTyWQymUwmk8lkMplMpiGyh+gmk8lkMplMJpPJZDKZTCaTyWQyDZE9RDeZTCaTyWQymUwmk8lkMplMJpNpiIyJDuoEPvmhT12SLBzkQwdABj/lnhb5RruHBmlk+2gWUgE4WqU4sE8jElBWB6Z5s89lPFDlOhQVP/jp48gt5dc1Y2oUWFcH/ZGh+TzgkCGLKqmATCXBO2Ygk2YcJYBZ2Qc2227F1BrzmL2Vh3ssymy0J8Nthpy7jrqP+RT3aSHG6RawogsKrHq4xO1/9xZ/5v664t+2uFKXQAXn0orbBHVabnE6RZJPhkx0F9oy7mZFvrRT3DGfk5D3cVHA4/JQSV4LtQGcP2S9bjhnRL6mx6w8rLrvAG+VJBuvHZQHaeTBtvuSOb7YQrbV8PHW7CPXnu93WgEskQ2GQ+KfFyV3bCTO9b1kmhnwmk2+1ePy4y6X+NSDiyJf6/75Qfpcm8fEM8dku1w9yW2Jc+VrG5IFivNydxr5vMCH68t+b/b55ss94N1fgBW53OL6FeKyLQtwlsCdVW4XhSqkkQSmgfcXk/btnirXb7HDE6LTl8y1UejEJWFQJBvTB7Yw9rVmwjWAe9eBNsrBGDiUk4DDNnDbHqjyfZR7Ml8uym2WifDkGE/KvimG/P8I2Pl9Odk5xVhAbfuq+yHVDx3qhw4ttOR8XQeblnC5rzV/FVl+UbAnjVCBkEEtB8asI31CLVgepPHsinMRfj3tS3ue6TF3NDKEM0hE1IB5ngZ/2lQZkd2NzEBkoBMRHXe+O0jHHa7TdLBX5MPycM43fTlABecamYZVeYbBosvM5dlgbpBOhJKJjmxQ5HNvgJNva54uNCAycyfk8KAm2IKqx9fVvEm0mUvAv4+RrOtUOMN1cms0TKVgdMfXfZL3sdXn8ZcEDrVm+uMaabnFZSBHlYioBxz+FLDnT3Qqg7SjfltzSYZ90RjYMYXnFpxKZMpudGVj4rkxiGadSMi6Hixw+zXhfI8vr0u+Ka6rphL8z9WjZZFv38Vbg/TRe3l9udiUnGFkBi/EuZ0rPXkf43C+zHSa29JXa0+C9dgzRtlfHK3z52s9+aEmnE+DZ/Oc8eX5NBsurzuKwDdPkRzoXeLrNsIu5JPrwVzI47LhrEB9VkW+DTj/aTxkO5EgxUslLm86xfc7rli7ERjbfgPXfXIspmC+jXhsM6JRXnRMBJLffHGR2yIDt6up1jjn0ZJq1jH6h3EYb4stWWIy4lDHNyr6Q6mUdCnuRqiqmLyug/4G9wUyXwMWx3i2VqDOdhhN8PjL54afsbDc5g7fBA56Gc4ii6szwXDsZIE/XvPk/Lp3ke3O03fxWiDx3D2ywBG2udG7bueX43L/cFURuM9ot9rSP+DZX3i/p5tyfnlQRrfLczIXk9fF/cgWHJeg9/s4f9FfoE/fk5N1wCNbHuhw4V4oHY7rsK156gi+Ljt0HZZwPbClMylpIxMRLg/XNzV56zQH7iIN53aUu/I+FprcGLgtzcf1Po0vlob212eWrXV3ZjSnQ3xOJNsIzzfBtVNbHQCHZ6roMyBQF2V5j4Xn9qTVgqnchfNW4HmXunVxzgDu1fWcXGjwfW3CnitUc9xz+L14yP2L5/4RETkO30dTH4YHSjk8f+se12ESDgUsqectXsC+CM9A0e3ahO7E+bSsNnytgAdgxOF5PZVU4xzWWci/zyum/wTsz3FN31HnEuZieMYjvx51tZ/j9OG8WhSCbq/AerPJrPIuPCvS54qk28VBGlnnOt9YAvn8nJ4MpkW+Mw0eB0nYD24q35OPuQ/7PBPbnptMJpPJZDKZTCaTyWQymUwmk8k0RPYQ3WQymUwmk8lkMplMJpPJZDKZTKYhMpwLaDYdpbgbo4QKp1j1MGy7MUi76juIlR7ny0c4ZgJDPYlkqAqG+10xr0ImyxxavdrheCUMwUiosLKnjXDINOImdFhZKYYf5HRTRQxVIMxhvcNvRl05dIoQMoKomLmUDJNIQfg5VumyggwDQRRFC0KwMHydSIbB3F+P7vj6g+Xxe4fzXAaGcCNehoioDmFcGMGiQ64rPseLrbX5OhqvgSFUKQj5m0rKsLI0hC5HA65f0pH4ilLAYc0jUS6jqAbFGPRHG5p5rS3b/ESf0SINt8p1oITIl4pyiG7S4bDDBHGcWzVcEZ/p+RwCX3E4BLnqzYt8iQiHK1V7PN5qKram7fN7DeL2L/bkuIxByE4o5p3smy1o81qD27LeVxMMtNblun71fnkfNY/rgaSBrrru3Av4zZUv8Ou9QN7wVIorj/N6vcN9k1chl5ko928RkAZ1FXO5AuH/GPKnESTLgMb4wgr351xKhnAnYWyjrdrsyrY8XuPQuzPuyUF6NZCYi1KDx1u+yX2znpbjcjrN18I7zMWkzdjq8X39wwqP37kUt9dlxbr4zEic6+pB/W7blGGC1T73QTHG40NFpRNQLmg6xbXNRGVdVzuuhYQ/DG10HUpGXPJD2dAdMHGdgP85EUhfiwS3UliCl6WN7BEgSaBbutQS+VIul9ELec2AmKueI1ExGIqbA//cUbGtd5T5wldBGPNVRYlLWYaQ7tMNLq/myLHth1An4rpuuMsi3y7iuVIFU3NSFidCofsQgh1zhttSD9o5FR3+246VFnfUlsdzr6XCdRO4uIBmXm5Jm4H4NcQxaGTAssf2ruwwFiQbSrTIgWRxkJ70rx6k6325sIqAXawGXMGuwqBh2Goj5P4dBewGERGS6LD9RxPSH+JYwjVNH8ZeQiFquv7OIeEFhdbDuQbdJFBCREQB+Aesq6d8Xhzwhos9ufYR+eDep5Jspy+6RuJc2qucsQ3rwV25hsh360aR6w7DYDvCkMvDOkTUMM9CWP5Sh8cllq1D/HHN0AD0QSKU4zfu8Jqr6VQG6VSosSpQV+hfPW8wnwuh7J2ebMtEJDdIb0UYy1QgickrgC1F7MYpZTPWOxCiD6+3AjlvOmCrZhzGtowEEn+Hwmj9DEwHvdc52+SMm7DX2fClbR+LcJsfBjyixuSZHp6aHpHnhtQNpM2NBzuv5yKqoT34HNrShkKsbcDcC4kdp8YveSGgImE9UQbOiMYMYd1jLl/H0WiRLs/fyik2oBN750Q+NIxAOtq2bzkN/qwA6/2r1Po1BnuGL6/zXDmuWCXnAO9yujk7SCM2lYjoVB1QamDfi4pzkwfbh259LsP/ZNVTKHAPtCvN6/1mXz9L4DJOQr2DUDqmHrQl4j53ZeU4Qp91FJCNDYVzqcD+oQR2u6X8XBX2VTOA9ZhIyT5EzA02n0aftPrc123NOzlfN0+yp9BeInqu1pP9lInCe+BvtM87XISyYSyerMv6bMJiwIF131dXNeuMlYdFTFtNyrt77GO6Dq+D5kO5754KGZeEa8o2yXYp9zxIw+uhtBldh/1jwgdMCyCQRxPaHu2MqKmrcYS4Wmy9tloDpgHhgusljRycALxLAM+yLlGIlakkIE0iSUgrpLJcagy0/dkap7tgs++uygF8W7UySDfgOepoWBykNUQFES5ol7Ww/XZl+L+ZpLz3WWFP+HWNpG37RE7w8By6/RLdZDKZTCaTyWQymUwmk8lkMplMpiGyh+gmk8lkMplMJpPJZDKZTCaTyWQyDZHhXEBx16GE62z7aX8RQh5ycBJ93JXfQSANJA9xOaMqLCLm8gWmEhxu0O1I5MqJGodMIsbkQJ4vNCSqh4iIdhU4pGurJcM7DxY5huXqUa7r0WpO5Pv6BtepDw0zm5ahDlMQNoEoioQrK4ini2MJ+jYScQ7d2A8ohWRUhmcgomMqyfd4tCEbvdnnq9W9nU9m12EqGPqchVCj8ZRso2WIA5OnlisUAJQ3BWFlvUCOoyiETKd8Dr2LODJkdTS5c3i8Hm9FwFngCc54yjURUcNhhEsMwoZdFYbvQThVP+TQoEw4M0inHYnkqMU4/Hwq4DDBYkKO+RaE7PUhfKcdyBAnDM+KgRmr9GSMbgFi4hCVoUNv2z7f4wPlIn8+JssrAy7Ghwm/ouZuawiCYywh27x1F4efntzgk6TzUTkYrx2vDNITEKp559LEIL3eleFYTcAJYZjVuaacbTgHCvHhyAXEHcVhTGgbhFGXd5e53xLKXuKp42mHsQg1Z13k67gc+tUJ+X7jHYk0iEMYLbbeHVty7CAqIwMxpoibSUXk+EVhz06nJIoJQ+ww3P+ICqnbl+UQ8SRgC043JbYg4UaGhm6aWCfrIcVdfWa7tLm1kOda3d0S+QKFbTkvDB0lIqoFjDjJuzxfEWVFRBQJ2U7EiO1iLOAxiqGiRBILFoY8rnQYoweIlCtK7PPGUrKuDrRG3OVxlQ3l2J5wDnC9ieuXD6SfQze10OA6bPYklgbDaJuAv9H4uynwF6NRvo/ZjLRBGPZa7u0c0tlWaJwagS8DHFk8kEtenFp4nROdisi37J4ZpHsh4P1cGf5b6PI4mEjwdfflpZ04DPbgTIv9+omavL+FNl+rQbJ/URsd/hyGYCPKjYjoZI2vOwLrh119xm5U+gpRBT71VJ0/U1Tx5rheKoLtm03LWVmMBZDmsrVvrAP6IBvlfE8fkW1+tsVtuyvPvrFxVo6jb5zgdUcDMG0lT46JKvhNDGnOKMwQ3j6iwM60IkPztWFdgCSFvQotgCHmuF7VZioB2BbEubQciaiJgz1pgnfU+ToKgTG4bCDnV6PLuL5kmseO7yhGCnT93T7PIc+XY2w0ZAzMFVkOla+o+b7S23n/UAecVsuV90TVqUGy3ef5idg5XZ7AFrly7iINb7PLN3iJIsqsdRzqDo9AN/2LOv2AfDegM/6GeN2BeZgAG54huc6NA7Kx5fA4qDgS2eY4XF4KMIgJ9RgkFeH/Z4D/gzjObYK6pmD6txQOcjbF86gNe4bgM98U+dwC1O8S9sOj35Xz5lST8+G+O1RYhKU224k1wDIu9Wsi3wasb8obvC7Y7MpNJe7NEMOhsYXr4Jd2ZXEvxnlchby5tMBzfC+sBRaUXW2AqVkAs9VUG3lch08DFjOmnk1Ue4hV5LSr8JdnGtwHy7CnQcQokRyXmJ5Myuvi3gzpkDMKhZsDNM6JOqdxjZuJyn6qAaYU/bo2Syfq7HBwf1TMyjZHc4w+qqXmBq7Em4CH8RSyCTEyuPUsd2V5G7TAZQBeq+SMi3wlQN42YL1fVev9RsjjHn1W15HorhBaai7cN0jjmNDPEhBhiutL/bhmFZA1iCbap54vbXZ5PuDeeiop97XTKe7fjTTbyKiaXxH4/yDsQ58xLn18Feb8A3X2m0tteSMHsrwoycBzC43nS8HeIh6ys5yMs20qJPQai8vA52l1Ty6ERuFzIzF+7xCsB4mIIvBcarWLz+BENjrdIArDh7fvtl+im0wmk8lkMplMJpPJZDKZTCaTyTRE9hDdZDKZTCaTyWQymUwmk8lkMplMpiGyh+gmk8lkMplMJpPJZDKZTCaTyWQyDZEx0UHV3oNcVc04mgBus+booZCHi1zqXFSydSaTzGA6UqoM0ht1yVWte8if5M8gG+vemuTDfXGN+UJLbX5vPi1ZaukY85Tmp7gOTU+ynSeS/P84QJ2eUpLsqMNTzLNbKTPT6axi/KYiO7Nnt3ryuiW4331TzLPqdeWQDYHTXOlxo3cUk7oDlz3dQvYxvx5RXYtIJOR9xtQAubTEdW8Apkoj9JAbVgWeWFJdGDlQ03C7M6lQ5WOQ00KLvw/TbHfUfuBXdXzJyZ/tHhykN7vc/jXFr2y5zJsVTHTi8kZcWbYXMucyE+ebiqu2RIYY3u2mI5lmyDs8EBwepOcycj6MJLj8Ioz5JcUwx/HSAH7+dFIxPodot5pf9T7fYxXmsebf3/sAM75bwGUfS8h5MpZj6F8yzXX6boXbuSyRsuJsB0Dw07Q6z+Ac8AQ9+JBCAVIdxvbBAnNCU4r1VgW28D3BiUG61Jf8upkocND7zCpDtisRUQk46NmQ7YlGlp1rcgVXAh6jZVcy1vd29gzSB3LcfhtgJL61pTmZnK4B4y8bk42EtgFZwJNJ2Tln4YyKIyWu69Om10S+by+PU8t/eGPwyazNnkcxxxWcfSKiQmRnlmqUJDvSJ/6c53BfeSTntUM8EJCjnguKIl8WbGHSAa4qXMcn2a8e/I/MUb3icOGV75Y5vdYdpWHCqTLqynWGH/B9ZInn9VRSsoBTGng6+LxsS+Ss1t0y1FsairFwfpDen2d7fHFO2r4yMKEjDuerA8s66Ci+qcMsRLQZownpH9BG9uAfTddHVnzc4faLk/Rzm8DanHXZ3s0pLvh8mvnm4wkYU6EsLxHhuk+n2UbiOSdERKebwOiE1ys9ma/i8fiLR3bmerbUmO8GPC4bcG7ERleOh9E4jwMfxm/DU+ulIv9/aYHbIROTc/dkndeRu7PspEaycu1ZWWb/MFbifOubkv3f8rm+yA9GH0pEtAZoUPSHxbi8jzYMUw/WoXr9hfkE3x8Yn9Mp6W/GEnCWQ4/r3ezL9XQf5m4hZDZ5XM21PrBdozBCAnXeQh5YpW1Y50Uj0ma4MA97ATPInciUyFd1yrSTmuGmvK7DdV/v8HhLKpuDfPK1gOf4At05SGdpQnzGh3tvA4c2r85/cWBvh+vzUXVuDy7XV+E8pI22HB+9ICAvlGPatF0bfouigU9lV659PDgDIgEs20xYFPmyAb+XJ04nwt0inwPjfirO82itJ+2Jg2zxkPt+LLnzWVpERJNgQsCc03pPjrFOwBkdYBMX7pBr1F6XB+DkT/PZKxfvXxD5OrBnON7gOfpPG5KrjPxwPKprMiLzVUPeZ4kzy9T9TsCzgCxMjw3lhwGXLjjoeB4Eni1FRDQD7YzPC/qhbEv03Z3+8P05/rvUQsZyoPKB/4ey9XMBPJNm3WcbueWuiHwjAdvCvX0+8+HqlLQJW3DOVgPY+rgPJSIagbNEOmlui3Vg3GfUE73pNJd3psH5ompzF4acD/fkmk1+G0DbN0O2v0WSfikTgWdm0ICNC5hDPNdlsy/Pf4nB84QInH3VVufE4P9bzrlBuhfIOZ50ee/Z8tkX4VkfRESRCOwZ4IybcpftTOkC54jp8xVRJTAUyEt31Iq/58M6F1731LOEGDwPunKSn8f985L0hwstvqcjeTibS53/VoMz1tZg3ddU21E8Qwafrc3JJQO1+3guA9u3PBwao+0MPjPD8+4ClXEG1teHC7weqavnmbeV+X98JrjTWVpDjl/aJvsluslkMplMJpPJZDKZTCaTyWQymUxDZA/RTSaTyWQymUwmk8lkMplMJpPJZBoiw7mAYu6DoUAdX/6Of6vJISz5KIc4ZKIyjCMBYSsRCJc5lJPxD/ko/3/rBmMuzrVleRhGNA4hTysdzne0KkOfmxBL2vE5dMFxZMg14lPu2ioO0iNxGW9zSZ7DLjDsSGMptmocuxFAXbuBDnHkz40nOW42p0J5McQrM8p1yEUkFuH+ezhkeqHF99hW1JhTdX6h3ec2moH4p7RC9WxCuNcDHuNDIl3ZT9fHOFRrRDazUBPCszAcMK1m4b4M1xX7A3EfRER3VbkPcxC1UozJ8etBHzQh9OZwXjYSht59e4vH+dGqHL+TAYdJbrncLhhi0w1k2ekI36QXcP38beOD38OQekQsEBG1fQ4TPhvhEMdi+yJ1XW6Ys21OZyIygAejg0Rouy/HOYZtz6XYLkQcWV4SQqviCUi7sm+y0L/jEKYeqhDHXp/7vrPJoXNJuA8dolcF7MgoYG0mFSrGC7jsVRkdJ4ShaThWRhLy3jFUMxly2FuKJEoh4XIhUQg/1+G66ZBty0iEQ8J0ODbavhQgJnRA+TrgYuZ8Dg8vJSBs3pP3NAyRlFc4lyi0C46cc21pGFY6/O5ql33AXErOteVOlNr+zggsE6sRdCnqEDVVeGc2hHBMh9s8obAZZRglXeJwwJRTEPnGnb2D9GjA+BTEDBARpWBtcK7HYa+rLoeYjgST4jMRmAOIcynG5Lzpgo0802AbdKqhQqnhfnNRtr/dUI6nqssh3CVArGSi0vZNpnisxzFMOyYd2FaF69uBtsyFEjeD/VGBsOEzETmvd6e5vhPAVcJwy6wr26gV8P8xaNeeisO/t8JrkFbIPgaxNkREKfB5TeLPuOp3KCUHbTOgdjrSTviAmEm6iAyR9TucZ2O6K8Nj+0RdYT0g1P1cC8OxpeHCezzWZn+TBrxRWqGONh0OWa8DVi2itg/93twgXSPGqiRUeXOZItSV56HryDnZBd9bh/VqMdBtyf+fXeOy9doiF+VxdLzB5Z2syzZabEuczXmNqHmIOJwO1BWxYERyLyDW57BG+ua67PfxJLctYl9wThMRJQSiCrBHalzWnAY9HLVhbKcBsTKWuVjkcwFvFACias1/QORDhEY6wvM/JNlGG87iIN31uQ6l/ojINx7ltYAL/Vty2W7hmoNIomzqHqCrYqqNwKBseTwG4n1pj0YSuK+Cz/flGjXuRrZhoUzbVXXqFHG6VO6fEa+PRvcN0j1iW5UjOSYQbxYHdFpfoYpysAfOQ99HHMkdqPR4PCPGMxV1dnydSO4d0epUegotAvNhEhBVm1uyDltgF8fuWRqkV5ckfmXvCNvmSo9tASIbiIgWm9wWuJZNK19b6vEcxXmj/QgiGPAt3S6YD4V1aCvu6bEG92EesGW6qBIgHBBxm4/pOcf/f24JnmcETZEL118+2KceyWcTLZdtaYMYBYJ70n8pcKCmx88pap5sc9xTIjJzsSnv45Ii4IiSuKdETJmsArYZ4kNqF2BVoL+pKpTrsrs8SPcdruxiKO3+VMBzdwuQvusKrTnXYVRRDdbuJ8LviHwhPJPLATIM15dERAGM2X7IfqTozoh8uYB92ybMgfXeHTRM53weO4do/yCN+2ciucZENE5VofUQVYLP+vScmcnsjH3pqPXNcpvXrGeabE82enKdNhLj/kWk1D+vjIl8J5qwZwB3XVb3gRhZfA4SUc8j8LlbKc4+FJGo2n6koP0Oxth+69GLJuS+Gvv/rZ5sSxz2RTCR2racbmxHvAyT/RLdZDKZTCaTyWQymUwmk8lkMplMpiGyh+gmk8lkMplMJpPJZDKZTCaTyWQyDZHhXECjSYcSrkNE8kTXxRacENvn8JaoI0N+MzGOtZhK8WfSERm6+LlVDtXCEI8VFUY6BmGDaxDH0YTYhZVOR3ym7NQG6UJvEvKpkBMIQdmfYYTAZFqGw1fghN4KnDT+Txvy3u+s8v/zgCSIuTIoourxkOv4HOYzElehkNBmyb0Qd6GOO04chbArCG/RpyKf6XJIfcvh0MB4h9voUEG2EYbipEK+v6YjT3peakL4E4SPjCVlOMtUEsaRx+9tquhhh2Q9BmV78jsvpM88pcR16igEyQMNDoNB+gSeqkwkT/nG8KJdGRlmnenwuIwBMsgj7ndfoTuaPr+HIeVpR+Yrhxxi14fyYqq8Sfcg1zXkOuB1iORJ3ojTOdeUfVOCiHMMDdKhWsnIwwvymUrxPEoCvumuclHkQ/QRng6fjkqbkQHcURSQMHhi+0xSBjmJEHMYl0frEZWP04gn0QgTEQYKnqPmybbEcLZ9Ls+vmCvz+cDQaTtsxyLK/mJI9BiEuR8pyvpJpAGfvt6vyFA+nP8xcRL98BBHxEMVIPZrJiXHA46Pjs9ln2zIudaAYYqonW+35b1HnO1ILNN2NZ0WRZw+ZUOJuSgleLxgX2818yJf2eGQ6YTD4YBJVd4scSjpwQLbRR0OiD5hyT01SFf6ZwfpSFQuv9Ih1ykLIesthfNpBWwL0N97JB0J/l+AMO26WxH5Vvr3DdIB2J2x/kGRD1FPXZjjrb6cNwvu6UG66XP47rgzL/IFYO8eaLJ/XuvItUWzz7Yf/RJeV2MTMDS7DeHYvqorolkQKTWbkP2Oput4h/1XoAJLcabWPbRbcg4jPqwNdkIRdCgB66c7KxzKX1U2NxvlfBlYGKRVgYkuj7kKjJ0icfjvRFL62rzH/2/2eW4k1PZhFVAvZ/ocjj0ek+Pozi1u2wUM3VeTaE8G+5Tn2nyxJvK98MrTg3T6Sr6P+rflWjY8zSH1I3G+p624Dvnl9xC/oqKThf/CMGQdxjyZ4vL7gAzb6nK62pNzfK3NDqLp8/htkFzvJwnRSTwuu460BQ1A8viwrtrsnxT5UhEOcy+Arcs4MtQbhaiNVrAl3ou53G9oC3q+DMNvA2KtH+W6J9QeyyUeO/NxRm3NBGw7EQlIRNQKuf0QIReqpRzO8akkX1djX9CEpCL83lxG+u5yN9hWF9N25cIsRSlBcVdieAJAavjQh4iUIiIaB/xfIcp9MKL2FhnoR0SzaPRnDMYIzuU8IjPjGv3JaTQnTxmRKJCrJzYG6VVActU9OXYQUVW+lQfqWlNiX8bCnfmLVbmdpuV2b8d82m/WwE44Id9Ixpf167R3totFhVhEzChiHxDhklVIRNxnIIpUbR+G4lHryjeiC+wEbPtW3LM0TC6sBTxlcz1oc1wrou0kIurD+gsRYatduZ/GNc1yiwejp2xHBFBF6Qi/twm/hS3vTCIjItnmTbW5a8G+eRXQhmfpbpGv2mTkUq8v/TBqK3Z8kC6lGO3SD2RbJgDh1oL1SDKUGMVeyD5mPGBsnO/I/X4J1tBp98Agvc13w7rq4jj7tmhRoueWu4x3ycf5uoiQ0niTDrRzFAZtS2GLljs8jnAeFnqyDgWYU0mwW0fyGqPK4+NMC9ayKWngplPcBwKF15bruQcAFY1zT2MesW3Rp5YVMigVRbvKr292OF9X4VwuG+H72JPmvq6oZ2GoAuBqNhXOBf9Du7/ZVW3Z87fNv2GyX6KbTCaTyWQymUwmk8lkMplMJpPJNET2EN1kMplMJpPJZDKZTCaTyWQymUymIbKH6CaTyWQymUwmk8lkMplMJpPJZDINkTHRQcVYSMlIKDjFRERewGCvUy0Ge42nJP/ryhJzdRDt8+V1yfX7QoX5q8jaPuceF/n8DvPULmlfOUjvyTAXbSQm+UnkMRMKGbAVhURrAK9zFJhLLU8OickMs6iQ23a2Le8dEVudgL+b6Sk4OX6uBTirmZRkrj11BBip/8zlJZKSgbVQZyZZHRhpSYUVn41hPm5z/BapKYumHLDa9gCjK1S8rjgAnjo+N4QfKvY08N0WGpxvoSW5duNxrl/H5zK2FLeJgKO12mH+n+bBI+t8d4avOxmXrKwTde77VWCzaW7eWZ+ZhFvuyiCdAK6qqzjOMdqZxeqqtvScndl9czQp/k9GuF0uxJ0cRpOuKF4X8vUmUsgdk0zDySRD57Z6yOqVbZQDhnkuxZ85tyTn17e3+P8LMbmfyThXOpjjOZkENl4vkN+JXjvFDNKtFrPnvrEpedCFKJdRBtbYRle2XgXYb7ELcLo9GIB5sC36E8g0jsBcSZC0aWMRHlejwBYejcsJuwlnNmAd5uI5ka8HZzE0oQ6VHrJT5dxFLh3ahYxi1yPDsQWs45GEHB8FYGoiv3KlJcfAVNolX0970zbl/4WrenFOsqx/YhefsXCqyXNgvSPHRDlkZngiBNZ5KH33RJLH5hgM05N12W8nemwjfVfakME1A8nk7LnMZpwIma8ZUTMnDec8+AH7tbIjr4P3gRoLJsT/rShzUCsBr03u7ko7EYOzMJBF6ZG02dVgkfMBy7bnSh7m6ZCvtekwa3MmOCTyleDckiIwpXFeVBVrs+ow1zNwuG+KgeSW5tFnQTvHFYA1EdnZ3rnqdyjol/AzmucaByZnAc61aKnzTO6s8iCrQfcm1Ppmf4bf7MJ6Va9lHWCLR1s83kbg/B1dV9RUHDm+crz5wJ7PRZmnrdvoruDYID3R4vMqLovItcD9gFxdjLGfDNVa4CqYa7sSlUHa23a+BLcztvlcRt5wBGx/DZiaTcXTx3UWjpeY4iWPw1k4yy3k33N6Mi3XBVtwUIkHZyKMuNK+4blJqx0uY02djxALeRy1qDJI+4HKF+XxUQrYJvZJ+rkOnisEXFs/KsdEpXuay47A2GmfoWHqxtgeRdKyDyd9nr8XZ9k2e8hi7cq6zsTgDB9oc08tlJF9XoI+zMkqUB1usQYXrqo15Wa3R16483rWxFpwjpLrxLaNxamAz9BYc9lXFINxkc8BKG+1z50zkZDryMvA9CfhXKH1rjSmq3B+WAmw6qOJ4QsxZAsjE308Ie8pAtdt+zxfm76sAw5NzwMmsjpfrdzmObAA9k6ZE4oD573u85hMOjJjgri8LZf3D1Hll4oRzjeV5OvGlZ/EuYLzCM8L6igmPfof3JOvdTS3mNsS99147gERUWTIT0Vxv0okz5BBDnoO1oZERDHYn0TgzKw1V9q0So//vyvGZR9qP1vk2wVnf+B5Jm11ZlkL9qhnYOycqgMzvyPt74bPdjqE81si6ty1nMP3hCx8L9iZuU9E5Lo8Ofp9eU5Br8/16/pszyeicm13RYLXCefa7P9bJNeo+PxAnr0m94Ae+KlNODuooM7WqABjfbHLfd2gNZEvCmd64H0EcAaNWhaIfe1Gl+vQDKQv2ISzHboOt/NKKB1Osc2GKwe+dttZddGd7ZNeU95TYz+83MbxJj+H7HP0lR21IS3Bmhyfu0XVQnIL/PJmF867c/jzFV+u450yz1E/5DGgOe/7M9x+e0Yrg3RhY4SkuM3QBtXVFu20V6b+w/Td9kt0k8lkMplMJpPJZDKZTCaTyWQymYbIHqKbTCaTyWQymUwmk8lkMplMJpPJNESGcwHtzfQpHfFoOim/W0hHuZn8kEMhMKyfiKgO4RB1j99bbMrQg4ZTGaSrEIKdIhk+jWGXKAzl3ZOTXZjtcN0x6CKmvi7BEFYv5JxnmjLE6URDho+e12GFuVjrcj0QbXGqKcOGEEmy1eGwl9m0rOBciWN5z25ymE8+IUMs8B7zQ8JZiIjGIM5kJsJ1xavWevLzeQjpjEOYtqL4iDC/cy0uca0t43yWW9DOnQZ/xj0p8m16HK5c2eJwlHxMhvlg6NeZFt+TxrlU4b4c6M7RuGzLTJTLr/X5vXUYr0REa86pQbrrcz+lIxz25rkyjDETFgfpurM5SMdU6JJGA5wXhskTEY0n+X6zMA/zKvQWmwLD4Q8X5dxtQKjcnjTP1+cfUsgFCP38X/ftGqRLMdnXYcjzKAtjtqDybXR5vGCNuipk6r4a4ESiHJq2G3BLZcDLEBFVIdRzDMKdnq5C/BEDs97lsL6RuJyTD9S5Doh20d/ETqT4lVFFm0KVIVQ+3+HwvaYn2wjDwoZQFYhI9mEINm1M2fMehLdh+CO2+FhS2lWguVAPzPnZlhyXfSgEcVXTiqyRiXDGCvT8/rys6660T21fxQuatikfSVDMSVBG+eSNLs8JGBI0qzABxf7BQRpLQHwIkfSjt23yvD4VLIt8DbeyYz1j7s6IlQevy4U3QrafkxHpg0cSPDbHAN1R6MqQVS/kcZOBNcxMRo7t8eYVg/StdMcgvRI+IPLFXUbHFEIOqc+GEo0zEtk9SCOyZoMWRL5ql//v9Rnfls3KUN6uz/+j752Af9Z7ch6uh+yjohB+7TpqbRfKa51XzZNrtgzgphD7ojEX2DfjgJ7SobYVwGYhCgARUEREqxDdimiBa0dlmHUcMAEbFc54uq5xU1yPHKwnMAS5phYQGwGvVcZgDPgK85YO+b0khMpHQzneyoAJaAEWpNOX698c+B/068WYQvrVePzd+nVeL3UU3uySPKOduoDdWuvINsd1G/obR2GVfDAoiPJIKid1goc2LTRkuDJ/Ro5fRNTFAbmA8/jBOnA6B+8lgzGRr+qzPWmCbcrEJBojSdyHTeibTCj3BTjna8Rh7i1/U+QLAgi9d7gOrittVRT+T8R4vZ+A+hAR1cEuLrcAWwR59F5H79POq+nJ8esCFgRxGKNx6X89GFcnajy/Kp5cu/YpoH6oYuRN2+RTn0IiykTkWMwGPCaK4YFBuqMRDiH3AdpmXIcSER3Kss0822afEFXjBW0Nzq8a7Oljrhw7aN8PZXkcdBSm5SjgBWYybI9SEbl2n8zxe2OX8Jj/5ufkPFwChMsioKKO1xTaDRAnuRggvRw5N9ptnm9rxDiSHBVFvhlYkyC2LKqmWh5u60ie2+VEk994oCrnVxm2jm3wS8v9hsiXIS4jE+F26Cq8ZwMwjV0YO1PBlMiHdnaNyjRMiJ7cBMyQtlWotMNIjr7yryNgX2YzXDbaFiKJDEQfg3uYdV+2Udlle9wnxawABQFj1UYd9sNR9wdEvtUMr9m6IV+r2V8X+Yox3htfHB4ZpBMKIzMC6+uFNreLRsjivE5Cv8fU2iICa+ikw2PiHK2KfHWXUSr43G083CvyOVGubxTqPpdBlKv4iHj+shJUBumqK31jK+QxhmvUCMm9SQ3WrAGsOe4uS5txuMj1Qx94si4riHtURCIl1eTFPS9imrQ/xXUpmp2VlrTTJz2+f9/hsT1PvFbpqfX00R7jddobnG90Ro6jXSVeg0xezmus7MKSyFeHZzYnmuphEShFCXq4ntt+iW4ymUwmk8lkMplMJpPJZDKZTCbTENlDdJPJZDKZTCaTyWQymUwmk8lkMpmGyHAuoGy0T5lohGbSMuzFCzhMpwYnZZ9TmJZNQKlkhoSEERHNBHM7Xr9JMkQ363BY42yKQ7DyEE4xmZSFYwjxOoSpTiR1eAwLI7qyUXlPeLIvvjeelLiOrs/hXXU4VVqHd0XgYohpyEVlCNbpzeIgvQhYiv2qvKkUt1kLQue+symH9j+27xykJ6D9L81weF1KVXY+DScwQ/N1VMg1tmwPOrvWk22JIbUYxpxUYWABBKeWQw7rI0+G9UcgXAlPMT+cl4Eocyk4hRj65o6qDA1sQHhrBEKIXBXGXO1BSJfHYTRhku+3HymIz0yEHC4WQshOhiTeIBFyWFMcwqdmMzL0ZhKigdMQSjkeVyFwTTwlnF9//rQ8Tfz+Cod0TcPYjiTkuDxzjpE1GN651pHhRdeOqiPn/0WXFWS43T5AK9xX57bwFXIlDuGjOA9jEMav5+T8RGWQ7na4/TT2ZbnD/+eg7JG4HEfTEB7bEydbS9vSgVvHUMXdaRnK/kCD77cCkdCTadmWaKsQ56BPHS/GuR7lHtdP21+csTWYk1sOj+VZmhWfQHQSRMDRmYYKHQUUDWIzcKwQEa1C6GIU+nZOTknKRX2KODuPJRPrYCFKCTe2zR9u9bgPEKFRkFOAJlPoU7kPY6rtv7HBZSDCpetI350L2U74EEYbddm+hSTHDubzIN3oy/VIzEXbzFKR2TQa5zk/A3OqqO69HwCCpD4/SNddGdKcDdmmTxKHJ+fi0jYHgHqqOxzm2/BWRL5Wl0NsXZcr5arfdjT73AeTgGbC4P92X/rGRvMyvg6xzU2QxEigb0tFEJuhmG2gCPglHXaMuJQRQFkpohTFAHmHa5+CijDdBHwYruGqnlzfJCNcxhrYloonx86BPN8/hvLi/VbUuiUG4wPHWNKRbTQRsg9dDDl0t+VIn1cPeUx04L3lTknkG0mwf0BcQiEu7wmRCU3wCYhKIyKaT/F9eIBz0bCPUQgxR3RiWeH+MDq+3GVfGVETcRmQd11AmkzAHOopBIFec52XRoa0QkDFARIlp7Av2Rg7lnTv8CC9hetLImoSr0vRHmlsUQfWr6sBY59a3TUapjjg/lKw7iYiKrlsd3IBt0sqSKp8PCY6gDmr+9wOU0n5mSSGpcOQqPVkGyNCL4D1V0mtKZc7O//2bDolr5uJOtQLIvTl2o7ZTf+iufAQRSlB6VCy/3ZnuK9xb1buqrkC3YjrtL1ZOV/Xu2xczwJ2E9diRBJPgr5SXCciP5ONot3nsVj3NH6JC1lqZnZ8nYho1Of1RO0El7HRleWtwB6/DcNUY1q2PLY7e+KAmlRohuUO2npuo5Lao46AQ8M1b8yV5V1Z4vcOFXgi3FNjNIP2N4g76QOSLnGBx1WIX5xKyzocrYCtB+Qj2k4iIlwWof096X9b5EMkX5YYB9cJ5UQfizMicCyYHqR3ZWX98LFDE3yKxr5sAGIudYH1CSoTsE9Gm91UiNYNl9dmmYCRFwmFFkH0SdTh+epH1DOzkMfv7eE3Bum94RUiX6nLth7bPE3SFhzM8VwpAWa3J02BmKOIZSu35cYqCXiyInHZm7AHJJJ+bwyQRoiGPNeS9473kYeytS+rAls3GXB9NPZlLTw+SDddnjeRrhwD8Rr3zWwGUHhqvY+01AywrLpqm9kCf3hpkV9/7rT08cuAfP6/i5yOKlswH2Gfj+vmYpznblQ942r0ufJNn9t5pSPHpR/s7JP1nqgKz297sAZMquk0Gk0JbNuFZL9EN5lMJpPJZDKZTCaTyWQymUwmk2mI7CG6yWQymUwmk8lkMplMJpPJZDKZTENkD9FNJpPJZDKZTCaTyWQymUwmk8lkGiJjooPm8nXKRrsUKj5ZrsUMrIkkMoQkSAfZYCdrwDQNJHurC6zBMZdZY/NxyQPqbYf5EhFRFdiMXcXnRtbrvizXJ+kqvlaP7yMGX6V46t6Rv4zcTa0EMOJ2Z5iz+M0tyVLb6CKrmD+z2ZPf53R8ZkStAdtyqZMX+Q5kkT/H+Zp9xSoOmOPUcZhdlmgyo6sYkcyqIwVgxSaYeXm8ISFTdeAdLwMfqxVIpnQuwp8rBzv3LRHReMicsEyE2U+luJyuON6QiT6ekOMtGeE+vLPK/dHpy752gNM6neK6ToaTIp9Hzxykqwnmm44A97wKLFwiEuD4eMjtrDl3o1G+3ylgmM4qVvRskm+4ApyrpY4sT7QLdG+vL+duBO49CWP+3NmiyLfa4UIyF7CeyCrPZ5gFrm1LEni/B4DBlY7KsYN9mIrxe8tNbhi8ByKieIrzrWzyvGkpu5WEuiLGDHnSRETT0OZeMPzmq4CIO5BlDt+BMclY7gVMNX6gxu3qqbmB5zwg93FFsUmx7sg+7auphtzBHjDvBI9ffb2chtuVVlpz2eEMCcC2jSZkJZCRPgtnL5QVpzXhRqntGxP9oXQ471M60qdaX/bHd8vAWYRhr5nocuxwf2hWKXIqS2CnZ2LynJMsONUywAYX+jwHNtwl8ZleyDzGulsZpL1QMhdP9fk95IePBVMi33iSGeZ5GIsKVUgZ4KJekuQ5We4VZUb43IE8zxXtyurANcyGVw/SJ+PHRb4gZHsSBc5oT50Ng3zt001mi+K8uVgewUFjcB8dn9PLLcX7huGSiuy8JiIiqoGvjUKbxxS7E89BOV5jo3FALlsoBwzdfWAj17uyvAhwx880gQvck0xIZIbf3YBxpBifkTqPkYkkl4EMTdfRNgh4pMAL99RyENdcW1225x1HcreRpYrs1Hoo+73lcxmnGlynmCPXlIUYXxf58hm1Xp1Mc/ln21yHEWWb0fbjujkRkYNiHZqpF8BZSb26yLfpMvs/EfI4jwFzeDQh+xN9YCrksjvKF+D5PE3YZ8QVz3MfGIBLonzvzb4cb/eWc4P0hs/jMkZyzVAHpm4A9ikZl6xz5AfnI7w+nFDnQk0QXzcT52vptUAKGK7VHtsPH86X0DxoHKfIfS0l1PqXcH7x63dVZd/g1mIcDmnJK770ZDIU58OYdtbBZIniboImUnLMIqcW1/Fog4jk2RMlOBfHVedVrABDeLOLZ1dJO5EB44989H0ZPItIlo09f6bJY35VcYtxiOBVE6q8iMMO474tXmfU1Z4Nhfz2iCPX56ku3hPUR/m5iQRwrjtHBulMVO89ub6Lbbarxai0J3hu0RpwqbE/NfvbAw761aPsA9pyS0RNWNjjlB+Ny/JGYCBVG2yr9HkmyEiPwL4Uzx4jIhp19wzShYDtnT7HokHsB3bFeIGi/eZfn+N6nPP4rK6a8t2C492TPvC8HLUfSeD6BG5Xr1uw7ksO+Ct1hsxUwM8COsTtFXVlG637vNaLgL/XrHPs+tE4XyuihjnuAfMx/pA+mwCfFSHvf84ryvr1W7ST2o58Hc9vwf1hDpjceK4WEVEb2iUFpP2o6hs8+y4G7PkOyTNkmj1+dhXGcK++T+RLAlwfx9hsSg64FqwZcP1wVp3xmFTrnfPqq7PIRlP8fGMmzWvFUK0Z1mB5tyvL78HRJtRUkyPucr50BHnmsk5fX+azCS7/RmWQ3rdfnn23O8OV2OhxXa8sKh8QjVE38InkI4sdZb9EN5lMJpPJZDKZTCaTyWQymUwmk2mI7CG6yWQymUwmk8lkMplMJpPJZDKZTENkOBfQ0XKR0pEExRQWAf/HMNw9KqLmVIO/k1gArMcmhD4SEc3S2CA9l+Fwj2JcxrBgWClEIRFQLraFvSy3dw4/3ZOWIeExl9+rAEIg4cohgfc+meQy4q78/mUywSEtiLJA1AkRUaUL4bYQmllT+TDsDak2VZXvXJvbD+uqw9QwBKsargzSGNLVCWQ4y/01LvtsFPtJjQ+Ij58G7sNaW6FxoFIRj0PM274MgQ3gPg4X8boiG5V7GA7MnwkUguB4A8PoEG+gQwgdSNOOaSKiKxLzg/R4cvcg3YGOuq+eE5/BUNke4TiSIT9HShzWBJGytCUjpqjrczunYU42Vcgfoo9yMKg0pgGxRQ0vumOaSOJYsP3S8jZoPMkV9iH8KR6VY6zSgfBJKHsuL0O6ED/ThjphWCkiX4iI7j/FIU41j9v1jooMvUPhOCrGZIhTMQahkByxui20dSbFIVN7JjkeKpmVnXNpwJiGE00O9T7TUDYDQqsRcTWWlPkmk1xfiCCk9a40BojAmohwSFcX5n9e2eIihBB2A34vqfp9V5rrMJfi9tLjbRJCye+vc38st2Rbjsa3h3+atms+3aZMNKBvbEqnvNriMVcATID2Sxi5OAp9r/EwrT6Xh1g2jR3A8hCf0nJ4XidJ1rVDjBmrhIuDdD+Qxs8BDETJZVucIjmvNztc1z7gJvKKVTQFc3k6xZW9tyIdzlob5j/c7oiM0KXZtHrhXxT//9n781jJ0qy8G1177xh2zBFnHjJP5snMyqrsquqJoaFpf4ZL+zKrMf4Hqa+EjWUsW0hGWEYg0di04CKQhdpgCS6+usbIYItPslrf/RB9P5s2xg3tnucacs488xAn5og9xN77/pFdsZ61zonqgobuKliPlNI+J97Y+x3X+747z/N7x0+Ln/0ilz+XcUybkowTU0Au7UOAz8NapapWsojawnmknpcDFtcWaPE/GMs8HCeMJAkcjm8+4DmIiPYSQIYMeeAeTloi3XqZ2+ow4LmypdYWiK/COk9UeR8D6qXhsC16qKzBd+gB/xBszy5xnbG9rO2/3F8W8pyhO0OZiTagCz4Tc2zH9RYRkQf9dIm2ZtcbOcm8QWzG/ogngdOJbMN1mHzf1uJ2W/Ul1u4M5lqM4ceBigWIBoCuvKSwL7jOOoF1tx6Hi2CBjx3OE9qT9Zg8hYbvRDz+pyTbBjErMYyTSk7W0TLkHdE/i2pNebcP+YC+V1R4iAL0+wSwTL7XFOkWID75Gc+1dZLjJgD0oZfgOlSt9wMeh7vuzux6kbiOc66Mq7gmxDhxuSzrEtcJ+2P+7PFQrqtqEEMWAS2i8X7NfEoTtTYyndeC71DRdQVyjIjoFKY9Pb+i2oAIxXtMFHagB+FgZ8g/jFO5Nw5THhQFF7Ee8zFvnfjivSyONSK5JkR8Zl/tM/6/+xzDEROk0YR92N8UIHBp6Mti8eK/lzyZaJQNf/NmHveN8+sfUUqLakGM+/XP93jMA3lCzIVERMfAbcExtVzU+1W+fgjz31EgklEG2NhOxvNhkVTwA23leL5ez75zbroi7ulVrBonHIeu1LgguiufxLBmcG7PrsNMYsEuO8/xs+DvX3sZI0hwfiEiigEtgviQ1JHtfprep4u05EpkiAfrNEQODgG5R0RUgBhcc/hdx1pexn18T1ODH0o5WZeI4NwbwVymXvRgfMf9/tRXOKIh97lH7uPZtUapTIk7E+IXOyljXsaOZH6ksC/IAQ7nanZLpGsADhLXlIg5JCLyXIxHXK+rngSL+jAgcO10DkkJP9/tc8K9QKJsVgo8/vfGPK4fD+X8ugb7/efrXF9BIlFAHYjn7YCf20UEYiIRl6jtlFGOB2PZf2PY3/gerCPvyXvge5XxlPvESCFu39JMaZyol0lzZH+JbjKZTCaTyWQymUwmk8lkMplMJtMc2Ut0k8lkMplMJpPJZDKZTCaTyWQymebIcC6gz3cL5HtFul6Rf8Z/rco2B7R07QfSDlQFC8oynDQcRhJtkThoreTvaAQJOqPwNFq091eUIwnRIg8GgCBIpGVq3snn94bz/19lF2y0GyX54HeusJ3nYMQ2E32Kbjtke1FCaI+TFvAS+OHRjl1Q9ki0raBvvqFwDJsTxkXELvtKShnXS9GRlg5Ez/SFS0rWEVqwEQOBmBwioscT7kcNj8v7VF3aXsIE+wf/vqHwGi7Yx+pgGxxMZTluD/jn04DTXarI/KEt/xAeVXBlugTscYgtWEIrYCYtP2g1nEy5f6yX5b3f2uSK7sWYb1kmtCFt1zgTQ+nMFNbADtg+m3lpHf/2Ne6/X+rwSeojhXN4qsZtiA5HjTS5tNidXftlsDsvS8tUdZftT8UTfu7Gzb5IV3wbW7/SIz61/eBPuYN8Gk6oJiI6CPizw4DLoeMMnt6N6dDCRUT0xS7bu97aYgvWM82uSLfcYMtUr8/2vdOOvN/phPv9AaCPTgLZz9GmhuNLdUthU5MIKJmuBqFwAwbYPjvFSdEchF0U41FJ4Vy2KxxbFgFxFaUyIWKC6nm+LtWUjTEhCqSb3HSBjsMilafFc+2xCDZOjKvNvKzn7Sp/tlDgCi++ih3fBdO0tv59csDxJHa4Aw4c/v1yell8ZwJWzXZ0d3YdRNIu6hc4FtRyb55dN3MXY1SIiE5DHARyLXCtyuXAkact4TkXxxfMNwqNg9b7CK4befncpxIu/wgs9f1Mxsg1l+MiziOdEPMg84rYJ7RZb1XVXAZjax88w3dpR6QbOGzfjSB/dXdNpBtn3FajKX8nc94h0lUijtUR4CvSynzUDtrAO9K1TbtD7n8LRajncEmke5xxGesFHhuIIBip9QPi6z55xt/pqsUdWrAvp5uz63VaF+lOYQwsZoyyW1c8jMeAXBimfB0r9MGoDxiTPK9Le6ocmN174JTfHUUqHSesQ5/drsn7aczdLA+e7Odeim3KcyjipbRwDI0omp8OkVIQZ/YC2Y8qfV5zXYc5ZjyVhWhHPAmG8NxLebme60f8M+KlCk5ZpFtI2co/BJTVkCRzYehwg1QTuV9CHYOlPgXejAdYG8RYERFdrnJ74L7n4UjW0SPAtmDbThV/4eUJj/H8hPviNzUbIt26L9chpou1XDyPxSOS+Ersp8NYVio0r0AzdSOZrgf7OYwnGcl0vncxVgKxVktF2cccuMcGYAVXinLsrlZ5fHmwtvjEkYzTKER66bIP5rD+dGzC2IVfyalFNL6PwHAcySmZypDuCnFsuaLmVxfedeAaH2kdGp2YZBe/ltJonIrHBUkyLh+u1bUQaTJ0RuKzTWdxdo1YmpWSzA/m/Usd7gcvxUci3Rrx3DbvfQsR0abPMTMMbnI6R1b6YsbxxQcUWJzwuq+vUCANanI6wOSFJBGBRYdj7mDK+LWhKzEti4AiTlPeQxcciWmJAE/SSrleMxUPJ7CJxvcZqcIHncEmaDeWe2PUssfzEmKCimpAtApcZ6OI13AFkoHoyD2eXSOWLoY14IpzQ3wndrhuEV237Mm5EdWG9wwObYvPJh7PN+OE2yNVOF6czRCte6beQ+HasQcolYor3+m15iCg7g7lPgORymNYyypCHTXgFzvwouZlYpTQaXJXfCfncp2lsC67HMg1ZQPft0IeDkeyzkd9bl98l7tYUBik1CWPDOdiMplMJpPJZDKZTCaTyWQymUwm01cle4luMplMJpPJZDKZTCaTyWQymUwm0xzZS3STyWQymUwmk8lkMplMJpPJZDKZ5siY6KCUnjDsdia6Wpir0yowJ6erWKDIvVoCOOvtWHGQiZlCK4DD1lzlDrCMHgOvrxsxT0hzRpFRvRfwc/fGkndUcDkhsl1XFFQWmXCos0j+/8tLXeZ1xSl/p3WO2c7fmwJHFrmlRET3gedehCyt+DIdfjYGrNF4qnlR/D0Puv0UGGlhJuHDYcI3x/bUTGn8Ga91mcout1VK82GJqyX+ng/cqyCR98N6HkHW41D2ieMJl3FvwryufiTTXQII3prIg8zfvb4D19wXNyt8vzc15XeWYdxgySs5Wee+C+2BDEKJjacutPXBmO+omXxF4JMiW7DsyefWSlwvl0LmdR5MJP9ro85czyYwr535GD5q/G/MDKTvfqf4bHnA91v675+ZXadtWenT+8xB/+gfMcPt/3fIDDjNg0YM3IrPdXS9InmkWzXOw20Yx//tSPaPBwMu72qJG2RxWBHpHg84T8gFvz+UfLJ7I2C2j7ndI8UgHU25rVyHv4P9n4joDJj3T9X4HjeqkhvZjblfRXCP9TLfe0HhpZEjiWxdzZ4cQ8w46TNncKjYs7fqHJufrnIdBakMLjvjPCWv0rdMT/Rnp3kqugXBUSWS7YjnJZwq0LwLLMSnqxyrlouy7yyXuA8Xvebs+iiQY2rgMscwBv5kCszKnuJNjlP+TjXPYzzLZB70z68or/imCz73pVNgxdYUrHAMVYExsiaHP+1BwqPx/E55MuW+XYS5dr0kg/iCzw/oRVz/biDvjXMlss5TqIemKtNKiX9Gdvqj4fx59zDlGBu4Q/EZ8pczuEYGOhFRkHQgHecvVlzrIIHzViC49CMZxFuwwMG1neZ4o7wU45hsRDfis2FwrXc0AQb39OL+RUTUBrZ+ovrhpTK377UaM1I1B7UZ8KAs57hQ9/pyDJ1mzMlecZm/2izI9XkAa72dIeepq9aoeMaNnmNQw4zH60kEfaLbEunevMDBBtdOZ6Gsl0mCPGL+DLmlnlpTejCWV3NyfkXhWDtyHs2ud9LPiXT9wdtn10nKPH7MD5FcJ1eJ2/Aslv0N+eZ5l9Ph/PzkHtzWY+K8+iQDdZRx3+k7XOd4hsST7/HaYillLirme5TIAxJOJly5uKZsqzngbMr9byHH+ank1FosupiVOlYA9C/0PApT+zu1r6R6PqOSlwqeLhHRY0BWA7r33J4Uax3PstExspdy+yIf2j13+gcL13cDWMPVcrJdv3n1dHZdzHNm/2RXnpnxiQ6PZTzDR4ejTVgql2EPOFDrSGSBY9k1ix/PPVoqcv4+eSbnB+Sg4xa6VVRzMnyGvGn93Ajyu17iD5Fxr/fTjTx+xt/R+PcujC3cJ4/0uVgB/6IG73JSNTHhOgPfH0zVPqPl4xqEfz90uiJdP+VnTVOOJ7qO8MwxZHpHKo6dEc+HtYTv7UD/zalXesh97zrMbO9P90S6BNalQXw2u17ISz73co6fW8nx/kafTzWCxqoUsZ30/hzPuOHvfH4k4z7ORRGcA3LkPBDpTjJeJ4wGV2bXl0qS2b5V5XFTCfg7E/XeqBvz90oO742rwM/fzOQYD+B8H99Ri2gQnr3wtirnYRjLM0HaE77/IGEue5DI+WsAZ8XgmJrIZDSCcxViWMNtlhUTHc4qQI7/SwpJj7EZz0paLM6PqyWYUxsR12XRK12UnIiI/Iz73qo6k3EZthYZYfyQQWMM78basNYeJ3LDX8ll5+ptnmyGN5lMJpPJZDKZTCaTyWQymUwmk2mOvq4v0f/kT/6EfuAHfoA2NjbIcRz64Ac/KD7Psox+7ud+jtbX16lUKtG73/1uunPnjkhzdnZG733ve6ler1Oz2aR/+A//IQ2H8q+ITCaTyWQy/eXI5m6TyWQymd5YsrnbZDKZTKavXl9XnMtoNKK3vOUt9KM/+qP0Qz/0Q+c+/5Vf+RX6tV/7NfoP/+E/0Pb2Nr3vfe+j7/qu76IXXniBfP/J3++/973vpYODA/qv//W/UhzH9A/+wT+gH/uxH6Pf+73f+3Pnx6Un9i1tG8I/+3805iobKVffZpmtDOD0pC/1pD0jdTjdLpBe6nlpf3gECJf7YZe/Dzbhciotpph3tPkcpl2RLsj4wY2M77FEDZHurU0u5AIgOc4i2XVOQ7SScPmuVaSlLlriihlN+Vrb8LEcQAw5h0/IOWi/4Q/PQtk4d9yXZ9fD5BhuyJdLqbTlRGDbzLlsw+srJ/Wee7H9TOe1VchDOv69ts2j3WsfbPPKaUR94ErkIA8apdICO1WaKU4FCNE46JjU/bwA6bTV9RVdKkkr2qXyZHY9iLnAt9ZORLoA7DafOFyeXV+ryPsdB5wO6yFT1synazyGTsHejVZKImk5X6+zbW6tJjcGH4M8+R6MQ4WHWVsGz9M3Pj27dD7xeZFu8AcHs+uTA44ThbxsJw+e9aU+98XbPa6XsrKYbte4bRAn0ovl2D0csU3q6VZ3dn0cLol0o5gHaQfGwBf6EtMQQlUEYJMaqn50vcqV/nfWuXwap/XxE/hsxFb7sifT3Wjwz1fKnMFKTj74xQGXdwT1sglOMvw+EdEYbOBnwgYm+9HOmPsljteJSncWcV0uA/JmLS/z6lKFxsnF9vGvp15vc3cwzSh1M+pHCuM1h7M0UpP8XfAkroD1c7kgxzVaFzG+l1RfXEwYmXDo7syuHYf70SiTltUEMBIupPNzTZEu73L/LQEG4TCS2LgexNn58AqiYMrPulEHnINaHT4cIlpkvlWzD6gHD+zmGyTjBFp7M5gQ41TGvgjQJxjjQpgod5WHuxTyc/E5g1i2J85fTeK4mksv0zx1AdUzIeltzeU474ieWM9WRLptmPSxK/Yi2S/Rit5PuH8cucci3SogOsaA+wtJxg5EzNzucf3VAAt4qSLndMTBBRnfO08yHZZjAkgYXH8QEb1lgZ91f8DpHjrKYu7ws1aI7c45NaZ96Ke4dlKEH2E5z2DCD1LZJ9ACjyimQSKxKnuAI6vDwxAlRCTH3h4gKnqwvkQsw5O8w9iAtUpVsQ+SjGPVacZr99RV80jG39sbc5lqCgeJcews4TVbkKl1vMuf5WBcY/8iIhpknA5jMaINiYjclMdKB+4xUn2skbL1eyvP5UU8EmKZiIgOY15jbk15jb+kUEcJrI2xrzRU/33O4zXgFPpRWS34C96rY/6+Xnq9zd1HgUO+51IzL2cpHMsxcDPaakOCcwKOj1Eq+2zbZeTKiDiG17JFkW4Yc188nsD+C/AhW2WZVw/2fV885vt9riv7bw/WJyH02e2a7GPXKhyTPAffK8gOhev6PuAhDiciGfVhjm/m+blXK3K8FqEcRzCHtgNZ3sMJt8EU2mahKDfyDfgR0SfrPpcvVLiUAeQ1gdjXzMs4XYI90QNYd9fUe5Q8xJpLgNfQ1LLHsLfA+WtfLqsEOu5bVriOosMbIt1K5eL9vpriBVosArxWgWR8mjiMI9qHubIIc0DoyIZfhHXBcsb70Be9UKTrRIxFcaGO8iTXYrhPLsNLAo1RPQSE1iLcQr9LwHcOiPQcunJdVQc0ThPWVUck1U8Zd9J2mrPrt5Xk3H2lwgU5hipLFOIHUU/NlON+4HCnwDlOq028xg/SkfisknD+OhHvtXGNRUQ0zHitl8Bng0y2Yd7lim7MQbE8+ZnrGcv3VF2me2uTy/h/7HH9n74K5wT7b5TI2IdrJBwP24UFSIXXRO2I+3wT1tbrZRkvN0sXx8s7Q4mHwf367gjfF8pBuei75+LSPH1dX6J/z/d8D33P93zPhZ9lWUYf+MAH6Gd/9mfpPe95DxER/c7v/A6trq7SBz/4QfrhH/5hevHFF+lDH/oQfeITn6Bv/MZvJCKiX//1X6fv/d7vpX/9r/81bWxsXHhvk8lkMplMfzHZ3G0ymUwm0xtLNnebTCaTyfTV63XLRH/w4AEdHh7Su9/97tnvGo0GveMd76CPfvSjRET00Y9+lJrN5mwiJyJ697vfTa7r0sc+9rG59w7DkPr9vvhnMplMJpPpq5PN3SaTyWQyvbFkc7fJZDKZTK9NX9e/RH81HR4+sWWsrq6K36+urs4+Ozw8pJUVaZPN5XK0sLAwS3ORfumXfol+/ud//tzv98cZFdzsHIbjGP6rAW3gtYK2VvJ1I8fpnqlJK8lHh4xwOOyz3aOZST9FEZoHbRdlOKlY20/maexKLEUerIslsA1pq8YJoAtWfbbHFT353HYElmmwmLWUBeudi2xp+UyXLSKILSEiaoEFCG3lfXXydi9GqzbnSWNG1gM+qbnrcT2X4MTfhGRe+2CN+mLIiI9FsM0SES0W2M46iDmDRVfmoQXeJbQd6tO6UbvgxcHvE8n+1wLbnKvM+2inega62Lov7Y5HgEg5i9F6I9umzQ4b2qxwupuATunFMq87p3D6NPSPel7WZavMN3+hz/k5VnbCgzHnvewBZkCN3eUitwdaIbsKafLyKduIvmGbx2cX+iiRtDgiC+iyOlS6D4iTj/8TNp3lXRkztupQLyE3YkHhkjZWecPxTI375XHA+dM2wQbYY1tg4Sx7ynI9BQv3WXN2fakk+8cPXub68+HU66JC2bzY5zx9vsvfCRSPqAJF/LZNrqPjvoyX05TztAs4rYLql1tlztOjMddlnEmLKbbAms/fqeY4f3uBtFJi3eL3PUeWaQ/sv9gXx8pSNyhx34lT7ituINPtTnLnTmF/vevrMXfXCw4VXYfu9mVFT6DuVnzuB42CjE9jsDje7UM/yGk7Nt8fLeauwkjlYe4uEs4xHI8SZduMEp6jJ/EZzVO9uMnpiDtM5shxHYHNt+Xy9WQq+xNaJm81Oa+DWNuiuecjKkrbRR9EPHbGgHYZKoQO2phxGbOo8AkpjDi0YD+esN207XTFd4oRT3qIaWnmZSzA9VMB5usFkjEozuDBcBmodVV/yugeP3dzdr3uS1s0xj6slt2RbJvb2aPZ9RnxvXMKy1YGi3MMGJRMrWk8WOvlM87EBBAwzUJVfAdt23WPn6vXWGjRnUCs10iLNriQDwP+oUhyEs3DengM4+bxRI6bMSBXllzMu4zhiPXpQ9/T9unY4fuhnb3qyr6D7RYl82PBGD5DTIAP+IVFX60V4VEJoFhStVZsAT4hP9ieXXen8q+BU1gTTqEvH0bSYj4iXltUoD2Kytbfges45XHoqPVN32G0igvj+DSV46bjMnIlAQSRp9qwSFxeRLjg+AwdaXOfOLKMryin1g8V6MAnIawvp2rNVr147a7bZtnPKHi1xf3rUF+PubsXEwXJeXwY4nE6YLffn8qX73lYv7Yi7rOIFSOSc+9SxnPoFU9iCxeLfD9sPdhe0lEgM/uhHcYE3R1wvl/sSdSDB7HhcoXHVE2hbMrw/uAU9hyRQgx0YY4+nPA9+moz0AkRS8N5v1WXa3zEUh6FXJdthTvoT/l7C4Ce7MeyHIjXxPi7WgTka1XW0QjacwxIiM2STLda43F9vMv9taA2gYieW/H5uYlCeiJmFPe8S74c/0/XOL7gPY4VkxYxrVfLHNNGibzfCOLLTsD1euzuy/yljPWYJhxzn3K/dXa9li2L71Q8zgTGt1HwjEjnQp+fAibkWnZVpKsW8B3GxXsdIjn/T3LqQ9Dn2lzeQQpYS1q9KDkREcWwprmc3RSf4Xxz2eFxje1ORISvuRDxp/fnl1xGM53A+rwICOTNvF4v8T1uQ5cdOzJuvTz5v2bXLyj01Dw5MIfuVu/L5w65TYsOt5Nep+3DnL9R4HXudYXMRS3A9L8z1Niti9cWRYW43CgjmpjzdLfH9T9We916Dt6JQCc7j0SCcQjx8sWuSggRHderRdWBe1Em9ievptftX6L/VepnfuZnqNfrzf7t7Ox85S+ZTCaTyWT6usnmbpPJZDKZ3liyudtkMplMf530un2Jvrb25H92j47k0QFHR0ezz9bW1uj4WB6wNJ1O6ezsbJbmIhWLRarX6+KfyWQymUymr042d5tMJpPJ9MaSzd0mk8lkMr02vW5fom9vb9Pa2hr90R/90ex3/X6fPvaxj9G3fusT+8q3fuu3UrfbpU996lOzNB/+8IcpTVN6xzve8TXPs8lkMplMf5Nlc7fJZDKZTG8s2dxtMplMJtNr09eViT4cDunu3buznx88eECf/exnaWFhgba2tugnfuIn6Bd+4Rfoqaeeou3tbXrf+95HGxsb9IM/+INERHTr1i367u/+bvpH/+gf0W/+5m9SHMf04z/+4/TDP/zDf6ETwnthSnn3PH8WeYWLwMda8WW6GnDMTiJOd6w4412H/5c/yJjhVKLnRbohsPyQaVgDVqGngJPIDdpLGMg0yE5EuiXaooukeWK7I/654nGBWwVZpnFyMa9Is4UfT/hnZGPfqEl2kQ/M9RNgHC0WJadoF1jqyHNH5jgR0XW/ObvuRcywCjJmlQ1JAomxzsvATr9akZzsOrDBXuoBuzM9lfcLma8VA5duZyjrEvmOZyn3gWq+KdIt+fzcjRJ/R7Ose8D/Rh625jmHwJVaLMxnaMdpHr7Dv//IMX//wVhyKJGlivzb+wP51yjPtZiR/kKH2wYZfEREO+7e7LoRMc88PpWM9U7EffatTb7fSlGxAHP8WTDm8kVTyRN7usrcNuRuv3VZMowfd7lcHz7mPGxXZZ1fa/H4x6E8iuW46XZ5zIfA1EMuq8ZvBjAmew5/Z5JKdh+GpzawE9cUR+5ymft2CdiJOVf2D4yD6yXue33FWK54XOe5HN9voMqOrPJ3LvG9J4m8Xx/iySGwyUeKxb5W4s+Qo47JDiby3u2QP7wEbDd9NkQIdYlcQF9xAT/d4X6VBxZzSzYN9WI5Ll8ver3N3a0Cke+d5xH7LrLn8cwMOT8sA6O+CWcntAqyfT2H+/MQ+J/I5yUiCmGM+RlzB0NnTPNU8HheSlwea5NIziP9kGOfA+dklJ2WSOcAH3OQcnmRg0hENIw4FuJ8rzmoyFLvhBiDFCPR4XE9cPjMl4OpzN9KJvPxijSvfhXi0P84nOrkRERUyeScjOxInMeP4vn1X3G4zWp5GYOw7AMormY2o2opl1fzl48m2Hf4OkplnefhPIeayxziiGQ59un27Bo56DmSi9Sy05xdx8ASD+H8l7ORnJOf8vivS282OD86VnVhSsW5qKcAlt3o4jZcIdk/cG2LPSxVZ764wKW9Qw/4A4UZvZxenl0jK7bsyILkM65n5P1eqcntEp41U4J54Fida/FwiDxc4I7CXBErzmgK0PwKzB0LEk1OFXjuOsSwRI2tL3T4YfsBt/Vd5wsiXZgyt3XFY95sNZV9IoT+F6eTC6+JiBp57jsNGA8DYNcSEXVT5v96DrdNzpEFxj1RCBzZEjDbS4qtX4KzIQLgre+PZZ0vFPm5eHaCnrvxW4iyV1sOWismNE40k/Xrr9fb3N3IP6nHWk62xxVADeO6tOHIWJ/BhqID/a+qzrVYoubs+s0t7hMNFcJhuFIAYxTb9ySUfSI4/9rgyb1z8uanMQeHnMt9drkg+8kQONmPx3x9OJbpxjAv4b37JPdfiw6P3zc18f2BrPNOdPF81lTnv62WuG7XYLipY0+oBzEYwhPdG3F82wskUxrX0Ncq3BgZzV8H454jUMxxnF8vQde5VZNz6CacmbMzwVgqnzWA874CWFdpgvLuCPnrcO6UPLKBTidcRpzbHPU3rkU476Pscix9usjXlyryOwfQXw4m3Bj6/LfVlM+MC2CN6qtz3dbKfP8lCM2nqkzHAc9zpwG3W8GT+dvJeG2LZ5E85V6ieTpOuG/XHBnrqxmvdxrAeT9QZ+6Fc84zOonk/IXnoOhnzfKt1r+41PPhDI+r6XWRrlXidcaJw2irAsn4hmu9IOV5czmVdYSs+ANip1Ajaop0DZfrCHnmHz+T66Bln2PBAcyVug37cA4g1rLmjGO9AIaeevD9USYXbat5rgtcD+o9/d4EYyR/9ngi4+AEFoVbcB4f7v+IiAqOQ+5r3Hd/XV+if/KTn6Tv+I7vmP38kz/5k0RE9CM/8iP027/92/RTP/VTNBqN6Md+7Meo2+3Su971LvrQhz5Evs+d4Hd/93fpx3/8x+k7v/M7yXVd+nt/7+/Rr/3ar33Ny2IymUwm098E2dxtMplMJtMbSzZ3m0wmk8n01evr+hL927/928X/Ims5jkPvf//76f3vf//cNAsLC/R7v/d7fxXZM5lMJpPJpGRzt8lkMplMbyzZ3G0ymUwm01evr+tL9NebijmHCq5zzm6K+IQQvD3aEp5kbAk4AFcI4jmIiBJi+0IKn2krOmrodGfXzbTG33Gk/WoIPomxM5xdo42XiKiash1rscB/YbBelrYGtDKhHeMgkF1nAC7VPtjcz8LXZoVcqEl70QjsbHmXv7WQl+n2oAsj8uPxVNo43lJZml3Xq1xnu5DMSaVdx4P2LIFte6Q9a2C7QjvK0JWW1bOErSkP08HsWlvRj122tqKja0nZ3hAJ0wIEQSUn+28Kdmdt30MtAKLnESBNRgppsgxInS4gOhBfEUIfJyI6SNhquOayjQYtZk/ucbGdsOTJ/lZOqxemGyTyubd73FYFsGb9rSWZruWzlWz1Xfydoz+Qz92osLXq5kJ3dv2gK+3OHz7mNkXbez0n+87ds+bs+hlAwhz0ZPn+xx5bv9BqeBpwnU+UdbgNAxYtjb6yWWF3xjG5XZF9JQ995zTiQumDNc7A7gi0CVpUCfcC/vAP7rE1rR3JhIib8QBLM5jKcjwG++TxhNs3p5BXlRy36TGgouoCQ5Op7/A9Voucod5U5hXRBVXAPG2puhxD3vvo8FfDs+Cet4mazuv+IKOCm9HDaVv8PnG4rVYitr3mXYkJyMNYqUMIauRl7T+GWNqP5ni4iegYMG1dl/FtUcbxw3Vk3xnF3dn1QpGtnwVXxoIJlBHRHbHCkUUO/9x2eE4pRk+LdBs+z3s4VAoKQRKkPK88HPFztWV902FsWQswEBM1J6CVtAXW26bChCwUcEzxePUcXreMExmnJxk/a989nF2PqCPSIY7lWnpjdq0t4ZUcWLgBEdZW0c/3uLw1mNfVUlHM14iVGKeyjsrEbY95PXNkWw9jLiNigfKeXFvkM+73aJ/OCNehquzAMcCpQy+DcO5ogP3/pZ60zXcBNYBrgY2KQh9A4MdxHTuyjurE6+FGyn0vdKTHHPufD/b6xaJ8bj2vmClfVj+SsSAGu+8KoPU01qMOfbYMHeEk5LUPWtSJiEZ9HgRo8b/VkIiKbQgNGz7XVz0v14B3+xeXySVlY4ZYMySuc8RLERE1ABWV5bkj9Kd7Il03g7gDdnhHPRfxBOOUx6jGuQSAc6kCdseHfZCXyQYou/wZ7sUGmewf+RhRnfydQFnHdwC9U4E5HjFgRER7jncujpjOaxATRYlcYxFJxAnutYeq3QJAUXUcxhg8ld0U6ZaLPKbQva+al7rRxWvldsD5eTyW2AecY0pqP4fCPWFtdG12Xc/LSQ+oWQoZJOtoyce1Nj83VnwZ/E+TBozXnNoPnkR8vxGEkPWyXkPz97Au22q//3Sdn7VS5Bsi1vVluU2mB0Nu3wcDrojL1ZpI1zqDeQ4eq+Mv0FLIh3cJSSbz+k0rHO+iQ35fcBTKWHV7iGgWvl9PrQfHUy57B+rFeRUEbwn2v/VM4s0IcGc4J0SwadBzMs7dKSDk6gp5tVTkOBskzdl1LS/L/nyDy3gMSL+RnB6o6PJniDocTeW8tJzxs+o5butE/Scf/lyFOUHjZhCHNwXMSqA2VjiuEf3rOhJ/N4D8juHdzvUS98WVksYH8XUO1lIP3IciXS/lubIIa5hIIR8vpxzHrhTx3Z9IJjBtsQNrC+dApJsAVm054/XXSSDrEuEsGI8w5hARvdDhOsd1jKsQOrg/rwMusQzvdgLVP3C9hLHvTMW3COZZbOqKK+OqA2u2oofr2vPvXl8rRNVmeJPJZDKZTCaTyWQymUwmk8lkMpnmyF6im0wmk8lkMplMJpPJZDKZTCaTyTRHhnMBNYsuFV3vnF0B7UFdsO8gSoGICJz8wsqjdQlsw22X7WdOKv9PAy07eTh9HlEZOVdaxwpoZQTHwzCTNncPLI4lOJHbU/bJmw2ui2Ye7dzSJoGnQE80kwC0VUU7Nv/+f7WlpRabAC1KnjO/bWR9yfs1i/wwvHcP7GtHwQC/IvA6j1y2xOwovEk5ZltZB9qzM30k0g28k9n1Mm3PrrXtuJGBJRks+p7y76CVFK1ptweyT2yW2NqzWWW7TSeQbY13d8E1fBbJ8qKNGfv8DegrhWFTfKcdcp+tFzjdWahQR2DbWgGfYEfZBKsJZ9CD/wvcrko71loJrHJQjFpe+s82Nrqza+e93zO7XvqTPxbpHLA/ttbYPnW30xDpKhBZV30cG7Icn+ywra7k8T06sey/90cXnxZ/pYpYABk/PjtgWzTGmcVwRaTbyDGCAHEpaNcjIlooYF1ymXqxnEZOoa3QdZWeCwtwP7ClIy6FSNop9yb8na6y1++Dp/AYcEmrnrSBYtuE8KgHUN6KmhmrYF9F51dD4Xmwb6PNUh3+TWtgbd2ZAMoiL+/Xn7qUn08NMX1Zd8YDyjkRDVyJ61hK12bXlyscG1pFOQ6xn3aBMKVtftjjYujQFeUhXoi5z3VhLio7bKXsJjviO0nCD04AR5Io+/o4PKGL5BdkDOo5nG4CiIQ+bYl0K7C2QCSXttTGsKCYgLW1QfK5iOU4C8CGr+bNS5AuB3PbrnSzCkzeJZiX+oASG8Wy/ncBBYJICcTfEBH5Due9T/zgJJHzSBfGctfpz64Rj0JEVIL29aHdNUJrH9ZLBxOe448d2bYDh9dtzYz7MuI0iIhciDs5aM9EIXSWM7aph9DHRmAhbmYyXk5gLbsLllzfk3W+DEgTRPJoFNsw4bzimq0f6djHfayC9vNMWtFrgGZ6mdginVNrwAAQIqcp94NyIvsvLndKAuMlkgmkAa6/8PdERHVA2zwccJkGGa8fUgXtwp/R5l5V81LF43RTyMPeRK7P6zB3bwJWcBhcF+kQAdkn7osa+7ICaIEVwA7cUdyik/Cl2fWAeBy2itsiXZAy08FDNItqw6WU1y5LgK8Q2AK1zsD+izEtJIkSPID4OxlxoCkoZEAK96glnL+TiWrDjChO5+8DTU+Uc5/860YyRrZh3fdwyDFSY089QDM959yaXV9ryDGAazBEBI5i2W64bkOkGcYCPV4Rh4WflTI5j6TEa+2TlMfazkjGoDLgw5ByU1LIG40ueUXPAFZMf2+xwP1cIwsQl9qJuF7boUoI39yD+Vpv/a9VsF5YAeyDthSZ04F3EEhIOVHvW46BqIPruYYigpYQ3ZGbv5i+2+U2WIb1eU2tyb/Q4we4sD4sqfkQscC3YQ1SzWTbLAFCawq1lM9k/50SoC4AU7gbcT86OZMTRM3jeywAMihWQRL32nmId3q+WS7CnAX52/VkjGzAeqSY4J5SdiQcK2dTHuO+Ix/cy/izgcNzRTlRqENAO8WAb2wV5TuRLqw1jgIeD0Em10uYP8TIFKFP1VV/u1xG5DM/d2P8jEj3eMLr8KaHOB3ZRxsFfsBxyPUwJlmXRZgr19P12fWO+1ikQxzeHjFqshCsi3RTwAXj+ryhsGUlmPPrKce7w7Qr0iEmcMnld1xlWB8tksTVYTv5EMPGyfxxjH1vSaH6COpoq8rPvVqR9xtM3Vd9j4myv0Q3mUwmk8lkMplMJpPJZDKZTCaTaY7sJbrJZDKZTCaTyWQymUwmk8lkMplMc2Qv0U0mk8lkMplMJpPJZDKZTCaTyWSaI2OigzZLT3jZ677k4yDH+DPA4ewonjNykoYx36OXBiIdMnr9hLlSp47klvsZ85RSYGX1gQnVjiQfsgL8SeQJImeQiKiaMpdrxWce03pZsq2uV5jDdaPO7K29SUukawMHOQSWUKMo/58GGcSIdzsYyzpHdjqiMnfHkseE98sDL93NZLrx9GLOHXI982o4jIET2kqZQToEJioRUQxsxRg4l6PgUKTzStwGfeCjr6aSUYv9o1ngPK2WZF2ulbhMyGnsKj7sJnSRgxH3qbrigsfA5A+hzyucK+1DW9/rA1s0N59LGWfciBVI55DkviHjD9tMkyUnwAMbutwe5dElka5ZUCDTL8tXbLzGdzGjK6ty/V/+ycsiXdbgz05+6XOz67XyRKR7O/S/T3SYE9aXGE66VOEy5lwuZab673KR0x0F/Bli8qt5+Z0q8XPDjHl/vqpz5IkiZ78vuwedASOxVeD6Q0b+k884r7sj/myq+gQiKxvQHguFqUj3aMzj+vGQ090Nz0S6Pef27LoInMFyIrmv3YgfvAjd4xA4xf1Y5mG9zGP3rS0uk2ZKIkPwFG5xuy9TrpUgTsNHyGIkIiq6Hk0S46p+JeXJoxx5tJnK8d8A1iByM1d92RmR0dmE/lv1ZN1PEmTecxt2QhlP8GyBTYjvI+I4kXhygLk+97ElujK77npyHomKPA838hyfHPU3EfP+QuJIMRKrAXJb+TpW3a4PvGSc8wpeU6TD42CQvVnNyfm14F08x7QD+WCcozdhzkM24105JdNCntv9TfHbZ9chyfEVAtsdz1Fx1MjuO3zGwhjm/7Lim65A/ysB11Zzd1+eMKP+FM5biUmuFX3iOLYM7GnfVWuVlOevIfSxU3dfpMM29IEtHsDvH6l16CBmFvvilMvbyMk1ZRG4qHiWQC0v15RZxmVC1qaOudhWT1V43kX2OhHREL7mDvi8oZ3sVKTrQbk8YK6WYjkfxsD13IDDMSI1zx1zlVENqiJK9BoVueXct7twfoNm65fh5ymMIXUsDpUhPq36anEBiqGtF2FN7vYWRboXYRmTOFyxrVSmq8BYxvzVMpnuDHj1ccLn8Uyzc5DlmVKoowVaE5+t55mZiueP1GHdooYatWFIjYFTvOhI9j8qzjieR6mM7cgtjqL5bNaC41GcTed+bnqihcKTfbfnyIbDs3XyLvfZairHShXOgFiBPda63BpT0eX7H8IauqLWzTiv4Nk4eM5D3pUs5nyP434Ge7GKJ2PkILl4P9IoyNm6BGcdvNCFvZ3qb+tlziCee7Ci1jfXKjzeVn0eEC/0ZDmasK7HLOn1zVLx4rktUhu/Ezhgogb7S9zTLxbkfL8A8WkASyS9Dz0CsL1krMu6XCniPpl//+JA8upxH3OzCudxeLLsl4Afjhz0w4nM4CjlexzTXf7AuSHSxbAGmQLrXJ/f4sC5FLgWWC5wORSWXazhsJYjdVZDAHNU3p2/B/xMlwcVtqE+Qwrn/x7s2Zp5Odeuli5m/+sdz8GYx1EbzvDDdd4T8T4X23qkNp94NsYx8Tzskwwalwo8R+C8iTFDHQMizgW4VeeyX63I/rEdVCEdV/SDkeyXn27zPcQZMvAekIhoI1vlfJf5HrVQ9jdxrhPMtUVXFgT7Ab6XCRQrHNdtmL+COs8Eme0yVvH1qVoW4HNxjm8V5L0xNsCyQJSBiGgAsQr7tu/KMrVKMY2T1zZ321+im0wmk8lkMplMJpPJZDKZTCaTyTRH9hLdZDKZTCaTyWQymUwmk8lkMplMpjkynAtos5RQ2ZvSlcpY/P6lPlsIC2B50KiSCvhRhojUUN6UAiJEUvYexI60Y6J9Z5OWZtcRWA2LrrTKVsBC/Ky3Prs+DRdEukXAXLx9EWwqyr5+o8YWzHIBLE45aX8QFlH4vxltCT8DzM3lCuc1SKR1KQdolhs1sOGds0zxZ2ghCjP54I8M9mbX9QyswTnGmyx5ZfGdHlj3e2Dn1pqC7RW9Mo4j+8c4YntxwWcrT9/tinTP+9xWN+pcpqKynIzBwrY35s8WlMMJbZILRW7DUk55tUAZWCRdytRnLETjgMOPzsL5Nle0A4UKVYG2TcQjhcpCdKvEtuE45fqqK1sk/uhD/l7qSSvv9Le4n3/D4A/4g6c3RbrkDz4zu/7U443ZdT+WoXShwHWLeJOmdCHR315lK9m1a2w3/+yL6yLdcMoFaQNWBe2OE1XlWxW2pr2zwn1bVaUAF6BFTyN5EBOUc9D+J9M183yTTh7HuEyHseGlAdjrPVmXXQiL2N9iR/ZfRFYhZiFRAXhnyN9z4TubFS7IYiJth9h3sBzdWMYjRAsMoYCBsvhDFsiFuJUp1E49l1KoMAKm84opoYwSWvBKc9PsDtH+K+dNtAMOp/zDmi87N44PaSGUbeTDvO4JNAPfAFEdRETrgHApZLgukB7HSoGtvGspWzhDhQ44cxkZVnC5XpZSiUh4S4vXN+uAS0E7PRFRYcTPQjustqLjOigH2JGyqqPtKt8DCS51FVBqeYyfnLAJ2KdYWfyjlJ8bAnbnaCJjwd0Jo1kqMPY0LiVJua1aEFs0HiaDWIN1pDEcLkRdP+P6r1BTpKsC+qTmcf5KykNcSPjnJszdzWlFpJsQB54QkDwBMZskVf0oAKxdAFgQ7ONERAn0WVwaLxRlOsTznIX8LKwvIqKmx5bkp9mlLVAHREQfOwFkEMw3pVDaoo/Qtu1wG05IrrtrMF7RDnwayOci4mAynb/mRdyJD8ib+rTJaVQ/wvXmIXH9T9tyPTKIuYzfv8Hl2KyORLpuzPMcjg1fjckq2NlXiOtoqSQXLtifTwPOe1PhjSj3jbPLcZ7HWjVrimRHzj3OE7TNKmALiIhWAV+BCAHEqI2UCzuAsdKJYNypvROOqSWf61V1X7rd43ruJozGCEnhuTLnVbE1pid6Zd/dKsj6exlwGx3of4cTmS4A3A7iw+4oxBe2IyJXupHeg3AfuVZzIR3u72UsKMJ+FXE/oUIBYfxtwZx8VS4FqJa7eF+bc+R4bQIW4fkG33u7Kt9heA6uRXmwHASyc6cwLx0C2ulBIPe/Scr7p8GUB1xBISHuwbr+MqyvNwCZ25vK78DwEutu1UwCBTJB9KfGOUG7IZJyfywTCsRHwHPtsw3ZhojQGkBfmai4g0i4W/QOzndOvmd4HDOWI4W1XjGTa9kFh+fyWy0eG4ik1Hu7AyhjG/DDC0W5vlkrcV7rsN56PJL97QD6BLaN3qHgOrKW52fpPeD1KmBeoV6nCmf6MrzX2kovXq8SyfXh7QF/px/puRtRsXzdczoi3SW6GPmFqNlVucwQu00fynS5LOeCdcA3rpZgHknkmm0X0J8uzM8LqUTNHk15zu8C8gZRpEQSn4L9BTGWTz4D5EqRG25nrN85cgfEMdoO5Jq8CYtCRG3hfK370SV4R4jpNDYqEcg7eE5ZrZM9HK/8ndtDOR5ahYwmyfx3WCj7S3STyWQymUwmk8lkMplMJpPJZDKZ5sheoptMJpPJZDKZTCaTyWQymUwmk8k0R4ZzAdXzUyp7Hq3WpBUyglOle2CPRXsXkbIdF7hqHw2lJaQLNlD3nIGBlYD1KwcWBURelD1prSjn0KrB14uZtFbcaPD33tLg8j4aS2/K/SGXdzTFE5Ll/78cT9C+w89NFQokBiceWqsHibTU7owgf01Oh1ZUIqIOfK0OtqGd6ESkuxf9z9l10eP2eNv0O2bXZVfaXvJg0VshtpWOM5nXNjGGY0psy5km0lI3TfmzhsOW+nb6SKS7O1qZXS/5nNdN6QKjCWSjBz4a15F94k9O2DtzoypxEfOEjvpxItsaESKdiD02PRgPjYLMw9Uq1y06IfXpyYht6cK9V0qy3a/DkMqDxXEwlffbLPFYWwD7fzeS9/til/2Ub354zPn5+Msi3b0HjFVaLLI9q1mYj8bZKvFn+tTxOti4kP5zbUnaypx2c3a9BFY8HOOf6qixCxY2cXp9WeaiDPb4jRKX6SiQfeUjJ9ymn4PiblWVLRp+HIMFy1e2aPwZbWXqIHWBJ5oC2qSUyfJWHUb8XCbA4agQi7ZNtP8ihmJJIZF8qKMA8tCXoYD6MA7x3sqFS3tj9kVWXa7n9bKsc89xzlkgTecVUkwJORQprEc9z3EnB7GmG8oxgKfFoy23HYhk1MN4N+VB0MjJuQORFS2wolamPHZbU2nXLQOKYgxWwmYqkQZ5WLb5DsYCWaYC2IEXAZ+wVpTPRUtsD/BEGtN0Fazt2J21hfhognGH81RQ+AREIaG9u56Tg6U/vdh+ehJynTfzap0BP96qcSN+/EyWPUh4IpE4PlmXpymgY+D3RZLtPiCOn4h6WFTrlprD+egT27kjZyLSuYCeQ2xBonACq4A7aAMixVGBp0Nns+s8YF8qsL7JqW2Bp8o4y4PCr2CdIbZkQ61b2oAJwnWjRgHkYUIcQx/oq6kWMQ0bYN9Neg2R7nCqUCNzhGvWDsSJfiQDMZYf3cV6zYuoQvxkFdah/URavbvE+IRSBig2dW8co4dgny7nZF57MVfS3hjrX1ZmBvev5SDOqDGJTY/t5meyryxk3AaIetH7Ho9uzq43HO6LV+pyPkRyFDbHKVxrnAPiA3Ht7qTy3hXob7jWXijIWDACNI4LdRmq+02ymLxzKz6T1uOxR76Xo5HaZyCxqlVEbKSsZ9wzYLgbqRiJ8z9+NprKsVKDefhun9sP1wix6mSIUY3hfcEgkeMrhtnjFOL53b7ErS4A+mAb9jp6DkXMwjet8Z73sC/5MAWYH/OwTlguyv6J+wSM57vuXZHOjZ6aXXccnlOWp0siXTIC1EPC46YHi3+NocS9LA74mpqG8OcQ8q2GP2HoOoL13EDN8f2Y5816nvN9Eso6v1zidLiHUVte8f4AUXsa/bdFcp56RVW1AKvAz0C1EgiT/3YgK3Oqg+GXpfcZa/7FKJWDsbxfH17gVAC7ea0u1wwN6G/P1rnSD9Sesg+IUBw3eu2JbY37odXi/M0RFr2k5q8BrD1xHxm48t3fS/HB7PrKlN/L4PoeMXZEEpdyo8L1enO1LdL9rx3eoz4Y8oRzqSw3HUuwwEHE1UsKV1WIeMxXIHhercqyX6twGyImdqCwSvhu7b9xNdBkKmOaRKzxs5YVCrMFW7NjKOLdPq99WgU5yHE9d7/PY8hT69qKQkrOE/YjHP4HCu00mjqvGaNqf4luMplMJpPJZDKZTCaTyWQymUwm0xzZS3STyWQymUwmk8lkMplMJpPJZDKZ5shwLqD9SYFKXoGaA2mFquTYRvB0la0HuxNpTUGURAns/42CtAUgdqQd8v9jDCcrIl2e2MbR8vGEY0AQKN9LEaxCaHNbLknLCZ5m7zpg9XKkreFTZ5ywDoiaV7NMLvlovdOniXM+8JNTZfM5DrieT0JpwUYlwvkFVpJMWuD7xWdm17WMrXM1QCnok9TRor8Avq1ddWJ1H1A5ccp2bNeVZUrB0nU8ZUzIUu6GSNfN2MqbwenQWyXJjlgFi+N4ys8qq1H9cMCVtANupUBxM+oFtBDytbZPv9BlTM2uyz4fxA+9KbouvnO5yvfDvof2QSKJhOkPuL7OFNNiGcbDShH6ubInPtfgulxv8vWLx4siXZBw+3bvcbsfnUkUE46Vm1fZPlnekHWJ9wj3li98DhFRd8x9eynlxjlWdsx6kdt+q8U+LrRVd6JV8Z07fcBXANrlqjz8mzwok+9xPTfyss7RAvf56PHsut/ZEOm2azwesI95RdnWzze4vyDVB63nREQuGDTrEO8KnrTA9qPm7Brte3jy/JNy8DWebF/J8XMi5cJeAMvkGVgQ26Fs95OA26kBKBFFLRJWzzrExIEaayXPofg12sr+JitPOcpRXiAgiKS9EFEgalqiIs4dQAnSNr9uzO3bdhjDUcuWRboFH23IfG8fbOQrJYmeQRQQ2hXdqYwFPcCCHQCeY+CeiXR+xgMdMQ2R4q/c57Ao7MkLCmk0AswN2iL1WEGEC87P2qKLCJaCy9dnoUx4EiDKgzOIaxCNlGkHgHAAbM5GSaGswPZah/H/eKyQJn22XJ9BH8gU0iRyuG0iQLt11BpmzQeUSnxpdl1056/nJlDIVlGmK0O6RyMOIhrnspHxHBESNGjGfUyjNnAd2szzOqOnbL2DmOcLtKJjvRIRBYBtqAI6AfsNkUSI4BqkH8n74VgeQjrdNkViy3QvO+T7uV2RLgbkx2OI5xNHWr1bKY/5LZfXmxpzM8h4Ldt3ANOScJ+YKIwPok/KDtf5rYZEmD1T5zp7NOa6PAjkugVzhPW3l3RFutjhCiwBtiRM5Zycg36FmCFd5yjsR4nCnCBCKAbcxEghF1LoE4hzOZogMm9+HlDVnBzjLbDNY3+rKBQA7uc8QEAejWW6YupRnNnc/ZXUjYmKCVGSyZiGCD1s66LCZiA+AZVTi65DGGJT6FaLruwHW7BX6cKW6zEElzWFl8S9J6Jexmp/48LfLTYcwDSpLluB218pc5xu5qci3WqJC3W7zTHo8z0533zzwpDvAXuJG1UZdz7Z4TXDLiAHYxWfek53dt0mxpHGrsRSjRLeZ01GfO/jgAuokbT4TgPXDBrnMpZVMdNyUVbmjSrPw7j/+sip3J/nYZ0B9FEquvJ+iDvBvUlTISUWCzBXhrxnfik4FemuFxiBswrvaTT+Evu9QFEE/NzDaVd8x5nzd7KrsazME1hzPRpymRDfQkTUhncTQ3jnUBrL+QaRw4Mpz1nLBbmnXPO5EauAIAtTme8YcHi4PrwzVJjXCt8P+85IrRmebiAah8fNS+p+h+7+7Npx+P1cVfRROcZxnYx768XnZF1e6/KYLMC+exLLPGBNLCO+pi47yBTiJ75X8XX/hf3k3oTvge1ORORA38b4pNHLfoh7HVjPqX3BY7h/DwJ6kOJAlv3y82eQDtaHjcL8V9fYHvrdVTpnfdJXm5jx1KFoDgpJy/4S3WQymUwmk8lkMplMJpPJZDKZTKY5spfoJpPJZDKZTCaTyWQymUwmk8lkMs2RvUQ3mUwmk8lkMplMJpPJZDKZTCaTaY6MiQ7aDxwqui4lmYQGrwK3Cf/XQTODtysM4Hk0Zg6U5j5XgRGJHOqCWxfpPHjY1Srf4wywY8NYcns022f2zLzMA36rGzHbaqkoIUILRe4i+8CHVahNul6/mCMXK67QNuBdlwqcV+RrERENIn4usua7EsdEWBJEU12vlkWq8vjZ2bWPHFTgEd8bBeI7vZh/jlPuEwXFLR2m3dl1f7LDOctkWzRLV2fXT7vvnF1PMvnc54pr8Fz+/VEoeVEhsK3e0uS+eBbJ/HWBn/p4yBW4l3Rk/oC9e6miwNkgZE7mgQefORf3PSKie33O3xLA3jQrGnO+VuLyBgoaiP2gDtWyWJB5yHv888I15vp9y+qeSJfO4euFxzJELpT5Hs0f2uQPlpsi3eIfvzS7flv5gOap+gwwfoG1n92V6T51yvdHzmYFyveZjsxrP+JKehAwy+7xUHJVc8CRfrbFbLumHJKUd/lZ1xwuO3JtiYjawK8fTZFpJvP3qMgN1wTmeDeW/Rdx+D1odx3rkKW4VQEGfFkGDWQc4h2OAs7PqcrDSXBxn20VNY+TKw1D31TFwetVbgMfoL6a95lm58+fMJ2XSw655JyPzRBAkX+95Mt0eI7ECpyrgPxmIqKdEf/sZfwlX3FzN8v8c6vADXh3wL8/GMv1Qw36L7a5pyZb3+F+Osn4ukJNkQ75qyEhM1iOfw86dB3GPLJYiYjyMG7w3AJkchJJtjAyCCtqtYm82ENgpGoWKLJQT2CqPAu5/vqRrEvk4UbAcz5WeV2BqsBps674q+tlvkcu4DkP4xsRUQH4nyWoZx0jccwvQhzcKMvB3oK4iCxWvcw7hHoZAou9SXIddL3BrNzTCed9P2Le94DG4jtXc8y1xbN1Gup8D2RIIpfyOJT9F6dy7PO6LufFvYlip+Ma6YtDnueKantTw/EBbV3KZB1NifPRIZi7VX7GDnKHmauqsZsxsOdzwAUvE/ejUiYn25oLzPwcnI2kzinYGV+87g7Vegm53qWcnrNYeK6NC3FHc95HCZcpgrN+MM5oLbm8vhykkp3cg/Mchhmf+dIbN0S6BZfXpZtlrjNcx+tYgM/C/B3EknHvTDh/HciePscG4xOe+YLnZRARjSiiaSbLaTqveu5JnZbPnZ3w2njy+L0KcNTLnt6PcDw4hHMf1HFLBNMKXa/C/hJ4vYcT+aUJQNZLsM5w1dydz/AsJx7M12oyXRnKsTOB8wLU+TgRsKM/1+PxoHnhBxOei+4POW4VFC9ZvBcgHh+eYhWP4WyHNOWyj6krHwxxdjHDvQWXSZ/vgSzrFoTFnDqvDeeYp2s8rt++Ipnj0ZQH7P8F51NN1JlgeP5CDPUaqjrHdQvuBfS6ZbvGvxjGXHY80+ZJPrAvzd8bI3Id+2wEIfe631T35jINkovPoCIiOoYs4UebZTnhLCbcD3YC7h97gVwzTGAdtAJrhgfq/IFNmHq/d4PPGIrVGRxEHPdjWPNOEt02sAcs4/iXfQzXtlcgD44j2e6N0bXZ9bLPz8X3afovkXG9WobzFMN92d+2Nvj9y3jAHf34VJ7nh/e/UuH3D0Ei1y14hsQjONPnpb5IRmdwntwxnB10fyTnw4bHba/f96F2I2a7LyWNuek6EFjbUy5HF2JJc6rOiZrynFryON96P12El6UFWFsP1EK5CbEG22kYywI6dP7crHmyv0Q3mUwmk8lkMplMJpPJZDKZTCaTaY7sJbrJZDKZTCaTyWQymUwmk8lkMplMc2Q4lwv0YOSqn9lqgX/iv1mSlgIf7GMbPtsQVpXlfx+wAaMJ2Hp9Zf2CbKDt6jK4OHrKhtCHnzsh2K8jZRMucLpTQKdcKmlEDfrCOJ223qERDC1n2oqOLrhmge/9NulgEdpn58e5ciDm41KFK0xbx/sRWIqi7ux6NZX2HVSBtKXoiYoKGdCYst254jOKJe+WRLpvdP7W7LqWhz4wlegUtJyg7eqzHd0/+OfrVUABqPyCm4quVLkvJ4OmSBcDfqaPFrGqtgZzfr94xvYbtIstlWQDHE34M4H4SGV/K4PdqwEW3YqyIKN1rpnn/L19uS3zCuPw4GVua19hi2qLbLvtn7AVbakqLU4bz/LP6bd95+zaOTwS6QjyW6hDve5Lu1LuEViNe1z23bG0au1MuMCIc0IUi7aYHqdsk5o4PIh6JDE+Adg2085Ts+ubdZnX1RJY5aAfDZSD+/GQfxGD1dNX/qhdcAB+AWIVYmiIiFqAlMI7IKqDSOIElgApcRbJvljN8fcKgKhpA89BowTQFXY84R8WFRZku8Y/o6X2LJQ37ISAGYJ7ILKBiGiznNEk0bHWpJVQQg4lAjNARFTyuB9chhi5WVIoIJi737rYnV3/6fGCSDcCm2oT5vEFxeHBuRvxQTgf6pgGLnAaAKYtVDHSd/lZ68T2ybwj+yIiGBbBwrnsz38uWoa3yhpvxuW9O+Tr48n8ORljhtYRzOs4XWgkVwOQJoOYE6KFOCJZR6uAfcD8jEayfyQZ1yXGo0zhK7COihDHQmUnjh2IfYCOmKZVkQ4xfkgn0QgD3+d8PBiBPVzZ9dH2XgPMSJQptAUEsqOIG2BEfF0kGfcRk1UDG3NeoUUGYu0JeC5FtHi2dbH1Vq+rFn20w3M9jKeyf0wSQNm5J3zvTK6/EFUyzHidMHXkBDYlzvB4ymiAldzTIt1mxmu9COa5UNX5pRyPUUQn4XSo7cnVOWinR0OZbhBfPL/q2DKGMYDV91x5UaRLMv4Zx81hJNdBiPQrO4gwk/mbh6UJYzkn97LD2fUEMIM9b0WkayOSZ7zNz8nx/TTSa5LyeqID/UMjKhYS7i9ZxvfYUYPtEmwuYIlKGyWJyXowmZL9ndpX1pVKQmVvSgsFOQ4fjbnvILYkVPiVGuxFl2Fdv1SU68hGnvtpL+J743qOiKgMa4Y1n+93rcLp/seJbOtj2Me7YlzLdUHB5XV9BWKpxpSewh46FLFB3u/h+OJ9wYLGPgESBteyBdU9EROyXeC1TymSN4wB4VQGBIav0Fg34B4Yn3Dv+nxDNii+MtiHfU+ksColwFdgFe305J7+0YjH9cMRzlHyud0p7AEBV7eo6hKfhQgNxAASSQztDeDknYUy9uHyHhEzOoYfAE4Xn4TowFTFmxTiWCfEvUmm0vE14r6WZTcnB/rfis/4YY1bxb6NpTgYyzqqwsKqE/DDWr5ce1Y8bqubVb5OMj3Pcf6uVnjO+q4rcv5CDSbcwJ7TFJ8lS5y/ExijuIzUYwjRpOMpt/X//MJlke7phe7sehhy/9D4MCzjENBEJ5FMh3naG8E7B4U0ybsXv/ZFfAsR0UHCHJi+y3ktOnJd1cq4HyCe5/m6bEMPcEwvDHg9fDDh8bo3kmNyI899YgHei2ms3YbP33uhD8g8NYYQ4YKfaNx13nXOYZzmyWZ4k8lkMplMJpPJZDKZTCaTyWQymebIXqKbTCaTyWQymUwmk8lkMplMJpPJNEeGcwGt+xmVvJSGU/ln/I/AAoR22KbycRQBDRCC/ayt7Ds+2LuvVPha0RiE3RNPHc/D9xcK0kaDpxMfgSW0o+xijbz83itCuwiRxB0sgV1pw5fWuwF8L0i4vHX1nFaeC1kGi047k3WEVhXMuyYbIMLhMpzGvF2RHuLjgC0oJxFbRBA3sZiTNpUl4KCsl7gx2sqeXMrxid+V8bfPrlfz0tp2tcZlRHtyyZO2UuwTiHdwSPbLnSHX3zFY4xfVMeHYj5A6cEPhOrpgwZ5nWSOSGKNDZMVAHvbGspKOACEyJj7NeSVbFenWCpyn4qscj4yPreW4Hhx1gvuDdnN2/XDE7fF0fSjS5bpc9mHE7XFjVeJh+g/hVPn/9//JH/zI/12kyyBO3H+BLY2fUCdvx3e4jNcqbH/qKgTJUpHLhbb5Edjy+rHy+IOKGddrrLAPy+kS/wAfHQcyIK2XudKRpDKKZZ3fqHP9/e1lvsdJKG1ln+9CzIBynKaybbyI7V4t6MC+smrh+EAckbZ07YG1FbEeT1U50GgEDCK+QuQvqP+HXoR43I0Q+yDrqCCQEPz7y5LsRFn25J/p1dVzu+Q5BRplY/H7N+c2ZteI+EI8GpG0ga8uMgbp2kjG8ALgCU5C7ot9hTTCHoe2wG1AY21VZL+8O+CfEwjACcm8+uB3zrvuhd8hkmOlBUi5ZkEko6UCd8CNEo+BOJN9e3fC4xrXKqNXsQbjR51ApkMc06UKR4xgIgABAABJREFU53WkkCYpYT0DmoU4XiaOjH1BwnM8ojZ6qbSYBkNOhyiKQSbnr6rD7f5ci79T9ORce9pvzq7zDuDI8rIuccwjdkdN3bQLsepOHy26MjYj8seD3uc68oY9+F4MCJxbPq9hMtXf0OpafZUdQwqfnUJbV1XZcex1Q1zrKFQJzKEYPoNElv0F54uz60bG5cA5j4io47I9OQTU2Tg5FenGEc/50bQ3u/aq8n5rxDiXQTahedqENQ1ifNAC/+JgIL7Tm/JYK2I/ysn2PI5kvJs905dxC6esEMbDelk2KOKmHva4vk4Ag0JElIe6zWWMq1kryAmsjnsQ6PSIpCIiqgBGJiSu87yyjiOG50uAqFuMIM6T/I4L4yGv+gRqrcx1jjGsoqzeQGagJizd7wxkunJQoJhs8v5KerHvUdHN0dM1WX9va/GYGAL+J+/KdWQdsC0F2FMOQtnWiFFFW/+tpuyL1wBh+lSD87A75L7dUXi+fszPxf7WUpi3OmBUcfzr+yEmaAXwa/WcLPu9Pn92HPAi5HCskKNwwytVxEOIZPR4CPt9wAxWchJHtjvmeTQE/NWlXFOkqwCuA+cenBN2J7KOEEuB+NaWwuLerMG7mDx/CfEtREQvD2BtAeuHWK2XVoo8r+N+/1JJLu4QZbcM70S2ynLNsF7h2DyF9zJ/etoU6b4AdE3EnrqOfC+AWI5GEecRTnOsXiJh/UeaUQnCKRrxVVEq54cbQMpZgPaoyazSFaiLlwY8DjO1X1oFXN3ehOv/IJBjdzTl7z3f5DH5zNNyXmrv87y312PMSKLQHJffxflzAS/3dG9XpIt2ObacPODxjwiYQk7WOaJZPt/lPPRjWXYH0GQY0zTO5RjeH+5O+B4PBvK5hwGPySGsjUskF/yrGY/lMuxDt1Uj+mPO3wN4B7fkylhwpcb3z8H7lyVfjofLS13+YY8xbeMpj9f1JVlHZXj3dARLd92VF2AP873rnPDhSPKIzqANDgHBNVTvMIIkFe8GX032l+gmk8lkMplMJpPJZDKZTCaTyWQyzZG9RDeZTCaTyWQymUwmk8lkMplMJpNpjuwluslkMplMJpPJZDKZTCaTyWQymUxzZEx0UNlLqeyl5ClmcBWYuiPAY2k29kHAbKCFAnOlkFNORHRvCOxTQPGsSHwPLRaYydMC5g/ygBYKktfVjZlr1AD+eF6xonsxc5dOIs6f78q8JsDk3gMeU12xGZ8BdtxZ1KR5WgZ+XSXHdTRMJMdsD1CPyHrVfOMN4DS/GVhZmiuFPNyrFeZmIRdNc0angEQCHNk5pixyvr6xxQysuuKEIXcMed831P02fGSSIVtfluks5roMMu4HvieZVc+1+FnYd0aK/b8InDVkXu8ovt4QutzuiAfBQXY2u/YU477jHvP19BGn82S6LeB1XQZm8FZZts0Z9Nn9gO9RPpPM8aOQGwHLtFKT3O0U+vmVdYbUuXnJxfrMnfXZ9dURszuvfvNdkS465TYEzKhgDj/JE5djDH32WlWyTkdTLuMEeL8B8IODqexwccD9HDm5vUxygX1g7yHPMVIc7+MJ1wWOG1eh630ohwdxB+MMkWQprwDkPh82RDrkhy8Cp1ET8xXCfaaHsqkFn/wIuGjfzU1LfTU2BjGy3vD8B1lHX+zy9QkwKnOOvN9yScGPv6yeOruilpdcadPFWkgXKecUaT0vmbzvYvQeva3VnV0fB3KyxTM9/vT+5uwaOeBERGfAuZ+KWCrzg4z0FnBQr1WAnao68AjOOvDgw2Iq+8rutMv3djheVlQsReY4nt8SK9RfHdYJt5Y5hv+v/RWRDuf/NjBcB7EceEVYQ+CxFvovNroRrJFgsr3ZkHWOc9YxMENT4u+cOnviOw7E1Q2H54ScykWQch5w3eeq6NID1v5jYKLeqMu2QRY1ssX1mQiHEx7ozQK3W02tb0ZQtasl5EvLDveYDmbXK8AF12zsFszxl4jjLK6PNJd9AH35Cx3O0M2GTLgM53bkgXmtQ90JTD+PA16z9R0ZqC/DWR2rJa7zUSaD5CTj+brpMKd80amJdNBdaOryuiVy5Fwb55i1PU35Os1knbddHitLwPSueXJBh/0Xj1HqA+9Xl93LeB2Ja7uhmkP3Xe73yPueBgsiHfb7lQL3ibwKQsgMrjhcjn4m1+fH9IA/c/l+pfiWSBfBWuBhyvxaL5OdInG5jJe8t/LvM7m/mTrcbmHGdTYC3n2k1jd5YML6Gee1SJK7i8xh5MbXC2ohDwqAtavw8uS73rlyms4rTIgoOz8fngIX+U3LzMLf/Fa1hj7lPnJ8h+f/L501RTpcK+O5YmX13A7sje/1OYbcHXI/2hvLPoZnaNTg/IxLeX3mAF/jERB6iefCevHVln/4WQDnYmie77U65wOnop2hTIfzSh/WOku+LEcO5sp1h2PN1aocK7jsTTJcF/AHDwYqD7AuwDKNpjKubsP+8BDevdwfyTH3CM4Ow/PCltS5DHU4+wD74sOxLNP+mCtwd4ydR8aT51b5rI3Fp7i/tF6U52dUPT4X7P8AJPdALdQWYO9zHaa2BzB1BIn8Dp6TVYB12USdK5IA+3yU8Hh6PJTpooTrGc9yqah1Sy0H507B+4wtOY2cO+PuFel3ZvjuJIL1cOdQ3rA75p+Rjf/5rtwXNP53rqcf/Fv3Z9fFH/8Oka50yO8tLsN7I+esO7tOPvFQfOf0k9zWy/XR7PqLR0si3c6E6xLPidFnYMXwbgLXYnuBXLfsujuc76wC13ItsDPEd0XwTq8gx8MGnFO0XFqkedqFl6JDiJ3HgdzHX27zmgYR5LhX0mz9RXiPugzDC2MxEdEnO/zzVVjMlj05HoLg4vPkElXpSZqdO2Nqnuwv0U0mk8lkMplMJpPJZDKZTCaTyWSaI3uJbjKZTCaTyWQymUwmk8lkMplMJtMcGc4FNE5cysilT7WlNeUIrLeHKSMcqmNpJdkd8s9XatLag0JMAtqdCsrP1gCUBDoLEF/RjaUFY80Hq4bL318oSvbMwRlbK07A4jCeyi6BWToDC/efKrTI4wlbppuQb42ROQ3ZdoGoklpO2i5u1PizKlgy8uq/fSo5thut19ka/Me7qyJdH/wjaKnFOk+V9RKRFWgFXFbYnVqe06F1eaiQEA+HF9ufNkuyjl4egHUZHMR3e9LaekxsY0ZraieUFqwDQBJgP5ooTkQF7Gx4i6GylaEN5jDtzq6nDn9JW2Wvptdn1y6ggK7Shkj35gXOwzsX2aK7vdIR6aKI++luF6zPCuOTh3TYmR92pNXoqRW2ZheqXI6Dx9ISfgL914XnVv9f90W6z+1cml2XoY82C9KKmndxjHP+Ho2kDf8RWAr7gH3owu3qBTk4/JzqqF/W/kim66VsP2u6/J1V5U/G3qzce0IhdLI7Q74fxhkiiYHQtnIUYpbWy5x3PW4eg7XycKyYFaAp5A9TfbpbOJ/4y7oMbkAcDhrpdTTGtua+uFmZ///Vd3v8ndWyQiQUEhon88tieqKtYo0KbpFuNmQ9v7XF87ULc5GvbH53hxyv8JPHo/m4gxt1/uyZmkQ93B/x2MH+UoZ5bjyVeb0EfQSRWXf7Mp7HDn+4WOA+u1mRfQdjPc6bGj2zB/ND+Zhtm/dH0lt5POEbHgOqyFOoIkSG4HO1dRfjDlb6clgX6XAeHiWc+TrxoIyzNfGdHCxtF4p87as6vzc9oYvUIpmHPHHd3o3Z4rsSr4t0KwK5woXC+iKSZUfsSNFVtlL48XoVLL+kYlWXEUQNKG8tr9aU8DWc47uAFtEYr250MSfLIdnfFgA/OE7mYyyG0P9OXK7/gCTSZCXjNWUCeLiGK+e1Fl2eXedgDRcq/EoR6uxqemN2PSE5J+9Atw9jjh8BYJSIiFIor+9w/hBBQET0GPYPZ7B+aKdsxz5z98V3Tunx7DohGGskx+QgPpxd13I8Bqqq/7Zc7mMlYEq82l9RlQEPlVeouCnxxJcDfMVpNhDpAihjB5BDrjt/67mYNWfXB+6B+GwK2IwpYFt6xPXgqFIVAXnVIkZUFdR6/yTi+504vN4cdqQdHtUqcNmv1eT9bjTyFKYJUUd/y4T6f1w7o2quQM2KxFx0Accwibj/DV6UKBUXJpkIsGwbJZluBJ+dxfydXixjJManIvRTifSU42ER0Ec9wJEcjGUsyAAbMIXrZYVLQRToGawxQ4W5aME2q5Tj+tJIrjc3OE8Px/ysvkKxhRn/PEz5wclE7keKgEK5AgyGSxU9d3DdtmB/4nuAsoplmZIB5+8k4DmhkajYB3P5Ltyjq5CIiLLbruJzZbs/BrTNQ0DM5NTepBNdvPa5UZNzcgLvN5IR4IPKMoNT2PdtlrmMY7U/R5TtNy8MIB3H9sdD+Z00wzUv52eqUBVjWFfFxH3AU7F0FzBGU7p4/iMiKsHe8011vt97nn0k0n3hEcfjYA4GlIhopcyxoQOYp48r5CDiUa9APHEUFGkT0L+FG4Bp+fzLIl263+XPmjAGcoD+W5ZjY/mdPG7SIV9/7g/U+yDYhy8CCu/pquwfp/AOo+RxHTU9uQ7ahSLi2ky3ISIMY0AsITaViKgK8+NTNb45Ii2JiPZhj4R7etwHEElkzRpgWtKM+/yLPfEVegDYljV43XoiQzu1A+6LnZDzUy/IfolLWSzuZbV3eq4xpXFC9MHXMHfbX6KbTCaTyWQymUwmk8lkMplMJpPJNEf2Et1kMplMJpPJZDKZTCaTyWQymUymOTKcC+jFvkNF16WPjaXlpJyxHbBK7CkYkrSfjSO2buy3wSLtSJvPeglOvQePmKcsMUsF9hsgwuUYrAzbVRLCU2vbYGvoRNL6gbb0KdhPbzSkreGSj+Z2/j+XF7vSBrY/wnLwc59b0DY1sIUAimalKK23t+ps43p6my2/J0eywDmw5V95D+fvxn+QbfOnJ2y5acMJ5HgoL9qdiIgWwBOH9ri1oiw7WvTjlMs+nM4fXni/01A+9wWo2/0I7GzKxpwHG0zocN/rJZIxcRcsMnHGeY2U7bji8f2wLkJ10nsAJ3u70CcwP11X+mCuEeNN3uw+Pbv+/kuy7H/nCh9PXmmCLUpZ/rDdbm7wKeilBYW8+di12fUg4nt8vidP6/Y9LlP3gMeQxsM089x3Wj7nr1CRdVkETAveI1CWxL3JxQiRw0CW9wtn/NzulJ8bEv++7Z6K7zzjbs2uF8Hir/ELaPfKOy6kk3nCk7PXIS6UPWmV2w/ANgvOtKmikiBKCZFBGu1Syl38f73jRKbrA5Lg1TASiIcYwEn0nzrl+y0UZTu9dQFsbzCsRwrZdAUs3Wgd07AaOEBc4EN0XruxR5PE/q/7K6mSf9KfNO5rCvF4N+AxH6g6DSHdKYRPRHIQEY2gnz4YclvXVB/FeQ4tzUcwj18qyVj1fRsc6/cBFRdnMkZM+4wXwDio7b+I8sD+p+sIy34HsDYPBvJ+D0aMZhgSL0K2C02RDsPLwRhQLAoLhusiF3I4UoEC56JGjselA9+/mpOW2hJYnzHe3Z50Rbq7yUdn10WPbdFj96rMK64BM+5H2jYfwJh3nYvzQER0vcTPWgNE1Ukk67wJ/ehNsCaq6fJ6F88jp4G8H+aiCHmHKYXujCWSowq4DrQTD2X3FXH/5R63ocYU9ubgYRZSieTJwIL9YDzi57h7Il0MaJFixnkNFKYFMR++w/1oksl0BYfrNp/jtsb+QSTRKnvEc6+r/jYpdbguDiGvscvX46QtvhMmjLIreNz36p7E35Vyzdk1YlXSTI6hgouIJW6Ptopv6PKvQKBoTuV6aexsz64TWIMM3K5INwUUDeJXmrRJ84RrGi+Ta2hswyJgkELAyGjMEOJcFjPG+E1Jlh0RLmfE69DMkf0V230UcV9ZmTZFuhVfxgPTxUoyoiRz6EG7KX6PaIb1Ko//8WA+ds8DZMCVpa74bG/Cc+ruhPuIng9RX+giYo37y1pJ9jE5/XOf3QvGIl3V5bzjXsxx5uNccI5RxEY6gm1uAoO3oDBeiAzBNWYtL8uB64lOwvONQK+R3DMMY8TDyOfilI/Iim1AmrwwkOhPnC7yUJfLqs63yhwnENfxpb7sH5dLnIkJ9CmNfUFMaQAIxYWcbBucy1uA9dFEyj/Z5fnMAdbGvZG8H66R1suI1pU3xHbDOv+GFneC5aKsywHcA5G0j4ayI/WhoXLT+QMiBtxPn3hM+iTxqAFktlng77TP5LrlCNAs+O5koSDfB5XyPI8gAnWr2RfpltY5T50jflYuJwPx6ndyX3IanG74hxIf9uABr7XXF/lFSs7n++3stcR3WhXOaxTzvHkUyj0lzrWrRS7TqsJQfbrLdXt/AKjZRJa9RcucP5gDF/KyT5RzF8c0/f5L7H2q3DZqm0GXqhe/51qXtGsKoH11XbyiFV+jk/gaUUV6XyaW19B19PpmyecyblUA0aze6V2rjmg0VazWObLduclkMplMJpPJZDKZTCaTyWQymUxzZC/RTSaTyWQymUwmk8lkMplMJpPJZJoje4luMplMJpPJZDKZTCaTyWQymUwm0xwZE/0C5UkytUrwc94BllomuT45+D+JFBhdieITIifMzyEDS8KGdoDbNgD+ZAS8KeSeExHtApv8MJBsJZEHYFthmYJEgoyQY4T5ris4mw8A5XtjZqm93JUsReSalaH35RzJO1sGYOen7zAH8uOKqeUDj/nW/4fLu1CUwLM3tzh/n2rzg3sRcNUUtxT5cMiHrii+Vgak0dOI863Z6ftj/ux4Avw6xVXFmsV+pLl0Z+7R7Lqb7Myu8843inSXC1xnyBm8Hw5FunbKzC8v4TrKKcZkBJxL5H8WgBXZI8mTwrxv5BTIH7Tfrc+uyyPgaaav7f/7nn1zT/xc+iR32v0Jl6mpsIqf7zJf8xgY9ZdLsg3/9psez64LVeCed2QoDZG9F3O93B7ImHGnz/foAyvWc2TMGCfcT3G8TjKuo2Imx25/ymNg2ef8LfkyrxXgQwcQgx4NJPS2VcQ4wd9Z9WXcOgBO4+Mh56HoyjZcgwAwhcEWKTB4CVhtB2OsLzleDyfAogOO/62GPEtgAxjEHRivpxPg4keSyfeJU87rO1f4+5slmVdsNYyXmpdYy/H3btQ5DyP5WNqbOIJZbbpY/TCjvJvSY/V3AQ9LHPuwFseKiV6Dcy12x9DfJjLmPnJ5/D8z5vMWKjkZUBYlhnCmNvTZWK0fDoCXPgAupWaJT2E9sR8xA/JQjYfnGxzTFiA/fcUCxbkN5zJkUhMRFSDuVDMeU2rapKMxjyM8c0ALOei+y/dOZHEJlwarEDMwTlRy88fIAeQH5y4ioiu5b5hdpzB6q2ldpJs4HNSaHpf9bU05YO8MOX93Q86fjn0LRZg34aNYjfUuZPfRiPvySPVfvAeGz0XFmNwAPmwXzhk5nHCn2HMfiu/4xPP18pS5m/VQrgsGwMM9Cfh+CqEp1p7LwPHUwrF2PL09u+70XxDpPGCGZ7Vv4/xl8t71jMcD5sFT6xv83igvzxlB9Z2TC39fyZri52pWvzBdO+Ozl3oTeQ5TNO3OrgvAPffKMrDUXebu5mlO0CG5fihB3xnGcn3TgMVoBNDWuopvU+gHPQIeeSb7ZRP6Tt7heHmSynXayGW+K+6jiiTnbuTVY7+sOfP70Wq6wt+HONNP5blJyEGPMmZZp2ottgj9CNnQyKQmIopSh2Ld+U3n9CdHLfI9nxTGm5p57ptFOPsqy2TCUzg/ZADnUB0r7m4f4h2us5DVTSTPcHg85Dxg+66U9P6Xr282+LNyTu5/8TwePA9tWcXpRp4/Q575pi/nr37M4xLXxh2F8/XU/voVVVWld2DOahLnXZ8f0Ha6nFfgWl+ryTrHs2Gwxh6MOd93+7L+cR/UwXMKAlmXpzD/4Pky1yrzDyI4gvclekdZBzg+rsn12VB54NqvAcM8UuuWAzgjBEt4ry8TlmHtgv3oalmW4/6IP3w45jrfgD5xqSQXd29a5rM27p81Z9ePhvI9Cp5zhudndBUX+tA5nl2PHY7ZK3DWxJNycJkwLH5od1Wkuze4+Fws35Nz2TeHHN+/88r+7Hr978lypHtwXtgA1jQnMn+TP+DPKmWev+4fy3NZbg/4/t8C56ZhDLrXl2elkP75yyq6sp+/vQXtVuZ9xgPVNthf7sZc/y2S64qyy2NvmHL59HstPJNmBOtzz5V1jvHuOIR1bU/2Szx/6EqV+852RcaqIexp8g7fO8tg7a/mS9yTD6YctGP1TnW5yGMS301G6lCSCbzfwLX2YSDj43HYoCCZ//4UZX+JbjKZTCaTyWQymUwmk8lkMplMJtMc2Ut0k8lkMplMJpPJZDKZTCaTyWQymebIcC6gkkdUdImueEvi99kcS16mvAdHztnsejlrza5bBWmTQJtkBjfvRtKKcwhuw2s1vsflCtsQFAFG2JrnIVuIiBZybJNsT/lBD/rSnpxzLu4iayVpEUEL3GTKtitt9a7A7dD5dRbJ/8/58DFbNdGqfKIcFmch2zoeDtnWt1GWaIsCWGm2a3zDx2C/9pUlHMsIFBq6P5LWVnTbY3tECkHSgbweAWqn5Mk6Pku4PQpgNXZI5i8mvofncP+okiw72lsOwMc4diTOJQTLepXY/jQE2xYRUQKWeD/jdorh/+RyJG19HRgb1Zjz+qWetOsGCVuhVn3O6yCW/fepGmMMagUeN8GOtBC9eYXt2OOEbb1oLSIiakP/Q1vedkVafqs3OR/u1sLsOnfnTKRbOeLvPTrlMmqsRwts/YiOGJLs6DWH2xTHdRXa/VZZWtbQJojPjZXtuBfxzydTtjHnlc09ALzUy2POX8dti3QbKVvi1nwuezUv6zyB2Ifx4xzOAa7Lr4K/igHhknfm///w8w3uV7dh/J9BlZc9hTCCTOG4fufiQKTbaPDP0RRQMSNp0esD4ufbN9javt+XqKMv9cs00R430zmdxgHlnIxyrox9d6F9l4sXoyyIiMqABUMUy1ZFIZJGbCutgfdWdW1CFyFaxwNoyzOFYptAugLkR6PTVlOYu0OOfXmFDAH6ElUBH9RSKCsf5kbsaog9IiIaT7nPtiFmBGrA1vNghQbrczuSMW1MXOA+PHiFpG27ANbg/oSfNZlyJaN9lYjIh3XLIGOLr6v+buRW/hLnFZqjoAJ1J+J5aa3E9dDIS7vz0zUuRznHeTgYyzm+D5yHFCKcxjmtQ/f75BnfI1QxHPvfCkypihJAy0V+wGnIbSPsseo7ON8jXk7HpRzU2UKR60jjiPKA/PAczsNRImPpYczYliCC+VVZeVGdmLEoI0+iWEr09tn1pQLH2TCVAyKbAtbH5TVDkWQM9zJuD0S2uKoCce6uI3LNucn3rsq4fzj8LOcHUAr9YEekKwCuKucszq5LCu1Shb6I+48wlXbnKbRNJcfXrkKa5BMe40OH568xdUW6nsP1guWYkMS5DGNGE3qQV480MpPLddO5Mrte9jnd3liOyTbgZkaAFYwduVb0Ia+Ih/EUdgsVQgx7cSDX0/4wT3EW6q+YlFznPDKDSCLNHo1g/auQVx6gAc4ibqsX+3INdwJzByLRolSmWwNE31aVPzsUuBQZ+65UOU9taPJAxUhEkGHMHau4j+HdJVwny3G4VeafTwOur5LayyLmFfccntqgh4i5gjX0WSZjM6I8HkP9PRWti3TvWuJ5eXfCbfP5M37u8USOwyqsH4Zxd3bdjSUK5PaAJ7o44xj+5sZIpDuYQLr5U4dAn61DH0D0BBHRYvHi+Vq3IT7r1ZbwRej8Dwac8HAi+yUiSHcAOdiD/TSuY4mI8i6/h1oohvB7OZfhO64Y4jTGN60w43iXKGgbYjM+3eGM92Odjp+F69zjiZyXXh5wXVw+4X138Q+ORboHh/wZtvsLfTkfYj196xL35WEs12ktQEo9BkxLOQdoERWPjgB9goi7ZkGW/V0r3dl1B5BInziT883jgPMXOjx/pSSxMSuwLq3C3rOt3ysS79cjl99ThJEcX0V4PfwiU5XO7eNdiCEV2Gc4qk9gH8tD/SPiEuMUkXx/GMFaxXdlOzWg71yu8L3XSnIMdaAqziBO6/eUrQLR9FViBcr+Et1kMplMJpPJZDKZTCaTyWQymUymObKX6CaTyWQymUwmk8lkMplMJpPJZDLNkeFcQCeTjApuSqNEWljQnol21rYjLYl952R2XYcT3PWp7YgxOEjYqtFxT0S6y+nl2fVkypYYf45Nm4ioHfL/i2SQ1yiTZfJzfD9niidMSw8D2m/Q3tAoSP9DHexPaPdIFAungZakEdjwAvlctDgt+ny/BenKEZaMnSHbwo4mygJfuth6vwDWtqJ0ftC6z3lCd/cf7kmrURlsryUPbFaRTDeGE7ArHltv1styGFYizshOyJYpRKIQEfXi3dn1Uv7G7PrZmrQGb1e5nvdG831lEbG1xwEcUaJsr+OsO7vu0+HseoOe5vykK/gVMYbQOnM8ke1+HbJ+FoGtV1nqpoDU2FzncRj1ZV2OwbL+tiWuvy+etUQ6xDu8ucGen2km+/kf/59searluV7GU2ljPAR71gjGl7bbITLoCE5cn0SyzrH+VgtsbcW6PDcmwRWGY+0kkPcOwCY1dNgKuURNka5V4Bt6YFPrkNQE7P8OoIVaKn9YE3ia+5bsvlSY81+9j6Vrk75hie17b2/i6eTSUr0z5j6Brtdnmtx3tO1zoci5XQUsyCSR/a2+wGOo9jyM8XsSiXSyw4VsNNiidzCQKIu8ex7ZZTqvffeAPKdAXnhJ/H7YBuQC2ISLahyiHbuezyCdfM6VQnN2jRZCRAQRETXyaPPncf2RIx4beq6tANYrDxNORdkn8cdtwLw11fhCCzFMc3S5LO2dQcIfnoAV9TiQdXQYct8+hrXOViJj31qJ71HJI6ZJTt7tKdqGuV56iRyvRxPOL6KjUuK4hbZUIqKFjJFSaIHFNRoRURAzvmrLAQxVSSI+5uGhNLLtmSaP80aeLdN5R2LLDgOuF8QE7E5kOY4nHD87YMut5VTcgb54Cve+VpNtiNb0dZ/r79kGYEZ6N/ErVAQUYAva1lcMBvyxBQMnTOQgwhy93OP6P3AfiXQb7vOz62EO2r0s11UNaLd8xn1srDB09ZTLiGvUivpbon0IuHEG7aEsv77DMdzLILaobVUv4zLGDrdhBv03hjRERH6B1yfVPJev6WyIdFheXL8duvsiXXfKeS3Cd/IKlzKBtUEFUEy6vzXz/NlmxDH30JX3O0nuzq4dGEOTWK5l/Vxzdp0COuamKy3m/SnXiwsdDpbg5Luyv+Wg/02hzhFTRERUynjurcL+beJIDNUIUHuH7oPZdTeS/dfPNSnN5uMQTE90tRxT2XPJVWv8MbQbzlFHgUYzcH85hv2vXsOdhrK9X1Fdza9dwAsgDQD3so9Hcg4lwAwhfiyvsGB92BPiml4jY+8NOU+YvXYk5xFcq3zzIqBdIjkGTmEuP4W99kQxC3YcGTde0ePo4+Jnz4U2yD8D95bjFdc+QxgKCaBsGgWNOuOEo5TjfkXHIJii+7Af+VxXrqFxrh3A+wy9XsJ9xmqR+8q1mtxo7ACaEdGumAciWecpNHCzKPsb7kEejvm5O5nEkd3Mc90itvcuzFctde9uzHldLnLf0aQKfK81gPUX7geJiPLQz1doe3a95ssxuQRoHFwXdNSwQRQIoo+aatOH8f3FAZfp7N6mSDcAjMlJyNfn9nOli1kdeqt1BuMoTLn/vaUJSFtvPvcD841xhYjodq9OF0nvd5se4IgSRrZq1A7e/ekGoFd7ci17CG3Qy/j9zdSdPx96Cfe9VnF+W9dygMLLyXVaHrCRfcDm4D5FxyN8t7YKaNiVkqwkRBhiHuoqDzG04e4Ix6Rsm2ouo5zz2jbe9pfoJpPJZDKZTCaTyWQymUwmk8lkMs2RvUQ3mUwmk8lkMplMJpPJZDKZTCaTaY7sJbrJZDKZTCaTyWQymUwmk8lkMplMc2RMdNAgTijvJud+j7yoPWBAh47kV+aAF9V2mWdVnCqWV47TIRdtJZVs0e0Ks59KOeSscX6eaygwMDHHKAAOop+TzJ8lwBq1gYPaSWWZpgPmC0XZxUxvIqIoZe5SL2JWU6x48KMp/78NIsMfRF2Rrgx16TpcD6cST0gjgF0NEoY9HabHIt0g5rpF3iny1itqNBwG/Nl9qIf2VNbREXCcYuAsuur/qMYO8839lMtUiRdEutOQmWTYj5B5SUTkOcAnBX7VREGU7wIa9GTK7M1QcWSfpTfNrlPgQZ1BnyciGkz551qOeZ3I2lzwSuI7ZY8ZXb0YuJt52S8/fsrPDQGW1VCstxow/StbXP8vfUzW5afbzdn1rTrX/7XqUKRbKHDe63noRxNZjiEw13yPO/DHziSHrwZnFdTzyGyUrMKHQ+Dmwe81V9UDdhwyF5HLrND6guOP8UOzmIMMGeZ8P9+TeUVOXR7Yp358RaYDkOSlCvIIZf6Qzbhe5ntfVczbS2UeDyGwMZHLSkTUyPP3nl9mluKjrmTPHYfc1q0C1/831ji43Fppi+9MQi7vMTARi4qH93iX+1/tlPNdkihLKvnAPjxszq4/3ZFA+JPQoTBVMF7TObWTh+Q6OQo9yRbOgL1ZirgfPJ+7KtKNYFwncA7CSMEUM2Bb1iB2dWI5VgYwz52FnC6GsbdPknm5OuW+swKMybNwPlfXdfDcCPkZZv3xGBjciimLzFUco/vnuK+sBpyZMSKZ7vaAx6EHHGRPQaUnhKzzE0gn1xZD4rF4Ft6bXSNHecN9E35FrNmKxIOvQGWRbupw3Q5SHq/BSNZ5B/jagxE/9/FEDmwHONkD6BPNvOxHy3DGwt6I6+ghMJaJiDrh4ux63eH+UVCN/RxnScTBTiznzdUSx7gA1iBXqpxumsk5D+dhZHzqdcYoRTYrcNRlmKYJhPcc9I9IccH7yR5fB3z+S9WXXPBCjssxgWrRzOsuwXkfU3W4DigPc28R2jMmuficEPeJzOX2Laj66zq8XsL+h8xRV/X5Uo7bvezwWMtlcpLHNWYr5e8g+5tIss9LcF1wFD8c2qOdcnsUUrm+wXN8XDj3ZBipuRZWNRiL40TuW3AsL6ZLs+sbTVkv9/vcwP0pxA9gYePZF0Syn59kfH7OQJ0xhP0FY1A5kzFjAGdhjROOTdF0IO+XRpRl5/eTJqmHozz5XkGsk4mIIlj33AdGeLMg2xfXuXf7HIOCRNY9Ms1xfuirOX4MZxjl4ObDGM8wkPfGs4mQf42xk4jIhzOzQlgLtEOZBwzv8uwJvRbkn//2and2/d8O5D4It+FRAuz0WMa0OpyD9Mj5wuw6iOTJRxWf930rKcdjtd2nT53yPOpAJSHTWzPCT6EuFlw+F26lpPYjwIPfg8Cv2dPYBCNgoi8oDvIC7AXwPK6GWi+t+lxnuG58MJSxuR3AexCIfUuZjGnYx9JzVG7WF6YPZ9cZvHOYQtx6SyrPM9mClxpH8D4D+zKRPPNlQrwOikmeTxMAI30z3ZpdV/Pz9yi4TFDIa9GbN8v804l6z9ODJWaQwBwQy8UFri0Ox7AHV/z7NZ/L9aY3Hc2uH96R4+bFAa/vMC4cBvz75aKsI1z31WAvi+c6aBXcdO5n+I6w7fJ7Lf2+ENfxN6rc9xZUvOzs8fyKZwTheoRInhF0pQrno6i2XvW5np+p8Zqh5ctGjOCMi4+3me2ODHh8z0F0/r3lK1pW+2nfw7EL/Xyq9hzQPzyx75HjrpdzRD97NdlfoptMJpPJZDKZTCaTyWQymUwmk8k0R/YS3WQymUwmk8lkMplMJpPJZDKZTKY5el3jXJIkoX/1r/4V/cf/+B/p8PCQNjY26O///b9PP/uzPzuzBmVZRv/yX/5L+nf/7t9Rt9ulb/u2b6Pf+I3foKeeeurP/bxxOqVc5p3DBKBNIglWZtcdsHMSEZ06bD8NiXERfZIWRwKnMGJLFvLSNoT2swjcHhuAAtCWicUJ3+PvrLMdwXOkXeGPjwvwGdjXMvn/KgnYi0LIeD6VFhG09jbAPhIoO9sBWGyOJ1yOtnsk0o3BWlINOK+hstENM7bS9By2nI2crkg3AqvrMOZu3warfCUnyzSa8rNOEm7PSNmEz1y265Yzbuun3Esi3eOUPUkDl/N3N5B1fuyydXmQ8L0b3qZIt+k9P7t+Ns/PWvKlDWXFxzbg/N0eyOEf0cXWU8+R9rMGPCvNuP5Ch9si78g+jziRKvTzSPWP3QkjZvrE1qBb3qJI99KA+0TuvzJO5O5I5vWLHb7/YcAWoh+6Ivvb5ZXu7HrnuDm7jjNZlyXAd+TBgoV2oiff4+s8jL3lgqxjv85t/3gE9v+CtBCjLRQtTmhrOm+B4/ytlvg5txrSC7U74n5whlY0ZYtG2+thwA8ru9qeyN8Lobj9SNZRF7BPPuChHimMzOXyxTbLdiTHzQlYupMDjtO+Qq5UodtfLvFY3qqzHXu3o2zpYN9rFngcFzzZnl+E790Ge+elkszDTbC9/ekpIwM+fiLRGHGWUpxJu+AbQV/ruXvdvUWeU6CAJKZpnHVn10XADlSUTXAMc/JpMN+OveBz30Qc2ct9eT+0NQ9h3CBKbMOR8RzRZ2jhXi/L8fBil/vDwxEiMGTMGIN3di9kG66n/nYiBz8jfmWo+h3aSoeAJttLvyDSIRKqADgMRGM8+ZnzO8jYVlpxZKwvEn8vAWRbQF1OpJAhiNpDBMZ6elmkQ2zWBOb1viP7EdqYk5TH+N2BbPc2oKIC6B8bJdmPrlW4HE83OA/h2TM0T3VYV7WK2rLO93+qxnm/P5QW3QDs52NADh2CHb6p6vIRIMccKJO22rqwjkQazpovy44om4My19edSUukQ0yLA1b+oitjcwCIjgTQZFV3RaTLEY5dwDlMZcxFhJ4H/SMiGcMdQH6EsFaJHTluytSgi4QoIe3izzlcL0tg275WbIp0aMt/0bkLt5N5bWQ8pmKoo0om1wJnxHXZcff5OYlE6AQDtoQv+9xhEkdikOouf6+RLfP3yzdEuqWU2+rNdW7fTigr5mzK6w5ELE0BI9Pw5qN6ei7jV7DfaA1h/6CxQLWM23PFY3yC68v1dMu9TEkW053gHr2R9LWeu+8Nn1j6V3w5zyHFBFGb1yqyPe4M5fqTv6P/RpDTBbBf7UdyDYc4oDygNpA0gEhWIok4OBzz2OtEMq8+rI0z3CO4MpYiBrECaMg9SbwS7yaOx7y+OQrk/Qpwf4x9x4BzICIa45hKOI5phFYlx2O5Dri0ekHWOc4RuKbBfKtmp29f4/t9sXN9dr1VlWW6M+By7I+4nsNUticiqtbLF/cVIqI2YGD6cJ2RjJE/cIPrrBdx7GurOj9NuLHGDl+fTWTfuVzgeBJB3geuxE11AOeSQnxv5Xj/u6gqE0lFOyPAsilsxo0G14vTb86udxVJUKP2XtG+wt+NYB9ZhHbvhrJtEF9zE9ZBGrXzAHC6uNYpqb3ip4D3e5bxmm0jlGuGBzVu09qXVmfXQaL2nrA/fDTmsvegfJcVfvT/ts3vAWsrPIa+8KU1kW4EmOeie/EegYhoB3B/XsR5yKl1PG5VdibzX+1egjWXHzw3u07VO0KMTwewQRqq9wLf0OLP3rTFuJnj05pI93KHf/58F1BdAx4bGhN9A/rz8WR+HFyFd15IxsJ1LZFcZiFaSO/zDogo0myqOXpd/yX6L//yL9Nv/MZv0L/9t/+WXnzxRfrlX/5l+pVf+RX69V//9VmaX/mVX6Ff+7Vfo9/8zd+kj33sY1SpVOi7vuu7KAiCV7mzyWQymUymvwrZ3G0ymUwm0xtLNnebTCaTyfSV9br+S/Q/+7M/o/e85z30fd/3fUREdPXqVfpP/+k/0cc//nEievK/4R/4wAfoZ3/2Z+k973kPERH9zu/8Dq2urtIHP/hB+uEf/uEL7xuGIYVweGO/378wnclkMplMpj+fbO42mUwmk+mNJZu7TSaTyWT6ynpdv0R/5zvfSb/1W79Ft2/fpps3b9LnPvc5+shHPkK/+qu/SkREDx48oMPDQ3r3u989+06j0aB3vOMd9NGPfnTuZP5Lv/RL9PM///Pnft/IFSjvFs7Zp/FE5zycZl/PpD05n21f+LwFR1pqhe0VEB8HsTyxPg+Wkbe1+FkrRf7OnZ60qaDlZN1nm0le4SbQzlbPczfwzp1my1aNElh5kkzaRdEOcQMQFYnCYbzcuxgZcknZrB2wjlfBxny5KO0eUcK2nF7EdXQSL4l0aGFH23wEJ6THkSzTKON6roKla9eVJ5X3E7a9umAvQtsREdGNPFtWR1M+BfoQbelEVARbftVle7efSrv+1Tzbn1ehz24q63je5Z9bcBJ6fSJtQ3ehHGjbaqTSXh84JbjmPluG/OmT1BExg1a+z7SlDQwt9WOw1EeJPDUbLeafZPc1tQqy7ItQXmzeYCpDX3mRn5sc8XcqCtexWeXy4mnT37Eq+8QO2OgPAq7LhYIs79UK/+VOJcf1ujdRtkiwxKENCU+b7oSy/6LVG22WFRX1v2EpB59xmfbGcuzujvj+BYiDsYoFZ8Ax8D2ILcr3VPYQLcDttiPDIO1VeOyVoD2Uc1QgBPBUbm0XxX4wArzBy50mlyGSX7o/wnHN4+btLWnXPQo5XRuwICfK6vmlHscqxO70UvmXXFNKafoGxLl8refup7xVyrtF0g68ewnbC0sKXYAqQnP3ADs0TuT4L0K/QoSZxpZhnF0v8zX20VTNDziuGwW0c8u84rjBewxiOQ4fRd3Z9ZH7eHZdAKwNEZEP65hixlbPsUKaDBxGIZyFjCgIY4lFqIH12wWEhkZMVACLskBsD89UPPGR1QJTICLzUvWdIXFe8bkaTYYlPCWuowatinRrKVtx0XK6N5Lj/xTidgUmOozfRESXy/zzjSrf4zSU+ZOYPP6Oxlw8GnA/PQJsmXJP0/0it/U3tRh5M0643XfGMrAiCqAJNut1tc4IEv4MXNV0HMqyL0Jz3qjjvCZRNvsTsGY73KdWFB5llxjNNnV5gK2mEueCa8rHU56vH2SflulSLv+KA9gRhahLAG+I2JcVwK8QEZUdLvAx8XNx7dQiiUtYBPxKzeW2USFD4A1DGK85ksgAHMvrGedvISdj4mAKCIKUx1Bn+kCkGxWf5evJ1uy6r9bGesy/opvOFfHzYpnrrwwIiAcDOR+eQgxyILYgalJvMRLIQwRYBVehCRChk4PPMC4QEXkOIyYuZ9xuJUeuz4tp0eZu0Ly5+3icUt5NKO/Kie56FTBSEEIWCjLmXgP0SQrorqFCUeiY+YoQk0VEtFC8eF16lnC8XM7JtsY15hrM93q+QcRakOK8JPOEpKw45R80OgDXDB85BfyVSofPPZpy3BkBVpSIaJzy+E0B+7SRf16k24D50If1yGkgB18TFuktiPu43lopyu980xLnYbnIe9ztikL1TWG/D9iGsVoErgFmCTE5Kqu0M+Q6whh0pHCrh4BsHCW4L5D3q0E8aTu8SUXsLxFRGDPWD+PnROGmGjl+R4Lrqqspv3fSa8AO7GPq+fllx/7XhPHQncq1Yg/i7yHMKflEvr8ppYBLgT5wpab23eJHWHersYvvUhA/fEcF+9uANCsT40OGqhyf6fCD7w15PdGPZSPOQ3UcAdJktSjn0Ld+K7ebd+vq7Pot+Yci3d0XeI7vhzw43rXcFel8l/NX71zj36ug0Yb9/58c8rzTyMsYhO8CrtX5M1xrEhHdG/Lc25tyLHgplQgoZ4/7X9nj+XCcyNq7C3jTjl6YzvImYzFiUAcx30/n9RTWmMfwUmSqYgHeH2lfatjQKM4ounjpck6v65foP/3TP039fp+eeeYZ8jyPkiShX/zFX6T3vve9RER0ePiEF726Kjc8q6urs88u0s/8zM/QT/7kT85+7vf7dPny5bnpTSaTyWQyvTbZ3G0ymUwm0xtLNnebTCaTyfSV9bp+if77v//79Lu/+7v0e7/3e/Tss8/SZz/7WfqJn/gJ2tjYoB/5kR/5C9+3WCxSsTj/8BmTyWQymUx/MdncbTKZTCbTG0s2d5tMJpPJ9JX1un6J/i/+xb+gn/7pn57Zw55//nl69OgR/dIv/RL9yI/8CK2tPbEUHR0d0fo6WxSPjo7orW9969cjyyaTyWQy/Y2Wzd0mk8lkMr2xZHO3yWQymUxfWa/rl+jj8ZhcV3J1PM+j9Mssse3tbVpbW6M/+qM/mk3e/X6fPvaxj9E/+Sf/5M/9vHLOoYLrnGPtrgO3sV5g/lEnlEzpNGMO8gQAWaOpBDwh36ngcBMgg0zrTo/53P/PLnN9AsXc2yjw/ZA/tViUbCA/h/xKhv+MEpmHssvMpEHKz8J8ExG5GsL2ZS0XJVjoZbhGtpW+XwVYyltVLu92Rd4v53C5vtDj74z7sm3GKZcL2XbIqyx7Mg/RlNlKbYc5VyOS3McE2q2XMO/sBcVifj55enbdhHY6UCzrlZRtkgses7yqeXnDRf9izrVms4XTi1lo3ansO6HL9bKSNmfXmv/XBr5Yw2Ge22qF+0pOdQjMUwcwl+3pRKQ7ddkOiqzebiz7ZTXPf9Gi2fMoyBI185xuEsu2Drr88xQYhFv1gUh3/dv45wz4adOOasMDTjcJuI/92YHktCInrATnFtTzskxLBb7/ccjtfgYssKJipNXgrAPA4Z3j4W34HJ++Yflsdj0I5V8NfbbDjLlPtPmz+yMJMW/HXHZvzCz7JcXJR2YwsstKOVmOO0P+ngcs9iUV0zbgDAisilpOxt8WjD289wEw4DVvDnl9VajML/Qkbw456K/GVNsfcTwPUs5fSLKfb+TqFKev62n6Qn2t5+7lskdFNye4m0RE41Pufz3iwDNSTL0rUMVtaHodWk5Dbh88CyDO5KDKgL+O/e1KhW+oWf0ZnB+yCGcnHIYyoeNcPNkGit+OHGSfOJbGJON+3wHGIdw6JXm/NFOgyi+rWZZnwSy6VzlPNKR5ihw4DyLjOSWvWMXIO65nzEidErBJna74Ti/evfCZSW7+GitMeI5PPJmuRdyPugnn21PUzJrHHbAAMWQ01TGN03mwhumroNGNuA0W4DyYg0kk0j3OeN4sjhiR0CzK/OF5JNi1D4H7ehrITr8JHFkMi23FOge8Jm0Dz3jVl/0mD+UdwJkUPXUWUT/mG1ah/jcrsn+4gzW6SFfqMhgMYb5+ccLrDs3tnsQ8B46KfLbOarol0qXwPRf6wWquotLxc1cS7r9uxmUqebLsOJb3MubQnsaSz41jdDzldGVPnmOTBz6vj3sOkm1dBJb6indzdj3y5Jq3nPFaIIbxuZzKthg6vBZIYG7zvflrWeRQ63RLUy5XD+4dQkwLVXzruVwvXai/Rv6SSIcx0gO2tq9Y+DHsuRby3Mdq6bJIN06mFJMK8m8Afa3n7tXKk7m7Joc1LRS4b+PZWgN1nhHGzwU4E8lV8yTuN09jHv+ban/TgLCx4uP6mvuBRvo+HFzMN9d7MZwTcsCAR1Y3EdES7JvxfKSxuuEY8pGHPVdeLREWYLHhONzPF5JnRbo+jJ22x3FwPZP7li04pwi3enppslLiX5RzeIbMfM57q8bnFrz3XTwnx21Z6S/+MZ9XMYF3HXujOS8jiGjN53q9M5BPPg54Tt2ucRw8msg6/y87HMP7sEzwPTmPxCnXeXXKfafnSOTRId2fXZfgvI+SI8/+qMF84cJCLYY5oK5eXi3BFs6H/eVQrUewP7dDjueuousXYZ02hviLZwXqfOC273JZ1mUvxvMv+PeRYlnjOQWohwO5Dqpkzdn1MqwVhyTP1rg7wHdwsAbJZLrA4TjhQ9lbLl9/y6Ks8wwOPsoOed58+Ytyfvifx5xXrIcFda5bGfrVrQa8y1F11IWqSNS8jprHAtf77qswxu8MuYNczTZFOgxJd4bc4fYm8n57I77HSYTrL77BRlnWZdHl7zQhLuuzFvHcygmc/3CcyHcTgymX6VYDYhhJVfIOhen8OIJ6Xe/Of+AHfoB+8Rd/kba2tujZZ5+lz3zmM/Srv/qr9KM/+qNE9GQz+RM/8RP0C7/wC/TUU0/R9vY2ve9976ONjQ36wR/8wa9v5k0mk8lk+hsom7tNJpPJZHpjyeZuk8lkMpm+sl7XL9F//dd/nd73vvfRP/2n/5SOj49pY2OD/vE//sf0cz/3c7M0P/VTP0Wj0Yh+7Md+jLrdLr3rXe+iD33oQ+T7/qvc2WQymUwm01+FbO42mUwmk+mNJZu7TSaTyWT6ynpdv0Sv1Wr0gQ98gD7wgQ/MTeM4Dr3//e+n97///V/1855YSDI6nkhbThOsKWhdUm4KYQeq5dmbpi0xA7ADVgAn0nDlAiQEizha/g/n2K+JiCox3+8Q8ARJJg0LaBseJJxvz5HpfI9/TonvrREaHfADPQJPl1uT92uCVQXtdeNQWn53p93Z9WLAllq01xHJusX2yKtyoN1oDPUfAoqlpuxTCdh1+w7bctD2SURU856fXQ8cto6OslORbpCy7X21zBbCpbAp8wpWym3wOIaqw0Xgo7nB7lphlyYi6sQXY182/JJI5wZsA18EbNGSL+vyqsvWNHRWPx5yfU0UouZwzHnai9ni/9D5kkjXDR7OrqsF5i2WSOa1CVbvMlh+NUYGKEgUgz1nEEvv6FmH2+POgO3Y+ZG0eq7e68+uK1f590cPZZ8YBGxrQrsptsWTPPF1I89jqKJwIncBO4I4FuwTGkPVKvI90FJXzcn+sepzJUUJJyx40j6JVTuCjPdJWqa6LsenagT2v0Ta669U+edlGNfN/Hwr2gkgBDSWpgjjpgv2uLNYWgHRyoh2TBxehxMZjzDOjKb8HG34wv6HtrlhLMcD2uhHxPUfO9KeOEqmNJ2D0Xg962s9d2+UnuBRcir24TyMpBxt30NLIlquKznZd/Yj7utTgXOQPQH70v6Y0+FaQltWD8H++EVAtnW0dxyEuINyTpaqFvB8NiC2RTtqbgzADhxnYF9VGIOpw/NmKc/W4rwjYySiEOrEFtacQhsUM46RizmOE3WFLUOETh4s0hh3PIWD28jznFxAVEwm437gsHW85wGOzJEoq0cOg+iKMP8vAnpN52/J5/imcXePh9z2aIH3lB9+2edy4dxdVpiLQszzNSJbNmXT0FaJY8ndIX8H525tCcfu93JvPg5npcQZxPXIg5Gsc7QdIzrmYCIxHFEKVl7AZmjkDa5Fsf6OxnLcnE3Zqh07XI4V96ZIF3k8BjxC9Il8bguxCPmLMT5ERDGM82V4yVgvcLoXu7LsDx3GArbTR7PrCSBbiIg8l/spImUKakxi3tGy3lNTyxAQUK2U0SktkniYItRLw+M85FVHP4v5swqML72mREzjw9HFmDctxFX1iBEJnsJBRSmP8RSwjq1sQ6TbJN5nYFwdTmWf33eOZtePI2YQ6DngWql2Lsa/EfS1nrsvlzPyvYyOA1l/j8bcx65VuA2KrhyHp4BVBVf/OZwejktECGhyXwPWn6ew3sQtr/4OxiBcS2gsYLPA/aoFU8xTVdnHlou4DuT5C/f0RPPX5BOFq/uGJR5TfpOvHwzlPDKa8tg5GnOsGqfzkbSbFb7fVlkh5RL+7DFgVhLRTgrV8YBxEc93eb+1VJP7jFt1HtcJoHTHUxVbYGvmAx7iLNR1yT/vA5Kjlyjs6eDi9fiiK2NuL+U4i7F5ia6IdBEgQ3qAZWs6Mj7dzPNaI0q4rUOYJxFrQSQxIRib9R5rH/ZE+D6oqNZV192NC9NtVuR6abt6MQqjHal9bR/ecUGfbSl8y6LA0vD10URhZKIaXaRIYTLx58sur2XXHbmPP4J6QWwZajSVefjsf+d5BPfQL3Rl3hAXsjvist8fyLZ5rnUxGkdrpYR7C1iPuDoGIcqZn5Wp+QtRTI7D46uj3u3gOiaAMp3HVfJ1x+FxXYWxq8IWfanPN8F16EjtpyOBg+V6WCGJ1sO4hdul9aIqezEVSO5Xk95Lmkwmk8lkMplMJpPJZDKZTCaTyWT6suwluslkMplMJpPJZDKZTCaTyWQymUxz9LrGuXyt1Sy6VHQ9YR0hksgF/KN//cf+u8Q2v/GUbX6NTFohVx0+NXijzA87U0iT/SlbHgYu40SOp7fnFYHCHFtT3ZDxIUVPohTa4cW2XG1JHE45TwFgBcqutEyijRbRBUkmPR1XwOZzAPb1trJWoVXzaMK29GZRlgO/VYbevF6WXTtK+ectYpvaPCswEdEJVFGewLKayQ5yzWWble9tza7PYolmQIsoWmrWfXm/IlgAF+CjO33Z427UwbZS5MzuT2TZsf+iU/tG3VPpuF5aRU6IeSWSlqIdsCFhn/JdZTf30BbF6QbxPs2T53Af6ztd8VmScZ/ogd12QXmIpB0Y8pDK/z988YzHJBqF9Inrtx/CCdsP+fLOQNrAdsFmdgoHfmskD1qcjgIub0t2c2GNmoCtqQhWrbr6zjqcRH+rzlbIdij72+6Ef+4C5maqTqdGLM016DtHp9IyNSCOVWib1Qio8fTittF6rsF5DwA3cxDIAu+M+bNHgCfQllogVNEAPL9od9QYj6OAG3EfsAB5NYVuFLguEIugY8sYbOVoA63kztsR4zegJfxrrZOAqOjK2EkkbdfNAvft0VTajl9m2gll0E+LnowTTc+HdPz7VK0GArDYpmCxRctwR1lbOyGnuzvi+a+gLLWbJR6vaNuOlP2wBFiDJOP41EEcHBFVsyaUgxEuA5LoCJQDaJaqI9c3i4CBQAusT3K8NnNcjjVAgWgcBtqu0QYagg16Qn2apz4dz65dR47rBvHcvZrxegnXH0REMQHKBuynE8gDEdES2IFxCZJT4e1kgn2MP2wpWynO/2tF7lP5JdknnDZbiHHuvlySa8oSzGe9GFAxsJyrqF0Bdisdw1G47vvcGSD9FGIJ7cU4vBDvRUR0FvD9DiO27juxrCPEiaDF/CW6I9J10oeza8SbbHlvF+mWM57jD11en3RcOW5KKT+3Dvg7HYN6MMeEUJlIcAkULiFyuV9VXB5P9cKaSDfOeK6NUu6zG4AOJCLKwd9L5WHsanyjm3Heq4A00rNzAzBZDehvOgatlnmsVKFfNdVapZrD+Zp/v+TLzjgcc/uOAVE1jAGDUJC4BEQp1AqMI9Bopy5xH7sEKCtECRERjWJeK566/NxiJpGD/ahMcTYfxWV6orJHVPIygZQiIgpgo7Hmz/97P+xyqxAjg0T2nUHMn00AoTeMZbvdH/BzMd41YB80jOfHwUXowK+2dtvwOT96n4E6DLjsiIwlkjHzUoXzdziRz8Uxdb3C69dJIhGyuA8vwB4/VHX5PA8BulTiQHYSyjH1CBAubYjniF9wVHQ5GHN57434Qe9ekff+Uh+wmy7f73vWZR2tlTiW/vfj5uw6SGSdxxn/fJ92Z9cTV64tEOf6vMdxFt+VEBGNiJ9bgLVYM5OYvDTjNX+NOH9VFU+2a9gv+fp2D/fgsuwe4IJLMNmqVw5inXwVULOKmkG1PCfcBjyX3nsWoT1wqNxT7zB2gjl4RIVVwfc5uP/FtQ4R0WTKdYY4HXx/RkS0lPI8WgV8ICJRiIi8EefjEFCOoq+MNGOFv4P1dxLJdM08f3gJkEj7Y5GMRoCUqcE8WVPrtAGsi7Br96eyEXGfi+sRvW7Bn/AdkKv+9hpvfwzL4bFaC4QpJyxlHHfwneODvhxDiOBDtJDGNQ9SjkG4vnnrgkQsIeqwG81H2VyrpOQ682Myyv4S3WQymUwmk8lkMplMJpPJZDKZTKY5spfoJpPJZDKZTCaTyWQymUwmk8lkMs2RvUQ3mUwmk8lkMplMJpPJZDKZTCaTaY6MiQ66WsnI9zLqKlbpccDsnDc3kaMj/w9iPViZXbdTyRpD+cDbLQEs05/K+9WmzPPZBI562WOm1ik9Ft+pAiu6lufmXS3JMoUJ8KZjfq5molfyyDsEZrNiUa2X+XuIgVssSK4QYpeOJsAqzEve2bXp5uwamVXIVSOS/NQ68Lr6CuaFP75tke+3AdzoT57JMm36XP+FkFmK40yzzvl7l6p871Ys+4Boa8BjuQo4iWy7HGTpuZZM+FydAVTItQ4Vy7oN7M3dITOn1hX8VLBK4RYDhc+/32co2TgBrh9w0DUjuA88wsABvqli1FYL67PrmsNsUs2h342ZUxcBYz0ZyHbfKjMjbR3q/yySZe9NOR9LBa6jMJF9Atnn2M9f6MtyBMBtOwGIm2bK1gtQZwCPK6hO0SzwZ9+yyOXF/K36ofgOct9rBf5OlMq8fqbLQLsMeGKXSrIuL5W43wcJ528pL/t5H5ih1ytcXwX1X7anUEmHwJBfKco6SoHTmsC178p0i0Xku/HDAoUlHQF2rQQMPOzz9/ryS9ifke3sK141sraRh7lVlek8h2ML9qNvXJLtXvZSmiRE//t8PLWJiL7UG1POSWilKPviQcixBttw6IxEul7GFVwgjhnXkksiXT1/8ZJpms7n/5UhiOMYOM+ehvM5kOPrybkRhVxPHXORd+wBQ9PPJG8yI+6zPTjXZZJIjqQLfR2Z6K10WaRbg3MB2hEzDTV/GeulAnO3PCtB8ieRg55A3M+RnB+CDOYHYEU7as2WA542lt0nyS1dBoamZrujKtA/joFLq9G4uxMuh+/yd5BhSiRZjQWX875UkPHpm5aAeQ1x8UxxOF8K+VkYF5Fr+Wgo741TUQz9Ou/KuuxGHFhPU8mUR20XmrPrp5vQLxVatAPTWc3ldiqo53qQwQLMbXnVJ1w4Y6Xs8jpZMz5PgH3eTZmJXnAkXxOFbNFAsUBxHsDxiuM9Jj3fXMzS1mUqAjs2c/neY5Jg1U04t6CW4z5QV5PyScCf9ac83+t9wX6E8RO453l5v0tQZQ1gwMZqjXoWwT4D+mKkBk6f+LkpnNGE1zpuORB3sow7VdeTjHuM+3k8L0CNyQLEvjLwjYfq3J576T4laq9gOq96PqWyl9K2gvweT7gjfLyN6zTZx1Z8PH+Bf69j7kIRYx/3WTy/hIjoSwMeO2Voe298cSwmkuMDe3ZereNxRT2GtbsT6XUk9zGM0/psnU0YX9sVng83FEO+nsc9COyZc7KS9qCM+Ml2TT73eoXnr8/3eNzcVcxrjHeLkCdsm34k9xlFmAeeqcFeTJ1jhfu0Wo7L95aNY5HuqMvMcXy3E6h1Bp6Fg424mG6IdMg3b5VhHVSUE5gfLMyuT6bcp/TaM59xH3va5zi9WZHlxdDqQVbfvMDfH0uktHjvgV0R929E8oyKVRhPj0f6XQJy7eE5mUy3B0XE5w5jOdaKhHGWC4hrQyKiOz3+Hs4xJXXYTKPA93sY8jyA8zgRUeJyRdXDG5A/2YYPU54jYofnjqdc3he05bZbnH2H/fw4kOkqsOG8XOby4R6cSLYhvtvRvPrdEf/iUcTnhfSdnki3MuT1+nKBx26mxmEN6hlrua7OyMM+h2e+6dh3vc7jpjZpzq47cBbkwVSuGw/cR7PrVsbvV1eylki3A+l8OLNoYyLPR0HmPb6XWVNnKPZilybq3c882V+im0wmk8lkMplMJpPJZDKZTCaTyTRH9hLdZDKZTCaTyWQymUwmk8lkMplMpjkynAso5xDlHVKmaKKKwHCw5WGrrCzcCVsjzkL2x2jbENocJvCZtmptldiWgLbL6w7bi+4NlsR30JoGRINzSIOtKn/YBSulcscInEgNXOWtgixTI8eWjFezDZ2AvRjoFcI6QiTrHC2d2vrVBAsV4hh0nT9I2JZT6q7OrkeAX9E2clQRsCNDheRAe9GNKucv58h0A8D14KOwLYiIrpa5Lhtgw6vlJVfl/pD72+e7fA9ti0ZLEVrj90byfhNAs7QAizBUvqHdiC03V3y2Eztg+ulE8t6Ps8PZ9WLGNrc9ZZUdxWypT3N8jy33zTRPFbDhals0YodWilyOnYkMfVjE05A7urbHYd/eLHHFajRDL+LPlkrcIFGi7Y6QDuyO6yWZDm2IiKK5UWWboO4fRxOul0dDjiX7E4mH6MPXhjAo1xWRCu+/BOOulJMdLgdWaLSyN5R1HJErON4XCrIcj8acERx6myVplV7xL7ZO746lBR7ROxhONiGej5S9KwV74hWwHZ8qvNQEuAja2o5aBVtZA4ZAPads/Zn9T/drUZ8mlKOEmols6wlxnxg6jPjQWAT98ys6TvviZ2/KGIgVn8eRxjQBaUTM69iWVWWlxhhUgPlmpSTHF4aQAoyhfizHzY7DFlYsn8ZXuGAfvZzdnF0feI9EOsSnbKfP8HVJ4mFwrRIl3LlDZZvH+QIxbe1QBt1exjFuOWMrakycbugMxHd8wFyMPLb1xtlEpIvg3hmY7UOSttITl+evHNi5FxTKBvE1Scp1rvtHF++P+IpBTaRbLPA9ggRQW2XZJzZ8rouHY45Pj4byuSOwUyOKbQ8QGLEj+9E3lRmtd73OsXig1mK9iO934D7gfKsx1IrfObs+CS6eA4iIfFjQrZe5H2n83Ytd9lOPAJ+xmUkb/iasm5cBQVb05A0fTLhtIpf7y4qy9W8WeU7NQaZU9gSaDfM+hrlCo5gQR1QBFMtlVSZUJ+N8jx3Zfzspl2M9zzFMuebF/Dok7st5tVUMIa4OpthH5Q39HI+VPYg7anoVMQ3X4Weh7It1QMcUM45BLqA69BjPOZy/s/De7HoyPaN56gBWsOVcFp+tgq18OWty3hQma+iMbO5+DZqmDsWOQ9sVGU9wzTuCdWmmxkoTkIg7Y67xiVq7bwBytAtYq/ZU9pcQ5jkXMBUDmCsuOwviOysQjzHWh2rfjdinOqwP9d4f8UYDKLueRzC+e4Bbu1mVa+FujCgbuPdUPhf3J4hoXVb4sKOA4/EXOvzZcSDZFhsljrPbVb734YTvfRrIMl2HtfGzLUZRnMJ+hkhiHqvw/iGfl3mtFLguVnx+1oHaj3iADPNinvMqORmsxgk/a2fE5V0oyn0VNhViS/YciZsZZvxzP7o2u+5FMu5cq3OdY9bftcR5SBRW5dOA6sRPEK1FJNGY+Ine2x32+Fke3FG/OQkBr5WDKFhw5f2qMD8gEukokP23nfH6rhpxPyi7ss5xfJQhHrsKAdXILsab6XG4kjI2JMlwb8flmKj3RqPpxWuBbijr/HbGefrWRb7H9YoMXPjOrIM4IhVbUij7qsfryGLyajhIeF+o0Lr43gjXLeoVhsAvJYCbKak5Hu+HSy6J2ZTrFtxzlDNmVyEyh4iokHGfeFOOx01OLRYfDrhuz2Luy8VQ9o+DwD2HRZ4nm+NNJpPJZDKZTCaTyWQymUwmk8lkmiN7iW4ymUwmk8lkMplMJpPJZDKZTCbTHBnOBXS7/8QmsypdQ8KGhP/r4HvS14B2MQKL9ImyMaPFYAQe7tFUprtSYytOFVoK3SPLvkIzCBsY3+/hQOb1KiAJEP+hDpgWdrQOOLXqyiFSzXE58i7aVmQXezhEOxvWl7ROnMLJ7HjC9yiR9s6UEL0BNvJU2ZXA7oFWIVfhRFBYf72Ura1FVSZEcqA97rm6tLb1p/y9EDA3ZU/+X1YGdXEQcPn2A1npnznjdCcB18ulikyHbaptMKhK7mKLkrbXo6BpqALtqW2HscN1fgR2tlhZKYOIbeV5j627E2cs0oVgNd7IGM9Tz8n2XAYkCbbNQHYjcfL5MdgLe8qC1QVkwlNbXM9+SfbfXTjVPEkvtkgSEY3nWL/0qejHMB6OwAIfp2xx2ipLC1wMp9kfg331WFmXJHKFn6sO/6aXB9wey0Wuhzc1ZX87DDhPexNu3yvVikj3lgXAWkGceWkg23CzxB+exYhOkvFytcrMlU+fsF3vMJR9Hm1waCsfxPPtW4jU6ECfeDiR9jO0+L+aMIZjfu6PZF5HU6Iwtf/rfq3SWIQqAZaKGJEwBrQLEdEwa8+upxnH7eMsVOkYdzIYcdzBU+6JiArgV8TT5+8NuW+31NSTB0ttI89jylOWxAb02TbEqjvOXZFukHKcRaSB52jbMfdn363TPLUAJXG9zFZZPPGeiKgTwnyYXIxYIpIxtxNejBYhIkoc/mzNY3zKIiAcDiZy3OXBvhtnl2bXD2lfpBs53Qvz5qi/LxlnnKeSwziMsZqXShl/hjinTiTn0CLgdaoOonaUrT/m+gtTLm+SyjVIBxBfQFURsZ2I6HHEtmi0zo5dHg9ooSUiOp4wvuL5BcQUyrwiprAx5XZKHVl2HKMPBjy+NssSqeQ6F8djRDsQyTqbEKBd1Ay24bI1+0qN46xGp61NeAzc7j09uy4q9gkOyy7034I3fx7BOiq6fL+WK+PHAvTZesp9arsmGWv4qKMxj+sHiazzY5f7fQgIAkThaSESpgHYEiKidY/zhOvG4zAQ6fYAg4blLSl0D9ro+7AnOiCJXMnTxdZ0RLhkCumXm4PqKnoSnVTwOKYhwmUjXRPpFgvcBlj/n40PRLpOtkMpzV87m57oT08cKrjOOZQCbotqsLfw1X4pEHtU7jurpfnjFa/7CiFw4uzMrnFuDAEpkaXPie+0El4LVASmRcaWOqyvse8gvoVIYlrujXgQ5dS81AXUw4Mhf3YUyD7/fAOxoHyNeFoioijl+P4m2L9mChNyBs+9XOHrhWJZpMP1dQ1QhV3YhzYLukx83Z4opiRoBFiaXUj36M7W3O/0Y+xH8jPE6V2pcv1pzNgRoGgeBNwnzoKRSId73imsYcJM9jdUOwWEnlr2j7schzZLHLdfhjXlYkG2J7432oWlSqT2FIi5QRyJWj5QM8/POok45oYqzhXg/RfuyUtqDg0Tvct8Ij2DpjCXH7gcZ0+T+/K5LsfwzOHvuGrewPdBuEYKMpkfXEeWXa5MRMDtjeUe4WVE/Hj8HE18jFLAUAFetqIwj/iaDDHKfbVf3a7h+y/ENcu1IiJIM8iUHoeLRdyP4DtQmb9TiAUj6AYau4fvXHCs+YD4WZ1KTFYJ2ikPfSrW7ZTxeF3wMbbonsSfHeH7PbU+f7nrUqze286T7c5NJpPJZDKZTCaTyWQymUwmk8lkmiN7iW4ymUwmk8lkMplMJpPJZDKZTCbTHNlLdJPJZDKZTCaTyWQymUwmk8lkMpnmyJjooH6cUcFNKclc9XvkmDFTZ6pwTvhjCJzsp5sSvoUsrnt9TqfpUIuANWsCk6gdzWcuopDBdBZLbtNCzJlY9vl+scoEFIMQWaf/96WeZ+CR7zFLaJzIsiMTsg9sy2AqOUuIvcsDj201LxlpyN5GprxmJjUc/l41x90e2YwLRVmvaQbDI+LvlxVXHFngyIFSCCxaLDAjLQesvI+eVUW6hQLX5XAK3KyxrPWDMd8POaP6uYi6WvL5h6vysQTYZ3q5zwV5CGxdIqI6MUsS60/z9FG97PDC3yMDnYio4jP3bS13i6/TVZEuBu6lAx1TM9dOANGJvK4VhdpDDl+UzB+TQcZt8/KAb/KmuuRnNYucD2S4tuUwpK0Kf3YScDmGivuK8WQK/NU+9I97I8lBROZiB2JGojrITcAg+3CewUAxb1t5/qyW48q8VlE8+DpzzJCtX1azDaIG8bwFFQqoBR0Y+3LOla3zoMcFwTLqmLZwMSKVvnjGZdJs7Rz0sSly5DzZkToR949xxp1qO5ODrQYFaQOjzlc8XcchijSY0HROLjnkkkuH04H4/dDhn5GDrrmULrDywoTTRep+QYm54DEt8bU6g6MAnL8NOKNiCfoeck+J5NkayFysqHHTAJZ6dwoMaOdUpItSLmPmYnyS3FIse5Bx2fOO5CUvAhcZ+yly2YmIHowYvrnvctwvZfK5JHHikB8Zw3MwDyNX+W+tcJ33YgmY78KaDWNu0F0R6drASw4cZpoiI5+IaJzAzzDP5R0VTICJ3gfm4lEi+xGy+jdLyKvX/FWOJ9i3i2FDpKsXgPFd5vbwXFmXexPuixm0B3LQU1WmdsoNhczcSyruI4ezFTN/3HHkwmAI5wwglz0bL4l0TTgXYAQ8TWTNE8kzhpKMYziyjYmIRnAmwsL4+ux6S3VLZOCWcziPyzF+BkzYE/dkdl1K5A0bGa+Xig735TgDTrEn++8zLjO5R8A313uO8pxzaMaK8xzBYDt2+bMSyTMQfJinJjAeqpnkh+M4RA66ZuMGsP6fJMBEVmfS4HyLayyPZN8ZuN3ZNfJwe2NmCS9WbuJXaAqc/Eqe15FppjIBGhGvS4fUFJ+twF5iCnNzMZPxsuS0KKX5zzA90b3JgHJORA1XrqVqsE97U4v7gWYGv9Dl6y4cCDFSC8kSzFlj6ItlNS9dyfgchGPgL2OM1Gc04ZkeuMet5mWM1HP5K/r82Xx2fsObPz/gz3cHPA6RXU1EVIVzC97S5GfFio2N6+ZDOIMrSORzl4tc3ucaeD+Z7t4Iecn8LHxOoDYkWGef6XLbHAVqbQzXLrCr9XIZ37fc7cPeKZ5f580i17k+MyMGlvUo5rNYCp5+b8T95STrza6R1f3kfjyP4NqnqNZfXeK4ncAZWY0Cx+ZBLOMlbiewvl5UAfhFuK7nYayp/ruKLwngPI3DSPLgkV9dh72O5m7ju7UxTG5rJdl/14jPWDkOmrPry86mSLdPR3w/WO8X1Zp3xeV2q0J5HwZqneZgP+DBe4uzQFerMm7d7XPeh7BmyKu93SU4SwDPdbgrj2sS6yzN8Z+nbgQMcxUzWvBuAt916CNonqnxvHlric8m0ecj/Pc9nlNxDaj5/Dj+T0PeG/czjlv6jKGew+sqD7j2OcW479Le7PrlAa+N31SX++4WvONbjbhPHMSy/4ZpRNNs/nlOKPtLdJPJZDKZTCaTyWQymUwmk8lkMpnmyF6im0wmk8lkMplMJpPJZDKZTCaTyTRHfyGcS5Ik9Nu//dv0R3/0R3R8fEypsjJ/+MMf/kvJ3NdajYJDBdc9hztABwoiSCbKLob2Mfzfiacb8v8qtstspck5bEvoRNKrgc6ZhyO2IXRDsDFHEiMRpvxzs8D3Xs9L+8Nmme+3XJyPcBhPL7aSBMpW2o7YwpIHVMnuRHYxxGhEYP3aG6lKhx+X4MEbZZm/PriSpmCzaocyg/ittTK3x1aZ0x0qu9hxwHaOyxW29bxtQaE24FFLBbAJerJtCvBzF+qrmZd5/aY1trD0J1BhJ02Rbjjle6DVeFE5zLHd1uB2b29J69I+POuTEbdbRdkdh2DzPZzwwyKwJ0+UjbXhMqYln3G+x0Vpm2/lr/B1uji7Lruy/yZg8fcBnbDoy7GGdYEIgmZBtuE+uIjaIdv8Vnz53CWf7UFHYKlDqxIR0cEYrJ9gHdWYkGaBLXFo1dYIErSwP8dOedoqcR/dC2Re7w4AMyTuJe+NLrPDgFNOZPcl5Big1XalKG1PN2r8gIMJ33ygLdxgm8c8Beq5jyH24Wf/24q0Y56E3Ni3h/Ontha0PeYvgNgZpTITS2DvbOTme+rujoZzP0OdQuUGCTf2ZdU4K/55K+1Xo7+uc3eVfMpR8dz46jrcR1ya3241h+2i5Xxzdj3MyfiEttCNAseCkrJqnkEMqcDcG0F1TxJZ97X8xX/ToMeD73IZ6zmOpaW0JdKN0hO6SNWcRJqUM0Y65AFvkk9lX0RbOVqw+5EsRx/QEYhFiB3JskLkxBTmi1a6LNJtF5qz6/Uyt+EY6q+Zl5UUZ2gjh7ifk5OjD3OM73K9xNkVkS4CsBe2dDeV9tMCoEtOpvyZS7J/eGA/x1hfULEZUUDOmPtbUfU3xATgelVb2y+VeL6pFdjSfLvPv++pMlU1subL0mgynF8rDvfLWK0fEkCxoeW6ncrnbBUY74LlHarJMYIYNoV7e8ryGwPWoxPx+NwZS+s4Yt8Gr4Lk6blscU5p/hyPaxfsB+2M79dN5LxRJ26bPPSpNJP3TuF+iBOoRU2RLg/tkSe+1igLbJvImcB35Hx6EgPKBjBSRZLW9lLGZcd6GZPsPBOHLd1liLEF1YbHEE+CaXd2HUP9dSb38Su0WHqaLlLVk3FwlJ7CNcfOsSvj0UHI9YdlqpO0jpezMk2zkB5c+PQ/v/7azt1O8TwWi4h6gMN4mWkY5CtsRjvgAYtIgkQFvzPAfYr1a06iihBReRpwmwbJtdm1RnpilpAoqad03IsdA2ryzlQiMxEdcyt/aXa9WJQ3/MKA42cf0HXHsUxX6PP+y4d5fEEhF3CNfgJ7AY2RCFKuv3qO71FVqB1cazdgn3sSzl+LffaM558pzLsaFVWBOFbNcZy4XlcxA3Bzt0Me4xqr4gDG1x3wWqrgyvlBzz+vSK8BV0ocC1dgv59T70Repj+dXfseP3cR0HBERB2YKx85HONWJs/NrpsFWa8HEy57q8D5q+V1Ou6MOMU0i7IusWp9GANLUzmG8LXUIuCC9XqkAXlCvMnxRCUEIepltSTr/PGQ121DmLs13vfpJhcE106TaUWkQzzs9To/61oFMIpTWZdnId8Q3yXieyciib99NOT83RnLdQYBcu0mdImCK8uO3RLoUuQ5si4Rl4T30LGqCVjhKpT3hT05H+L964Cdfqkvb3g45vGbh/WvC3iYYzVbdoOHs+tGcYuvnTWRLoX2bTu8LuuEcn2zBi9Vn2lyOw1P5fyTZAm91r8x/wu9RP9n/+yf0W//9m/T933f99Fzzz0nmMQmk8lkMplef7K522QymUymN5Zs7jaZTCaT6fWjv9BL9P/8n/8z/f7v/z597/d+7192fkwmk8lkMv0VyOZuk8lkMpneWLK522QymUym14/+Qi/RC4UC3bhx4y87L193jWKi2M2EFZuIqAPWIzxxGS0JRER1sNKgtWJvLP9iYAnS3YITcAN1UvbLA7ANgiNjCn6bUGEHcpCnClhRlnyZBzw5F63Q40R2CbSj+B5Ys3LSHlPPcZ3tB2DlTedbTqpoQVZWqIdjtnQVXLZk1JTdOQd2lC5QJa5UZV3WwGZSg7zjY1UW6G2LbMd61xLbV6/Upd3mdpc9NmjtmapTjFt57jtogUlUunqdrVV+kb9TPpP2rm9e4DrvwqnciwXZf7ENM7D/fqYrLVguYZ749wuutDgFYLv+jPO/ZtfTjPtyw5OnZjcyQLOAhRjxLUREl1K2TC7nOF1FITTa4cUnJ6vDv0WfQBRTSzr0RNsPEv7SusLIXIaTsgdwsri2O9YB4fDpM26PQ5J4iJ3hxuz6NOS2XlI2uut1Ltjf2WDrpwv96GBf26w4f6I9Fcqm7PF4cOH0akS7EBF97oy/92DIFXizLiuzCve7BBa2Y4VLQhQTxjeNr8DcYjvtj6RVC235L3a4zptF2ThouxxAn0CrfMmTcXAKGayDBfHRQI61LlhqnwJsxoKy4d7tcx8rQd++Vj1vY8w5862Nf179dZ27U8oopYwyhVLYSNdEmlfUc2QMX6eF2fUo5bYpODI+BSl3WsSv6L8JDFLuF8GU+1IZrOIbZdkv0SqL47WnQl07QiwV32NtfFmkO6Xbs+sM7I6IbCGS2IvlrDm79l05Bub94aMeX4OELfBngHfISK4Z/IznlaHTnV0XScaTKzW+P8bt/QnXf5LJ8YXrDOwRW1U1rgELhul0PB9BnPjSgPtOqBA1A8DIhMQN13Ik6gERZL0IY4jMXwvixlaVC6+bogPZwHlJkXbEPHKtws9d9Rnn8lJP1n8d+jnG31BxD7vwsCDje+s6QpybC5bZU3dfpHs84vXJis/fQes5EVHbYdYDoueq2bZIFwMiYQTXnzuTldQQ+CWwCbtHIl0J+m8h435eIzkvXYJ2a8LcsTDmNdGjyUh8pwoIAdxXjBRCEtdBk4TrfOB2RboxcR3VYC225DRFuiKgYzzAI8YKpXDo7syucVwPSE7ejYzb8GqZ66UTyj4WpvwZLocP6UykC1IuhwN7nQz62yQ8EN85A4xXmsGio3RNpItTwNe4PB7KmRy7HvTZCPqRjm9lKgocxVerv65z9ziLKEcOhVk8N80IUIWLuZL4bMnnmHlTbpGE9sacDmeLWC2vYljrXYO5B1EqGn1wCFhA3GcM9M0hch+NAT2VyQlnkbggmIeHA7Xfh5IgAiZ25vc7nBNw7U9EdBW2emJPpDCvuLf9zBlfa2QF4h2GcA+sFVctLNBhEcK4PnUPRboAcCcliGmdUNb5nRHP14+yz86ul1w5/lOH6/YUUHZpT+5/EQ3iQXsW1Z4B6y+BvcWSK+PJPqBOHWhPje6KYT2BCD58jsaUPhrCew/AeDzX0jg4OaZe0br69aXS9MJ0oXrPg1lH5NhRMP+9DGIK9XoO8STYf9dLsrwF9+L9Zj/S6fga32utV2QbYvkRW4TvtTSiZqPMv7gE7ws0BeiIpxu6N2Yc2diROL1hzP1lDO+XtisyFuDSoAfvg7pqPDRh6kXc0oJCIh4Aqvf+HUaplBSmeAneN90Z8HfGUx37WLjnbXm8RhgGas1W4LpA5GYlrYt0Yxd+hsd2Izmn9KFeMH8TNfd03Q4l2cXvmLT+QgeL/vN//s/p3/ybf0NZNr+STCaTyWQyvX5kc7fJZDKZTG8s2dxtMplMJtPrR6/5L9F/6Id+SPz84Q9/mP7wD/+Qnn32WcqrQyv/y3/5L385uTOZTCaTyfQXls3dJpPJZDK9sWRzt8lkMplMr0+95pfojYb0Sf3dv/t3/9IzYzKZTCaT6S9PNnebTCaTyfTGks3dJpPJZDK9PvWaX6L/+3//72k8HlO5XP7Kid+gahUdKrouhYnkkxUBOlXJMVBI8wmRHenDdzTHe5wAFxF+v1GaiHRXq8xJ7ITAoe4yqOlLHckWxLyulIDXpcA9B8CmioDvNFLsM81xekW+YqnlgGdVBXZUN5YPDpKLGVHI6yIieq5Rhc/4932FzSsBz+rZOrOaLpUlN3N3zMzKGHhuWHbNE0P2lg8cqHEkG/T+CPsE32+5KJlKyMb9yCkvjjUzfzhltu0SMNEHU1mXV8p8/6frzNSq+5JBujfguvzIKfPdFNJU1AVywiLF3e8SP2sUn8yua3nmezdTyecuAod3o8D56ccyXdnhurwOkGDNFrwfcPuuOlwmbXRtB8D7BtZeP5Z1HiScLgGOZaDG+FTUEf8e+6G+/yDjcX3qPBLpHkyY5ZUn5HVpXjL00xrfrw28vlt1yVL7hgVutyKcWdBXPNJHI44nCZwx0FYstYOAnzuM+R5H+syHEuf9BmD39f2K8KwTaKepgszheQnrwFw8DFRgBSGDcDyVsSoF/iSyjgfA9FdoVyqnXN5GxN/Pu7Lszak8Z+AVaTQmcm4XgCm9UJAPTjOiifvV2bf/JszdS/kS5d3iOT4djmtkol/z10S69TK34+fayORU3EGMDQkyEuVzY+hAQ/hwZwRcUNWsuGZYLML5I6qb701kn3tFy55kbS65N/nexJ8105ZI13S4X1xv8DypueAYS/swOa6UZMLtKt+jETwzu47VuE6gPRxnaXZdLsj7YWlvVHnO251wxZyGsk4myNCEiU6fDYNZ3x1zut2hZH/iORknLs95yKElIuo4fF7FAvD4x4qvOIQ8IUs5DCWENEy4jGtlXj801Jke7YDv1xnz/Z5qyLq8XOKKuVzmeF5weX6eZuqMixzcG4qRU4MNOcFj4ljadzoyITRBDnj8k0yme+Tsza698BLNEzLWi3A/35ED55DGcH1M8xROuV5Ch9c6BcU694G/njjQX9S4xnGOdbRc4h9yipOLvXQCcQbPZCIi6mVcphNgBneTHZEuSnkvMc1x2xThzAIiok1itnCCMUydIeHCWsWDtV1Mcu1ZdLjPRjD+q3nZL304D+ok4rVdrHj6FZfjRKXA18gw700ey7zC2Q5hxH1sFEvGvZ9rzq496DvIIiYiSjMuL47dSMWCKvnknju94M+nvwlz98gZk+dMycvk64gqcZviWSd63sTpAlnHeXWWzHrp4r0nxn19fxyvuIdeVOu0AM4S2x3xDfAMFCK558J1rT5HYSHPfQy/46mYi3z+Rsbz+qW8XIdiOQ4neD6Y3mdwxTSAkVzyZLqXB/zzacCxr6C44AVYH+Nci3VZVXv/ty9wu38OeOsHKragcL2PDHQiogfuS7PrDGJpN9sT6Qpwbkkb8tcltQCD7JZT/k40lDF8AvEA48C1kmyb7/G+aXb9eMRl9NTeopE2Z9fX4R543lJB7RUquYvjz4Yv+28ezsIIYI+LZ/Y9uT/XXxn2lO1QnrNzGnE/QJ5+T72/wfcl+A5ILSnFuMax21brPuTD495Mr2XxPRK+F9C1hfcbwburKrwjrOflWhHPncP3ZGXFEk/gHJXrZe47o6lk8OO5OHj8kD4r6xl4B0TE/eMwUOctwHmImyWe21Z8+c5sb8wxCZn3+H0iol7Mbb0z4jJ2I1kvBxmfb+JnPMYvFTiveI4YEVE+5rGXwPkIy478z+UR9eki6Z3zIayNNytcr+tFNb+GKzTNQnqJvrL+XEz0paUl+v7v/376rf8/e28eI0l2nfeeiMyI3Lfa916mu2cnZ8ihyBEl88EiH/1ggfazIEGEZAmSYQOWQC38Q4YMW4ApWfQCy4JsWRJlQYIBW7IF27L1nmVb1sJn2tSQHJLDWXqWXqu79srKPTMiIyPi/dGcPN85XTnkjLhMU+cDGrhVeTPixl3OvTe6vt/92Mfo4ODgS3/BZDKZTCbT11U2d5tMJpPJdG/J5m6TyWQymd58el0v0S9fvkzvf//76d/9u39HZ86coXe+85309//+36dnn332q1U+k8lkMplMfwrZ3G0ymUwm070lm7tNJpPJZHrz6cvGuRARnTlzhj70oQ/Rhz70Iep0OvRf/st/of/0n/4T/eN//I9pbm6OPvCBD9AHPvABes973kMZZf+5F/R/b3SplA2pPZZ2VrQrRGBrGMby/yAOw9NtIZcq0sOCto4IbIzKvUMLFbZg7g3Z/lABey3iDYiIbg/YrtAEd4a2TCHhIAt2llJWGiAy8NkI7Cyv9KXFaQCokUfqbCupq8NvjgK2TZwA3qGgbEdPAIoCsSh7gWybweT0/wc6CKS9CBEzaAPPgCVGW8J3oC7jlO0250uyPRGDgro2kFbZ4xN+9hfa/HyIBSAiqkCdFYacfrgq7TbrJbbyYl/JKqvhs01p339V75yT15vLsXXr40dsl7k+kH3i0GUkSZpy2csOW4HdVLYLWoPbEbdnPa2KfIhBOBxBXw5lneP1ci7fq6WQIdtDiUh6VRlntj0WLZJ7gbSzLRZ4HDagKy7nZJ2/0uX4N3K4nsNE2o5uZF7mayeM8Zkj2X97gId5ant1mi5ALKl4so4efTv/1VKmwTHs+iekXWx3xP10BzBPsfLNIubmIGbLZFlZwgsR//wcOPQn6noL4LF7uMH3LSk0DqKFihm2dF2APk9EdHaZ7WKLObb/txW6B7ESKxA/xwkipGT/RVsfoiz02F30uH/gfbb70hK+WuRxvQZd8SCU990sTCghxQp5A/pGn7vrOYd8173L6t0Z8JjAj5YLsk+sF7iOr4M//MVAIiYQS3UCMUmggEjad9sTsOgGPDZ6Cs0wmHBASSEGNRRmbL3AP2cd7L8K3RGcm6YRbVHJyjkZ8UvnynxtRV+hQwil3Qn353wkY1UDPKdrJY47Gh82mpw+/zcDGUsRq4bz9SLMc3Eqn30BioTfb8thSMfQbIcjji2HsYwtWWgPxOHU1Twygflwo8ht3Ve8nxsR96vI4X4wTmWf6EPd9rp8vfmcWn/B9U9ibqheU7ZNewz9N+J4h3bpraJGFYClHtK4NtTC/j9Im+KzvIMIMy5PCZA+RESNhNcTFZ/7kZ+Rz+6H3PYaPYdC9M58ytfW2Jdn089M093xbf5O7qIsH9j/530eu3pOwPniTJnLikiD7b78Dq6RUIiGI7qDwnhV+HxhLNcZ4wnP11lA5u1nrol8C2l9mq44PJfpfukRX8OBtV7qyOfwYW12K+Axhdg+IqIArn/gMmqnlUosTZUkhmt6vQz3Hbeo+gegclzAy8SJrOMYyhDHnN52viDylTNsOQ9T3uuMk77IF7tvpZhOb8fXo2/0uXuB6uRRjl4jnFAbsEW3YoktzAfc9j2YyzSyEfebjRx/1vBlPqBmCQwK4hySsaxnxEPgWlHjwy6WOT4twDy5VZaIjwXAud1X4v3D59pyrT2X47X8rHITEV3v8X3LsGZIFPCgC/jV1Tz3XY2afabN+8MAYu5ttVdcL/I4X/B5fk0B3XOlq/aXeuHxRRXv2ity+ToJ19FtV8a09pj3qxnAlq3RQyJfJgV8qMN7Cb3+LkOM9AmxT3LuCQE/Ezv8WZLKtkbkypkyl2+sFkwe7A0erAPiAxAuWMdERA21TnhVffXeBJ8wD9fT77jWijz25sswDtuybZCCeqbIZRpO5HrpBNZf+LQa57IG6/M5eMabQ4Vsm4E6XFLrfbmm5PSKpOmJcV2HKasL9Zd1ZGGX4H1VDOjgui/nAg/64kZp9nvAusc1g+vfrEL3BICYwbGrkU0BtGkfsLGViXwdPIJ8uyNOB7FcU5bg3ZoD675hotayboc/A0xeNOZ8l/Jz4jvrgDB6KWVMWyuR64cMvMouw/iKX2PvjDFSY1knlAh052vpdf0lOqpWq9EHP/hB+q3f+i06OjqiX/mVX6E4jukHfuAHaHFxkf71v/7Xb/TSJpPJZDKZvgqyudtkMplMpntLNnebTCaTyfTm0Ov6S/RZ8jyP3ve+99H73vc++mf/7J/R5z73OZpMJl/6iyaTyWQymb4usrnbZDKZTKZ7SzZ3m0wmk8n09dMbeon+67/+61Qul+k7v/M7xe9/+7d/m4bDIX3/93//V6RwX2s98MARVX2Pio9Ju03wLFshn/kC2wmP1InEWbAKJmBr2CxK+9kJfC8CuweiXYiIntlfnKbRxvFgha9X92QZxgnYi8CNoE8nXivzh2t5tlNo+85RyM+EdjZti65kwV4EpzYnobRdjMDxBF+hY2Xh/t/H/OHb5/gZz5eUrazD9qBPHsMJ08oKhdbjzpjvtVnm52uF0r6B9bcPt22PpYVoHqx3bThtersvLSIr4GQqwcMXs7LO0b5fBJTFYl4iAxLoO88dsKURbT1E0vKE/SCjTnf+Qpv7Pdqs2iRtqiHYVvPZOqcBedNz27KsiF9J1qZpj2RZQ+K+0wo5XcnKUFV3+b7YtjtDWUc3AT2D1vHF8VmRbwNs1uOYPV1+RrZhd8x1dgHQBw/V5YnwT7fY8t8AG2LsSgvhRro8Tc/lETMk74vdeS/AuuD0maL8TgJV0XuWP+sq1FERThAPYRgejxRWAdowD3bukrLXh2Bnz4JNytc+NRDGk446Iz2EuLNUACtqKrE0FcARXYSTyo9G0gLbyKFNldu6kOFnKqqZEcd/D1xqUSpjRgJoJ8ThtBIZtxYTebr4q6oonNajc23qT/70lvBX9Y06d1/pDynrxAIjQSTjCcYajU/IZwAnBG2QVUukc3mOkYhBG49mYyR8sHvO5/A+cm4cg40Z+1hZoc5qHpfdhTIoioR43jGcbL9UkOMBLed43ytdecHtgMcUWh0PJRWMHMA7IW6uLh+XEHrXhFilUUrzEK7esbE/Te+dcFx99pZEls3BnLye57apqrn2pS7aVAGx5sjxug6oshLYYzWGDucLnDoOg9n9IwMomoJCY7XAVh7B/NoJ1LzpcAUiamOUyPkw7nL/nSSAUgHcX82T10Zr9u6AnyPzGj5WB0yuOcCe3Ckf37cKdt3QkV7qssP5cDqMlN2/6sEaFdZfR7Fct6wQr6cRtTOYyOsVJ9CXoM+6aq2SQnzHPqvn7hSM6j5YsFtjHHfS+rxPjMDJpVwPBYV5m0t5Hklhfj5QqJIwavNzAL5Gowq2s4yvmUu4vsaOHOSRw9cPwJqdIbkWyIH1G7F2TZL9co14HYRYGkfV+f10nsuXA0TdCOb7zJH4Tp8kTuhVhZOO+Dmb4T7hudwXU1VHrYjXlP1gl/Ol8gV2WowpSWeP+9erb9S5O0kTiimhhi/79mjCdYfIIB0jcxBoMbzfHsi63wt531yFNfCDamLC/ZcLYxenpd5EjnGM9XO52WiGENaHb29weXxXlnWzzjimxhLnc57fFPluDrnse4BiRBQpkVyr4HyqEWtDWGvj+4g1tb9pHHDcWchxGfTzYqi+NuDYgOv9kiKOdIeA/gSMT5XkmhnnmDWf55hi9IDId8vndcKIuF7ziUSL5CDYnzg8rn2ajf7sO1wvOlZ5gIcpw14lr/aUV3uA3oC5t6rWQZtlvh6uAecAa6cxLUt57Mss3X9xLsJ3BGuqPYOYY+7zR7wmuj1SeA343rkyY7zeUpN19Nk2d8YBhM+i2gfdD+jO+5Z4TXR8RY4HrNsq9Kuyuh7GCRzj2P+JiCLo0IiHOoLpMFAv15bhPc0mIHdLeTknPwjY6OsDrgf9XgbrEpHKh6HcmwQx91NHxC2FZYX3RvsjLrtG4+zAa8uXurwe3irK+PswoIU2Srj+leMmHvF7n6bDc3Lf5TF5OZRzqEDFudznj909ka+b8M8nDpdPrz13w/o0PZpw3wkUBrDjtL9sFNsbwrl89KMfpYWFhbt+v7S0RD/7sz/7Ri5pMplMJpPpqyibu00mk8lkurdkc7fJZDKZTG8evaGX6Nvb23Tu3Lm7fn/mzBna3t4+5Rsmk8lkMpm+nrK522QymUyme0s2d5tMJpPJ9ObRG8K5LC0t0Re+8AU6e/as+P0zzzxD8/Pzp3/pHlB7r0Cx51PvUP4Z/+d3tqZpNEbUPGnBxP+RQDPV/kjaVBHhgqiXuaK0EC/AZxU4obvdK0Ie+R20qV7v831zGWnpuFRjC8XyHFuSsllpA7u+y6flPgu4D7SBEMnnRTvKjrL59CP+rAoW4pNQ1iXau+dz/ExPzEsb6CagVa50Ob0zkvn26HCaRrvzsMPPp5EhC/qI6C+qPZZ1GYDFZg+saGPlry9kuXwrgN5YzsvrvWeZrS7nH2brUndX1uW1fS77y32uo9tDaUlqBlwOtAZdbkv7ZMXnHoyWvzmSeKMsvXWaPs6wjSZOuQ2LqbTRzIPtuJCZHXb8BK3ynM4pC5xEH7AFqJ/Kdh87PIbwdHeNaUF0TwlsdNo6jrqvzGNSnFRORJcqaD3kumiF0uKEp0JP4F4jfTI7BBdERbXBhue5qn+8OHfqdzQ2CsfrHFhZy8pOuExsSUSkQUqyrAXw+SMqIqeGE5IQbg/RfjbbKntrwPmuduUFbwyXpun7Stwn5nwZWzION/ZKntvwls/jU9sdEQ+F5UOUwJ18fG1w+9NSRqJncjOwSoGyE36l9Y06d7/iPEeukyVfWQjnUsavrTo8HrqRtA2+1OEGwbDdcGS71SBG5qAfjBNpccQujHiHmg9phZ45GJ1+Gryr7J3tiMuA+BX97SEgE8pggc+6OpZy+qU2X/AwHoh8VbBJRoApOCFp9c6PeUxslDgm5dUaZAfmqc749HUBEdHZ4umcXz/DZVgvymsfBoC5gbkM0VVEEg81mHDGfCD/4hNt1mgZ1vMDIrASiIsa+4TW8SHMUW4qn32duP9ulbj+9ax0dcA4C7zvRPUKvPoQMCZIINEoAESVHI05XiaqFGWYfwqA9bgvuSTy4feKLucrARqOiChKwL4bcl/W+DCcr1cK/FkwkOvu+6v8cwUu8dSx7OdDpzVN45omr6zB2O+dAVfgWxpqvQ/VeRuWCdh1so6ca6sJr7liROGpLZsDDYfr2oI3J/JloG3m/QvT9EStl9rpzjQduP2Z+TziZwxSxqL4qo7aLtclohQwLhMR1T1AKUSA01F/5oXjCBFEpYC/307luhYt3Y0cv4R2XdmPsg5fwwEEl+vIOq96PE8ueFyXYSrxQTFFlKRfOUb5N+rcfSdGJdRVSKMY2hoxbYh5IpJz7S6Mr8FEWvRLMAZWi9z2DYUZQ/IWxnoXYmlJrWXHsJhFLMUthY0bRNihZ8fzK31exywdcb3cGMpnH0xOXy9qRFUJyodbi6H6Pn7tIOQ6GgSykr7zHGPVPgnY2f1QVswsJGoJhtQ5GTLIg3F5vcPzzT69LPLVnfVpuh/x2EO8FBFRP+X99BhQpCVAghIRbWQ2uNwTjhOIriIiGsL1B3T6XEFE1HAYX4EN7LlVka8bc2xNIXbhWpOI6OEaV+YjNS7DLuAqcwrdUStEkA/aU7X7CPrswQj3IzJG9nDtCeiYGwN5vQXopvj+a1u9D8KxgijROU+Om0cvHvB32nyN5bys81HMnyE6ra6uV4B14AFgURBrQyT33Vg+7L9ban26CvialVWe86rvlh198jvcd64PeD48Gc/+2+YuPG5FYZAaHrwTgbZRS16xBsan1YhmjVh+VWO15n25C8hWGP5V1X9XYthLwb2agFs7cfYJFTv8wNWU48wcpImIQhiTOF+X0rrIFwH2sDnhgOSrOX4+nbtrzTNLb+gv0T/4wQ/Sj/zIj9Af/dEfURzHFMcx/eEf/iH96I/+KH33d3/3G7mkyWQymUymr6Js7jaZTCaT6d6Szd0mk8lkMr159Ib+Ev2nf/qn6caNG/Rt3/ZtlP3iX+8mSULf933fZ2w2k8lkMpnehLK522QymUyme0s2d5tMJpPJ9ObRG3qJ7vs+/dt/+2/pp3/6p+mZZ56hQqFAjz76KJ05c+YrXT6TyWQymUxfAdncbTKZTCbTvSWbu00mk8lkevPoDb1Ef1WXLl2iixcvEpFk892rOuiWaZDNURhLys1+wPAhZAtnHAklwm8h5/Ykkpywx2rMTFoFtlWzL3muq3PMUxoAC60dcDpVXNXPt5m79F93GDy0kJNlHU7q0/SDE+4GT/yILOv4XzOv6JUeXzuvmHDIWQsn/OEDVcmbTIAPfQL8qWEsWW/IQV4tAGN9IOvofmC7x8DdfmJesrcud3mhuT9kJlQ/Yu6TZjhtlYDTmuX08x2ZbxfAXsjaRm40EVEIldSLON+lsmRPVYvMgQM0I80/JllU/U8wAPBMyPXXjWRdIgP7Spc5TysFWUcbULXIVrtYlYzPQQTcfeCs9R1ua2SgExGdKfE1kAevsVtHwAVG9vRgIsFeyJvOufx8y4phXIkfmaaLr8Fib4V8/RHcqzeRvLNvXeZnX4YzC7ZbknOH9XcW+tFSXvadQ25qUReTVLY1suOK0CeQZ+wpdvKNHnNV8fyGYSz75QB+3gCGXnFJ5rs5BN4slFWzzs+V+EN8jBtDmRGvUcxgn5DPgajoXoT9Q963C3F2BGzxIJlNLWsUuAHWC9y2J2PZV4bAmEPusWbrI9e+7J1+pgIR0bUej8MOnOtwpiy/0x/71Fd8y6+EvtHm7q30AcpSjnKpbDcPAiieeaF5qdcC5vouZniOqnty3uyK8zA4XcrKPuYL5v3p7O9IsQr78Av8Tqg4+Tg8ahDCRxNZBg/68AAghCehnB+Q23otYfbkIjVEvnNlWINA4ElC2T99iMddOANFs9jrUHbkoq7mJZxxzuefX9xnVvkJjJux5j5CN2hDXDgKFQsUusFSgT/bKMl2bwEe8RAClz4yY4wcb+CvTkj2N2SiOynXl6vWcyt5rvNVmDevdOX1kH3edZjV7aXyOQIYH3EKHGo4NyVMZGX60G7II2/HgcoHMTLhCvNeY5uRwARRUQEd+/3BhNcWJZLPlIczOJD9u5RTZ77A1/Bcm7wry1eEfo9Mzs1kXeRr+Hz9GDqCXhtji97qc7vhmSM1X5ahg0x/+L0+Twb7y3zMa67QfUDk62aOpul6ugzXlkzZicMLjVzKa7ZCKs98yKT8kIcu18N8siTyNV0+iygD7dYF7jwR0VHEAeDAuT5Nu6rv4Jk3uwM8l4VVStW5MzAGCsRtOMrIl89tl+uoHd+apuNEsVGhmzrwg+fIdXKR6nfV71dC32hz94JXIM/N0WCi+fFct0WIl6HizN8Y8s94LgPGBSKiMxXuS4/X+Ts9NW82gUmM/OA+vBc4VvMIzuW3+nxtPeeVIQbh2Wj6HCs8jgvPQGkrZnMLpkqMO8vqPC9cl+JYWc7LWB8lfP28y5+91KqLfOuw9/mmJX5H8Jkjyeb/fIvLgYzlMey19fpmBYbR26PVafppeewU5RKYGzO8/3KcZZFvFJ+dpjtw/khOzY1Zcb4E95UayXjST7iAkcuxAc/fIiIKiZnLeEaYZk1XXOSg832LWdXHoOHW6zzHIxM9o/aA2LNvQR/TOwrcYrQB/t1Woa8bcfmwi+m9WAfC3jGw9QuuZpPzRfBp1wtyDVi6xPma/5OvV8zIusRl+AD6VUUtQXbgnR7um7eKsnwYC2BKpotlHuPv3toV31l6N5yj8PDZaTp+dkfke6nJ/XQIZe2qKQPXqIfwg14vZYDrjWGnnJWtXYGfMd2NZH8rZnFdxX2snpt95t4+wPX1+68NOEjCd3k8eAG0Lal3V4TrSP5Mnw3zKD02TfdgvtbnZ2DH92F+CBLZj+qZvIjPr6U3xEQnIvq1X/s1euSRRyifz1M+n6dHHnmE/uW//Jdv9HImk8lkMpm+yrK522QymUyme0s2d5tMJpPJ9ObQG/pL9J/6qZ+in/u5n6MPfehD9OSTTxIR0Sc/+Un68R//cdre3qaPfOQjX9FCmkwmk8lk+tPJ5m6TyWQyme4t2dxtMplMJtObR2/oJfov/dIv0a/+6q/SBz/4wenvPvCBD9Bb3vIW+tCHPnTPTubnN5pU9Xw6PpZIiGqOrSVjQJVc68t8HbANjwEhUM5Ii0g+y9YBB6wkGWV1eWFvcZreBpzLayESPnPMFtvn6PPT9FIgrYudW2yL/nybbaAb//qmyFdbBT/PDU5WpROKPHA+vNBha9Wf29gX+fDZP93k+z5QldaJNlhLEGXxuba0lSKa4n0Pc9mv3lwQ+XJgH9kF39sOICpWpRuTqlnlEX/1Wsqui7ZcTJfU6EK7HdptNJKnM+BnXIjZItb+grze8YALjJaSM0VpTZn3+fotsFYpEoWwEZWElUxZxOAjjwAF4rKV71uXteWa0020xo+0nfB0dEVdoXbQJtkFL79G8nhg1b4GD7gfSJ/aZpHrvAYsgBUVIh+q8vduw/j/XFt2HsQLYFUqdyetF/l5BxNAOPiyk7UAmbDdRws851nMyWevwKA8DNnmuhvIttkbcb6LgBa6vyrtiXNgXz8GVIRuMkTZ1DwevB1lm73e5/vuDLhtYoWyWcrDGAWWjR6vJbCmnYAF9sWetHTFYAM9X+HvXChx7BwrBIzr8M8ZaMS8GkR1n8tag3F3MNKIBLToch2tJrLdV2o96kXS2vin0Tfq3F0gj7LkCdwEkcQdNGB8dCLpmbztXpmmK8mj0/R8To5/7OuIVekrNgu6TB0oQ2vM9x0ksl0RyZGkHFuWCrIM91f44gchP9MeSeGz9xzG1VyXFA46dnmOjgHnMKeQXF14RjEHOLLPRinn64D9PE7lmLrAtCnaAvvusUIp7QVcF2i9bfhcD3VP1v+tIdqEuQwaoYM/rwM2bqgs5kgGmMudjpEgItoLOWYiyuK1lAerd5t64rNkhAg4Dnjdiew7IeB6eg7b613VNnHKdZYLuZ7zLucrKZvwCDo6xua20xX5agnMoQ5b4K85N0Q+tOjOJSvTtBdIfBBi2tC+W1DopCDmRsS2uVST+VYAXfAJoAzq+eY+uo/LBHHfV/fFNdL+CMZ1pGIGpBH/NYJyN0MZj265jDTJAwIxH0sEHyqAPuAqk3EOrtF32lzWtCnyYYdGPMnd6Equ5zBlbFFOIU1ayS06TVV3VfzcBgRR0alP09hXiIia7jHfK+a+g0gejV8IUq7beob7aJiodQGMDQcGvKcs5llAIiH+RucrJ1WapDZ3fyllnFf/aWwZ7I0hhsdqvyTmObDy9xOZrxryPNIcc/89UYgUxCfgGA9gTugp5AIuA3PQd1I1QyB+BfEJiJO8c30uxBYwYCpq3z2ACQyW0FRQKBCMO2UITxW1x41T59T0jtozfKFTn6Y3ClzWoULh4ryJ7XsCSJOqYm1sFXhMVSHm7gznRL4lwJFulSFOK74C7nmvdPlDXKcQyb0n9imcG4mIcoCLKiTnpukuyf0SxuAcoKw6E9l5ihC7jkPuIJ4rY30Z6uJqsz5N4/uDjkIHnynyeHjHHF/76kBe+wTC7Bq8uBirzV0b+mkPulhFVTrm2yryZ39h80Dk6wKaGOfNh8/KtdPwFUCGAFb45lC+D7oGhVqCqehoLOtlCHvtC4BmqWTlu5OjkK+PyyLEhX5hVyLM3vYUr6cbAc9/k6Zsd8SM4jsRRa4V79oEVnCk+lGW2/Rs+fT3CkRE3ciDz/j3GomIMQTjjKPi72oR97Jcf7jXISKqeFy+kX7IL6qcltRv9M93lHNkzFjM87XLgKfGtSGRjEG4zm2o63XGsh+8lt4QziWKInriiSfu+v3b3/52mtzFNTOZTCaTyfT1ls3dJpPJZDLdW7K522QymUymN4/e0Ev0v/pX/yr90i/90l2//9jHPkbf8z3f86culMlkMplMpq+sbO42mUwmk+neks3dJpPJZDK9efSGcC5Edw44+e///b/Tu971LiIieuqpp2h7e5u+7/u+jz784Q9P8/3cz/3cn76UXyMlkUsxuVRSJwO3R+wL6YPlBLEFRPIUbbRGoN2JiCiXBcQBoDuaI2lNOQjYooDurBZgEY4CWYYh2AcrLuNg5hNpzfZm4DCevr0i8r33/PY0fbHCts18Rlomd0dsEUEHUHcoLY6fP2HUSx9sJmNlvWtDE5QBeXG+JOtyAt/73S+wtarmyXwetBWeKr2ikBCoK32u5y0ow/tXpG0LUR7Xe5zv9kC2DZ62vVbk9ANVaeGer7HX+MYLbHEOJnK43hxwG6D1LlR1WcxwOe6v8cNfbkuryyud0y0sw1j+fgwnGSeE9nq00egy8L0OASdweSTtxGVCRA1fw1e2dA+s3mjR0biUAKzoN5KjaVrbhM86Z6fpd87z+FzOy1iAJ40/2+G+vTeUbb2QBysUVF+oqniS4hjga9R9+SCIr0GcUANOyl7MSftUCZgS+8FsG/h95dOtVc91pJUK49suDIET9VDLgJ/YghO59RhPwK4bQB/rpJI3UU84ZiAm42ZfZKN5qHMk2+QVfglcoNQGK+9nxtz3RqqdjgOMH/wdjQ9CxMycz/UaKptrCaxt+J0LZdmGhUJEk6zyDv8p9Y04d1+qFinn5ul6T47r+C7gxh3N5+R4uBA8NE0XM9x/tcUcySoYz/GEeiKi/RHHjSDleDKEuNN12+I7uZQ7ggfYuFYox+HiAl+7lOV8t5RFF1FbWw7P63gqPRFREMu1wavSts0TQE64UC/9VNZ5AH3dhXiur/dih3/eHnAsVU50eqTOg7GcxTHFGV/uyUHei7jd5yFGKlKJsMffX+G4s6fiZQ9s0jVYpu2NZP9CJI+X8jOFzkjky4K9O3IAQUByPjyG63XjrWm6qJAVOG+24NpjZTHvODwH1hPm6ZSc2XgurDJELJRSWYZjwHqMHK7LWjIv8u0DqmTHuTxNt1OJ4KtHvH6tA45Ez/EoRCzp58A1yMUqrqVkW4/hIieA0no2eVbkq424fHvJi9P0oP2kyFdxuR8cJhKB86pyqgyew99xAYOkbcyIBQtjjjMjZyDyTSDuTGC8hnFH5HNgnZXP8LyL+BYiIgd6xYS4rfOqT5ylx6bpAuBOEtU4fShvN5UISFQOMEH1LF8PcRU3QvlMRbgvxvNRquZa4oG9QuenaY1ziGBM4njYKMq9zlLBoTAJ6OnTm/wN6Rtx7g7ihOL07jUo9u0A9hwaH1b1ZiMqUYt5nAf4frgGJ5LoROwvuNQrqDUl7vcR8adRnYi5wLJqnBPGifyIMSbHCsV2CFhKRBo1FIYO9yM1QJ9tFOQFd+AdRAHm2rovnQ7X+9zXPxtwxSBCkogonlEviEdcKchYoPcJ02dQa7a1Iu4B+fdNuRwR8XwC/SxK5SIfY9Jkxhi/cy++2XqB591zrox9OE8hKmYwkdfrAVoQ11LHgVysnKvw81/p832faSGyUNbdtyzznBxDvb7Uk3Mt1h/QTMmb0RZEEivcUnvAAJ4xD+vN9Yflu451wNc98ynGovz6Z8+LfK9ADMWyllR/u9HnukxgHTSv6gWnnx6stRMVC5CyhBgkxBb976Z8iXR9wOjkx/f5+XyFdcY9YR26dqzWLUs5xChyWdsKUTOBh0JE61Je919Ot6B9t/uyDbvQL0WfL8l5bt7n61+BmN2cyDVvDzAyuG/uQT0sqDEUJrj25wZYyGs8D5dhBTrw7cHsvXMB+qXGXxERjRVadZbe0Ev05557jt72trcREdHVq1eJiGhhYYEWFhboueeem+ZznNeY0Uwmk8lkMn3NZHO3yWQymUz3lmzuNplMJpPpzaM39BL9j/7oj77S5TCZTCaTyfRVlM3dJpPJZDLdW7K522QymUymN49e10v0H/zBH/ySeRzHoV/7tV97wwUymUwmk8n0lZPN3SaTyWQy3VuyudtkMplMpjefXtdL9N/4jd+gM2fO0OOPP05pejpr9F5Wu1ugSdan44Hk8hyFzABC9nc+I+sgAuQR8tM0o/qzzfo0jWylicLC1YFdphnfr6rqzeb2pMD1iuj072u90JO8o7XPMDsL+W7vuXhb5BuHXI7hiAFPGVdxoKAunpxnXtSu4sG3x8AMznPZ37lyJPJ9ap+5lH+wz9d+pKH4aXlmuj1e5/ueW2pN07/z8qb4zpUuX+8tNS5DayyHDfLIm8AG03w+xIYhD/7GQDJvg5jrsppjJudIMdGf7wJjyuO+OOfJOkce/Lx/OpOPiGgCY7o74fs2HcmYrBOwVF2u58MAyhp74jtId3s+ujVNHznXRb4lujBNj4EjqxnmS/HaNL2WrdIs7Yd8jcBl7uZcsijyITNwzmeO1kZFMtw6ETPqsS6R0U4kxzLWczeSMWNWGG0GOrZA+wLw+77y7HHdHHMb+NAHFn35narHY+MEvvNHe7IMuwHz9TCehCS58WWPY0YzhLMXFJ5sq8Sf5TM8/j/Tktd7acRjtArs31JG9rESsDFX89wA+4otiO1RyfIzXu/jOJZjCLmKyF/V3N0AqrYEDLxH6jLjIcTLDjzuQSDHeBhkKYzu5oW+Xn2jz90XyikVMik5JOcv5AQiL9JXwa+RY+YynhGSz2qW4ulna5Q0bBtwgE2nPU3HxIMgm85efnWJY9VgItcjyG186yIztCtZyTZ3gR88QEZ4Xj5TccAsyl7EHbii1hbYDQcTjhk+zV6DIMvWT2Qd4RkQ1wZcYet5+by3hnz9OiysPIe/r8dhH573CA440CzFSzB1HIU8l/UmsqzIN32+xde7Pm6LfBifCunyNP2i85zI1445pvku972cUxb5zqTMBh1DzA00Yz3lNQTy1vvAKSciKjkcm0cQtw9wwpL4W4r0L74oV/F+j1xmWSPff5Hq8osJn12z51ydpluTGyJbx9nhH5zHpslMKNumCCzgGNo3p7rlAcTcaz1gzY/lfIg0DIyVNZJrhnPu6jTdc064fKqfV4Dl6cX1U++jubvFhLmqLWjDmORckAFGat/hfO34lsgXxgyV9Vxum3ymIfIVHf65AH0qcuScnIHYFcN3vFSv+2BOdThWdR25js9C3K7T+jTdUDz9sz7HuEtwvs8OnD/UH8s1Wx3ODxgnXH833WsiH+6XkgTOH4jlnFLL8M+PNPh5v2VB8qWzTkqDSUgkl7ivW9/oc/dJPKJsktCKL/dB87DORb6xroJVOF8Kp+uBCls4Rd8a8g/IQCciykI+XM/J8zPkYvb2iMdDBHvcu8Y1XBzLmnXlfrUD55vcCnlcFx2Zr5A5fQ0xVmDlzph/xmd/Yk4+xzk4dGipwmuQXiDHwNMt/vlalyu6kFWxGX4ewAICj2XpRvI7eZhsT8b82aJ8RSBOu+nBY9zsyWdCrjJ+R+8fIogNOL+eJDLmLmZ4jj5T5vJtFmW+Tx3DuViCES7nzQjYy5MJnJWizglA7n4P3h9koL7UEU3kZ+GsqRG3mT7zCfs27mubY3lBZIvj2WZ6vzoP87Dn8PX6t2aveW8NOU4j05tInjEk+NyubEPkr4+gzhN1xlgN3pdgXXoKhbWQ43x4BuIQzuPJqTrHz35vj+crXecPVvkXa/COq6/Oz3oYzshbynEbPnUiY8HlNvf7xTzXS1ntYXAfjuGprjrPAOolSHiM3+jJOu9ABeDZZl1Hrj1xb4tjrwJpTx14g/uqtSKc8VSS/e0ynK+0P+SyrhRkf8OzCMuw93+xK693HEYUpV/eWWSv6yX63/ybf5N+8zd/k65fv04/8AM/QN/7vd9Lc3NzX/qLJpPJZDKZvi6yudtkMplMpntLNnebTCaTyfTmk/uls7B+8Rd/kfb29ugnfuIn6Hd/93dpc3OTvuu7vov+23/7b9+Q/0NuMplMJtO9Lpu7TSaTyWS6t2Rzt8lkMplMbz697oNFc7kcffCDH6QPfvCDdPPmTfqN3/gN+qEf+iGaTCb0/PPPU7lc/tIXeZPqdrdCpWyOLsy3xO8XouE0nW+zPeOFnvQXIYVgHtxPL3S9mflQBWU/zTj8fxzockC7yDvn5cXe1mCr5gvdi9P0obLblLKnW43aqmy3ADWylGf/Tn5eeuXqD/N9h59lG0e2LO/719+/O00ncLPP/w9p2wzA0nIBbGVpKu0el6ps3TyJ6tP0gi+tGGjFaQKe58Z1to52lYUbMShXBzxUtGUKsT5oibk1kAiSC1XuLw9Xvzy8TggIl5Inn2mjwDappRy3x3xO3vepEx6TWPYLVVmXL3b4ObIOWL8iaa8vONyf0a67R2xpvjzZF98JUkbCdEO2abuODEE+YBU2Eka7LGWk1dOHB3mt/wkMwYqO1uALBYmAqftot+O2WV6T1uABYIZ6EV9jR7rrxViueDgGZJ2j63IscAnyeutgo7tU5g/LYNe7MpD2LkSpvGOO+8Si6h/diNvzEKzy7bEMBimYIRHhUiIZB1cLfF+sV/3sKERFlJRlNQfW8XIWrV+y5Q9HfBEfPpNlkPgltPKh/V+jjrIQgNGWuh/Iutwscl3EUO6VvLKECqQO31jRfuh3r63TKJb3eKP6Rp67P3WUkucmdDSWlvo5j+NYAbAKA41VgjTOtTrWu2D3RHt4Wa2k5nLcT3sB46/2XY59qUIzlNL6NN0gji0bJVmIhTz3h0qZn/fPvW0g8rX/KyMhPnHE5QnU1IMIFw/WHDV/tpXXc7nPDibS6v0CLHD2E4773kTiZgpZtPxyuhnKuFMcog0ZrfJgN1XPhLbQHrxoKnlyYJ/ArY4A96HRU1gV2GoFkrFqAp/eV+I5azC4JPLtZpjx0Bxf4fREzjdhgddSVYexO7VE/iXqvsvrql7KqAzPkXN3HTBmHZfn6wHxmtdRM+r9yUP8fY+fdyeSdl0PkBy1lBEfGLOJiBYyXPaViPtE35Xt3nIYQYIDtKUQNX1Axy1BH9MxHOdDxHogcohIjoEBzHMdhSDpxvwc3YTH9dWsvN6a+/g03YA+fxhwJ0O7NBHRCLBPQ7BFJ8riXyYuQz3lmHHkynbvBox3GSQH07SXORT5+l59ms5lOG4VANlCRORDv+onfI3r2GZEtJE+ME270K+yCru1CviahQxfu6LwS482TkdHYfxeSOVeIpfh7yDaYeLIebVM/L2my3VEybLIt+bxXuctNW63xby83iePq3fFpTeqb+S5u+j45Dm+2JMSEbXC0/EfJYUMwT0rUjxDVfcvAu4AMStZhRBARCqiNxC5UFFjvALj+jw0RSeajZFApGdN4SZegql8RDzHn81VRL4GoBSKsAbReDMU7k2+0Jb7oIdrHGvq8/ze4/ZVmQ/nw7XS7HVVGJ9eEFzv3+ipdVAW4zT/fm8o6xz7hA9tmJC85yDlDrLscWzpTOQkj/PKCaDJKiTn2nzMnawZ4twm+2UX1lWIiillT49hRPKdz1Len5kPtQ6hfiUn66g15P3ITcAU6/6B7zDwXULtrvUSlx3f0Sypl1c4piJYp+0fy/57tcv96uaQ+xGiWIgkdgSxs3tqo7wA5SjC2F1W+6/7K9y3rw84nrcVWqiu0Livas7HNbNCJ0W4luW0Xo8gxgjv81hdrm989/RJRGOoRjHXhQP4Kz0Ex6c/0l3xdwmQMC14L3AQyD1WB/DGuN7PqTke0VOIacFY0InkmKwAohVRuIeBLCtiO4eAlEkUanKrCLgqQCWfjAsi32jiiTXia+l1/SX6XV92XXIch9I0pTj+Cq0WlHZ2duh7v/d7aX5+ngqFAj366KP0mc98Zvp5mqb0Uz/1U7S6ukqFQoHe+9730iuvvPJVKYvJZDKZTPe6bO42mUwmk+neks3dJpPJZDJ9/fW6X6KHYUi/+Zu/Se973/vo0qVL9Oyzz9I//+f/nLa3t7/i/xvearXo3e9+N3meR7/3e79HL7zwAv2Tf/JPqNHgv4r4R//oH9Ev/MIv0C//8i/TU089RaVSid7//vdToP63xGQymUymP6uyudtkMplMpntLNnebTCaTyfTm0uvCufzQD/0Q/dZv/RZtbm7SD/7gD9Jv/uZv0sLCwlerbPQP/+E/pM3NTfr1X//16e/OnTs3TadpSj//8z9Pf+fv/B36S3/pLxER0b/6V/+KlpeX6Xd+53fou7/7u0+9bhiGFIZsvet271gQP9cuUD6Tp8+35Z/2P1jhvCdgXdgZSksBusJWwPq1O5KWggisBzE4BvTJwOiMQvvNtT7/38dNhXDAU5abYPfwlWUNXZJdsMP1lB3rU3AC8HsW+Zl2XpH2rnMbvHgCRyiNDuT/04yOOV+7iacxSyREZ8IFbIVchoWytKwvVvnn94DFbm8g7aw3wdb0co/LBE40OlHUhDb4XqKEv7NRUicIwzXwpPiMI+0siIuIZ7QtEdHtIfe/+2ts716qS6t3Gax4myWuh2dasm3+cA/s/4A0+KZF2S/PVcBqPOJCZRxpwUIrjjg1G/AVGUdauMcxW5QycBJ92VsR+RDh8mCJrdnKwUkn4JNtRlyGeirHw9Bh21YD7M5LBXlBxGusVaCsOWnpiRNuX7R+acsUusyW8zgOZb5n2FFPAVxEl68BSJIAyrAP4/+q7B6khjxfqyA3OkWPC/tcl8dJMSP7RxWskP2I+7a2w6JtbavA9qyGJ6/3Ug9s1vDsa0U5bvDqS4CKeUYN2M9PGJFQDrjvPFGSfcyHYvRgjGMb6vqvQ7fah5Pj5xP5THOAvMBqebErr3ceMFcT6GInoazLUexQmMxoyNehb/S5+/JkhzKOT44j26MYMwID7eGummuTFOMdIkNmPyOO8aoMd/QIYAeybY6f2dHmND1ROJe5DMZ97nDfuiBZUZ0x3+zj1zam6bU9Oa57EJ/wMfpqjm/H/L35LJfhtXC7GMcG6rN+wvE4gPibceoi31aZ62hhwnFHW3kxnqBdvwgW5J2RbM8A2uZSjeurrhzS+1C1eN/BRLbNRul0i/+iJ9eKbZiL8BqrWfmiK50wvqIPSI1JLNswSrn+usT5Aldafg+DF6ZpRKQ5jow7GZ/rArEtJeIXZH4qnwlxLGWY+3OR3D4sJzzWMnjtrF5n8Gd9sD7vq3XyCnE5BhOen4/jIc3SETR8IycH5ZUOf5aB4Fz3ZD6BeoB1n6ewZSNAvWDcyZC83rWALc4Pl7ieS7AWa43lGEdM3mLKcVrHjPaEY6EH1uUH0sdFvlbxvmn6xGH0zyCWiBpcs1ZdnjdjkjbrQdqcppOU67XiLop8qy7Pw72Y27OcyjXl+TyvzRC5VMzotfbpOAyMR5WMHOQdwKENidN6bIyJ22AhWZ2mz/l1kW8BNk//3xGX53duSeRgOxrRJP3To9i+0efuxZxPvutTP1J9G2Jpw+d+qbEI3ow/BcwpagZiR3Buy8xaKKtrI+7TUa9OylkuewqxD9GrRERtwLvg41bU+mE5y30pUvEYhUXHuU1jWXF84F4qUmvLqz2+b26Hx/VWoyPyzXdxf80XvK7m7pOQr5EH/krNx/29dFM8cwKYtgnue2S+CNBWeVc1NpYV1jSIAjkEhBkRUQbi51rC/T1wZGzec/em6WGXUS+e6hMRoEQv5OvTtF5T5gA9GYCzJFAZsb9g7OtBaK5r/Aqga8cwlz2ocLIZQJIcwzuutsIRdaPTY+6Zssy3APvpB+Adzf5Qvpe5OuCOfx32r81QxgJE9IyhH4QK71uANa8n5nj5vG3Asp4p8prr2kCOtRtD7hNlmHtGsFZpKFwo1hEilnTcwi3ImRL3sSd/ROYbf5rnaPdTHPfvr9VFPh8waEjXKam3vPjuD+PEQI1dfC81BNwcIu6IiG4Ro+y6Dq9RG7Qm8m04vBaoQmAdwn1vjuVu4hj25zfH/J1GX861WLVFwMYs5WWlX+nzw7/S57bWM0Ap69AMks9del0v0X/5l3+Ztra26Pz58/Txj3+cPv7xj5+a7z/8h//wei47U//5P/9nev/730/f+Z3fSR//+MdpfX2dfuiHfoj++l//60REdP36ddrf36f3vve90+/UajV65zvfSZ/85CdnTuYf/ehH6e/9vb/3FSmjyWQymUxvZtncbTKZTCbTvSWbu00mk8lkevPpdb1E/77v+z5ynNn/a/uV1rVr1+iXfumX6MMf/jD97b/9t+nTn/40/ciP/Aj5vk/f//3fT/v7dw5/WF6WB78sLy9PPztNP/mTP0kf/vCHpz93u13a3Nycmd9kMplMpntVNnebTCaTyXRvyeZuk8lkMpnefHpdL9F/4zd+46tUjNOVJAk98cQT9LM/+7NERPT444/Tc889R7/8y79M3//93/+Gr5vL5SiXy33pjCaTyWQy3eOyudtkMplMpntLNnebTCaTyfTm0+t6if611urqKj300EPidw8++CD9+3//74mIaGXlDrPv4OCAVleZYXdwcECPPfbY677fy52UfDchjQJFXvdWiT9tKL4mIqw8YEwt5yVcB5max+HsvzAYAvfqKGCWz+6Q+U7I5yOS7NOjEeer+jJfDsqAXDXNI60AY2s/4Afe2ZPsw/z/Yk5YbYEhXZ2mZEwl6enPu6o4zcie7044/dkDyQLcDZipda7IrKaHFk5Evv8N/PWXAAS1UeLvawZZexxBmn8/iiXAbrmALHFua+R4EhG5wG1tBsAMVUxOxJptAdCquiz5is4NTu8Ca+wkmgEJJKKrE+ZrTQ5lXV4AsG8pezpzjYgoAJA/cj3nI+ZcJons813gyEUJ87/Opg+LfG8Hztc8rLc/35SFQC7qscscrjCSfx3jAZ/UA25pIJuG5nxmfr3S4jIcDSR7CxmC+MdBulfP57gvIRdxT7F7+xEXBPsLcm3vlJe/l6bcvlUf+c2y/yJn/Pf3eexWsvKZLiwwG3CzwG14qECNV7vcBjdSbs+liexHk5S/1weO5EEo++XOgPvICMoaqb7TmnBs6EbczweJZLMhJzh2uD13hnM0S8kM8HMuIzd7RRgPVRig2M53xD9jXF1Qe8eCy884B/PIsx0Zp1thTOPky4SzvYn0tZ67M+RRhjyqppI9nYV5DpmrK0U5vrB9W3CWyLEKFGPop8jXXC/JsYLrhLPAgD5b4fF11zkKEOI2i3Buipozn+nwnLo94HxVT3KG56BvYj8P1Y3XfI4HGNN0+WA5Qdv908cuEVEeYtdbcsxsxzM3iORZEVlYL52MZZw4hKWB53K+lTxXWD4j23MPsNm3Ic7sqXNsxPk08Bh1dXjFAtQlssSv9RQPE7jK/XD2IXvIT33Aedc03cm3Rb6Rw4zIMfFDadZyNOF8GZeDjeY+4/eQcx0Sx8719NLMcuNZAsu+XNt1IlgvAcu9NZZt8xaPAx6eJaDPg2iNgWEO9z2Xl/0cWbGDCX/n5Y5cMyBnPIXx4KsDV/BeUcLPeKTYxNiG8xlm6Lok+/mhywzy++L6NF2A2LThy2fCPcII2PrDWMajYcLPiOubdXWuSAGY/AdDZte/kh6KfBg/PTjv43nn0yJfEHenad/l79SSeXm9PPB1gfefUXXUDHku78D4X8jLvnNzwJ+V4SOcIQ9ieThM121P01k4t6eSyrJ6xHW2nJHtgdqFwzBOotN560REE4ppovix94K+1nP3q8qov7ZfzvO+ez7P7V5Q+GuM201oAs0CXgQ+7grsyTcKst12RhyfrvQx1vP3V4uyELhXx3XfcCKyic9w667PxUJ+O65bdD7cs37igPvaYdIV+S7keA38xAKOIbm2xPOWXFjLLj8gz6H4gHd7mv5/rvIcf10OPSrAeRgY7wCPTlU11w5mcNCDRFYmxjvsO3q93PC5LnHt403kmq1KvLfw4dq3nJdFvmHM50EEGT43IiEZm3243tqEz3zQ+zT88QA47YOx3KcdNbmetgp8bay/ttr7e2KfAXOFK8vgQFvfgvPjwljPjZzGM4EWS7LOz5d5HbRS5bXFddhbE8l97TwcOeKr87j68Fx4Dpvu5yV4V7QIMUOz/+twDtijmzwHfu7ZsyIfrnMxBuE4DNS5B6ilfDrzMzzTZz7Pc2Py5LtEPn/l5jR9bv/yNH1b8eV9OGtuAPEIefdEMu6MoOz6TED8VgbWTqGa5yL42XegXyY1kQ/HZQRrC6zXc7A+IiK6OeazGPoOB5cGybGxQ8fT9P0Oj8mx2jrjdm4QYf3rMyOdL9v9Nftt25tA7373u+mll14Sv3v55ZfpzJk7hzKdO3eOVlZW6A/+4A+mn3e7XXrqqafoySef/JqW1WQymUwmk83dJpPJZDLda7K522QymUymL6039V+i//iP/zh98zd/M/3sz/4sfdd3fRd96lOfoo997GP0sY99jIiIHMehH/uxH6Of+ZmfoYsXL9K5c+fo7/7dv0tra2v0l//yX/76Ft5kMplMpj+DsrnbZDKZTKZ7SzZ3m0wmk8n0pfWmfon+jne8g/7jf/yP9JM/+ZP0kY98hM6dO0c///M/T9/zPd8zzfMTP/ETNBgM6G/8jb9B7XabvuVbvoX+63/9r5QHK9iXq4MwpKxDNO9LTou2BLyqXEbaJI4C/vP/l/tctasK57Lg889xyrYGfZ+6z9fPgQUDv1OVjiRCByvaERBBQCQtsF24cSM3275ze8T3dZXT4RO32ULh7XK5L3dlAdFGtwp27OtDWedo0esBEgJtKkQSNTKYsB3Td+siHzrGyh4/B9r15xQaZ5JyWW8DjqQXSHyFQ2z9PAjYlpN35fAqgW0Y6xytu0REZUB0HAT8TM3b0sKSy/A1ah5XxJp6jobPbRCHbLHRtjfEBHXHnA5TaVM7W+KxNQCrse/y82Vj+ewe2HxKPtffiittsziiXurwfT3V4WoplyFJGC20npN1lM+APRmuoa2eN6H/PcOOOlotyjhyqaz8ma/etyhjAfbt5zp87UiN8WWwhe4CmgHr/4443wAu8nyPbZZ5R2OBuD3QQvjplqyjkzH0D0BH1FRsqcC4qUy4H1WzcuyGUPT9EMqtxu4xWLjRlt5ypEVvl16cpntjxv/4JAu4kVyYpk9cxhYdJ32Rb+jwzxm4BuINvGBVfGe5wM/oQSzpq2dCu9hWEcenbPgjqBe8xt5QtnszjChKFU/pHtDXeu5epgZ5lKOA5PgMAH9QBOTHWPmii4T2R/59L5LtsRNzcBg7HOvHvRWRrwA4IJy752FOV2GfgEwm7Lafa8v6QKwKxjTElhARdSPEV/Dv9dyN32rARLlakBlLWawzzteLZF1WPH72zZIDv5f55jx+kKw72/Yap2jlRdQG2lfldwYRIBcAB4UW8DvX5qdf9LmeFeFDjPlFsOhe7cpy9yF2IVLKVyiQFWKUxJLgBMv5EHFfiDAbOtJe380xMmQUsrU1Jdkn+rHEd7yq8YStsklOxpv65Jum6fkcj6GaWisGMdetB/M/rguIJNJvNc/PN1JT3nbCz4EolkvuksiHY7yTcFs3SNZ5AwYiDv+jQCI35nPcr5YK/BydvkTFIbboNmJ8VB0XM9zWLUAErsCcMp+VYw3Xv60Q9wuyv8XQvgHME74rcS6InnRg2+eM5DMhxu/GkPtYb7wn8hU9XsPNuXw4ZDaR4+swZKv3K84z8Byyj805fI2FBNo3kG14DISkzTIHzGPoPC2Y+4mIvJTrYpX42nVPrlsQBYRrp+ZY2tcR2RQ6/FkjrYp8c9k8RUlIT9O9pa/13F3LOZRzXeqreWStyH0RtzQtRchpA35NzFDKjo/764UcIvlk+94GnAsOywAu/slDOeHguFwvcr9s5GQZEEUDdMS7Yl8ASBO990Fhv99JGTMSO3J8eS7HoBK8tyhm5PyAlJoEYu6VzzVEPsQx5GHuxr31nXx8/Q7sacYJj695tV/FFYme11FhCrgvwHVUsnJzh+jaLNRlI5Lov32C+oO9AOJbiIgGIcf3cZb3Esu+RJPmYM7HvbanEGsx9NpiyvvkiZ67Ae12HALGtsLxbdFX+4cx92VEa6rphgoZ/t5mkYPsW+fk2LjRAyxYCGgshQWq53iQbv0F/mx9b0fkq/4v3ru/AtfeC2SfwLVsIcvPWxlJVOf5KmBCAMF3EsnrLcAa5z9ePsNlkFtPOgw4X5KqDfGr98nL9kQU42qe+2gpI9umBnid7T4/+9yH/lDkq8xzG7SbvHfXGJEqtAG+Y8yptfXJmOtiB9ZsOTUMyx7icPjZB6Gck+OU22AxrU/TGdXPmwmvJ5IRXyMDGJoLNYXJOuF6Kcbc7rEaG1VYNyOiEdGcd8rK6VWYXzQirJlItOVr6U39Ep2I6Nu//dvp27/922d+7jgOfeQjH6GPfOQjX8NSmUwmk8lkmiWbu00mk8lkurdkc7fJZDKZTK+tNzUT3WQymUwmk8lkMplMJpPJZDKZTKavp970f4n+tVQrHVCWIgpCaYXE0+dvwmndmvKC/yOBp/LmM7Otynhy7mJO5vPBhlGFliqBPclT/w0Sgk0KP1JOOYE3wZPGz5RlRjztXNrI5fVe7HEB0Sr7Skf6sQ7BN3EdbN/NQN635KHdbvZ9EcuB5fvfTWk5WS9wa90PlpFjeL5YWfxHYOlCW3TRkTbQpQIXqurzfW8PpKUOLX8bZS74VknZGAEFsuCzLarZl6cx1wHhUsvzg1Rz0u/4+3v1aRqfaaJwLkHAbYVIjXIq8R+OA7Z38LzgKehpKG3pu4DNKBDbXovKN98DzyVanwvqtO4SWPYWMnw9ba1CZAKOyViNh08e8X3bcN9d6Zonh/jZEU/wYEWdWA3jcKMAaBxl77wx4HxYFxo3cQiW85sOW+Jazq1puu6si+887p2fpoOYn6+n6CCfPuH+jHiorsqHqKc/V6xDuWU+tM0iwkVbCHE8pGBpROs+EdElevs0XYB+1KdA5KtBv4wTtpih5ZqIyAE0TgbQWGU48TtSY2NvyD/XIXi2FYOrGfIY2hlwxVR92X9PAv4e2sV9927cRJLqmcakdb6So5ybpz/pSCv/PsSxRxy2bZZUZ8R5xc+gjVz5/GCOGYG9di89Edk2Isa7uDCObvRgDijJWPVQlftOZ8L31WgRLANa4LcDiS1KYE6tuTw2Msqn2Is5thzB1NEda9smp7H+qr4sYAw/Yr1WsjKmFbP8IE2wBmtEEl4e6GH0Uod/0PGyF3Ol9wG/4KZqfoBYcHvM9f/KWMaWgyHb4R+Z47KW1YIkM+H4lMCaIZfKNWUdkIEVb7ZvtAjzXCbg+NSdSGtxx7vEaZfzJQqbUc3yHOED7uSYrkzTUTIS3+k6/HM34ucYx2r9kHB7rvlchrWSHENFWA8DteCu63UcHsuI3dody/UIrsciwDlNVNz0wcqP/WWU6nUa3wvH6FJB2v9xzdrpnZum9zK6j/H19pP2NJ0PuU/pfoTC4bqck2vP1pi/dz1h5MqVnqzzQpbXjnPQFbMKbzSABqkBkmrRfUDkw2fKw/qwTBLnUclAvoTXab1kX+ZL6tP0QobLqvE1iGOYh7VUBtA6AbQFEVHF48/OVPh5hwoV8WKH52F9X1SBuA0asGZYyMu2OVtxKUwmRC0yvYbq3h1ci56X+jAs5wCDthfLfDjcRtB/R6p9yxBncV16pSfjyVEI2Au4NmICjiZyY4BxZ5147ammRtod8jVw6bg7lGvUOiA4EXeiyScOjt8Ro4owFhMRtWBden3A/fQ+hafsRDxW/ueBRGWgDkPEqvHv9bjB/VceGmoD5oSKomTEsM/IQGzvjOXD70W8/opgHe8rTh7iXARGVa21CzHPhyGgKBqZMyLf0GHMWIT9QA5/WkwZeeUB0kzvBRDXEwC/ZkByH79A9Wn6PCBccH+J6Fsioj4sxprQZrpfZqEuELXz6KZEk116lJ/96BqPm95Irm8Oh1yX/+pfMvpzksq2aUV8rzKsD/VaUX/vVS0oBslF6M/bI+7L+3JJQ0WY29pQBo1RRcxoiIg66FOatoRom2qWy6ORhbsjnisRsXL9lU2Rr3gN955cvr1ADpzDAMvE12v48r4Y+zrj2e9EJhDvcI0/SeR7qAzUH+5fg0TGliy8kcT3N2O4caBiO+IDCUKkRhEvwprhPniJEShMFo5/+a5D3rcZJHftKWbJ/hLdZDKZTCaTyWQymUwmk8lkMplMphmyl+gmk8lkMplMJpPJZDKZTCaTyWQyzZC9RDeZTCaTyWQymUwmk8lkMplMJpNphoyJDho7IcVOSvMkec4T4HzVgHtVVNymHnB18JOdkeQ2reUZ1PPkfG+aXipLztrLJ/VpejfgGwcJ3wfZZESSEf6eFeCWZiWfCDlQmB4oHikypq4C41fzBOuKB/aqsgoY9UKXua3zWeZmaYZbBvhOScoXn5PoLUFPrgKj+tZA3ncEbDDAsgvG3F4oOUsJtGIm5WfXLLVzwJG/PQRevTP7/6iQSZZXrKxz0A9qeWazDsayklfLzIQ7GUl+Lep9q/xc8y3mkzUD+bzHIfO/isC9yqkwcQzweeS5IXsOWWJERDk4V8ADPmwwkc++XOSKmQc2aS+S/SMFfmrGQaaZvF4B2L0dGJ+awb8Xcp2PgEU3dCRneH24NU0vATJwqyzzXetxDFnM8WDZVRwzHB5zwAI8VDCvHTrgdPTMNB1Gbb5WQV57P1ydpssu9x3ke2sdAwp4fySZfMj1Wy3yvXTMQC4l1rJmOyMDc5zy8+JYIyJquszhKwNXVSuGuyEHXfOI8TMf2K4Vh/O5zuxn8tz01N8TER3EHM9fSLnNiqEsdyHl8Xrgbk/T41TOAfPOBsWKjWi6W9d6IXkO0S3nefkBNM9exIzKtbQmsuGcgLFBc5rnHW7HUsJt2HIk+PZyj/tBxeV+hX1+byj7ec0//VwRzfXDmHEUcJzNO3L84/yDDEI824CIKDOCuDhheOStUMY0wZ5OTudDEsk54WXglneL8r6bcE7JhTLf95Mnkj39Yvt09vlxKDmyKM0TfVX6vIUTl7nbJYgtsSMr/eWYz6GIm2vTtK/418spTwq3HWZU63mkG/H8UPe5f+RUXe70+TkQN5t3Zd+pJ4vTdOjKe6GyxPfKp1zP5Swz/D3FtZ5LTz8rIlRrtjCdwGc8oNJUrpd2oL+tivNk5LMvjTamaWw33c9xPbYK563UFHsW2fN4LshqTjI+89AGuE47W1RraOCxZtz6NP3ZjshGATDlS8Chx3HTj2ScQQbpCszruo6qwNYfdbgPHDjyjIaFIbd7Bc4E0iuBARw6gOP40fhhke844T6G8ymeuUNE1J/wgiLnwpoyc17k8xJY68EZDXOenLsfqHO/rwPrdQQs1Ucb8jsrBc6Xc/n5Pnci6zIErjXOG3ieBBFRCfof9rE5FVc7Y6Jw9lLL9EXtj1Ly3VSs1YnknHwA+8GxhvfOEPZlIqIj2O8EwDfG8ziI5PpY7GngejHJ2Jcj3Ktw+fYVmB3j2BacixXkFBgclEDg12e5OHC9CcTcnmL84joGY5rG/npwRtvV/uw6dyAfzksZtd+v5zDWw+89Ls+D1QF+hQYTrpf/Dzjlep+xmeP5qyE46rPPu9H8atQ5vw734vYN4LwGIqJj7wqdpoI6OwznqQGc0VJIZJzA+WZyV0RmlTJcL8iR/2yLr3ei9vQLMIiwH9XVPIL4amR/b+4siHwPEO/Fags8r6XyKCL6f+F7h8Ajr6r3RLjcwTMCip7sb8hLx+8kipV+EvHztmEJqLn7OdjDbRT44Z9zZdvgeB3GcOZLiR+kpN6iRlCmQoa/U1Dv4HDs3QZ++x/vyTl0Geb/hdc4H2GA0xeuFdWaEs+kwT15W/WdAJ4Xz4bS57cU4byVCgSX3lhWOq7dt8p8PbH3Ukt6vFcf3lPiGoGIqABnG52BMwZf6spnx7rYg3d1TTVR5zMOOa8RK1D2l+gmk8lkMplMJpPJZDKZTCaTyWQyzZC9RDeZTCaTyWQymUwmk8lkMplMJpNphgznAiqmBcpSjjokLfWlCVhEwWZysSy9BzeHbPF4rs1egEh5pjIOWw8egN83GvK+62CHKIIV5FNgd77ekzaEJbDH+uBd8pWP6SjkfHuj2fa4jQJ/7211ft7tkfTloBUH/2fmUk16TkZNtkaegNVz270h8tWS+Wk67rFF+kRZ7x5p8PXfVsf6kxbd6+yup6tdrsuKx98fKvtfASx62ZTLoy3heyO0vfHvz1RkWdFG1wLsyOWussdkuH03I76G58ryxYD1ebrF30lVc6J1bgmcqXEq7zsCG9cY+l4pM9tqGIK9O4LyZBTKZiVZn6YbLlv00G5GRLRe5Gtg2zzXklaoAPAfPbDirLsSa1MBTw5iGp4H3AIR0Yv0mWl63j3D6UTa2WKo3MGEn/F/7M+LfNjWyzm2Z2EdEUlLF7hAqTuWdtEK4Cd8l22D9SLjZVaTc+I7mwUeA2iVVYQKCsCqeRAAPkghEWoJdx4cKoEaNyFcD2kHY4XawXFUBpRK0ZW2yIT42bcdthPqcdhMu9N0AZAEh+6uyOdAhMqly1w+6FNFhQyogV0Ura2dSFrvOi7b6DsJIyCyrrxeKeW2WU24v6EtmIhoQAFNUpumv5RcxyHXcWgzkdiBLNj8lrPcr7SNuQVdHXtzXuE6Sh7PeyeAUrvuHIh8LWKUx2py3zS94EhUHOoElhNFaHI9OyP5YTHP/WV/pNcZXL4R2JM3SjLmFsHT2QLL9LIn51C8HmLaPLW2wPERi5gmy/eHhxyr8XkHChV3GHDjdFKOTyGdjmUiIpqA3R5RKiGgNYiIWvFNvjYgx4pOQ+Rrx7em6YP06jT9jvTdIt8DNa4zv8c4kh1qinzDlJ8pSrhPaVwHzoGIPtO29PkUcFEpzwOIDyEiahFjhzzApVWI568GIGmIiPJZLgOWR4+NG4BEO4L+Rn3Zj7ICacTXiNXCpeGU4DPu9Fh3RETb6e1p+gHnLJfVl+XDeQ7vVFScMXQ/A7mOhrHMh3FiHR4xSiQqahwDAgrsyWgD14i1Q8A0FcFKjVgmIiKkOzQAC3QgpyXaCRmZELe5XjOvYVluR1zPfQrEZ32H108Y6+acDZGvDv1yNWG8XMeR6y8f4rQH+yONg8TYcKXLdYZtWPXkd7BbHcO+J6fs8Jt5rhfso3qsYbw7BFyHjtO9KKYoUcHMdJfOlh3KZxzakdtf6kPV4TYhUGgRXCs3oFEHkcw3iPmCnTH3g1CN6z6gI0PABM0BP+ExX84P2E+PIGgcTCSqBNe5+Pplraj2yRB3lmHP9mBVxr6Xejx34Pw8pzo3vnOowDuMl3pyfghn7OP1OMQ1Pq6l1JRAK3D5GuxDEYm4WpF1dDzgL60U+OE1LqkOz9Ec47sO2e5DWNN0Id96Ua4ZMAI3x7NRceXMEnwH+opaWxymPNfmAKNWjeW7E9h201yWnxf7K5HEaXYhvmukJApbbR7QOkt5+Z0W1Auizp5SaL0XuhwjH6lxuw0nsr8horYM8VijCXFc43uU1lheb6uYwGdcDw1fPse75tvT9NMOz8NX+rLv4Bo6gXHz+Lzs54sDft6Xu9wncDhsFeVDNaF8L/Z4YXBfWc6ht0fcJz59xAW6lR6LfDTidxA+vIvRSB7sEwPoE1e7IhuVoD2QzFJRmJYq/NyBGNsey7VKf8I3Lo0RPSPXKojGwhCOt9W4ZpzvEau9npdrSsQ5ZR0u30JO1tEhNME2TDBRKp/pobpPYfIaiyOQ/SW6yWQymUwmk8lkMplMJpPJZDKZTDNkL9FNJpPJZDKZTCaTyWQymUwmk8lkmiHziYNy5FGWPArVCfO7E/ZD3Bqy1XVNoShugTV1f8TXmCjveJRwtZeys+3dJZ+v0cizleSRKjabPDke7QrPtU4/DZdI2jiLeDq0cgY5YFLcLEvbFWovYBtHADYIbbdBi8fVIWMWjiN54nXiAVohYetGMZaoByRJvNDlfGjDI5KnEKP18wis4qNUtjv2g9ABe5eyJF0BHAtaYs6WZ5cBTxp/qSttYAdDrstvWuRnekdD1v+VHluttgd8vYdrqoAgtNGo7kvnK/yLUQy2dPVfbWgVut7jeinC6eGewrl0ie27Oej/aSotdVVxKjf2I5kPLZedmMugMShoEW+G3J5NR9rrfcD/rCUr03RRYzigffeg2Vqh7OfzYC8quGAxVf3yBL63PeALNp2OyLfvXJumRxEjQ7Iu+58SdbI7Ipy2Sqfb9YiIbg+4DNUsWw3byoq6Q4xFCAb1aRrt10REGWi32UZD2aZVBy3rsu+gdXQccPy94dyS+Rz2A3tgnxyCrfJOedkuWgXcQTkLeANP2Wuhz2Mf66XS9pk4/JnncJ86SWRZK1SfpksQwx2FqClQjswQ/uXrrDfbZo39qqVOY6+AxRFxAK76O4Mh9IMYeneQyPE6jnnsFL36NN1IpT0WBW5xYTe9K/5CiMOxsZSXsWq9xF/E693syRjZm3AP8x2OzTnlzR7DOma1yJ9p3AHa6DGfcouKmIm2TW3Dr3n8XIMxz9cujJuCwrmcwHyDsUAHpEpmhU5TrNaAk4THeezwZ0OVb7Pkn5r+xMFsLJgP668FZWcNYu4UCaylxrOneMrFPAZ6iYxPReh/E4gsgcP9NZcuie/M5bhPnIT8Hb2uLUI9HzlgZZ/IbYbEdQA2LqvtxNDu0Eev0TWRrzW5wWUAFF45XBb5EJWDt9J4s1L2dBvvyz1ZvgNgLpyDtdMFtaTvRIhY5N8XMlx/88p2/JljXCNxvp4aG10Y2C9GjFU5ppsi39DlPnE4gTpPqiLfGuQbQd+uOBL7UIP10lLKfXspJ/cjGBt2Ap6f+2p9swDz4YUqX6OrrOOfbfG4nsB654zD7Z51ZTt9psn5EDW1UpT9Eu3r4PC/C5fQBCbBS+n2ND0I5TpjIV2nWKGHTHfr1eq81VdoPFjX52BeaqYSBYR7szVAH5Qysn3nfY7HCzDZ6pj7UJ2/dxhwR5jPcT9QxCDKOPxZDvBLmaEMBhgysY/l1V5svcgZ3zXP7x+OAllYxL4cA65So7EQgTNJ+WZVRercG3I+xJvmFPcJER04PnTklGg2Hl8XyzwuDhXua5JwvkuAwPhCR8agfWibi+XZyMwO7L+eh3i5P5b8oC5g3/Ycfh+RIY194UqvOjxXBrA/IiI6cgAv6QASOJbPW4LnrQOGp5zIOIY4mxIUCeOWCpeiPRah/855cgExSfi+2OcxTUTUg34wjDnmljKyv8m5jX9/WYZ9ugaxdBhzeiEn67wBczfOoRpd2xnzGG/BvKvXlC90+XnxeufLsgIX8nzfK93T1wU5hQ96vpWemm8vUO2exXx831VSOD3NSPqiNPawH+HY5XSoFjgHI/4ZEVAXa7LONyEG7Yw4Jmos4wBQPu0xF6oVyj6G706q0DMxJuq9DnRLOlvm8jmqKfB7eL2KJ8s6SfmLRUDj6Za9vxLTSC8MZ8j+Et1kMplMJpPJZDKZTCaTyWQymUymGbKX6CaTyWQymUwmk8lkMplMJpPJZDLNkL1EN5lMJpPJZDKZTCaTyWQymUwmk2mGjIkOyjkeeY5HTToRv3eIwTzXe8w40zxn5I4hRzJR4CbXOZ059YljySe8BNywRxeZ4bwIDKJcXzLSYmD+ILv6cCTL8JYG/4wMyEAxmwPgcv2vI+YlItObiKgIDKwRXEPzl5HltU7MrOx5Z0Q+5JjGwCDLKBjSLeA5Pw9AsHak+OYps5qQc72YZ4bWUSC+Qj4MjzYwPkeOZKllxsAki7ncYSxBd4t5zteBsl51rsrrxXzfauf+abqclUytlTww833+zpmifJABMEmRV//WmuQ11oDBn89wfR0HknN5c8jPOJr4kI+/8yK9Ir7TSXam6ShzYZqeC2W7f/qY66gZcvnqvuKqAmyzmbSn6cNEtnseOGRjh+tFc+7Kzvw03U2B3am4dLWQxwCyxvoTCSgLYr7+EsDnFSZMcGUj6OeJgvcWnNo0nc1dnKYX0vVpeisredD5LPLcuKyaMzqIuewdePaWeyTyLQIrvuxyuxeysp/3YOy1Umbtdt22yDcMF6fpuQzzDm8FksW+696eppFHHKaSQdgdc74sXC9WXGAHmLztDPe/usP191xwIL4TEvedcspx+pbzosh3PLg8Ted9Ztt5ruQ5HmW4rNXk0jQ978uxFiUJRYn9X/eX0kHaogz5VIskgzSFcRTCmQgZkn22MQY2PvDwkaFLRNSdwBkawAxuuJsiXwbmmPmEOa1zwGUtKbAq8oORkVr1NGORM+4AwxR5mkRE6wUe58gJbAayP0UAHszA+iFVMQixqF3gL2p0ILLTkb/oKXg6cqk9mNdbYxnDqx7H/j6wu3EexnMsiIiqxGenRCnH1aYrzx/px4fT9DA6nqZ9dVaNn+HrlV3moLqKv7o/4p8XYSjP5+T8hWtFbA3N18R62RvNZtQ2oO1zwAJ+sSPnpa7D/XcIDNh6wrF4JS9j0Aacp4Hs1JsjGac9GFMpzGV9kuuRrWyNTpPm+BYyuJbl368nWyJfBtiWXsLjC89NISLKQNzHvugrHvEyhOoudMVXOnK9tA/nIBwG3D8avlyTY/s2gGu7BWfm3FdS7OwFjh8vA0cW11hERMcxj4EuztcqZODZDhOIW5Ejx1oz4esFDoyVu67HZT9XYDbuSlHGlp0B99mtAtdRITgr8uF5RLg+Ud2cFj05j3I+PPtCMZFh+sfzH9YcOSbzcLP2+PTYSUTUnHC97MbPTtN6jj/OECXqzATT3braS8h3Y3oluS1+f5C+PE0vurxnuAtgCzqE82+KE7lfmku5ffLALdfcXNzLDqEv4d5ar+MXc9xfzpX4+2W1NsYpcN5PIJ8sQx/21390yPFyqDjIuL/Ga3fUGt+B8QEYasFOJpJnZmH6rnEIU0TwGvt9VB7mtiDmOHGo9pdYpIbPD1zz5DNhO1WycEZWJMc1PkcZuMzdSO4f9uGsDTw3KU/yHJtyyu2RTzjfNWdH5IshphVc3jPguwgiIh/OtcK5SPexAjzWOxoc1Oo+zx0v9eSZcU14NxFA9bUjGafPwvzTibg8fdXfXurAmU+vca7IJpzBd7bEH47U9XagDZZgjVVRcGzsV9hj22PFbJ/wNXBtrM8SaYan8677EzleNyCEnKnAmQowxbci+Z3dEc8PuPYZxXLd88QCP6NYz6kxVINOgX0Zz0DQwrkso86ny+pDjL4odUSIuNdqAWOa/D5uOwoZXre0VNu0YYmD+PU8jGN9bhIG+yacoXQcyPa7VONCYJ/X3Hi8/n3Av/ccFX9jV7zHfC3Z7txkMplMJpPJZDKZTCaTyWQymUymGbKX6CaTyWQymUwmk8lkMplMJpPJZDLNkOFcQBnHoYzj0EIyPzMP4heOlKUArZpoCZ2X7k66OWCbwHZ/tkU3B5bw9QFf8OaAPSa3hvI7lztsJSmBrddRlo4e2MUuAjZmrNABhyFf40+OuKw9ZZU9U2arC9aRoq9QFSzsvsv3KqQSZTMgtuX1wVoVxNJmXYPn8tHConAzaKHyUrZ+LBf4+54rG+pwxL4atORvuxJVEtMinaa9QFrHxwnb1sbgf9JYigBswp8espWsO5aW9W9e4jpHOyHiW4iIbg45H/aXYkYiTc7VutN0CtbFtYy0bdfAPlYEC9b/swMWrvE58Z22w3WUTbh8oSPH0B5YxLtOb5pOwgWRr+Jx2efAXtd1ZF3mAAu0SKfbyImIbqV70zRiTLIK+zKYyH76qnqpRIaAk0/goIrKtlkF21oh4Ht5aV3kO+ty/eUz/B20ZlV92ecRCTGC8ZBRdq4MIAlqBP41hUjIQV0gbqar0EmIhPFgitlM10S+Etjw0cnfdJoiXz9l5EKUjE5NExHFCffLlLhfxbG0ytcLjHDJQizYjnjceeSL7zjw/80JXBsxO0QSI5MCRsJ3pSV0Av0l73I9aCv6YOKS/V/3l69j1XfaDmN5EOHkqDo9hn6fGXN7uCrf0OWYlMG+nUicSyXD/ScBu2IB2nelINt6CWJ4BPFXIz7QJrk75njXjnSf5b6I3SpQ3lucN32IVYGyep9M2JragjoqZKSdtQQW9u4YxkBG1mUVvJUxrBk01gOxWQ1YJ7QcQLGQjL9FiPsdpz1Nj0nGDNfxIM3PpMdryeU1Idq7Q5J+0b0hP+8Y0Cf96HT7MJG0Gu+MZB3NwZJkuw/2ZIUPQ7wGogQHJGNfB8aHD/1jK8voqbWi4puAENXjvAZXAeOnRifhnOVDGtfPdz6TV5xKNjVFCeN18nBfjUQ7RpyTg+NL3ncFfPNoae4kEkvTd3m9dEA3+b6RbJu8w312Zcz4tb0hx5x2XY7dhRyiDzitx1At5TVlMX2AZsmDOBalGGfkOr4P4+MY8ASJ6uc5h8dAKWDURqrmzcGEr499NK9iBhajDHuEpbzuY/wcI/hOC6ze2q2+Cv05A01YUN0c7ey9iC8ymsixi/2qkK1P074jY0aRahSr8We6WzfCDmWdHCWurOd6hudUL+VAmEvlPq0O6K4T4vk5q+ZuHOdo87+t9tAY1xB9gvvaINYoBb72+RJfvK4QJGNca3t88dsjuc+4Ae8IrnS5DwWJHK95lzsxYs9uR12Rr0yITOH6a0eyf+L1cIxGCluWAfwBYrc0guTBCuJYOL0bcJy4PZL7VcS0IPalr/b0iFmqABtuzpf7kQmspTAWPJosi3zlEfejEazPNfIqD1igPZfj/jCSGOCsw8+4Qrz3QRQmEVEF4h32luNQ7SlhDjxbkv3lVRUzap0Bt7rWh/WW2iu+eMB9AhE/jZzC1Y15fx7BnDBQiNujE0ZjlrJcrwuS3ENHgPJZK3E/0O/MELm0ABikg1CO8Vnos8NYvsNAYZy4NVb4u3F9mr5Y5Xyree7oz3XkRFKA925+ys/uq4lpGdb7TeiXiwVZSSWB8eEJ7JmOzNeE7oKvEsdq7I6gfXG/2VVTFc6B2B51X8a+CyUu00IeyteS70pi2GsvwLMfQRt2FcoGES57mmUFOgrwPR6ibGS+7pivV4NxV1Zvwp9rE41nL9mFbHduMplMJpPJZDKZTCaTyWQymUwm0wzZS3STyWQymUwmk8lkMplMJpPJZDKZZshwLiDPcclzXHqwKk84DmdYuubz8v8gGmCdqYH9ZDUvPQovdzlja8yfbZSkzWcIp8M+22aMyW4AlmtlK5OWSS4fWoaIiG6ANfhMkfM9WO+IfHv7bGM+ATtsJ5XW1vr4dCSBcmaL0549QLEUEnmSOtroF9LZeJ1j8HTm4b5oSyMiygKmpgj4BHTYDJTlOki4sFivHkkbTUhoE+bnKClcCtrXD2Mu93p6UeTby1zncoMFJkxk+dpwCjk67+eUfee5Nqd3huz5GSfSM3V/lX9eq7MdcLctbTm3wIZ8EHL5zpT5vuuJHEPdiL8zBHvsWPXflsOWuK7DWJVyIq2yHtj8KoDhKaSyzrNg4ewD7qMJFn8iogH8HEN7Zqkh8mF/RtuQq6ztiA9Ci5K2RbbglPAYbMLVrIwFj81zWyMSZm+EdixZl0htmMuh5VKW4TZgUfqAw9mnl0U+H/p2mXhMDh0ZM9D6jafcn03uE/nygLwogR2zFG+JfCfh6jT9cubKNH2UyPLV8mz/RdTLID4U+RCvc8N5fpqeQExbdx4W31khxh0ghqoay7iFoaGbMCIoTuUcMEoYV9Vx+b6DofTUhRQJ9IvpdLlpljKUFfgmIqJiCjgstz1Nd9J9+X2H+0TicGz2lHU8Im4rH2LDXeMfbNdoi0Z0ij45/nDG3zRoO+YBIEMQeTVU2Iy4g6fP82eVrFz2oS16b8Lj/9iVdRS53A/nAPW0lsgYmYdnxHr1lJ11pYhYD/79OJExvAYfTgZc2DzMCROSPnJsjwrgvvRawiXGju3kdqfpRF2vCBgZxAk03ILMB3EMn2mjJOscCC50ADbVobKsEzzHEvAnDiWVRiC6TkK+nq/6xELCsbQAXu8stFk/0qgC/qzuo/VWttNJwpZpxMZEJBF8+4AIXMtz/Wn7LHaXEqxfDwNZPkQzLXgcgKu+fPajgGPwGNZSg1jG5s8e8/XPVfgZN/NyTTMKeDykgKLokBw3uJ5AC/y1hFFTx01ZR98yz312AOge3H8QScRMDH02r7AqK3nus7g+3x/JINSfAA4S4lvkyDmoBsjLQ+L53w3rIh+uwwswNsbqOeYBQXQOlnoT3RVBsJwWVnlNTtoDJsfNgONbISPXtQ9D0bF82+OeyNdyuW8n0I8ihYoq0jrFqeFcvpSqTpE8J0dzJMcXxvAIEXpZOa6b0GcHgFgqJTI24/4JUWVjhSDBrtmCzT/GI70XO1fh8YbYkmsDuX7YgfX6AuCrrskuRs/2eEztuDem6Ykah2WYz4oBD5yhK7GWGcCRTiDQ6n1QFhBk9wFCUo/X9pjLvlXk6yHKhojoYo0fbAJ7tg7sXXdiue7Bekas3UDNjYjh2Rlx/a8XZNss5GA9B/Hovqq8bwxrxcGE+87l5IbIFzg8zyEicM17VN4XMGNnC3ztBfXeCEleuCbU7zBQiMBw4TVe1ZOdeR7QNohvxXdIREQ3epwPx9dGSeabz3K9HE0Y4XJIV0S+lsv7otXuE9P0k4p8GwHitgv9clFhvHAttZLnMbAXyP1XB3AsiN2qOvKdTQD7sUUPsIdqjboG78YulPk7+wHXuV631P3TX6vWJJ+ObgIBB4dXoOJRO8R3idxZ3lKT7+D+2z4/I2JMNNIE7zWAOLg3lOsgxDIewNQ2l1Pv1hyOcS/1uQx5haFsAAamE3FdYPE8tQXCvcSDde6/x2oNeAKVhmtw3Z7wER0GsOZVGNWMk95Vb7Nkf4luMplMJpPJZDKZTCaTyWQymUwm0wzZS3STyWQymUwmk8lkMplMJpPJZDKZZsheoptMJpPJZDKZTCaTyWQymUwmk8k0Q8ZEByWUUkIpKUwzNYAnnHWA+1yUGQH1TBmHPzsKZTUjfucCcKgVMkmU4wQYQkeAQkI+KhHRNefGNB1FzI5qndwv8kUpst6YRec5khOIDOhihp9jMJGFRbZ1BNAlhSQiBwhI7QmXr6x4eFVihljRZRYScq6IiHYmbf4MGmArOyfyXSwxcxL50HnAOzVy8pkS4Ot2of7riQR7DYEj/WzCHO+GypeP6tM0MjlXFZtxld7KP0D91XzNouI0cnN7E5UPHgs5gzd7koF1vc98sbe+h5+juiuZ0n/yJxem6V3g0lWB7aa5Xr0h8AiBBx8kMqMHdb6YMuN6y5fM0ACA30Pgj8/7kgU6gj7RQX6w0xX5kG2HDOrQGYp8JzHDwYoJ36uekcy1InCQkdGp66UIjTif8vV0X0Tu2vGAvzMCppk+H2EIn9XhengOA5HioKfMGe8Gt0S+RuH8NI28v5tuS+RD3rcLMSMgyQdtj/FsAq7/8xXZf5HBlna57xWzsk/kU+6/icMVdlKUjNoU+JrY7iWqT9OeYutHwJstAOdZn3tQj3nMe8DqD0gyKlO43gg425Ej28ZLPYpJwbNNd6lAPmUpR0VHjn/kf/cmPMccpfKMBZxWsrAsKirGeg6g9/mUWYrlrOwHMUycyOgbAG96uy/5ph6w9hvAVdRY4MEEme1839iRwSWGfl7P8HMsF+V6BJncHYfHbi89EvlwrHSAC+wmMlYVIp57KwA51JzL+0pcXjz/ZV+d34IIxhyeRzCpQ9lkzJjAs5/JcZyoq0UWxsxSeGaavprsiHwRcVutu8yhXSrIdsfL47Wr6pk28sB2hxg5VBBo5GEeB7PjQBnOv6l6XKZCJOulNeafs3AmzXHEMSiI5RjaKHmQ5uc4VlzVWynPlWHK8S6v1nY16IsZWE8fjhTPNXc621KPhxyM180yp/X6/HKbnwPbRnM4D4GdjszQTcWHzTjcD8Yx9/lOtCny5WC+GAAf3QXeL54FQ0T0SkfWGd9T9qO5DMcgXNv5GflQyBbFcxmy6nplAo40rlvUOgjPLMrC2NPlQw46qqoqfR36FXKBNZcUz1TCsYaXU91InHU0hOfojGUdX+tz3zkE3v2ee1PkqwCHuu6uTdMJSThuLanYeSZfhlbyPvmuf1ffwb0jdqPjQNYz9vsJjK+uI9dcIZx9gHsx3Re70F/wrIn1IvL95TMgZ/gzLV4jXO/JsuKtYhj/z/fkfmTP5bU3rlezaj3ipnzBtsvz9VyyLPItZnltXPXgnBIVw4fA8Md9/LKv4wmnGz4PuG9ak3vFQcBzycsd3ucOgIPeVccGYL2cLfK19dz9Epz5gu8frnRlP3pigRvrbJEDyE4gG/H+Gp73Aec3tOUZTQms7S5Uua0VVpm6sNbDeXMpJ/vE9pDvhfEOz+Mgkn0Ox0Yf6rKn3ss8UOH6Q1a/xj3jHH+lx3vcjx/KYOrc9c07yjs1lY/LcSPgfXdJna8WJzgPn35GDhFRH3j41wdc55pHjjEE44IuNXLQHfjOObX3/PNLzL+v57ijBjE/h6fY9bj2L8A8fKQmplt9LnwDFrm+mvQ+N+Ax1Qx5nfF/rct1GrL1T2DduKvui+c5bJX5S6tF2d9CGFOz0kREL3f5GYfwWUM1Yk2Uj9Nj8X5EXrsPsRjXXwU12G4MOLZPUh4orYnkxrfhHRCeVaPPStLnYr6W7C/RTSaTyWQymUwmk8lkMplMJpPJZJohe4luMplMJpPJZDKZTCaTyWQymUwm0wwZzgXUikeUdRIi6VykKvgQLpTZGtFR1pnDAKxf4IocKHvhGthMF3y2KxyPZ/+fxloeEBgxN9u1nrz4iNgWNorZmr3rSrvCurMwTbdDLsMfHkqLCJomhjHbgdwZth4iouaYbS9LOWk/2yrzM56rsGXitWybPbBFBcrufDLi8m67r0zT41jaOB5P2GK7zs42yruAvFCW632wug0AkVIayGFzPWYb3UbKNjpXsWwOx9yxGlm2JJ0tSxsNWmK64HUpKgsL2mDwVu1I9qN5wNccjfgztPUQEX2+zXWZ+z22sI1ieb32mK8HTUNNsOighZZI/m8dIoJqCoOwmOH6Q+vYalGWoQ/P2Aw5XVL2YbSIhQH3ibEa5PV0ZZpeSFen6XIqMS3Y70PAyJQcma8Evk1sm62S7L81j38+Aav9sXIB4zj/7PjGNJ0DpIRGkKDlem3C9rMolW0zdgBR4zSm6cBri3wVh1ElZwGRcDxuiHy5DKMyEM/jqelmAuVohmytulCV9rjSjFlq4ki8AT4/ImrmkhWRD5EpWEdbLsfEVDEDECeAfWCs6rIAuI8o5fs4juyXscuf+QmXO5vKhy1SjiY2TX9J5R2Pso5HR2lH/H4Y8fgoAerFVX8/4AGqBO3TniP7Yi4F/BpgDOZyqm+DTfUY+vaNdHeajhw5yOdjRiRlIh5fiDAhIiqDHfvxLH8nVp12DGVoj3msfLYjbe6IODhKr0/TWUfO3XO0MU0XEq5XjVIZA2orhHGTVfMhzlMvgpv9cCTHNdrPcbQ9VmFra1/NN0eh8oi/Wm41h2KZMoCUSkfrIl8r4RiZz8jnnaUS3EsRdGijAHMHXO/TTVk+nEcRZddW6xvE312o4jPJvvMcoAZOQi4D4n7mVV9GezE24ZmKeqgex9mrsF46k1kQ2Z5YkPPUq3r2RK5ltwfchmjTRhQhEVEWxnIDlq8rOYU3qnGZjoLT52ciovk81wV05bvacLnA19gDXJ2rsBRhgugujgXYnoEzEN85mvCYfKTKsUCjJrF1cT3YDuWzX+9z/0XUUd6RD4WonXWPF8rdSNr1ezG3zbHDaKe8QgGhRRxxThsFOV53RqfjevTzdsf8C0Rt4BphQYYtKsxxGZwW98VhLOvo2c6ITlORJIIA43YesGAlwMkREa3lCzRO7O/UvpTCOKUkTQXmkUgiL6se4M3UPHfi8JwfpDyRDEliBhccxnWVsnLNisK9KGKQkJCkSBsCT9CBPjpRZa3C/gTRjno/XQO8AGLjtPB7HYgG845EHeIcmod5aSUvB0srhHzwwBuya1PNgzkeUGwHHYnJa485HlwbcKVdbvP3XxlJlNVjFW6bqMJl2B/JOvLE3I1YUZlvCCiQE0BvtNUS4eEqz4dnShwLqp6sS4xJj9cB+xTITvE/Dzkjvj/YUTHhWpfvi1igrbKMzasF/jAH7y0QB/Na8l2u842ijHWOw+0WAj5yQDIfIoN8QH/1YonMDACz283ymrfff6vI9/Yi77XfCkOy5slYgHjka/D+RS37qIYIuAjwMKrOEZeKKCC9Frg24Geci7h9b424PTWCz4eLDAAnm9HNBI+4M4I5ReFhSlDPuMZvR7KwVXiX0IF3NIi+JSJqJry2WJzwMz0xL58jAkTwLehjrVDmw8fHWx0Fej9yOgq7G2F96WfiunjnPNfRraFcZxwDNgpRbC/Qp0W+kwHjanMerGn8bxb5uuPiXaigWbIZ3mQymUwmk8lkMplMJpPJZDKZTKYZspfoJpPJZDKZTCaTyWQymUwmk8lkMs2Q+cRBZSdHnpMjT1lg8c/6bwzZXnC1qzgtoHmwiyCKhUjaauZ8tnfeX5U25iDme1U8ztfw8XRiafWKj98yTTfhJNoFkpakSzW+BlrTepHIRn3gddR9zliMZddZhpPLcyFbxJQbU9gun2iwrUTbHl/s8TXw9O7hRHosGi4/f5PYD6RRD2jLq2X5Gniy9fZAWknQlYufaNv8Etx3Kc+2ks5YtjueFr+aYX+cr9zhO2ANRkv4BekqJaxdRKx0lc3n9oCvVwbMyMPKLnYGTkL/TAvsMSP5vNhf8ABmtFIfqv+f60fcHuernLGqbJEHcC88Hf61rDWID+iqz9CauZiylXclXRL57itze6DlH1FHRES7I7bRVzL8HHi6NpG0hRXgo0vlscoHFqzo9D5/52fASKVsnUMEwUZ6QXwHMQvYZ+d9aYWqxOem6V7MAeCWJzvcmZQRB2iHy4fyJOsTh+17E2Jr1VxGnnKPdtEcNNSJQtkcwOniMeBTCiQ9piXi59oDi7m2Y5dT/rkKVjm0hOYVFgiHPMaFfFaOoR4gvVLw6+WUJbeP14ByL/oyXy+KKKLZ84zpjnrpiLKUUOzIupqD+WGpwG016ctx7YOFsgh4jZNIdsYi4InmYR5GfAuRjP17KduVuw7bYcskEWt47Sx0Mo1BWS9x33ygwvf5QlvGoMsdjlWtlHERLbDkEhHFYN+tOBwXN5MzIl/F5fL1Uw5QXZIoilsx11kl5jGKmCciokNwCiNu4oR6It9awvMrxozHGoB2UAiNTxxyWRFl81JHrgsqMH6xR4QKGVIGtI07Y11AJNEna2C/3izKhRXakxHXM5eTcWcIxcV8lVTa8LH7HUC9zuVkCRcLiGbhZ8d1QV6tRxBVgngYnQ/nwKUJt5lGFXShKnCJpG3HQQL9EuZajVjCOeHFDl9wotBY5Sz2F/69Ip9QCdYk6C7uqrUxrl1Got3komYw4RukMbdb1ePxUB7PxsbhGlAtf6kNCyPENyF6kYioCciLkDgu4JqIiGgMSI0qzEWI0yAiOom5ky3BNeY82S+XAOHyRIPjwq5CH+wNT7fUh4rngrFwFapsKcf5zpVkzH65D/sRuNxxIrFWE4iDaKFPHIWKiq9M04UM9/PF9GEyvX55rkO+65CieFEn4X6KuCo9H+I6C5FjQ0fuBlYAkYJzaKDG/5kyx5dd6Je3+5xxqFCuOSjTUcAPgjiCO2Vl7QxORzsREXkpx7ulDGJP5bM3JzwOE5jB9LyEuDPEzWjh9ddKXHbPld/5fAtRKpzv1qgu8m0UuM4Qr3UZkYgkAyviMY4AcdlX/aMGm89FiAVqKSbatwX7ZMTEEhE93+UYV/G4/i9VZPmKGb7gmQrHkKNQIoJWi4jG4d9rjAyuGbD+VVNTLcv3jVLAm0F40ujLAbxDQjzH8VjuMyqAshWIILcp8kV0OhI1SiT2JQNrRd/lsdtI5V4MES4fOL8zTU8UQjYCHE4Q83xzEMrxhfMjzrtbCplbhy1wM0QM8+wXDYjrwfceiCkjIjqB640B5VZQPJdleDFwe4BIWoUtGvMaGtdpO0OZ7xg6OiJllgry2SdDHiyHI+7b7Ui+F9gE5KAPa89RQd63A8MjhfWm7ueIw3u8znuG21Ce5lgtKkE1eAfaVvtuxPgiMq/oyjGZLT06TTdobZo+66yKfK5zN9pnluwv0U0mk8lkMplMJpPJZDKZTCaTyWSaIXuJbjKZTCaTyWQymUwmk8lkMplMJtMM2Ut0k8lkMplMJpPJZDKZTCaTyWQymWbImOgg13HIdRyKUslFutxh3lMO+F9lxTtbR4YYMC93A5kvSrjaNwvA+JyXXKkFj3/2PeYT1YvMpToKJcdob8gsqsqY71NUjN/MDN7PgkQaCpb4ccDXqMrbCsZUnAKfaCK5Y8gnS4DrlXUU+xDKh4y5l+PbIh/yje8H5nJVgcaRJz6M+eIHwDDeG0o4HvYCZNzrdl/IcD0jl3K1KIfXAz5zl5CR1lQMaKxzZGANFBPOm/FfYAojKXh4yB1dKshGLAIrvpTlAmqC3gDatAwsNSyPZgRPAEaJTHVdVhR+dlsB8a7Gh9N0JuWy1mPJ3UV+6qLHHDjNVaz5p7NsP6/aBtmFyE5eUFw05KzWPP4OnodARLQz4jYIoF9qtp24rwMM/oS5XqtZ9ezwjMgc16wvbCs8D6KQSuZ4HrjRyCnvOSciX38CzHZg490imW8u5nMaqlmuB33uQZjwzxMYlXgWARFRF3h95bTO13MkY9lD9jQxbxJ5qb6qJKy/HoynSCH0IuCq7iUvTtPnnLfJ6zllSHNj6yHdS0OapKoTmu5SnnzKki9iARHRQp7rtgpj/P6a7Nt49gG2tZ+RrGJklS/C3H2zJ+PTYczMvwTmthL0WU9xrVEO3KeoGInIY+wAmzVUMbcE8xLF3M9ziWQkYmypufy8jbzMdxwqyOEXVSBZR32Hn30MbPGTUNYRxh3k/vuKKT1MmHHowBkQUcJtvVaUZTtT5rrttWbDDatQmXKMy+8EKZc9hkF/UfUjjNuV7Gy+5lGIz8G/n1ddohVy+RCpuZmdvXRvQV/W50tsld1T04K9rrCUWC+3YB7WcyjO1/gJcnuJiHonGO84Z1/FuRHEcy+pTdNVT/aPADje+wF/xyHJfV0t8vPi0kwz0XH6wc/02TU43B6s8pe2hzKK78CRARsAXMe+0hjLfjSesTBSWFXB0N0J+EYnwEAnIsrC+SjI+O6lQ5Fv6DLj1w/4DJS5nKzzoqM2ADOEz/hMhzv3vjpnpwmxAc9vqfuyn68W+XmrsK7C6mqN5XfwjCA8U6Eay/4xII4hIXBV8QwaIqJxwvXs4BkBrrzvcJJQlMyOAaY7ymfvrLf2Ajk/xDPOgsmrzesDWe6nuQzul5ZFvjoMYDwO6mJZ3gf3ov0JZ9wfwdlLY1mGCbRze8JxzFVzI65tjyY89pruschXTnjDivNzSe3jl1yOG9V4HfLpcyM4fTTm++pxvJzH/Qh/6XJbZKMrA44T1QxOWjJOIC99Lc/tuwXc+SRdFN/BOA3bDFJHeIl5ah3Y60eKk92BpQGerdVXTHR8V/FKn+vhoYqcl3IZ4MvD+W9rBbkG8V2uC2R3NwORTcQkXNvpM0d2Aq4A7KMLcB7EJJX98tqAy5Cks9dBZ0swf/W5T4Xhisg3gTF57PLZV3i+ChHRapbPh1hL+BqPz8nzs/7vi9vT9Mr7+BqpWiv2Ps999v/Mcz1fO5Z7wLrHZcc1Vimr30fgPMK/xz04EdGFMrc9nmPzmRbPHXoPeF+F841gfT5W63M896wb8XxzcyzPGDpwuY42Ruen6WIgx26Q8jXuK/L+8oGayEYLeWCQB/huQpbvGPjkt4en1xcR0XKeK2Af3hHq8TUPvPRLy8zaXx1wef7LbXlW3W04p7A55mfS7y8Reb8A5w1mY3lOXBveBSCfP1HvH5P07rMVZsn+Et1kMplMJpPJZDKZTCaTyWQymUymGbKX6CaTyWQymUwmk8lkMplMJpPJZDLNkOFcQHXPI9/1qKlsy2g5RQvrpZr029TBXogOj+db0hfQj9gSswd+xxvDBZEP0Screb7G4w22UmkrCbq9EOHiKevt7oAvjnbdzbL8fxXEtPTYLUJbJXk9fMIE7JiIySAi6sM1Pt9hS0whM9s7gfbOE5I4lwa9ZZpG6/5KUVuX+BpdtMBDHS8UZHuitWpRuPLkta92wZYOHpCy4q0IPMmAf9BPXoMb74PP+sWO7JeIwKiDt+1MWZbvYo3rZWfAHWZvKDvPHwsbLNp8ZPkK2dOtPXvg2h4qjA/2xSOwEHXHsgxo1WyBpes46Yt8VWJrD1oSl/LS4oRlHYC96DCUFr0mVO2Lba5LR3WjEdh8sy73X/UYwl6E+KBnu9LeidiALowNjXM5TLp8L5etbTWHr1dS/W2jdPr/kbbCdObPiKhZj+dFPrz+TsBl8FzJIFjK3j9NbyZsMfVdOb4qWURt8LUPR9LKN4rB6g2RNVbYLR8s61sOWwgPE2mVP3I5hoTEnXY5BivZWLbTMvj/EfvSHEtvZtttTdOd/s1p+kZRtoUHqJwQUBEZV3rllrMlgf8yna5yxifP8UXbEMnxj/NpQ9EIHPh7ArQJNxTD4RCaG+eOjAoUuGaYT9lP6dHcNL2v8EZd4jFViRFvJK+N+K8+xA8dMx5q8C/aENuHEzled4an44Ja40j8jP20CP20kpVYhHYE14ei51w5BhCRtpnh6yGOhEiOt5OIAzWisKJUNijOtTj36HZag3UCdp31oqyjfgS2aPi9ClVU9/kqeVjTBLF8dhfu1ptwHztWTTEPmDBci2WU/RTtrWgP11SQC4AuKGU5vTvi59PfQfv5ENApO5MmzVIOUEUjR+JcxvBzO92ZpnvRnsiXz9b5evTkNL2UUW2dwHoOcFp6jncdLtOFKn9nWXZfMQ+jtI15zueesAFW/rfPyfv+8SGPf6zaKuB+HqrKPr8K17vS57lI283xZw/mP1f9fdSJw0iSLODMIkf281Z6a5rOAQ5ny1sV+Rby/NkhMBe6kRwQQcxxZ3fIT98MJEID1/iY1lg1jHEYF7HNgteYL3FN3lR0qhRG9tCR601UPlOfpi8QY9q2SnfHoHGiAoTpLgUTosRN6cGaHIh7Q65PnAM0zmUE/aAEeMmqQgGVYC1Q92Ds1boiH+JIt0c8diuA8WyPZbvuxYxPChyex9PJnMjnwITYAgyim8rxivildswBOE3luhTxUDVYqyjqi0SaAMaoodhdOD8eAN5Uvz+YhzkfPxoohEMLsDdxyu0BUz893JCFRbQYllvjTHEfOYQ5dE4h0XDKR0SNRma1AamxUeJ6jjQipcvPjvPwpYrcC4wT/t6LQNeqevJ6S9DtESGxWZRzAsZ6/GTB53a60pd9fhYi9LVIFXXYi81FEr+CqGMvOcNldc+JfJs+f+9ildvmL6xKzFihwg0c3eB507tUFfk+cZnRTMsFrudnu7J8e6PTsaXNcDZ6DvfuGnmTh7XGEaD1bvb5Anr9/J4V7jsrsPW82pXt2cidvj/Xc4+vsHSvqpJRiDUYXziOFQ1S7H1wXCdq81SE9WsBYq4eh9ifMTTr91CPVLmtl9/LZW9c4T5xoVUX39kf8WB++lhxkEAVQPwhytmP5d5/HfBXQDam5lihBCeZLxvFZn+JbjKZTCaTyWQymUwmk8lkMplMJtMM2Ut0k8lkMplMJpPJZDKZTCaTyWQymWbIfOKnSCMh0BaCFgyNuTgKT7c8VBTSpDNm+80f9K5O0ysdeZo4WlPxxNlhXOHf+9IWifaMPlhvJsq6hAiXpQI/iD6pvBWBzR2ujXY4IqKax9/rgfVZ253b4J1BlMUolte7FfAJxVecZyFfS+Qrgt0GXMx3nd771hpbiAdw4vpOdraNGet2Jc92jxsDaalrgiWuAKeid5W17Rhsr80Jl2fRk3YdB6x8IVhKbrk7Il9jwriN5oSfI+OURb73LPF9qx4/++/vyVOgEVWy7vE1/Iy2QuEJzPx7RNncdRI12HwDsN5pO+HBiC/4CjEOI3ClxamU1qfpScondBcieT1EEA0AC9JJhyJf32FLZzUBC6cj2+ZCgW1mF8ESrtyO4r7oXNa2SFQbMAaIaSAierDAbb0V8fMiqud8WY6hMEGLNJZNXnsQc2HLDtpD5fQQwhgdEVuziqnsbwtUn6ZXCmzHCtQAQ5vluQrGI2lTe+qIyzciHoeI8SGSdrT1IqCOlKXWC/j6IVxv32XLezeRz+SHjHpBi3lEMl720qNp2nG4f0SJRBoUstyGURpDPpGNFgoZGicukQx7JqWVQpZ817trfOGPiAmZV1ZKtADjeNUOQsRPYbqg/NM4Xy8VeBwJa3Gvgl+h9Tx/B+Pidl9yByYJf3a2AnOF7Iq0C5gmDOFlZSdG6yeiFBJl+l32ORYifimn5gech/cDrkDHkeMakWs4XdQVQgfDRrPJPyCObDiR9Y/PgRZ6baGtAIKvi2gcVUceoGiw/lYKso4WAPFxbcDtvjeU+R6pp/AdbrhRrO31nJ6g7Vhhaapw34UKpxHZQkS0VuA45LnwHZ/b5vmutMDiWhYxRcVU5jtxj07NN3QkLqGbcJxtDl6cplOSwa+Q5bhddmbPI0HCzzgElMKxcyjy9UeL03TG4Xn8jEIYHgMmANc0uE4mIpqDdlstcr3ms3LRe6bI89d+wH0CMQHLeTnGv+X/4LXeO+AxPv6FLZFvB+zOOF77Tlvka0W8lqp6a9N07Eh2zXzKFv37MjznXZChilK4lwvr1YpCyqFVfqnAz7uh2FM4PjBO63GI8yMsp8VeTK+w8DNsz5Kyw3djwAw5B9P0IDoS+WrexjS9nuF1wpayr/tuemcukZQik1LFdyjnuvRNc3Lc7Je4X70EFIiBQkUiYhXRAno/hz8GgNpojyX/w4Gc5Sz0F4GGlP18xeU9wy7s2Q5cGYNi2GMN0/Y0PVb7EXIYj7FJjFK6kcrOFPX5ncG3LvO+dDEnY2kT1icl2POq7Rcdjk4HfVTVyw6cfvaG/Ew7I7lPOw55zbABuCNcQ2vcBCyXRPmu9eRcdnvc43ww167l5Z4N8WYB7GEQDUck5yyc855uyf6B66x1mAI7kYxpI8CMITZuPjcbeYNxcCkvK2YxB3hTwPOF0Jc1iqwCIQ7xcg31/gbbcwvmQ3yXQyT39Ysut2dN9Y91eN6VPOBmehK/cvU5/rn4IsynT8mF9ws9boNd2L89q/ZFe0Nu0/NVLl9FodgG0PQhBAq9f9gZ8XM93zqd8+Y58tkRR3S2yM90FMh8/ej0tfY5V6LTcH1ez/OA0GXFeIcfvdIT2QjIpGKvk1PXe6zO8+EjNc54HMp9dw/euZyB562oddBjFzh2JT0uRHaV2/bxVRkv2xGvVQ4A37g9ku+uYqhLT+0FUA/W+V6I53tOjfE7V5j9rubuvCaTyWQymUwmk8lkMplMJpPJZDKZ7pK9RDeZTCaTyWQymUwmk8lkMplMJpNphuwluslkMplMJpPJZDKZTCaTyWQymUwzZEx00PmKQ/mMSweKC5YBVhDyhEaKQXodmF2IF1rKS/AYsox6E+btNYEBTUSUc5m3dxs41JXW49P0ow157QB4cTsBc9bqGcnxRk4rcsf2Ank9ZIPtDfGBZb4/t8S8qKU88A3HkiuE7Czkd91shSLfDfdlzhfenqbr/hmR75azO02XRsyLrHoSgpXPcNnnciHkY87SONFcVS77YkEyjVG9CYPRmvC8z7ckE+powu2Rhf+/6k9kvoxzOovJTyXrLXT4OZDt3AplPuT/tYAFppm3yISrAZdW880HAGfFjwBDS+1I8sO6KTPO3AS5hZKrei1lFmiHmJ1acuZFvmrKDMLlLHPVXMWx2pswo69A3Cc2sjWRL6Eqnaasaos8QGoR/9mWeD262uU2xTo6X5Uhd6vEHx4HEBcUHHsD+JB5YNYhk/5kLPvvEw3mhr3S5z5RVBDCeZ95YDVgBlcUj7QFZz6spVxfyB8mIurFXBm3Ror1CEKuarHLsa6Rk/dFVn/PbXO5Uwlq9RyuI+yjWvMu97kg4fEfI5ucZIO2xjy+VgscS7fykp3eDTkGtbMczz1XjslaynzeNZ+v8WBdxtWFXHpnntkl02voWxZjKmYmdHskx9chjKnVInCts7LPNqFvw5Eld/FX90fcL8bQX3w1v26UIdZAiMP5VJ9dsQrM4L7IJ8dDCQIPxqCTUJZ1CGMAGaQ1dUbLaoHHAPKm9fENGPvwbJMgjlU+7sNzHscWzSPHyzeh7IGq83ru9L/1wNgSJXJ8rUO8xHinz0rZGZ4+16opj270efzj8z1UlWXDoiNvUp918FwbeNgF5NXKG/tijuHvdBSesxlyxvvgbIxvOycZuovvPJ15e/wpvvahYl7eX+Ofr/fg7JVA8k1rCY+BCZ7Ho86NCICRXsxxHNQxcs19eJouZXhc696AZ2Hk1RoJhWuunRG357Fa8wYJDz7kZmMd6zKdKfFnY8W178EZMDiU8WpHqs6f/+TCNH3+bHOaXlHr0CqML5z/kL1MJOu2kTJn1EvlOvlSju97qQas86zswA3gwV+E/vZZxRZ9sc11uQYc9Mfqsk+cLXIb4lkCRTUeMJ7jGJ2H2+qxhjEyJ9Zv6tyDmO87Sbl/TNR5JuOU15Q9WLu3x3LuWS8QTb48rOqfaW0WU8pnUhrEenxxA3/TPKePxnJ8YXtXoL8gK5pIMqEH0DA3B7LPogpwbsSZMp7lJNsa569ejyf8Q7Wn7014TzMBBn8uK/cfCdzXR+ay6tv7dDJNX+muTNPFhj73hMuH87+e5wowV87Dewt9jgrOS7cHcE6JK5/3EPZct/r8jGU4R+l8Xj772QqeJcS/H6tz02KYV7qw/11RcwDO/5HYe8rrlbIcg48D/qwZylhahPm/lOV+0B7PPktvJZ+c+nsi2Rdx37IzlP28AOey4NqiGeJaR167Dme+4NpHv2I4hGt0YOujjrgQMfM44NiHZ9gREYVwbt/OENc3swMintemYyn2gx7U11Eg353cTIGp3eUzPdZKcp7D7SGekzFUZ9dgW03wPA04+25FnZWCZ6q0IVZN9HsUeKheyrHATeQ7kYtVXledL8NZP6/Rj/AsEn0eD/adBQh9eD6QVg3emcUqriL7vAjpo0Duifpt/nn0J1zAXJ6/c70pzy/DuPUQ7I0LWbl3wnNUDse8LwjVOiiI+YEfqHBFlBZlGw4mDgWxS9SkLyn7S3STyWQymUwmk8lkMplMJpPJZDKZZsheoptMJpPJZDKZTCaTyWQymUwmk8k0Q4ZzAdW8O7ayjkKQAPmEypB+tiUtiYchIyuqYA3aU0QDtADWaHmaDlxpiVlIVqdpxFSMwXJ9sy8tGHtgN+86bDuspLMta8+32PJQ9aSFCO0sR2N+voOxvG8jx/aKi2W+3oMVaad4GbASiHrJu7IrIu5gzmf7ad9pi3w5sG7dnvBnXl/aQn5vj+1F94ON43yZGwetKEREi/DZyn29afqsslLvf+r8NB0m/Bxo+SEiGoPd+VbKGJ+BeqbV8eY0nYMhmjqyv3nw7A2HrdX312Vd9sBjjlajW+62yOfC/6lFwB1YzUi7HeIEmmPub3mXnzdIZSUhbiZwAGujLJw5h+vIc/j5Cqm0jucJEQRcL+1UDrYB3CsFdIyrfLZ1wPogPkFDQRB3sJjjT/sT2daIw5F2Ynk9xAnsDRUTBjRO+IuIkTqGHwYTWZdzPtefB7YobdFbLJxuBwwUrgptkWXAJfUjGQu6A27rpsN2U0choLLw8wsD7i/lgbSBHbqMJIiJ+2U/lQioosNtiFY5jS3qJBzHmg77tYpgMV1xZPzwwEqJlleNvOmM2Q4f+++apvMKW7TiME5otcj1cLEsY5DnpjRUuAzT3bqv0qdyNqIglrGqBXbKxRz3A40nyGcQM8ZtqlyqtJiXlt1X9Vo2VURoIdltvahwAi6XD1FiiF4jkn0ObdW+YpWsZNF+zr+vK9smlh3XE9rqDaGP2hE/VAvwBkRES4B6WshxnNCEpStdiNtjQDapcV0LOR7sOzxv4pjqR7JdhhAL5wAPpZsJqwzXOjr2uTAn9Cdc1iCRsWre5/H7cBXROLINr3T4BkHM114vyeDc8PkaOHdolCCW98aAr/E/b66JfO+K2O68eIE7dx0oeQv7ch46CLkNH67z759ry2faHSIOi9MaGVJzeM1b9OuQT65RC7C+kVZqZSeGuagG+IBaTiLbcN2MOJfDpCfyFYnLUchyXYaqzjsQJy53eO6oeDJjKcMdH9cWSEFbycs+PwG04J9cXp+mexNZ57jW3u5zuXeHEn+Xy/BYWUt5jkIcARHRuQp3sgWIlzX1TN+8yXNyAP1jP1gW+SqwnzhT5LIWM/J6VY8/24C4qP/K6xiQCXvB6bFFr9kQC7Ipxpdcj0wgnrTiDf5AhfwcYLiihJ+jGag9ke+IsW06XQ0/pUImoZwr6w/nQAfWcBfLMj5drDIeqgkIgVBhldoR9/X9lNN9tQfBNsMSZWENvV7QiEvYE8GatxPdFvmCMX+WzXB8W8xcEPnmEx6jRdhLLMRyXA9hX4Xouat9hSeA9TCiWTT+owv7+hHgU7KuHok4RjmdI7nORQQpzgO4H9T7EVzTXO1yGTQ2suYAlhH2m6Ws2lNCN0CcbBDLuTuGOaY95nm8l8jYHMFzHAIKLKc4LSFcD/sH4m6JiBoQX7DPDuRWgF7o8mdYR/i4OtogyhWvNxnJsYF11BnjOlReD+ulmfDeeujIfXfW5fctVehvu0MZnQ8DHsuPNLg9Bmo/jeudBOLCSO2NEGOGaJeovyDybRR5rkScy61QrmVrLpdJ9j8uH6JE9PV6MZdnNS/XN7jvixJeMzcnEh+Ga9sHKtwX33FBMj4PDrnOL7fqcF85yHOAipqD9arGX3UgXjagX/pq7i7D+qZe5LIvVeTmqVznffcfX2bsKba178q6HEJsRrTLUl6WFddLiGzJKUQgbtcRC6bnlCQlGsbqRd8M2V+im0wmk8lkMplMJpPJZDKZTCaTyTRD9hLdZDKZTCaTyWQymUwmk8lkMplMphkynAtoGDuUkHMXcgEttdsD9gPoU5u7xJYWxEVUCuqU2jFbMjIOW4Mc9X8ax4Ax8IltCeMIbJHZRfGdhs/Xi8I6f1/ZsfAU6NaEbRZX4hORbz5lrAFiEfqOtGpchxPJHbDUzvmyMl/uonUcLOGJtLA4Dn9vPmVbbiaV14sBcdIAm6X+3yE8rfgATqJ+sccWmK2itJK8H+wo3gJ/Z7wv7SzLeX6OY7CIKHe9sAOhta1CEh0hvnOXOZXVctnaPge4k6GygW1DX8yBtbWYSvQB4lwmgM2I1CnmiOVApMYQbJAjCsR3em5rmu6mfEL9PEl74jvLfML8s9CnOgp5cwi2/hzlIS0tU3gKfIkQ2SJ7yFHE/W8Cdf6OurSEX6ryZ2hxripPol/musDT61/uyL5zNOb74jPNp7Je9uCk9l50OvogUYilXsR9G23avhocIXj20BadUfZEtMOjJRRxCXfuy+02HvPzdpwWzdLQZazKbZL5csRtOEo703TLkYifPrR9DixdGWXbjul0PEoC7V7IyO+sgu8VMRzHspuLOPvWLGOZFvOy0pHQsaLsnahe5NIotv/r/lL6/b065TN5mqiqfKUD9lNgqawWZPvi8G2Ds08jUnBMoDUwo5roZo/vuwJ9B+/TVW5BH66NQ6+qkEG4Pilm+YHrvsLQQb7lHPf5MyVpT366xeMVn3egME3D8enjJqd4BysFRG3xmLrel1bZAXHsK0AM12sLD1BWnnM6TsdRc+1twEMNwJZ6tiIbCq3VRZgbhwrDEEI8wXlc208RSTCAeVfbrLGewWlPhyMdw09H92iEVhY6TAXoKdcGEqWS2+f14gogSNAifRTKORSxe5sF7gP312Rd7o9g/k95EFWUxX+eeN0RJ/zwuH4mknG7JNBE8r5DGPRYz6sFNXahk1QBdXazJ7dByzBez8I8XsjIOpd4I0DhpfK+1wCtgG1zpsgxIpeR/agX4b6ApftbOcvtsZCHdfdAoWxgHYQxTM/xI1g7onu9kpVj/6TLbbo74Hn4Ylmu4882eL4OYRze7FZEvgjs2I+d5/VhYUned+cy953ggC36u4BV0AgV7C0VQFk9uSDrch+wNOeCc9N0Z3xW5GsGXEkZQAt1I3m9az2XxpqJZbpLT5845Lsuzat1ZAmGZQXmuY5CPRwEvH+qe4j/kPfBrp6HsawxBogNQE1EjNTjhr/TJEa4jCOJinJhv484l7zCVZ7N8/hIYd8dqbVrGebN0YQ/6ykk7QQG8xGwMXB+JpLzKI7/gerbPix4JrA/XE8lPmzOOx0ji+XJqzWWRJXMxiEhPjQCrEdflTUDa3J8XL1nQxTuXsrvQWokY9W8z2sQvfdBRePT53ic04mI6vB+aQu6wc2BzLc74Gecg/3EIqAt9DsHvC+uayMVlzyoZ8STHIwVWgTWbIjjrKeyjnAtimvyWyPFNgYhkq43VlgP2JgWgV+DfYCIqAw4x77DmKdWKteUeZgvhgnHc/3eogbjC9cPer2JWoCF9ypgilclWUSsNxs5zteNZF1iUyHi6vLNJZFvd8RlnfN5/VtRmGIcUnWf9wLnNuS7P4K1bBb6aLYi+w7QUWkScNuULsi2SeA9SO0Kl+n2CN5ZJnJMYsxFHNFRoNfJ0GdjjrlziYyruEaNxLtEua9ohkRhMvvdG8p25yaTyWQymUwmk8lkMplMJpPJZDLN0D31Ev0f/IN/QI7j0I/92I9NfxcEAf3wD/8wzc/PU7lcpu/4ju+gg4ODr18hTSaTyWQyTWVzt8lkMplM95Zs7jaZTCaT6W7dMy/RP/3pT9Ov/Mqv0Fve8hbx+x//8R+n3/3d36Xf/u3fpo9//OO0u7tLf+Wv/JWvUylNJpPJZDK9Kpu7TSaTyWS6t2Rzt8lkMplMp+ueYKL3+336nu/5HvrVX/1V+pmf+Znp7zudDv3ar/0a/Zt/82/oz//5P09ERL/+679ODz74IP3Jn/wJvetd7zr1emEYUhgyC6jbvcNOcp07/7LqvxYOA2TgMrhJc3MXYmYZecDh0syvhstwpBawt4K0K/J5DufrExpEkSwAAOYxSURBVDODi8AInlN8SOSlus7pHE8iyco6co6n6duTZ0S+0Htgmr4vPTtNJ+p6x8AJDGIuhObSXR3zXyu0geldcyQDGvncI2K+U5Uk4ygGHpsLoKrlomybJcDSI3sL6+FKT9bl1jGz9nKfP5ym00Tykxxg1nWBd1ZRjNoO8NLnE2Z3eYrPPU65gJqBh0KGM3Lpnm1J7ljD536wACy1QiohXU2X22bi8PjoKUb1w5P7+BrAv0fOeDaVz+QC66oGHPSLJcn/KgBjcs3nfr6qmIHHEYO4sI5qrjx/YKN0OkP3JJD1mgfuLjLgBwqy3InwuTh8KsSy4M8hs/VIMeZawG2TbHcZmgcTHl9YIuR7jxzJcxtMmFt6MOJyP1SnmdoZ8ljLq/hWgra5PUB+uOznW2Xgrw+5n19T/PAsjPE+8PW6dCjyjR3uz0HMffHIlWzn/pjPkJjEXM/V/KbI94DD88J5Z32aRr58XkGu74Nueh6Y0jezsn/lMtxuD1W5zSqeBGDfHvH3isDD1czbKMncxTm8l/S1mrtv9FPy3YQCNeFci3jeDJPGNL2YV+0GdYzMQORDaiEntKz4mlU4CwTnHry2blb8rAldW8+1yK/EvrGal2WdxYDdqEhO60GADHPuv7fUo2MsLALrsahiFa5BxFxLks0ozzPhuaibyvJhfM+lPJeNYV3QjWUsGADbMjfhdr8uL017UIHYhiW1Mr5UhfMgYLzmFcv6KOQv3oBzLHA+ICJqhVwXHYBKzmflnJyDdSRyTA8mkvE5B3zdXIbbUyFv6RjYm3nItz3ken2mJXumWNvV+PtPNORD3Sjz9SI4z6SWlVx2rGe89jXaFfkysLZ4yHtwml6TiHW6AstmrKOxYh3jmTcbBeCR+rJ8OKbOl7iPlTJyzTCEsypiwe7UXE9OI98UWczX+vKhcE15DjjjRbWeHk64PR6tcV22xmWR71qPr1HK4vkeco7HM1+QLX4ylgPCBQ56CZirnpq/Xm7y2JvP8Rgtqrp8us3X6728MU3Xrst58/qA+zmWD8c19ikioou109dpnzmW/eMw5Dp6sMb30eul51p8EWRh59RE3Qxiir5MruqbUV+ruTtN7/zTe8V9WCqPYYIO1ZYIx95jda7vVM2wGYc/w3V8S8XmOiwNcP5CJm9Vhoy7yv6qyrkV8XMlwz+HKZ8RMnTk3v+VAM73gTiomc3HLn+vne5M04+Ovknk2yjyImQ/HMN35F7Rg7l8xecxqdeg/Ygrpghr3gfKcv9Vgnra7nP94/Vwz0dEtAftXoTP5nIyBkmuOo/X81UZ05A5fhXixFg1Gr4jGcGeo6z2nmKdBmzm1xrpeN6Nwn1TE/j6yIMvqzVIGc61WoeJYAKTij4T6BiujZzxbiQHURCf/p4Bz+kiImrSzWl6ns5weUj2c6xabN4Hq7Iu8d0Y9omRYuHjmqEOlaTPnCjC2TUunKXXd+TC7xqcwTd24Dw0R843lYTn5a7D47UT8u8vFuU7jLfUgXFfxD2lbPhjOIODAhh36oysNThzrwhz7c2BXCveHvE1+rAGOQrleHipw+Wo+/wcb2nLM/JW4L4rJR4Pq/MyVpWW4YzG87Am6sl5OIRtvQex2IMzKMqefPYXulx2jLkPyKLSx/fhPSDE/Zjk9UZ4FiRsGfYGMp/rEEXJ7HdvIu+XlevrrB/+4R+mv/gX/yK9973vFb9/+umnKYoi8fsHHniAtra26JOf/OTM6330ox+lWq02/be5uTkzr8lkMplMptcvm7tNJpPJZLq3ZHO3yWQymUyz9aZ/if5bv/Vb9NnPfpY++tGP3vXZ/v4++b5P9Xpd/H55eZn29/fvyv+qfvInf5I6nc70361bt77SxTaZTCaT6c+sbO42mUwmk+neks3dJpPJZDK9tt7UOJdbt27Rj/7oj9Lv//7vUz6f/9Jf+DKVy+Uol8vd9ftHqgGVsildHch7oSWmArYST1lOisCBaYZscYiU5SQD2JFMyh4FxLcQEXnEZayli1wGYgvGUFmS5sD2huiOSFmITsATFwG6w1FoEUSGzOW4rJVE+tm6ET/vGGyuibJT+GBNW002Z+ZDS8Zilp8XMTlEEnOBNtWKstsVwc6O7VmEEaDt9Xtgc6fdpWlS23J2R9xOaP9bUl02n+FCvdzh8qwXZWGxrW4O2Grkp/KC5VTaiF7V3XXJ0n0W1U/Zb5N32AqFtmoiaSWvOVwmtPUVsnJ8NXJc1s0Sl0/bBD9zzBWIFrMLVYlfWE7Yrny5w+VpJRKX8pDH5cC2DpTlGq24OF57auDsjdhehHWp/zeyB64wvPaiL8f4MeAmemA5q6bSjo1ewT5YOnfca9O0bqeMwzEjAF9/lEh71yY47JoBP4m2+OH4Gk34s9ZY2rY6Yy7HQp6/cy4v+yuiBo4AdXTkyufwId7lM2wPzzqyj3WSG9M0xrFaZl3kW8ty3WI834D7XlB2sW9Z6EzTjSLX/2ZJlmG7z5WJdv0glj0ECR070KdW8nJAnC+NaDCRqIp7QV/ruTuMU0rSlHR4y8EcOoq5n97qy6XPUgHQQhHYeidy/PsYJ1K0jsv2RRvyzT5fb6nAv9e2zYyD9mS+3u5QPlQr5HxoPVfTCLUBLVaEcu/0ZGxBLAVajU9COa7HgCbwYR6+GR+LfFGfMWgrBS7gGb8m8tVyPJbRWn1ThnAxn825PL6CZHJqHiKiHPF9DyK2onqRjH1DgrUPtKHGaT1W5zpbK3I+3e51sKNi12kqfFgz4TL1HI4t8yTnB7SYNydcMdrW3wOEXhtswr7CjF3ucHl7Ex6XiKHT69UQ2n1vyOnbednhVqFv7w64/nw1ya+XYH2YAl4mlP2ymtSn6QtVsPX7ckz2Crju5t/rtUU5ezqqpJiVz1uBn7EqDkL5vGhDRnSEvu86oGNwTsBxt+DLsZZz+TtVHzAoOWk33+nwnLpR5Gt866LsRysFnkPnAW/wTfMdmqVPNXkSrHiy/9Z9tnovlrkv73ZlG4aA8fujw/o0rfFGiNA8AvzPU00Z67sQmxEvh+N/raTHLt/sxR7HhYnq572Ux9fNHl8jSuT1GoDQXIZqPgrU9eKIJqlsr3tBX+u52884lHMdWsjJ+rsNYxQRLhqdsgTz6ATQLusFGSNf6PIYeKU7g79CRMMJX6MzxpjB37lUk4F1EdZt5weP8PfdtshXTni83nQuT9Mn8U2R73b09DRdzfH6tepIbEYr4f+IaA2vTtOXS7LdquMnpml88q4jxz9iVNcAndpQ2FgcOw3gt12oyNg8hDbE9xHz0A26aog0YY5ZhNi+WpRlOBgBngvmFI0Za8Gc34QJ8VpXxrQ+4ByzMP+PSK6/d2GPdAjpUkbODxtleF8CH+n+i8+P66CawsFerHB5ce5BnNFLPdkvP9vk2Dwgrpc+oEmIiAZOe5peStb42q6cl/IE7wXg3VWg4tzekOPsHPQPXFsTEY2hCarwvPrZcXzhJzuSRkQ5B2I13OpQ4VxGaefUfHpPOQaUYAT1d+xy/S2OJYotgb6ThXaKFF5uHxAur0DxFvOyn9c8rj9cg+yOZFsjUfIWoASfbymsCqznFgBrGaWyfI0cP28I2LibBw2R7wzgfsfb8Oy+fI4TQLZdBtxfF97F6PeUDdjf4PMhnpKI6EKN62Uj5v2HKoLAc+L7G0ftHzIO0ThRX56hN/Vfoj/99NN0eHhIb3vb2yibzVI2m6WPf/zj9Au/8AuUzWZpeXmZxuMxtdtt8b2DgwNaWVk5/aImk8lkMpm+arK522QymUyme0s2d5tMJpPJ9KX1pv5L9G/7tm+jZ599VvzuB37gB+iBBx6gv/W3/hZtbm6S53n0B3/wB/Qd3/EdRET00ksv0fb2Nj355JNfjyKbTCaTyfRnWjZ3m0wmk8l0b8nmbpPJZDKZvrTe1C/RK5UKPfLII+J3pVKJ5ufnp7//a3/tr9GHP/xhmpubo2q1Sh/60IfoySefnHlC+GuplJ1QKZu5224D7iB0AGrbK1px0O7cVccxH8fsQck6bLuYS+T/4qOlaAj2mx7x91/uKQuGz3bKWaeMExH1wb67nmxM025WWkTycDI14hc06qEDp3wXiC1rqarLItjr0f5U92VXbIGVfAtsUXfbMxCvw79fUtZAtH/cV+Z0BSy+PYX4QHt9N+LydZT9DPEuJag/bVPD7rJV5sZZl24g1VbsF91SVu/9EXdMRAucLcp8DUD83FfmfINoTuTzBo9P06tZbnfEFBHdQSe8KrRq4/Np+x8+49kinOas0Dge2Km0XQmFlum8y21zlLZEvitdfo5HGtyPHp+X4wbt0yh9svVgcmo2ao7l9dAGulXmupj01anjYx4rE4cvPiJpSTwAW+iY2HacJx7vMcnCIXZnCcbxK13ZNgtoRa1yXXpqrGFcRNv8CUmr3PWUETX3j+7jMij7fzPi8kZgmzuT3C/yNVzuz4cp3Es1mVvktgrArreZnBH5qh5aRPnZL5a5PDVlXw8AlxCBDf/Bt0uUxQMxn2b/x/97a5ruRDK+Yb8fwLXRpklE5LkJea7yuN0D+lrP3bmMQ77rUF6xFJYi7veDmOPOzdFA5OuM2f4MtDUaqXmuAnPyKvBTxmrRgLbB1pjvi5irhiRUUQ7shSs5vm9eIcwQ04I29+1IxzROo00SrZREclyPIIQgPoSIqAtrkHPEqKhiKq9Xz3EdoS23qNYWaHHuQdk3g0WRD/27GEPaY0DwTORk24cY2XJ5TCIi704+jlVrCceqHVeyfWs9/sxxuAwaH3QVwtO1Hq+J9p0jkS9xuEHQQh+rBRPOvYUJrINI8qYyUEkD4MjkVZ0DJUggTfKQreTJh1qB5y0AQmOg1pSLOb7vO5f4O9qiWwJcigNz40EgsVvnKtxWb61xexYy8sYZwCBinFbLbtoN+DPsb3fN6bB8qitEEgrXi7PWD0R3rw1eFSJg0lTOD1j0/qQ+Tf8fW3siXyfyIM3P11B4mPcuc/1dWuM5q7gkx832S3yvcyWe5CtZjWzj4OVCKD033xb5bmwzGgARGhqTV4C4vQ/4qn4k2xrHxxjQDPOwhimrXe1+wHWE+4C3Lci4mj2pT9MhbHY0ggBjbhnGSlmNm8WcT+Nkdr94s+prPXcPopQiN6HbQ9keiE/pQzddVMi7S2VeKyPuoDeRHQHH4QD6lV5dIXYB74Sl0zgi7HP3FXhNf1WhybJwFURHDBI5P/hZvoYLmEaX1P7cZbzGXPES53Pks1+HdWku5fvmSW0+QT1EpcZysYLIJJw71FARqBGMx7jmKKht3jzgLHD43BrKWIBrrifm+fdbZbm2SwCxmHFm7yk9qNstwObgeoaIaDg5fT3uKiAsvn85W+Sy6zXDtQHf93Kbr63RuihEb+BepTWW30GES+BwZ4xJxn38OQS870TlqwJWeIW40uu+7B/3A+4I5/vbA1m+OViO4X5zpNYWWOWwtL4LnRbDexAca5VUvusYAoIkBhRNzpE4MsTeIC51IeFnX1L4MOznu4AExvUWEdHzbU7vDLnO01SuUeseX3/e54q5VJFtcwzvKkawR33XksK+wBjAd1mDiSzfy/DuZBfQM8s5tf6Ctc8IYu5cQQa/UYTPAXt/QEOfqDpaznNlPlrjthio2D5JedF2e3j6eotIYjtxCTJQYzqfdenL3Xa/qV+ifzn6p//0n5LruvQd3/EdFIYhvf/976d/8S/+xde7WCaTyWQymWbI5m6TyWQyme4t2dxtMplMpj/ruudeov/xH/+x+Dmfz9Mv/uIv0i/+4i9+fQpkMplMJpPpNWVzt8lkMplM95Zs7jaZTCaTSeqee4n+1dRLvRIVMrm7cC54AjueIBwlyioLLoJjOKl9eyC9X02X7ZTlhG25l/LScuKD72d7yHbzHXd3ms4l0vqB1qqGz2W4UJInlR+GbDP59AlbWMo9iT4YgtXlJOLn2CFpdx4AbiI3uTBNV1xZvjaxJaPusCVG2/Az4KlHHI6v84EFyJnt4hCWpzU4tX21yv7rYSgtSZfbbJV7qceWE20hmvO5fNsDrvPrfVnndY+v/0iDGyqvrMB4Ojz2N33fIVhaBhN+vpxyrG0V+foPVthiHiUlka/q8/Pi6elFdRLyPpxI3h7zZ8GE07oMaOmK4NTjvC8tSQ9WAVWS4T6vrUaINJjPcdu0RtKOhdih3SFaEOX13r3SnqbjlMv3/+7WRL59ODl+H04G74yl3bnknW4bmii7vjPjbGe01BERjag7TSeEdiquI23R23WvT9O3U/5OfSixUQ9Fq9P0Zpkb7sl52X8RSdIdg899KE/rjhIuxxjuexzIOr9Ct6bpAE47X1X4FdQePJO2tvrgw58Q19+A5HOMYrSwAgIG+uUgkO3yUp/HbnufrbZ0bUnku6/Mz7ua53hZ9WT/gMPE6UKZ+7yjcC4HQY6Ui9V0ioaTlCI3Je2ex3kEsUM9V2KfBjHPRZHD/ddLJc+hQtzXsYe4avLxIKYjggSdn9reeQDddBkszXkVS5cAm9GOTrc+ExFVPP7Feehj2gp5fcDBHsldJUfOh4cO26RTsPUWSNZRCS6ClvxQ9eNrPR4flSyXaaMk7zsAjzg+o+eiNV6O1yogZo4gFgwV7gst9Yh90cLm3dUcE9D1cXuaPnRvT9Nh2hf5MHZtpQ/A72WfwDm14s1eg+C0UgXmnc6H/Q/nJUTPjSayIy0V+CKI09OYliNA66lLCGG/R9v2RlGuFd8+B0g0WKctNBSK6TojQ1ywKmtMC94X66U7loUtArLmOMMVpl38fViT4POWs/J6JyH/jNfYLPH3myoWYNl7YOf+5I6cu3E8IFbQV17k4YTr5dZhfZpeHksU29oGY9CqsGeIYzm+dts8Bx6MeA0STGSw6gLyBp+wFynWIcSQ3RHfN1SIOoxJi8CBWC9wReyrdca13umd8UxZPhNiQnaGiJ6R4709AcxNBrGMd8dBL3mNTYmJiIheHrUp6+SoDHtcIqLVAiDWxCfypx70uTqss5qhnOfwW4gFakdyf75R5PuWYR1/AhNYU/FRYwiMWdi3I4qQiCiE9XAl5TVHonBJE1izlmGfPJdI1Fne4T3cocNrY49kXY4c3qwgGtZL5VxbSXm/U4I5WffiNjA1JJpFjqlvmjt9P7ec4+fNqDXv/zrmMiEyt6ZYruIdAVTfU8dyz9Ycn75fLShM6aU812UD9t2K1kEngM1owJSFODMiIgdAQRtFjhl7I1/lYy3Aug/nISL5DiOEePz5Nl/vpnrngGuz1QzvjfV7lHG8OU3n4b5XAlnWAvHPiHDZKMm4/2iNnxdRuBm1TkN85Qm001Eg+wQitXDvHsSyzhNgfCK2JCK5n0blXX4Hl6HZ/LZGyvnmsjjnybJ+4YTTuOZV3U2gZ3AtiwhEIqLNIscnnO+bY42ROX2ddq4o59oKxMibsPYfJbKAn28DwhjGTb8o29p1uF5CKENtJGNQ3eNyDCBmD2Ht01drtg3oHytVjlsrl+R6uveJs9P0S13ul7HaFOl+/6qyav+2O4goSmewe5VmQ5dMJpPJZDKZTCaTyWQymUwmk8lk+jMue4luMplMJpPJZDKZTCaTyWQymUwm0wzZS3STyWQymUwmk8lkMplMJpPJZDKZZsiY6KD0i//2RpKPg3yhCiCTNKu0mmXG0cUyp9cU69FvbvE94VbIXyOSXK5xAnyhgBmQZ0vI95UMsl1gb3mKkfjYHLNB6x7zDT/jy7J+vsnpV+jmNH07/KzIF8XMqQwLzCtaTM+JfMvE3PdeDLx1xfJCnlUeuIoDxZ9CFhdyvCPFnkbuWtVnPlYIvK7r3Yr4zpU+NzZg2QXjkoioDSwvxEUh84pIciB7EWdczMnrpUBJQ/6qZr4jM78E3N314myW2hzw4OuKQVgB9mYVeLpLOcmGilLO1wacYD3H9YAcdiKiIvA654CDHqeyrE+dcH8WrDjFm/OBo7VUwHFTFfkGAG5tI6CsIkNfe8yd5/f2uAz/syU5uQXifBn4P8ihYq5d8OqcDx5Ec4tbLg+wAfGYbKY3RT5ktSHHF5nqkpUumYsDp83fVyzxm2PgoI6Z+XwYSj4c8uvOw1BJFEcuHzJjPUq4/jXvtzHh8h24HLewToiIfDh7ou4wE3aNFmiWOikzICuO5FI6UA5kz77Y5bp01VjDWL8z4Gc6CeXYOBhxXSwX+L4rednweThnIOfyd+bUGQF3xo2C8ZruUj3nkO+6pMI+jSBOlICZX0pknxjB+MWzBTJqrAxjZPzz7xvqEAjkeCO3EZl863LqFjzBQ+D6Ig/yjvgzZOhrLnvVg3Mj8lzYNcXnHwIP/gTOOqj7clxngMH4Em1P0z1HxsjdwfI0XUs5ngyBy0pEtAR8eeTI6r/saMC80oGJuDUG7iaNxHe6cEZLL+bzW1xHxv0l4vVJETjqmmV7DhZ+yHk/GsmYW4U+5iVnp+mKI9dV2KLIwxzEcvwv5Ph7CxCEdHzagPkWWd03ejJ2NKHPaubqqxqqNdYE+iWes/NCW34v0YPviyqo+9RnnPOyUZL5HqrxOhL53GEg2zCfiSHNv9fPgSxVnIfHelKG8dWAtdPNobzvCUz5Q32Q0gwhX345z23zck/2+sMR8s35s+NQxRmIb0/OQ335cj1yOOK+/Xyb10j9sRzj73rb4TTdawJ7VvFXBxP+3s0h7zPwLAcioizEJ2To41kpREQezIF4dpCjxg3GggcBfYysed2cGHOPYbw+dSRZ2AEwUCNYS5VVGZCDXspye+hZOuvazP3laOyMKXaIxqmMfQXocysF7h963z0EPvQ8rJ92A9m3MRzgfB0msm8ja3sCnQnXjVXFLc5AQPYhumP/ICLqjDmOJbAn12cjJbBuxmWz7ou7wEHvhDwnt9Xe08sw7zsDc+BS9n6RrwYs9XkYa/r8C1z3zuVmv0aqAgf5O+/n2DJ3P4w9PQ/9F34/8mKPx1pXHaMgYz2nD0eyLpvAb0fGelUx1h9r8GdrcJ7RtmKYt8b4XoZ/34nk9bDOlqGwD9bkORSuw5upIvSX9bxcp20UeY3zUpf55tt9vtFR2hHf2czw+5YLNW4n9apJvNPIQV8eTuR+Gs+DqMJF3rMo55sH53kve7PD18Czb4iIrg54jOKaRq9l8V1KM5gdVUewdo9gjvHUuKnBPjID7zM8kvkW0/o07bvcNgcTft/Vncg4g2cvdWAvW0vmRb71DK9/52EMPd6QHf0+OA/mqSP+zrW+OgcIioHn9FXUer+U5esvQhzU8bIPxcB3J/r913VYuyzDu5i6r99n8jOW4X2QC2uERL0P2g24fP9rl/cVDw/l+nzWX4OXPNWPxPl0s/PdHkU0SfW5LafL/hLdZDKZTCaTyWQymUwmk8lkMplMphmyl+gmk8lkMplMJpPJZDKZTCaTyWQyzZDhXEC+m5LvplSTrgZqjTEPpwfSJUFfOGF/wEKeq3ZRUjPo/WtsUbg+4AuOpcORitA6iHqZB7TLE/Mam8EX6QEG5TOtksj3cMxleGjhZJoOleXkZp8LURqxLSfjSttLNsP2iknKlrXUkQ+VgNXNhaLHyvg4TNketAT3Skk+70qBf36wwt+ZKFvItQFbsl7qnY6B0Ja162CFRrduN5INjy7m++t8n7fNyTovQXuiFQ0xQEREUY7L3gM+zJLqR4hSCaCatbPYgToLJ3y9VOE1unA9tD4/XJVWLexjEfRFDxpUlwGRMstgS9tXthy0d6NV7tZQ/n8fjkPE+CzkZL7tAV+wO+ZCHUt6EP2PA7Y7XwPfYEehRcKUbXRzKfuJNaqkGfI1ELvTHkuLUBXsYrHLnwXUFflyDt93MdmYpvsu55sopEwVypcDmyZiC4iI9t2dafplQBhl3bLI58FzLOTw9yIbDSY8PrBeSlmZ8UyObYyl8Pw03SfZOEs+l/dba/wdRVIQY8Bz+Hn3VcYYBiyiGfD5tGUQUUWYzrpqshD34fTJWPaPAGz4GOdPxrKO5v1YIERMp8tzHfJd5y40A1pOW2APR2smEVEd+kuUcB9zFTcjD99zwHKqEQI9GAOdhPtzr4frAjk/IJLguRZ/R1tbS1nEbvHzrcvL0XKO+80zJ/Vp+upAo6wgDdbn/VAiUmKXnwnRURoj1aGDabrl7E7TBacm8i0B6mUw4WvkMhrXQacK48See118Nk4ZbZEiOkLhXBDdEzj8vEmqUWec3oJ6HigLdx/WXGdyHD8vVOVDYIveHnDnaYXyemiJhaWFWgVJLBr2xb5a1HQn3NgRnR5Xqhm5tgtjvjFeezEv+yWik7DNymqXkQOUVRifbo0nItof8dpgdcLt6fty/bVaYkzQAyGX9ZqrESR0qiI1eAtQ9g3A37UiOcB2gE7UhcLrea4CVmGsix5gdxDfQkQ0noGH0fi7tTy3YQOQTRmFb3xolZFLJzcZBzmKZb/c+xS3/af3FqfpdiTr0r2rB97RcCKffQjt+xYmC1ClL+sS130tiEEP1uX1EHGA+xucN/V6xNfsoy9qpOzaJ07n1Hy5VD57BPPxKObyHCqb+9lS8a4+bbpbZ9wF8twcNSdqvpmxTuuotdS5En92tsbr4ZtDGccK0MlqPu4LFN5wwOXAOT4HiFFPxRZEaI4Bf7WvxnXfYQwEzj0a51J1GTeRSzkOIqaQiCh2+BoNn9FkjiOv15vsw2f8HJWkLvLdV+N571KF67Wp6ny1yHG2Cf2+4qn9eZXjdqnBc0/vOpeheSJjASIN5wEB8bIanjloz4vV2fMICtf4mj52EHCdtSPuOx1JfaIGzMOIZXxJlS8PmJq6z1+qZOVEhHEM4zvOp0REAezdEVv01jkua7m3JL4zB+8SlgEpeVmV9XqP944VDxAral+LKKsV6PNnKxJRs7zJP4+h3Bm1318ArG0A46autlVBwnWE79lCNU8ihm+ccJwOHRlbAoL1BOA+VxJZf4jTw2nEi3hPGiZqfwnrqkHK7w889c4shr0/7lu2FTaukgXMoA9oXvWuowfT2Ta8LzkM5X4f1zd96Ip6D1OY8XYY15B3fsaf4D2DQlnhuDwCBNc+4LM7Y1mI24BOvdrlfHtBQ+TDbyHWRi+jusBwwTUHYvaIiMZxkcaJS+o1zKmyv0Q3mUwmk8lkMplMJpPJZDKZTCaTaYbsJbrJZDKZTCaTyWQymUwmk8lkMplMM2Q4F1AncihM3LvsQCOwAOCJtUeBtOX0J/xhG+yFN3oKY1DB093590vS6UK+e7plciHP16tkpZXk6TZboxBloS3R20M8cZqtEQeh9NEk4HkqgFVD28WGIZ+8vV56B6fTZZHPz/D3ci6nDxQipeMyYuZ2wGWdz8pKQivI2SLasWTX3oN8Q2jPQcTpbiTrcgwnpG+VuAxVX9bR8YzTou+vyOtdKLOlqJpjS4zjyHbeG7D9Zg7udb48EPn+xwHb47vwHDvS8Ud1jxv/C21G8mh0xAI42M+XwCaYl5gQF2gbiAxAG62j7L54sjjKV7bjjQJf79kOWL2U8zwG79Hjc5zOq+tFCfcDxLk835HlQasm0hPmlb2r6XI/P3C4XlYS2c8b0G4lqJjxUD7IhLjf5wEVk1MolWrK1uoccV8cpnztgqNOUnf5hHS0r40ceW3EwLQdtOXJfIjQ6USn/16rlXA9Z2JpK4vgxHVEM5zxJfbhW5fRDs/195kTGdQwVmXAq1VU9vr9EY+9FKza2E4rRfmdsyUu6//P3p/HypJd573gioiMnKdz8szDnW/VrbmKQxWpwRoeW/LwbKhF4BmG/GzJggADlGGZDUMQDEuNZxgUYMCWG6AtNFqSge4nW633LOlJrQE0ZYmiRKpYRbJY4x3qDmeec86MyMiI6D8uK9e31jl5WU+iyBrWVyhg58kdEXtce+24uX4bT0XHU8uJZMhf2T977BFJdNRqiS+q+nL89mOXhrH9W/c30nzu/hqn0SeImGrkefwdBbI/OsBS8+AeFV+OsatV7gukBO0P5f3mcjz/fbCzJzHPh5YKE16F6bFUOBuhQUS0DGMTUUCXStPDhN/ocD2utyUuSYeIv6W+I/9eBPsUERe+kl4R+VKHx3DTYXs5SJsiHz4X0RatcVbku1RinwZD/A9dDlFvjiTOBVX2OTTed6T/kIANGkJ924nOx4vj5RLby62BxL4MxtwhBQ/DtGWZlvP83DzkOw6l35KD4Ydjdncg+xoxeXmPL2qN5dq97TJeZww4gaVkdZLuxnJg7g35fmWYDzXZTQSuAJUyZ2P2iCRq73qXQ5w7kq5BbQj1bkE711I5frMQHn+txuvXYVgX+fZgyUc7ca4s5zj6IDXwfT6sfO1Zn8flnx7xPTSWsQTuItoMjMZeK00Pzcbwa2XehH3rIMpG4VLmgdX32Exrkj4eyvH74h77GfcGXPD+WD4Y+3cG1qwThTcqAu6gARge35E+NNYDw6zXC9JPy0C+20dsMHcG/PeuQhghxgN9+ooKr6dE+h1vKa/QHVnYt/iQPhzJso6TVCC6TGdrvuBS1vXIU9gB3OfimlxSqIeqz+NqcYXn/3cr9NGLh8wTwn3BSknO/6ynuJlnSPsZBRjniHPB8UFEVBszKq4H+KAolWPHg/kxAJ/cIVnWEuzdfZrO+5rxGOEUOLwXmEmlj4/I0A/M8HP//KQi8uG+uwjYEl/tZY973Kd71/lZBwE/aGuo3jlAOg/zfVn55CeAkUE746vu2wU3ZncIyCv1DuM44DHhIboyL/Nl4Ltd6LbdoVw3i7AOxwnuZeU4x3czH5rhwuo99LkZHi+vHjJ2F/dfywpLca4ISBMYl/sDtc9Iuewu2HqsA5HEGaKZfaVZF/m6r/NYbI047av9+fed252k/3SL/bTnT6TNRXxoHRA1HfWyrgtuEfZvI5EI3xjmDaKUqhnp1PhTUGBoj3bpRHzXdhid1huxj9pJt0S+DKCK3NEy328o23wR9hKI4EWbQyQxV9uAQemp91poutA/1OgebAu0YzpfAfCLGbi5xiDhPDyGd6fo06MvTETUAyxmCcoznzsbRUgk+70bybLmxHsB/rt+vzSXd8R8eZBsd24ymUwmk8lkMplMJpPJZDKZTCbTFNlLdJPJZDKZTCaTyWQymUwmk8lkMpmmyF6im0wmk8lkMplMJpPJZDKZTCaTyTRFxkQH5dz7DC7N/8sDR+cQ+Nep4o4t5BncdCdgntgJMJKIiAZt5pPNZxngdezLf9OYAfbTDKCarpSn84CuMzZLMII1KysHrLGtYfbMvxMRPTfP1zW6zAwstD4m8u0VuY7FhNlnjbzknZWhjiVgqTl9yXDMjZjRie1aUnxjRGLdBYZjR7GQtnrcZu2IgZM1n6/JKM5dLsPT44rETQuVMsywqkB1NdN+EZjml76fYWrDmxIG+tqXuZ1rwPvLKybnhSLyooC/psYllmMAfOWaLzPOA6d9DtLIaCcierPPY/YoPJsbdbUsma3IdxsAr74VyfExiPl+iJs+VtCqCowjnK+RYhAiu/vNoDVJd4BHSES0njKTLAtMrSJJbmaYct/s0e1Jukp1kW+5yHMAOYNbPclcG8Rc/2OHebWdaEfky/hcjpS4LSLgsucV39BP+d5YvgLJMnSBfY7nHiDrkIioDpfhmGqOdD7u324AfZNKft0o4Xogl05h0egYmNIbwFUrqdULeWc4XCIFlUamaQiM1DUAEF9VNjYCvuYtgPf2YznOd4kZ0Bst2R+oEPotiJkFeLUsy3qxNKS+4hqbTqsfE43T0/zwPnRPDyCOVQXyX4PBtN3nm8RqkUfbehnGSDOU90PbtVbmebhGnF5U7M4ycIaRuajcAsHdRzT2nb6cEPuAjtb2E1Vz2L4jW3tencvQivmGxw6n5x3JEl4tsq3aGzKHdi9tiXw+uJ9Fh41LrOzE5oCfhTzGENiuSSrnYSXL9nwlfWiSzqZyvUHuu5tyG1Ud2TnITyx4nE+ZSJoHgDWet3KrLes0BkbqSgHO3PEclY/T2wO2O1vJscjnp9x+c7AOlDxZ36vp+Uka3Z3A4ToVTzGgOSPOp9msnBvYFjiW59WZKmNgTSITXeNHaz6XqQL3qDUkP/iFG+xP3+gxy361IP2qLrD2cTpcVra+Dj7XzTazgI9Hsl0uFLlMj9S4P/cDWZFpZ4a0o7PXKyLpRyJHPVT3LrjoB/FFeU+ON2zbALj9XeV/IYceOcP6dCa0SegL6LriPbAMecVzDcDvKwJj/XZf+l/6ureEZ1ocRnJ8VL2czn7/71nJnl3Pnr1e630e7stwuhZ60k+u59y3zVV9P8t1HfJch1aLsj+erHOf1oGnvx/IMfvUwtEkXf4gz//FuCOfc8Rr0VeP2TYEiZx8JWD84npYFedByH59pQnnAoR4b7kuBXCWSARnEY2SnsgXO3yPJIW0N30d74Dv6Sp2+gzxeliGPUwjK/cC6N/87i5veu/15HPRz8K92BuyunQA51JU/bPthD6DDvdzdehqXCeJiBowD/FcsXpGnYkA7wzGQFzPOrKNcH9Shb1/Ti1Mx3A2SQsKP1LjCH04B56FayMR0XyOG+2ZK7zv6zbluSx9OBuuA3voWVgn/8cL++KaxUfYv9n4Gq9lQSzP8PrzI37/0gc/eU4dpjftiIc3unJt3Av5WXhGwGpeDpDzsKYiL31fnR2G57zgGPXUOxtktiPvOyQJ6F716pP0Ijh3mrF+EPB1vZTvd+Ayy31ALXHNGM5sGceYluvSTua1SXrJ4bNIXEfOyTfBr3/Q+WNw1J/Yt+g2QsXQo9qy4PkhdVgr8+rMCDwzAPtmpF5ETSsFnjEwVottGeZhHSr/ZF2er1QrcN/carJ92w1kW57g+XQwLnfVkX2+QxS+zfNM7JfoJpPJZDKZTCaTyWQymUwmk8lkMk2RvUQ3mUwmk8lkMplMJpPJZDKZTCaTaYoM5wI6DB3Key61RzKkAHEAzRGHVmUVOgLDJkqATOilMiwnJA5peS3isJBIhe1fGq5N0t+7xPd7tMahX+2RDFcgCOPqQyxwf6xClyCcGJEXBZlNhI9cKHE7JKkMXVwM1yfpXsSBIXnFZljIcxst5bl87ZHCtADmogaFWC3KoJD5HN/jRpfztULZh4h6KCaAaYEwXB0qc6/H98AQs6dqsp8uQpn2Q753nMqyJvDZgYfFCj3TGSNyhf/eCmVfP1JjZND3z3Lo4p3DWZHvq4CVqAES5nxJxrBc73JI1y3AjuiQ9c0BhMFA2NVMjstdzsiBdDTiEJsKhNthGC+RRHdgFN2CGpiIFjmBSK2DQGSjl7utSfqO87VJepzIPkwg/HmUcrsgMoCIqAJhV+eTRybppUxF5EOEy8MVftb2QIYWx13umy3iiRir8g1TDl/qJnuTdBaQCyvpqrimDmFXJQxjUmFWhbEMj+UyyM9H0LbngL70gRkZKjcc49jh8s3m5ATDaOrdAZdVIzm6YJ9moflKKrTbg5BOnEN7A5GNcmC3ES+FQ2w/lOPtTpfvvTPmuabROKHLY6fntCbpYip5UB5gPfYDbtg/PZRrxeP1Ig1jZZRNpzSKiZyU6CTUyKCzcQJl5flcKsXwHY+Jl2XUIN3scL4GjOcFhUvDkH9EM+AYraqw3gDCC0s+p1VELR1NCSeuKu7Lcch+RhBzuTEslYgoA7+lWM3zuh6p+X8RcAf1Ec9X39F+EKdnsoCYGDdEvjwsLLjGN1Uf7oZqAn9dZYfvF2VlntWUbfMln/MFyqh54AcNiedhNSttIrZFJ0Ikj+zDENazPqAAEBtFRHQPlpUQ5rei1Yl1Du3WGqm2hHDnKJkeh+qDP4Zh+CchjjdZCOwnjLYdKzuNqsIa/9jjMsT8i19hv7YH0d1VtQzVs7yu5HOASOjJjGsVbsydIS8QrUhO8jriBHKcns3KEHNsvT1ARwSqWfH+FwHtMuvL5w7jswOZ7w64z3qRbMwUfMUtYOhoNFnGwbHoQlrO8faI83kQNq+7EHu+DliVlYIM9i4CVmI/mI6DxDHig+07VPirJrgQiM0oKRuJuD+0i+tg0NdS6YvhmjAE36es7OU87E3Qhh0qnxJobjSECu4H0g8ap74IiTedrax7//95hTdbgRB9tAVFhTRZ+RB3kFOtT9JHB3LhRJQlwgs0PgwRfYh2acD80og1xD7ccTYnac+RtiB2+N4x+Ps5V/qHNYcRf62UER9BKhE1iEEcxIz4Wso8IvKtpozviABbptfW5iHbiT06maTLqcStPlrm8qLvc6Mt+6YNU2IRHOxlQLNoPx6tJc734ikEFKcj8J2OU9nvJwHYT+j3vCsHnA8OIm6R3u4MzisMGq7JS9BG/+PlTZGv0uBxXnyK/f/6khwTyT73/cKfbkzSnQ7XY+FqX1yTWebvsq9x3a+UpVE7DPm5d7pc1rm8XLsQJYxLUXssJwTu69EXULQ6evFgbpLeDri9Csru7w0A9QJzTaNKGjkev4UM1ymvUGw12HxWwNc+CeRzbzmMbN2PXp2kXVh3Z/2L4hqck5Uc78mzjnxnhv5rJuVrOur94+stxOzyczVu9fUu+0HHDs5duR568I7Qh/ScL+fDLCyw2G8ngd5jQRomTk519iJsbS/Cpgiqd8quot+9XuR7I76FiKgP/s058Acrqk6fO+S2wPcZzVDjg1yxd3yQ7JfoJpPJZDKZTCaTyWQymUwmk8lkMk2RvUQ3mUwmk8lkMplMJpPJZDKZTCaTaYoM5wLaH6aUddNTp8gOxxiOAmEqvgwbwh//ZyDEeT6dEfk6xCFUfQj5r6UyRHe1xGGS54sc+hUlfO/NocJDQCGORxwnMW7JEFgML96OOWZ93ZMokM2Yw0IeynHoTVWxTzBU/jyEapTUCMMQzC6EAI1V2GNBxzV/XZVTJ1tzu2wAKkOHIa1CJM3rbS4ERnGMVdxWGGPYPN/vYrUr8p2/xO33xZc5fOdrbYlmeOGQ+zf+//E17VCODwxhywFmpJiRIScYyYThv+WsDCtdznMbYShkqnAzXUDW7ACypaTCrOMpUS4YXqSxKvtDCKmFEFg82ZlIhisH8KAZhQLZ7vN3GAKkUQAxhu85jJTRp9ynELTXTfiU++H4WOSLfW7LNTjx/lxZNhKGHl6A8KKHhhL/Ead83bj75CS9kZ0X+TrO4STdG3N4fA3CxfJ0NpaFSIZZHUWyc3znbFzIsWrLc2UIU4MQvW4kr19D7BP8O+2lihw4GN6NaKwjFVKHUWFZCBfPe9PDrTCftlUYBVeF7hhAVL86pJ3u9OGUe4dPVV9Lr4l8I2KcS2e8PUkPvbLIV4Vw3WMIqZsJpT3fCzwKE70imbTGKZGXSgwKEVEMpjCvY0lBW0M8fR5D/uU1oylohgtFaU+OAEuFthRPondU+DrqXAntoHzm7oC/Q3SHxtAhwuUo5TXryN0R+SLAV5VGz0zSGlc3k+PFvAgIknYUiXwDwE9UfLZJeh5ifyDOYqgW4qqXg2v4uavuhUm6NVrDS6gCSA2kkwzU4lV0wADAVwU1VhDhsA2Ij8drEokWJuxoYJgqIgKIiDpQx9aI63StpkPMuVBF8IkKNH3dvDNuTdKuyueCd7tOvB7WIbxZ99MqhNFi82nriyH1I0iPhyokvMG+D+JX9hXiY3fIbdHdYnt5oSL9r5VFDnO/2Of2/+yBtLkCDwc4l4wjazID/tNKgft3byj9uRMIEUeMiavuN4ZQ7QLYFnRpBoqNg74n9rurzA/6PlvgWwwUAqwJYwyxFscqzH0MPuGFHIdMnysrrB2EU9/bZD9I2yrEsXxmn/v6YCjtJfb8hQqX9fGqtC27AZaX712HiHXdn7d6fL/maLrPgNg4tEeBshkxMI0OQx4rR9QS+bJRQ6AzTGfrLZwLYh+I5BgewNx7ZulQ5Au2+LqTF9i3/fP9JZFvc8CjbKXE49R3pd+M/Y1r1MUy//16R47zWcAsjAP2yRtZ6c+9EbHv3ibGMnrKd0fUC6Jdeqncj8zAHiRy2UedSeS7hMUC24aNIefbdO/JeiSMfSkS21KxTpL0yQ+G6Zl/JyJaKSLShPNdKfG8OVE2CDGos1ncr8p7t6Kz9y0VX2a8WgM73WNDEasb4h4T3/NoqkMtC9gnl8tQV+sm5vsby7zmVRfl/mvrdn2SPpdrTdKFh1dEPueYbXDtaX5W6RARLrIDvvC/87uF5094PdSIMcRm4Jgvqvc3uEzJ9V7mw/thf+L7DCLZthVY586XZVtiF/guz6lTeDNYIGcBqfj4jJyHW0AxutNh32xr3Bb5DuI3JulewPu5Uo7nXY4kpiVMuZ8yDj/3avq4yJfC+OsTz4fDQI7rg4T9mz14j5JRr2+bLrwjADsRuQsi37mUfeXVgmwXVAc6FceEbvNj4Cohylmjdk7AvytBX+M80QhJtBnni7DW9qUvNoK1IgJ/y3fkDfG9TAt8AY24LftvH+Nkv0Q3mUwmk8lkMplMJpPJZDKZTCaTaYrsJbrJZDKZTCaTyWQymUwmk8lkMplMU2Qv0U0mk8lkMplMJpPJZDKZTCaTyWSaImOig6KEyKGUWiPJsRsBdxQ5oRcr8t8gAAUqWMcOSZ5YccTfXc4wg/yphmQhrRckR/Mt3e4xg+nPjxTXGqBCFY+fcxKFIt9d984kPXSYA1WKnxD59h3OF0O7PEKLIh/yxGagukt5SRbaHnK+fWCp+Qr2iJ+Rx3Qscd90MmKmE3IukYlMJBnOrsPD/hbATouKw45MJyzebl8ysFbD1iRdy/L9lvOy37GGzYDLHSs2+azPZT0KuQ/bkZyuc8DurANXbrnREflKUKY4QZ6bbCNkSWVdZk4NT2EdkRcH3EIApuVUf64U+d73etN5k1noa0Tb7atCHI6YVbqSL5z5dyKibefmJN0ebU7S+Uxd5Jun9Ul6jj7I12fvinz9hFlj+w6fFzAzkGyxZ2eB8Qv9VlRMuAvAfhsnfI9RT3LMUGmW71FK65P0PWdbZeTk+TFzGl3FzYtTvh+e11AbS05jCzjPd4Hj3cjKPuwDW/9KBRlkmpHGnwHJS+fLMt8ccBErMDeKimv/Zp/brzlC+yGyCd70uQLPoReafP1GT/YT2svD3quT9Lgo+YYzMI4SoKq5inkp2JZQXc+R/d6JiMK3C2d7Hyvn3WcMO2psV4Bp3gyBZau4+8gZx7MdMsqOXQT2Ltqu3aEcZNt9OAMi5jSytnH8E0mGI3Kotf3NQpkawCn31U8ishGeUQHlSZsin+/wetZOeDw30xORb6/LDO0cjOcu2Awioix8VwHutmYk7sIhBEcx32NIck41nOokXc9yfRcKXOH10vTzafCMhbwr85VcWFOBiazZ09kpPzcJY8VBBfvUA1+sl0r/C4X2tzWSPsNjNS77k7P8rI2+LOCNNrcZMu/zJLnghZT7+njE14xifq5muy7mkI3J6Y2B9EeQ3361wpxWR7Xd3AX+7kKT+z3jSM5lAZ7VHXO/vdysi3wJzPk8nBtzoSQnzuaA79GKuFCLiSxgN+LxW4A1RvtpOEbuDfiaN6T7JX6p9EiNL8J5HcbSyKPdQX7oclGW4XyR59DmkPsD/WwiojJ0VcY5ux3u5+M+7IDfMlKM9aVHud/+Zpb9qu3jmsj32zv1SRrPwqkoY4VnGCEH9SCUY+wAuKoFKBKOva6qE54V8aBfjeF5OuiW6nMxQuAJF8F+NGJZ93o2Q6PEmOjfSDn3vo82UJzm2322SRfhTLDXDiTvu73NfYB24mQkexv9QPQL1HZf7D2R/4tHdSSKp41M3YU8+5F6zZsf1SfpY2A7j9X64Kf83QltTdLIWCYi8uA9Q9VlTvMokWcJbA15X9SDc3t8kvdDn3UG9oCaH36rO4Rr+LtZX96vAR9X8myranC2VKTsahH26jjXDhUD/W73bMf46Vl5v/UinmkHa6jy8XH+j6Hfswr0vgADaQVsVV+NX/wUgv3cv1cV+ba6vEavxa1JOnn0IZEv/R++d5J2X36Zy/fVG5P03m/J/e/nj9hh3eijHZT9iXv3k5D7qT+WbY7vSK5WuYb6pAl8F4anDM4q3+Jajc83yfRKk/RhOJ3Bj+ceVZVtxnIgurunXqUhx3tnzAu2R7K+5zP8LuAeuCcJzK9OeoCXkAv3SOBMtqYjHYOWw+cj4PwPx1dFvtjhwh86vNZmFYt9LuH578E5D7lU+lV1n9u2DOvwUJ3L0ovG8B2889FjB9jpJyn7dmOSjT6CvYUTchs9MeZy19T4QD/Xg3MZPHXuCdqQSoafW/KlHWxkkdnOZVgvSz+j4p8+32Wa7JfoJpPJZDKZTCaTyWQymUwmk8lkMk2RvUQ3mUwmk8lkMplMJpPJZDKZTCaTaYoM5wJqjWLynZgWCrJZMMwaIx40qgRDCkcJhwpgSBgR0cUKhyU8VWecQN7VIVicD0PMMEwNQ8+JiGZy/FyMEB20z0bDEBFdSB+bpB+rVMR3+d5Tk3QtwyEiWRXvvFzAEHj+u6J1iECQGWhXHYqOoZoYGvRaU9YDMTWrBQ5b0XgYghCbGQgxw3B4DL8mIhokGMrLJf+dHRlGszm4OEkjbmKtIMPSqznu64sXGOfwys0lkQ/HEYattEYy1GgxD98NuO5vHM2KfAGEK2Mz6zC6NoTLpTo+C4TDuQOhPNjm50okhLdrQDjcVk+O+azHJVwr4RySfVN2eW7gOI9J5hsDGiAYMcYgSeU4qvscUvehBhf+KHhS5Pta/2iS9iGUcjiWMaGLea4XhkLnTmGGIBQSkEidUV7kG45mJmnPkWgQLo8MpUwhlGyQ8thrZGR4VxBzvgLcI1RtXocIO0Qn6TkeTIlgns3K+41gXK4U+LuVvBwTPQgp3IKQ9ZIn58MQwq/QftR8WcBVuP9SgcMfvRbXva1ifBGbU8rxfB3FfZHPdzCU1z8zTURUIu7PEHAYr4y2RD4aEcWpYliZpqqoPBocmxgSqkP0cT2DJfQUCgjH9o0233xnKMNoW8Tj4uE8h59XIFyxr9YbF0K/O9H0UEJcN/sYhhvI++H8rRLbtF2FFjoOGHnl5Ll8QSLDT48hFDrvMLogS9KerCeMNMI276j67sccyjum6diDKOXvhhBTvw3hyRrFhtiLWhb8N4XuGMD9PAg71ki/OWgXtOGvdKSdbompChgf1eYlD20DoDuUIzQG1Mtijsu0VtALNJdjPmQkn/aDNge8HlbAn6vDoB8og34E/mYJOlSXoAprWaMI+IATGZp92Oa1dh+wdnuhnGwxoOLqYMM1Eq0HCJyHVnh9zirc10KO/bYXmtxeu4Gv8nEdt8EHf/5Y9g3inNBO7A3l+hWAH7kH61cO0JAa+4A4gT6MUdxXEBE1ANV3s8f3bqklA33FASCINPZwxueyIuoFETdERK07/Pn5TQ7Hvt6Tfd2GcmCk9kCFjmN5EZtT1vubLLYLf3e9ifeT90Y80Urx7OuJJGYQ5w3u/4iIQjGsuM3XM2qOZ4jCZLodN92X797/X9sT3GcgpuXPjuUYQ3wo4o0Cb3rb41o7GOvvOI3INUSJzkuzT0cB5At5ToZqn+EDQmAtuTK1fIse78PDBDEt0qYhPqU75n3fHqAmiYgO6IjejkYOrw+45636cmxvjhgBO+uyb7GkHDCc1x5gVLGNNZYRe+0eYDI1/u4oPNsvbkVyH7Tgop/Af9dro/AVPfQZ5MjMefz5mZneJH0YyOe+0uGB+aUT7s/1QA6eEtQ/AyhLd2dP5EtPoE+/+Nok+dX/wve73VvGS8S7hAbYsbpCcL4KhL9bxMiQwki+6zg3npuks4CyWizItsSmxWmYdeXa/eUT9iMPAdWFeEUiIlweO4DnOpavWIRwfLTTwdR8QxjzhVT2zTIxXnPoX5uktwZfmnq/1ZRxLIhOPXYl9iVIeQ4l4OeVSZahSVyPQczvjWL1vtAHlGg94f1lTvmeiF4ewdjW9hfbHHGLGfXb6z68Y0H0zNCRbd5O5Xh+S7cC3j9cobr4DvdlBwFg7dTGDOdQAO9bDrsSZ9gH32e1xPcO1fYjjIlGZ9OiTsl+iW4ymUwmk8lkMplMJpPJZDKZTCbTFNlLdJPJZDKZTCaTyWQymUwmk8lkMpmmyHAuoMO4TxlnTGteXfwdI4Vr0GLljApPhnBUjMpNFDZjFdAFs1kO1XipJRkYmwMMv+G/lyBs9tGaLAOGYL7ZwXBTGf5wLeGQk5k8h3voMPfVAocoNSCkeU5GLtHjVQ7pQAwNhugQEVUhLBfDfwueDCapZDi+YmvI5fvzI3mK+W331iSdDC9P0hg2RyRPJL5Ww3pwfV9rqbBNxOFASK0Ow9+DfpoHFNAFdeLv9y1w+A1EQtFGX4bDY/RYAOEnQxWK/pUmh6rUACNTUOFxNyHE9nobsBlFeb8KRP1gGBgiR4iIIghRnRYOpDEIiIfpR/whVOGJGM5ehfIcSFqCEGILcsqkrQOqyC9zaJqj/v0whntgJP+p+QC2AdEzGkf0evds5IoWtss8hFx3InnDnMcnuh9AOOAIUAdZNebxpOyQeOz1Y1m2EYScRnCi9jCWoagtCJtvgD2az8k+fKIm5+hb2hzIMLUGIGEaWX7WosIgHUH4f3MEyAsVAt/I4v24TDNZGfaG4eKbAx4TiOqYy8txlAQc1tf0Vybp7liGqPnEhjHrcrnjVJZh7PDnKOXB3XZkPiKihE7/zSRV8u7bnL4KzcZ5OZfnfn+8pjBSEGa6A+uNpoK1wW6MYfJ2SBqoCPoRw4ERM7SQl3Mc8XB7ELp4pyvnF64PiDeJUo1L4ut23d1JOlX5PEBjnYRv8vXjrsjnZ9g/yeZ4bCcKoeUh9gViedWyKWxwDnBrGn2SdSE8M+J23nN57hVHMmyzM5qfpDHEfGMoQ0w7gN2JHW4vfyzLUBxwqHwD0CeIgCAiOganoTcGW6rmcAfCYwtgMxYSaXcQUZVCe33nggzXP1/i8p6E3J/tSN7vcwdsgzEE+5k62+zDcPq2APEUiE4gIloFNNYQ8B+fubFA01QBH1pjwQ4BkTAA9MZsVtt9Lm+zzb6U48gbHoVcplsdQJ0VlbMCPsSNDj/rxd6+yIXoAwz79kiWD0PJR2NANMKYzzgqRBoQXhcLPLaryq1AJBr61gOFKkFMSxOwfUtZaVsQqYN+Xikj+/oIfM/f2uL2ujeUNsOBtii6YFdVGwVDvn8enONaVo7FqkAznc0czCuMx7kS53uqzvP9Vk+iCm528H7YlnJ8YJg3orXOl2R5il5Cw/htxoS/j5V1U8q6Ka3k5Vgswx7waMTjQBF+qAlUjytl7oPFvByz9wZ8jzzsbxZKet3kB6zC1gztU6QwPefKPEa+csJ2IVKYshR8hgTG2Iwj9/6IUkHb4LvSTsyDQ7EC86bZkhv0pstotnLKCI1E+QI1wBZWszhfpeY9nv+rRc6nZ+RhwH8Zg989BhTmjD8dy7jR53Ss5jvakGqW71FS+9VK5uw5uFqStYpEv/M1yQN+a9oecT00ugnHC+6RVgrSV8T3Q6+9wLiUwkt3RD5cz4YRr+NfOKpP0hp1hggt9IWDWJYVsXEzCeMH9X46D3tetIPa78a6Yxki9Q7j+SMuRzOcjhyOp/Bltc+bhfmRB78xUHve2+51/g6whUvOQyJfnXiNyKVsDBCVejy8Lq7JFvkaxIB200ORb54YA+zA3r2j0DM9h7EvKez3x6ncZzcB4VRLGeeSdWQfFmF9rME+fn8o2xgRc2irPDUflj22J1Wfn5Uqa5B1ub5fG3BbFMH/xfFFJHHVON4cde8i+Cc4xj67J+1gHu63Cst/V22x++NUvNt6kOyX6CaTyWQymUwmk8lkMplMJpPJZDJNkb1EN5lMJpPJZDKZTCaTyWQymUwmk2mK7CW6yWQymUwmk8lkMplMJpPJZDKZTFNkTHRQ6AQ0dhL6alvyotaB/3m5yuwihWYTnLUD4AsVFMTteMT/dvEV4KAjA5JIMqyQJ4qonpJiYJ3AvUvAc85nJBMKeUB470ZW3u/7Fpi7lAfW9u2e5HjvA29yF3iuxwqP/DCjnQW/K1WteTKCdoav6hnJOLowvnzmd7O5tze0kYNcymjOHdcJvwkk5o7GU7CHm335+cUTZkcdfY2ZZptDWdYO8JmwrzUzFMfHnR6315pivWEfHAQMEIwSOSaWBRuUa3yo2NM9YJojqgzHuUaYwSXkA2i46Mm6I0IS61tUfTNKuKxVaAjN2nSA/dZIn5ykw1Ty12Lg+n6pxQyygSP5ZH4KbP0sc8cqWdnmr7W48C2AeWUVZBk/LgOj/pxEJIp8+S6P840ed+5CPouX0OUsj7EbHebwxYojh7z0MbDeNKf1YMjfnYMxpjl8PvClV2vMSJ1XLMBgzH1/s8s29ovHVZGvBkzDc0VOtyLZlhdL3BZLhekQfQ/s9L0eNzSgk6mek/eOU/6yMVybpIdOU+RrO8x6KxGzBTuJZKeHTm+S9h1pS1EueRSTZvaatDzn/v+z6qyOJWBR1oG7v16S8/pwyHMF55rmaSI3sxvxvT3VR8gJf3PErMJ2szJJN3JyvuL6g2vAvaAn8rWdFjyXbXghlWcO5OC7VrpN05TP1Pl+Dl8TeHJcIkvdhfpWgb9IRFQENivaen1OxuKY515nzOuSp+wOci/7xLzZAbGdDhzZRhdTZqI3kVOu2PUFgn4HNqnmPh6H3CH3evxdN5ILHa7JeA7Nfiyf2wG7kUmWJum+uh/6c/Mwtl9v1US+63AGB44dzQ/+6BzPge89x5z8EXCfv7w/J65pAVcdbey1Ocllf+mAr7t7wuN8oPiryE/dTpFlK8uKLXG7y314pNZa3+WGGQOD9IPnpM1d7LPT+3CNx7b255CDvgOFddWYcGAOBHDmRd6RflXb5XZygTc7grXWVecmodC30LxftE5z4LtXVMevwTkjz8yx7TsZyjl+s8f9hgxSPA+FiMiHNb8HG4gdV9qZXno8SS+kzEQNHOkcj11uv4WEzxwJAvncjQD7A3y7DNejqs4iwvOVNJMXhfz2o5Dt0WCag09E83B2SioR6+Q7KUXO2+Oqvp/VihzKxc6psw42urkz818qT9/zduEsocW8PkuG+yoLPuBKfiRyudBnyKtGO3i7L+d4K+RrcC+hmdJ4blGL2DfeJrl+ofErpzwn8yTnwyZcloG1tulK/nJEIaRhD+jIDfqsy/Po8Rkuq8YD92C9wDWwE6ozrjLYH8CAhzO8DkfSMbjT5Xy7Ay7rTiJ9bVSVeO25I49loM0+98E8uEh6bezAupSHM9pyrhxvt3p8v13wG4cK6T2GNsvD2XXaBuHZDHgWlrY66MsWwP6iy/AgLjuuc/pdAu6hH3Z5/1VWh33hGq3PAUPhOyV87t1BRuXjjCdjXqPms3JdGoENxnGA/hsRUQznACHH+8DdEvk6451JuujxmpxRfvw+8bP0XHlL41ie4ZUn3st2HV7/ciTP7akk7MM1Xc7XB/+eiCgCn7fo8TjHOhARHXh8PmCVPsjPVWelbYFPswH2I1Dn00XC3+dxdQyMdiKii/llTlc4nz5LzwO72h9zPXAMHAdyEm1DmZ6Y4b5eyEmbXYZzzw5gTrZHsk5v9rktN+CswKI6M9JxiCJjoptMJpPJZDKZTCaTyWQymUwmk8n0l5O9RDeZTCaTyWQymUwmk8lkMplMJpNpigznAjp2dsh1fBokx+LvW9HCJD08eXyS3lWhKYsFDmXwIVxspy/DysaA0ZjN4TWyPJfLHIqQg/vtBhx6oEOhMIxmGSJiXBVW+IE6h3Remm1N0oNQhoshcsGDsKZmU/ImDgIMRedneSpk6ijkPyzkOF87kpXHsKTd4fRQ15rHIX+IENGhGGVA24whRK8IIScrRfmcxTzXt+ZzX4QqPBlRNNg3Wn247qU2d86uIk8cQfxTADFhVRXG/HCNPyM+6EBhgbAPfAiV749lqMs+lAOfq8N8YmhofFLZ577ojqeHcHcgxKYfy/CduQyPv6oPZVBtjn3gQbNcrMo5eTjkL/cBZZNR/36In2+7t/ma8FWRbyn3xCTdHnHYW17HBoKGEIq2F58dEkZENAPYqPmcHL+ILkEswpUqhy7V5dQloK/QQo77ZjuQKIs5l5+7UACcQywDChFp4Dt8P0fZlh0Ip0IE1OZAxjsjGuO1zvSlaLnO5cBpfRjKkNrPHfJzdwb83VOzsq+fm+UYttaIn4shjoOxrJMLSKNHsov8hYwqo/30zUl6lhj7kqRyDRjGEJLoMQ6jQBJlU05rFKfqIaZT6sf3w2f1HMhD2GsZwvr3BjJcFMdSC9aiVqSRRvhM7lONkZpNOVRzDIG5MYSY6nnonvD8WIa1aC0rw0DdEX/XAywCopiIiMoQ6nopfXqS1ravD4iTAWBRjt17It/J4AY/N+BQ0qD8qMi3nM5O0jhfZxQiKYi9M9OuwocNEh7/cw6Hti8R47n2Ehliephw2yYwl3W47pzHbV72+btY8chaI+7r3QGvWW1lz7FtF/Pc/udJIlLGgPWYLbANQqQXkQyf3gE/6Csnsh5YPlQ966nPfMPfvr06SW8O+LnaD61nAVUAvlOpJO1SJcPtknF5Iqolnlojvh+SMpZkZDZVwbzHEAI/oygPiGYYwDjaU1gwDCdeAt9uP5AVPgY0QwFwBOdIYosOxzzGcI4fUUvk60NIeODy/CoBeqaRNvASms/wuDxXmu5PY39cKnN5wljW6ZkneL6GbR4D/+uXFkQ+7BtEC90pqrB5eO5Cgdvc78nOiVKYhw63UaDxFaASICtwDhER3Qy4Lbedm5P0fsL4pnJ4UVyDaAZEKSwrjMeTwAL70iG3UXss5xba+gQwbxpb1B1nKNCD33RKm72Usm5CSSptVXt09t4sr/Aa6Cuj2T5Q/qHeQ7wlxLQQES3m2abf6bOf0B/z2NF7tg5svEsu7mHkvZtgpxH5hhhAIiIf1u4a2ImU5Lp0kDC7ZJhyoVbTNZFvzuf7ObC+Ho5kReZLZyNMR4otgrYZ/fg0lXZnrQR+TIHrLhBLCmWFeBhE1GQS2ZaIyfpqyHbGU/2+CP5ISmxXz5dkWyL+9iDksdhXU/gE1occlO8wkI2E90N7/Pmjish3q8PXNfJ8jdruUwzzI4y5LQ6AJrKQl20Jy5do15zCimJZEctalU1Js+ALIOrwaCQLuzWAdyID7utYvZfB/T+O7SCWvuwY0CIZeG0ZkeycAqwdxw6Pj4PwNZEvjtn2N4qXJumZRPoMHTp7/cpmuA+dUyhHLvtcyj5WOZXvzDrga7swb9qJRKLNeucnaYFbdKXPG8ZgCzI8KI7U3M0n3EazGbYL/USuh1vuJpcP3hvNAX6QSL77KMJai+sukbS/iwW+3wjWZI0Bdhzu6wulBP6uxhHYcNxLPFyXNmM24D7AZwWKbzSK01NlmSb7JbrJZDKZTCaTyWQymUwmk8lkMplMU2Qv0U0mk8lkMplMJpPJZDKZTCaTyWSaIsO5gMrpLHmUpXlnXfw9gJDE3aQzSbf6MnZ8GxACi3kI6VIhtX0I/cq6/O8YF2WUDxUhFH0/5K5qjjCcUF4jDseG4uHJ2EREO0MOXazB6dCNsgwx3+xyKHkB0AznizLEsRPxw0oQ3nGk6BUzEA50vsTP2lLh9Rga3Bxx+VQUEg0g7GepwGWYy8t/H9LXvaUq4Fz8rIx7yUOY8BhD1lToPoYyzWbPRk8QER2EnA8iwqkVyuceBNxoGCZcz8lwoEsQTl3K8Ng7kIdF0+0ut5EHoXxZT4X1Ryru563nZqWZwGx5D+/Hfz9UZdgHtkgnwpBG2Uh4v0U4jbw5kmXoQSEw1GimJPvGgX8n3IIyaPxC1vXOvMZxZBuNILzrIGaUwkCFMT9e5/HchdCxY3Xy9mWP0SDLBQg7VmPnTQj5W4KQ/2sVrtMbXWlnDoZ8TS+C8EmNfQCMwTq0352ubKNKhvsAcUsvNmXdMQzxdr8+SWPYIhFRFvq6CYZL4wQKHo/tnIchfzLfZo//cDjiARgdyZDw+Szbk5s9rjuG8Q8VzgXxDhi+7jkytG0w4vUhIp7HV50Pi3x3HcYEdSIOtR9npMH0KUexZsaYTsn5+v93e7LfMhAOmEDYcME729YREbVGmFbjAG2fC2GlCnnlQ4gnhne3Ex6XGquCIYVzEKLeGunfOvA6MIh5bG8pDN0wPRs3o22uC+XIpWy3Mo6c1ymE1EZjxiq0hhL7crvI/lN1fGGSXvWkfWqINZrbaKQWzgTm5UqRy4SItqOWvPct5+Uz65EliZRyIMw6SLgtS55cbxAxM0h48d51d0W+Ysr+UgbW+2s1aYNw6e2MpseN4jeINzgOp2MiECGQV/yPV5qAEBjxPTBcGhFhRETrRbCL0F71p+VzP1Dktui+ymPg9bb0k9vABlgrcb9lFaYBw4HPlfi5yjTTAMqE2L3XWjI0+x7gFzvgvl4syecitmEEc/Iglj7vjrs1Sc8ATsRL5dgpOFwORDhgeHhOXTOAMPcIMAaNrLIzENb8yPmDSbp8RWSjwT2u0396iXEnuwNZdx8cZQzTfrXtqHzcp4hpKnYlegptRo4A95fsiXwu2OkOnZukvVDavo7bmqSTlNuikPK87kSyjV48Yf/hSgXwPEW5rj5a4bW3AljBu325N8E1AfEJRR2+npzGQplOq5ghyroO3e3GU/PEiDtQ+JVzYJ9wfmwOp7/egKVDYJ6IiLoRj5cu2AJPXKPKB2vWScL7gmwsN/W4rrhgt6JUYlXG4DuWHZ7Ma0W5JuMc3QXuSCMv18NqlvOhD9NQe0r0yXEfuSSnAC3lcU/Jz8J3E0REK5AP/a+9QHFCQHXYQNztsc1FVAcR0S4xzgIRvEVPorEezrCPPgO3OA5lWXEUoAtS8+X4wHuUwI+czUlbtQx1R39Tv4vA+uJzO4rQhticXoT7IMAZKZ9yEZbyc7C3a6otBRapBkOspGwaYooRldgfa6wwIrQA46UwLTgfIsDLDhUGC33tQsodgMgRIrkOl1Oee+u5D8n7AVrl2eLyJJ1RrvYd2FMnUI9OhsdY1pNrHq7xVfAH+yRfihzRxiRddOqTdN2T7x8XEi7fwJHv51BZl+vUg/cMuXRR5EO0E76bbI1l5TPEc7SWMBIpp14bv95me3cd/ISLZWk0EHeN6yL6qJfLcrzNgj3fB8TSZ/alXb1Y4nyXS2xLn6jJydYu8j0QC6ZxsidhSqNETdQpsl+im0wmk8lkMplMJpPJZDKZTCaTyTRF9hLdZDKZTCaTyWQymUwmk8lkMplMpimyl+gmk8lkMplMJpPJZDKZTCaTyWQyTZEx0UGBMyDPiajjHIq/x8AW7bj83TBuinxRzCyea8PvnqQfL82KfIUc8smYARSnksGDTLftAX+HvN5YIfdC+MMghmti+e8lBYBy+i5zm7KeZFbh7XtjLg+ysYiI1ovMF9oYIONL5vvOudYk/dT/tTdJf2hB5gtf5HZ+5feYCZdTHO8s8JLLAFNuSHSc4LsdAH4ugHZ5tCohZMib3Bryc3TdkZ2+mON73OgphtuA8wFilcaqjRpZVfiva13xvq9Uuf2uAd/pS0czIt9Xj7lPjxK+ppZIPmwBmLVV4GTnFAAQuXnIFuwAp22cSNYmct8KHrdLRgHi6jA3IpgPmiOHXOoa8P40qxDZcVXgXGpUJXKCw7QH+aazkyPivo4U0xQ5dR5w1Qup5ITN5vg6ROq90RHZRP2xinVghl1U46MJzL92zIxF5DUTEV2p8ucGsJiPFY/0XNmF7/jvJ6FsTGSid4BdfxLK+VUGxvrdiG1p05X2t3vE8/9cmfuwM5J9szXifouB+56NZH17Mc+veajvEdiIY1VWZErXoIKXq7KNxq2rk/RXx7cn6RHJ++GaMhp3J+mUZJ28jE+JutZ0WqPk/rzQvwpARmcOzj24VJLMxcU8slR5jgaxnFOt8dnnmXiJfHIP+IdeAlxFxSMXdQBY6UHgnfl3ImmPazlel0Y9afc3HWZUI39ZC8dcFuxT2ZFs0Uzl2TOv8Wg63xTtbEFOQ5LYVuDDqqJWYV3HtQfnf8dpi2uKjmyLtxSkyrDSypn5BorJeZR2z8y3nq6Kz3M54K+7aDPkdTgKJJdajiNklyKf+2JF3lAhcCfSTH9xNgasbXWf+1D7TsgZf3SG2zm4J+3S/k7tzDLgOSynv5v6FR0CA3MffLat/nQe/Hlol4y6N7bEcQBn4biyzZcKuCbwTY7V2Bk7vAgKDipJznAjvTBJhymwzmGNGiobP0zZfrwMTO/KgizrM7PMI61/4rFJOn72WZHvM899dpL+2snZPFIi2R9d9OfU+NoZ8h/wDJOGI1mlm3AegQtc64Ir5+eVlMveyOB5MmodhnucTx+ZpC/leewNxtKAHAZg32D/8GZPsv/rMKWeqnHfPlWXLNsdWB+Qq+oqtvatXv6Un2k6rfmCQznXIT2tkQ+N+9ySemtxDGv8AZwdFirEOvbvfO5sXjUR0daQM3Yjvvf5IjDHc7IQ19twhpfLfO78WNlpYP/nUp4bRTUf8FymxRzf41xZztelPJf9sRpfo47ZouU87FWA9Xu7L+txp8sXbg95DvQiuSgUPdzj899135QyyMPma653pyxYJNeyCqxL7ZGch+OUy4cc6UYiGdAOLAR4Ppj2M/ActQUwDXVf+V+w38d3Nk/W5HlGH1zjcx9e32NWd5TIMYG+wVaf792ZckbZgxSf2q+efb5HQS2OyHlfzfNYns9J+4vn4u0GPCY0Yx3n68UK92F/LAcIlsmFteIwkmcEFGDeZCE9n8gzqYpw7sasz+mlVK5L58pcpocq3GZv9vRZetxOAZyHFifc13otW0nWJmnca3fhPA8iogrxmGgkc5N00zkR+fAeeN5KxZN1j8GHQJ65PktAcPdH3J/InSciWk3YNy7Be7Zt9d4zC8+aBS77QB1e407xA3EM3OnJPN08vgPlfLe6cny0Rrwmz2a5PM2RtDN3+9y/6/D6S5/DVso4lDEmuslkMplMJpPJZDKZTCaTyWQymUx/OdlLdJPJZDKZTCaTyWQymUwmk8lkMpmmyHAuoMPwdXIcj3KZqvh7mnLIA4ZtBiMZ1pBAiP6wyOEGLRVXtgYxTxguttGT4Q9IK0ggTAejfI5CGUeDobO9CEN5ZFjDcpG/Q1TE0pwMW8brbnQ4JEZHKSLGZBBz+MicCg0+DDCEErAZfRkKFR5zoR6tcBjX7lDiMMowgs8XAeHg6lASrscdqGIZwrEcVSvfPTucqjuWYR6XSlh3CFmTEeZ0q8/19eDfr65UJFblIoTsYSjZ4zUZTlzOcd/HEHpSzshyj2HsjASCRI7LVQhFx5DYk1DGWc5AKCOG/2KIGIbG6XyISBiqkB+cKnf63GcDFeqJCBe8xUZf3m+jD20EsbUlT5o+DN+PIGwrn6nTNPkQxhSSLCBiZM4VecwHsZwQAaAaPn/Afy+pcDtEA+C9dwAzVFH9/iGgMRQyjGzSiKVLJW70UobTubpsozbYKgwJ1UFPe0NuiwzYIwzNJCIajvlZIYTGD1JpV2/QvUna7V2YpE9iGdK14d7kZ6Vc+XnAVRERvQ7zEkPM3ujx/Go6sgxlCAecieqT9GUZJUjrZR6zX25x/Q7dTZGvPbg7SQ/DLZqmXmbngUgh030NximN3dOx8/d63Adbfe7rfixD+ZcgvLvq830OAjm6O4DDyENI85wj79cG/kEK8w2xDUOSa3ch4fmxCWGlimAmUGxoJ1aL0rZE/YVJGkNCESlBRNR0eNxHxPNwSBp9wlpJLk/SY2X7Rg6v1xjBqXsHbToioJYLMufuEBEuEO4MDtJCOi+uyQD2QdtmVMmB0P0U6y5Dx11Yr0vEfa1DYH2w0xWoVE5Fr3uAfnhqFtB6ypjuw/jrg99xsSzbqOjx55dbfL+TQK7xYcKfA/BXYwwT9uW9r9XYYM7Psg+z+aYMY36jxUiNl1rcrm2F3cI1AUOLI4UzRDwGoh52o77Id+weTdIOrA9VFaN7r8/rBY6Jo0D6lN+7zG0xl4f5Fcj6XgDcUQH6OlT4JYE7gC3XCJB3ofLFjhx2Bl6O2B9Z6V8U+b4b2ihd5fDr8T//f4l8uwFjxpaL0/F3+0MuE6J/RgpXhbgpHNt1tcY3ovN8P7AzQSKd447L46pB3B/aRw1c7vuZlDGZGBI+TuT4Rb/PEX8X2egQpvznAV+xXpTh8BdLbCdwjJ6MZN0PQofCtxkS/n5W3r2P8DhflkYS96XoDg8VpuWVJvdHO+FOXMxIrNKjM7DvHvN4RiwFkbSlKMSrnkLPwOrmp2D7HLmf7sC6NIAxX4GxTES0TPy5DpxHZUppZ8jfPVUHZKNCCx2GPDanoSGJ5PuDEfgJrZFc597sAUYVvrpSlp3TBpTijS6XtQ3reFb5N+gzIMquq5BtUcp28Xz69CS9WpD2HG0DmuaKWh+WC/xg9E060XT0DEqvN/Hm8iR9b8Dt9WZHdmID9sq4h24rFNs+rAM5WEcyDvqD0la1gOtzEPL4uFyW/ipiD2+ALxAou//cHKNGjkaKUYf3g+d2wM3tRHLi7CUtfpbD9SuTfAeXaGP9dc2pvV0pw32F4+iC4gxdBYTL8Yjr2FFtjsi70GH/oeSxv3kuuSquWfT5fc5JxG3uKeQrIhbR34wd2UYbzt1JGjFP/eRY5EMkVCnh9tMIyeMx16MPz9UotgslqAe8A4ochReF24/Avyz5cnx04TJ8z4MI6qbiUJ0gkhbQM3sk6+53uT++mAGMT17WfQX2FmgHi+pNeJwSTXn9d0r2S3STyWQymUwmk8lkMplMJpPJZDKZpsheoptMJpPJZDKZTCaTyWQymUwmk8k0RYZzAc3kLpLr+FRTocE+nG6LJ2p7eU/l489P1zjMpCSj/GgmyyEFGAr1hsJ/IAamDjGTeD/XkSETGLooTz6WgVsY8XgCJ9j+6e1VmQ/SiDfZHMpKYXjyRxtckR2FX9kYcPu9+btc9tFYtmWUcKhsMcOhJAsqPKOR5TJ99+r+JJ1Xp0rfPeRQl1mfw/wQ++IrHMA2nNKOKAsdWLQJoVo7EHp+py/DjvddDtGdgfrpEBYMSZrLYkiznK5fO+KQvzXAB0UqhLQBp7uveBzir0/o7kKs4G7IoVVYbiKimSGX/WKeQ4DOQ1iaDgnfg3ZBHJGOmNnscb/twJi9UpN1x+isu4BBOg5kKBTOh0qGL1IRvzRIAfsCyIX+SNY99Dgkedbn8Om6Ov0b718F9Exe9Q2G2+Ec0uVDBEshw+18DKFynprj8zluiycBzYJjmYjoBLBPRxDapoYHHUFo1XYfQtHGsheXITYKcRM6dPwo4Gd1+4wCqNLjIt+az7Y0AzFYXYWomkvZdjVSvp+rglZxvmHZh4CySNTIzMMaUIP400gNYMSHHI1vTdI5rybylXMc6hngGFPolnrhAiVpTCc9Gb5mkprPO5RzHbrTle2Hp8/j2Lmh1to3ISR2CDG/iswgUBSoalaNRfANAliHOxACi+gPIqLFIl/TCqfHEnZg0G0N2FahfSMi6gCWaoZ4Ds1mFHom5pBzDCXNkQyVLad1vp8jQ+VRTQhbxfDd7YFsI7QHDYion8vKulfAhmwO8HcffD83lHVCtcAVcFT/BRCyjiG1OmS1CvZ9RDymtM2dgdD7JXB9lImkANAll0tchqJCckUp+xY4XrrKhmOrjJLp64isLy8eWY9DdxFtRER06doJnaW7d+Wa93Kbx/bLJ9x+u7HEAp3z0RZO/x1PAm1UBTd3xpU+ZS/lzxhRj+G/RET7DtfjhBih5caPqSfzuF8tAoop1r42pwNgym0FEjMWgj9xrVjnp/g8T9DvISJajnl9KLo8BrA8RESLszx3k//8h5P05/5sXeRDhN6HZ7nf90Ppx2P4dAka01Fr6CEwNfqAZcu6sj/PEdcjBs+56EjbgvapD2g9jcNIYO5FgORpRVynqkIaoC1Gk6voC9QFTN7egJ9zMJQZSxmYk+CzjZRvV84QZRR6xHRarYgoF0ufnkiuD0BsoZZimuCYm3PYjjXU/hzv3wM0Vk/hOSsZ7lNwtQWOQOMlFwCFOU9rk7T2H3CuEPGedKEo5+E8YKRwxmufHPdZOEr1vN4D/NI0XA0RUQQLRt3jNXVZ8Q4Q01qHPWojJxtmP+ByoL1EPMQwnj5JRN8CopGIiJyHJsm8w+WrKj5MAYYB1lyPtyL0O6L/0C4QSbwe2v35gnxuN2I7dAzjVyN4wxjwYVDYrNowJVCOPqzdcy6Peb0fwf3vPLwHaOTlvfcGnG8uNx1BdQjvc9D2aSwrojoRsdRypC8wcPmzTzyoyqn059CXRZweomWJpI+Ew0DjOhCNtwkI2CP1/qAIuL4rySU6S4XMdNxPDPvIyJHI4kHamqQDQDt5JOduN+H9YZwiSkxWKgsYNETjFFRbIqoUETULGqEDY2cVXjqmvUWRr+xz/XG8aduHfbAMLtwurK+bPdn+iH1BLM0SSVswlzvbzmz05VjGqbwAc2ApLwvbH7vkTp8GQvZLdJPJZDKZTCaTyWQymUwmk8lkMpmm6B39Ev1Tn/oUffjDH6ZKpUILCwv0Qz/0Q3T9+nWRJwgC+sQnPkGNRoPK5TJ9/OMfp/39/Sl3NJlMJpPJ9FcpW7tNJpPJZHp3ydZuk8lkMpm+sd7RL9H/+I//mD7xiU/QF7/4RfrMZz5DURTRD/zAD1AfMBn/7J/9M/rt3/5t+vVf/3X64z/+Y9rZ2aEf/uEf/jaW2mQymUym969s7TaZTCaT6d0lW7tNJpPJZPrGekcz0X//939ffP5P/+k/0cLCAr344ov01/7aX6N2u02/9Eu/RL/6q79K3//9309ERL/yK79CjzzyCH3xi1+kj3zkI/+nnnc1uUYZJ0ctkixr5GPmgFe0nCuKfMvANEV2Yd2XvJ2aD9zckLtgpahYXsA4XQCs0Uqerw8Vh+8w5DLc7HA+zYo+ATxTb8x1OilInmAATCJkGuXd6Vy1tTkGzuZb8rk+sBDjhOvbjyQHKgDe8Z0+crjks2azfP/jHvfHerElyzQD7C1gu48Sfs69vmRtfq3FD9vp8xhYKMhpMwMMMeQoaTZjb1yfpBvAMUOeIxHRvSbzrO50mVP1N1Yl22oO6t6NkAklG2kBWG04jm53ZR9uDPm5XWCQ1ZMZkS8HZqMbwRiDsTcrqy74c9hGbizL2o+5nQseXxQobB6y31aKfI+yL8fREDjyyJjb7ctxiTzcUczzP+PJNh/HwM31kXEm+bDNEdejA/zaTiy5aAWHy4vs+p6C2yHrbRqrqzuW9sMDZlof2lmzO7eBh6fZvajDgOvUGnM9ApLjN+rzHF8DltpSQT4XR9+euwf364l8zag+Sc+nc5N0H/qMiCiCciA/bZDIvr49Yr544PCzGsTnBSylC+KaPLDxLpSBca3YxDsjvp/j8HzIODmRb+xw2dOU29V1JWu6mlmlJI3obCrxO1ff6rW7kU2p4KWnmLz9Mc+p2102IsgpJyJyge94CMxKPCuBSNo+tO9jBZ9GW4NsQDzPYFmt98joPRjinJT3DuCGAYyd8/mcysfs6QDmwFpZrl9un8+AqWaXJumiArC2gUW7PWSWYk/NQ1Q/4fZrKU7rWokrXAIeqfZp8rBez8Nai5znvaE6w8A7m1NZ8WQZEujPFLieY3U2Qc3ldQDrpBH5G8C2DoB1Oq8YpGhn0R5LUrQUroH7QzkmtuD2eG/k7BMRZUP2cYZjHi/ou7qKB9/a4bpvHvOYut6T4w3Zs3N5WHuUX6VZr29J/1nhMSd6uiGdi0cTPpsEh/YXD+Va6wNfPu+yX1VRthl5qUcwtHcH0gk5GLGPJNYvV65fVThj6TjkHj5f5noUM3JcHgVgZ8Dpnc8qP77F68WN/4Pb+e5A1mkuy2VfzHO7pIp1jv7TBvi8usfQV0Hm7bmsHMGXqjx2esgcH07feuK5E3GkzhJJub7reU6ju3SuLMf8Ipyj1BN9K+fQwZDbFv2HqivvV4EzmuRZRNKeV4I8DeIpg/gdrG/12u18/f8TOV2FD4zLtT6DB1WDs0lKav3Ce+A4qKnzmx6rsf//YpPHM5411ZUmks6Vue8bubM55UREXTj7C8/zeqQq3zkchjx/7w14rrRGsk5oMzfgDC+drwPl3YYhGSo/6CjkjEVYQ5eV776YP7sPjkM5r/H2yEFuwl5xbyzPPQiBHV2N4Ww5V9r9ZeWjT8qgNot5aKQWnpGjFu812KQiC3+hML3N9+H4i0PlgxxDPnzuIFGDB/ojAj633ucNgHONe50svCup+rL9kQ+P7aDX2iXYQ/+1ee6PnaHc/24O+HPe4/rm1A2xaYtwbl+PpC1NgRnuplzWAUljgGf1IKtbr0vos1ZgSdVnGHRgfuSnI82Fv76W5xtiGUL1bgL99UrM87iYSuZ4lw75mpTX0AW6IvIVHbgOKlxK5fvHfsrjo+PwvTVjPXLknuYtbafy3C0v5L02+ipV9Y4Fz0frg20uZaT1Qzu7DO8wy5AvTeX4PYGzofZGPCZm1LlOdTiLCOfhOJX2Df2EUczXrBdkvgvFWJyN9SC9o3+JrtVu3385Ozt7/0DFF198kaIooo997GOTPNeuXaNz587RF77whan3CcOQOp2O+N9kMplMJtM3X7Z2m0wmk8n07pKt3SaTyWQynda75iV6kiT0Uz/1U/Sd3/md9PjjjxMR0d7eHmWzWarX6yLv4uIi7e3tnXGX+/rUpz5FtVpt8v/6+vrUvCaTyWQymf5isrXbZDKZTKZ3l2ztNplMJpPpbL2jcS6oT3ziE/TKK6/Q5z//+b/0vX7mZ36GPvnJT04+dzodWl9fp4qXJd/N0u3ktsgfphyeueA8NUljqAyRDL85DPjDYDw9XgQxCzqspApRExgWNZ/j2JQgmf7vIEc5vmFWxQbtDTmk4xZgXzxHhklgFTEkTIfKXatwmWLASszPytDWax/jUKG4DSE6Kjz59quzk/SNLoetFDyZLweh3q+0K2emiYgernC4HKJPMIw/UGiRk4Dv3QYMR7cnw2G8Ptf3cpnb77l5GfbyWouxKIhBaY7l/YYOx4jFKddDh7mvQCjpYcBhQ0t5GQqVb3A5QqjjSyfyfiHEla173P6NvDQTZQgvHsHAhOFGLYW5OIYiHUG4XX8sY6GwDHUI39P4BYzoxDBLRG3cvz8gXABbklfjNweh7fP+Q5M0hpsRESXE5Y2hrB2FIBklHLv4aIlDko/6A5FvCLiIFNqokZNjp+DxGJuBqEac1ohEICIqZyBUzuHrtwey7lt9FV74dUWJrHsbQs56gEFpOXLj9GbMvzDK9zgUbba7IvKtOg3+LuGQ94EjkSYuxLDtuweTdDmR4XGNlMdsyUO8kWyXUsL2JAPhhXUHEEup7Kc8LJUYuVjXocAVLpPb/c5JOlJImV33Ht/P4znuucr+UoHid88yfaa+FWv3s402lTMB/dFBXeTfgm5MYRw0I9kfBQ/DrHmCuSosGvEuI5gfeRVG2xrx/RH/gZisrUCGY877bDMwDDGIpY3EEZcFZFMwlmPxKcBeHAY8HxCrQES0BJw2tJ8PlyWm5YUmj83WCJ0TkY08sDU4d2dy0ldBxBeuvSW1xqOPkwWMXAXWITXFBSJtATA3iwXpZOGyMu5AiG4qK1UBtF4F5uPxSK61iLZJBzyvx4kcH2slLvtinu3vxkCGrJ+EnE+sryNpm5tQjuUC95MOs0Y/MIY1ZavH9c0r5M38PtvpbQjvPg7lvRcAm3G+zN/1IokWKQO3CNevGV+2OSL9DiBsPlJr99N1vg7v91WN4QC7X0i4Hpfqss0R1XCrw+MoUDa85bAv24NQ6PX0IZFvFhBdCxASjkMW246IaCaLuAkMm5f5vnzEPiVi/OqqLW/3IRSduDwuyfshUgqRJs1U+jerHod6L/rSjqHaI74HDsWlwvQ1DUOucwqndxyeHYqOtqUqh6/YVxUBQXAwnI4SrPs8JuYUigmRHJuDIlwj/ajHZlrUG59d3neLvhVr9yi5Tym425VthUitMthfRTejIayPHWjvcSrHTkJn78P1XrYE/Yhj9hCWw3l5azFv0C4+XJHzcKUACJc6Y08f/h8k0mT/i1zf2zf4HxuaocKvQJkuVfi5WwOZ7yTAPT7n08AC9JUjaGhFZqASzKObPS6rxowNwCepgJlFtMicKxFQB/DcY4dBhncV1HAMyI/ZdG2SXo/mRb4YEBGHY/a/mo68X9hdnaS/e5GNiMb94D68Cf7hbE6OI8SRHYf8oebJ9RCF+JCKYtfmosVJejvgeuTAB6wrH+tcicveAKRXlGjbx59XqmzrT0ZybdwBzNg8vDvB5xARtUeAcwODvpo0RL487JMz8JvenNrzzMMYQcyQ9m+wFIgZUa8PaKl49n46KsvFQ2NV3xIi1lZr8uaITtsBv/uPdudEvi6M527C+9qBKyN0atBmI8CA5kn2TdNhH6SfHE3SWVfi9PyUx58DvnpXzYcWoJTilJ/l/wV/eo2+GdrwkxH0uzLRFXgBuQYoN93vHfCHEUmFey8iojr4E4UMP0yPj8V8RIP47PciWu+K3flP/uRP0u/8zu/Q5z73OVpbY2O5tLREo9GIWq2W+Ffx/f19WlpaOuNO95XL5SiXm27ITCaTyWQy/eVka7fJZDKZTO8u2dptMplMJtN0vaNxLmma0k/+5E/Sb/zGb9Af/uEf0sWLF8X3H/zgB8n3ffrsZz87+dv169dpY2ODPvrRj36ri2symUwm0/tetnabTCaTyfTukq3dJpPJZDJ9Y72jf4n+iU98gn71V3+Vfuu3fosqlcqEt1ar1ahQKFCtVqMf//Efp09+8pM0OztL1WqV/sk/+Sf00Y9+9P/0CeFEREEaU5zE5Kh/W6g4HB40j6He6no8pbcE/+CuQ6EO4GTqNkQbaJxLH8KBMNzgRo/LME5lWIMHITYNKIMO3/FdLhSGpeuwBgwzxWfhc7RiCL8uFWW4c3CPn/X66xyetFKXISwYnr0A+BqNNMEQVh/K5LuyfP0xD3X8BnEws1kZencNwny9DofE3ApaIl8BQl08iFnRLYQhbIfqBHHUZY/b5ZE6hwNhWYmImiE/dyHPYT6lrAxD+QCctn1zi0ODLlZqIl8v4tDUAYQul2I5gC+U4dR2OKW9CxifvUD20/4QQi4jLl+QyrJiCHH2AaeJH8FkO4DTmHuRNGmLcLJ6DUKkB2N5w4YLIckJY0cO3B2Rbww4hvZ4c5KezVwS+bbS1yfpmSFjPWoKl9QFhkszZeTQIJBhZY9VGA0wmwUsDYR3106Fw3N/YPiUtjPDmPt6x9mfpF1l4RyHL0yB4ZAnGY7pA5KkOb47SfdIYl960GYN4jH/kL8o8mH4GOKD2mrsLOfP/pXTSSjD8Nd8Rq5gyGo3BlQHyXsjzmEGwuubkQ6f5HEVjLld9gOFgEq47q0ij6M0PTt88N2mb/XafRTkaZjJnUJyFWGsj2AgRYmcBHnEJUFIbNaVIZMDWEdw7IxVjHkCCxjatH3nziS9o8bYyZhfVjzqL0/SpYwsa3PE1834sD4oI4m+BYaf6nBMMb9gfX2tK+fTvR6PTUQ9VTOyjeYB/xXCfB0o3Mw9IERUwTZ31Jw6Dvi6EtQDUXO6jeZgvuK6qygo1IU/zOb4mpIv7a/vnv3ccihDZXeHXI4gYRu5PZC2ueyzjQwT9MXU+IXl7EE4vSzY3AaMX419QYwZ4ggRrYNrOhHRIfgZOwHX7yiQ/YnIFfS/NPawCP2B/uZA+RkPlXkAD2Meixs9+dwWjJf1IjfSxaocl3Gb2wiIBlRR+I8BLBeYr5GV86E/Yp9hzjk3SV/KS7/q0Tq3Gfr0mEbEAhHR1TLXYx2wfSOFb8R2ruXYlzgZSj/j+WOu5O7w7FB2IqIuoJ4QxZRToePbA34WYt5KkcyXgXDx+Tx/t1aS9RDjQC7XQkUP5mgGx++DfEX+A+L9tB90pcJzGe3MhdL0NXlnCHZG+eeOk5LzgH3SO1Xf6rX7cJCQ78a0kzTF3wMHkBVjHs/lVOL+8g6P7RZgFeNA9sdCHnzoB/wgvgX27gLsX49DLsO1ivTnCoA32R7CflD1/3ed5/3E3LOAOnpK8uGX5xnN8D//LqP/XtySv/THdwE7gK5rh3osTseqoXyX2xLnVzUj74fzFf2M7YH0aXBvkQ95rmTh/UNPIRdCh9t8AMiscSrfJQxiRlbEHj93NpH2twF4srUs++RzscRQXQIOFO6rNPLqIID1GhqiR9LwxOATzmTBTqi1exfabAicUsTsEUmkqe+cjSZStDpaybPfgb6x3isSoI4GIbdDOSPzSawwl1u/b7nV5bI3YSxivxMRncvwnESUUHssxxHuzVJ435JVxr4P7YdIRY1HvQg2fTnP+S4UZfkOR9wuiKTFteNjKweEigBDR00ei4hXJCLKNB+epF+A9wp9QMMRERUAb4rzoaNwVQHYPtxHDtIWTVOR6pO0T9IoXizwXMFm7qv1GdfKgwDf8yg/Ht57ujDeWjD99WsxbGf0GU4U1grnxiDhsZOouYtjDLHCr6n5dLvvUfg2t+Lv6Jfo//E//kciIvre7/1e8fdf+ZVfoR/90R8lIqJ/9+/+HbmuSx//+McpDEP6wR/8QfoP/+E/fItLajKZTCaTicjWbpPJZDKZ3m2ytdtkMplMpm+sd/RL9FSfFnWG8vk8ffrTn6ZPf/rT34ISmUwmk8lkepBs7TaZTCaT6d0lW7tNJpPJZPrGekcz0U0mk8lkMplMJpPJZDKZTCaTyWT6duod/Uv0b5ceda6Iz/N55iktF5n/05XYJsGjioCnc6jYkcjXRB5pTvGsnqwzswd5kfvAm25LnBidA1zcaoEL0RnJfy9BVtmYOJ+veILnipxvHpiLWcX/S4EHtvAw852Ob0pu03//GvOm7/S5ws8ohmMpczYH/XxpIPP53Alenytf8SW4CVnZATCrloA3eRBKZhX2TQGYua5in1VcZkkhO0qzHWdyeB1fk6SSRYXMtJns9F+FtKBOT64ylysM5bT2izyOYOjRwxUJoJoHbtjzR9xGsWJDIZu1DNw8xHqqYST4cCXgWnqKAXuUMNfrJOQHjWJZp/0x8MOBndgcSM5dEDPXa7XEdepEsu4ZbBiobyveFPmQy4fcsU4ied91l8d5Cfi8sWIn98c8gY/c3Ula88m8LrdzI891QrbbbiDbCH9QhHOoqXhiIfFAjR1ORyTbqJgyS9yHMesrXuoBMcMxSnh+ZT3JTh8BBw7rvhhXRD7f5X67WoO2TGV9keuL/Lr2SI6x3hhsH/YhlCen2x9swe3+2Zw8IslIHACfTzNvy3D/BrDh9Xkc5bRGcaqMvOmUXu3kKe/lqebLsY3zei7PbZtXLEX8iMx7zTRFruRrHR5wW305V+pZ4EomnC+JHpmkCyQ5jX3ifg6BOZ7zpGPgA2e4AovUbE7WCY6hECxwvaLguSc5YFve6CifAcZzBZjhdcW8xrMnDoZ8v43+UORDXvJRcDZLnIgoSvlzFtiFl6vsM6yUzmaEEhF1RlyGvaFmtvK9r8L96ooVjaxG9NNmVJu7Dl+40Wfj0AFuNBHR7oDbbx/sdled1YGfcIziWCaS59+4MGYzikGKmFX0V1eAC6p5qbtwPgf6vHN5fe4JP3cXnJ9qVvbNc3M8jo7BL9V1z4PdXy1wmbqRvN8JmMdZeNZaUY50PM+gAf32gRlpxJED+4UsnBOj/LmHvTpcw+xTZLkTEe2Cvx4B0xynzaEcHnQMXNr5HN9vtdIT+doBj9m5GWSiyvsVMrz2dqCfyorPC1NcnNESObJOeOZD22GW9UCtm+vO3CSNbHzNLUfW7gJsBZBnTnSaw/+WsC23BzLPIRhCXPvRRhMRPVQ7+6yf7146EvnqFbZjA+Bfb7el3/Li4SwNYuUgmE4p+fr/syTbzyH2N4UtKMp+a4AdOgFueaj4umi3yxnYj2T0/pw/PzJ3Mkk/tcQ3XPiQXEde+cOZSbrgIYda2qpb+7OT9PBzzDdePrwj8vnLPI/ysAe/Um+LfPshz6+Xwa+PlQFYAecYV4RQ7UeQg45nrHSUbT4IOd9AsLrVmQ1wXgqevYYs61SdlYLMe9xr59M1kQ/XtkPncJL2UnWWE6SRB3+lKvPNZ4HfDoz7k5HcZ6Av1YEP3bFcINDnr2env2rDfHimmt53LxaAUZ/l8xuGYxzL8poTYHoPwHdF5jsR0Xlg/+N5QVlXTqJZ8I2HMLa7YznOAygT1m+tLAuIPkgPmq+seNojGKcO9KgavuK5eI6VzrcxgPc5Dud7uCLfL9V8/u4KvFsrwvup3Z48o+Ew5LnbBl9lMS8LcQffnYAfVHIaIl82YZvWhnHedeUZEgWwlw5sWIepzOc77NMsJHz2kn6vVc9hG/Hfj9RZEzj+4lO7Cxay1Ked0dRXtgD3W/juRPdnJPY308uAZyfhOGqP5DXNNBVj7kGyX6KbTCaTyWQymUwmk8lkMplMJpPJNEX2Et1kMplMJpPJZDKZTCaTyWQymUymKTKcC+h+aH9CV6oF8XcMlY2mhPIQEbXg1/8YRj5WYQGP1Pk7xAHcaMv7zc9z+kqJ4z2bIy7fhbK89+NVDjVs5PnmlYxEKdzoANoiwTBrmW+lwM8tQAhLrDAcHoTA3f0ah7aOVDhbD8LCmhAh+vkjGQZa8fnzpRI/d6kqw1mTlO/3UJ5vGKnwolXAu+x1uI6HgJF5tSNjuGUYGH+6MK6JfBiulJ8eVS5CVpHuoBEfJQixjaF+GwPZ5oMx/xvYSoHDBJ9+VKJFXr++MElfh7pXFfImA6HyCwU2DfcUt+jlJn9uj7ifECcwVKGUGL6H7eCqcTQGtEgr5dAqP5HjElEIMaBFIpJ1whBiDPnV4Y6IN0I8zDiR4bgxfK5nz0/SJZoR+SpJndMQr6RD4OsBh7AWB4yRClLZ5k/Xuf4Ysoeh975CT+wAJuBej7/bHsh7FwFBsJQsTdKb7j2RL3A4vHuJ1ifpXirbKHG4TCWfjVjBkW0UpJ1Juj3emqTvZuoiXybkcVnKcPmempH1jWCuvN5iW3oz2RL58qlE/kyeA8th2ZH2qDXi8fG1E763DpuPAA+DaU+FyhUA0fEkPTxJFzPy37UPglAghExna6UQU9Eb08PVrvh7N2I78WKTwy77KjwZl/INwPVUZeQ4rQAibR6wA8WMNPxoMnsR4BN8nu8lX/Z1d8Q3RNzUMJaDLA+hmmhXdfg6ooUKsKYUHrBGYfhkpMY22swCjNOW8oMG47NDNRfyCpEEi+DGkG3Lm+7rIl8KfK1H0icnaUSuVBXG5wQwEIhHqyay8hWfP6Of1lQIKFyjMSy1npX5cCmfAVzEsid9C8SxeA7Xr6XITTh2xL0VRqYI4dmII1iqyUHxMPhPl9ePJ+m72+w/vNmVay2GPuO4rqi5cQ3copsQIr1QkGVdybO/mXf5Jq935XbkGLB0uOZp7KEsK3+pTCk9VuN2Lmc4vVyQLJXHHt+fpK9u8lqhfUr0gV84YKzCnYHsayxGBcZpC8bYUIUx47x5qc1lGCp/WmMV31I/kp2Dfin6QR01yQ8DwCg6vO74qbwf+l9lWqJpOk7Yl3LAdY8VkqML5VgtQWi72qF2Y/DxB1h3bnO9LwuAi4AIl6rCUD03y+Py2Ss7k3QmL+/3hy+x37cNqCNXhZEHiUNBbL9T+0bKew5lXYdSX86bJWBP4T73yZrEgpUybJRudXmN31Z4wyzYRRz2d/uyj+72eY3OH3D6+xYAv3JB+tBPfpzH+cnnWnx9Ve5HWge8dy9XYX6tS5wpgT0Ih1yPw770XRHTiHuLnb7GIvBnrO1A2R38hMilw0Da8B3A16FfoPdVuMajPxynbJDOFyQOA3G3vYjXokEs17Ic4Fz8iG2Q70gbiSVC1Nz5olxsZ7P8GZEcrWj6HK4B8sJx5HjDpsB2ORqebbOJJBpPr1+zOcTm8N9xLNd9eW90HWPwy2azcvwuAtb2Nqz/hwoN24a2GMN+qx/L8VGBdsb2n1GYPHxPgD7WWklWHk013q+jkMpxgr4K57zTl5iWMSBSPBgvviPf/VVhf32tzvvVI3hvdKQwwK91eE1AlHNJodP68DKxkjLCJUfSFrRcRrigL3wcSwRU1uV+myVGH8WObKRzydVJ+kKe7VtzJPMdAcOwCf70SO1HEPt0MQfomVCOxZdabD+bIc95tDOBujfipfYBxVjIaD8IcNLZs/f3RESz4Hfj+yqNlxuMHWFbHyRb4U0mk8lkMplMJpPJZDKZTCaTyWSaInuJbjKZTCaTyWQymUwmk8lkMplMJtMUGc4FFKUxpRSLEFoieRptBsKT2iMZXnQSc0gMnnR7sSjDlVzALlR8DCmQ/6axCVFrYcLhRd/R4NDn+YIKbQOkyW6HQzU8hXp4tM73u92ZHgLbGnFoig+ho/rg2sMBh3Ecj/iU4KEK80HkwgJErRwrYoGK6pjoZQibJSLKw0noD8/zSerlkrxh7SHOl3mZ6/Enm3w6sUaxFCBEeh5CqepZdRK1CqN/SxppgqdtFyF+RP9L1hKc4jyb5Yu2hvK5233+7nd3ua8fXjk6u0AkcTo6vOjVDpuD/QGcMJ/Iitx1Nifpgz6HEJXhlOpHKhVxzeUql32jx+XWpzEvufVJGk96r6k2j+EkeohwOnXCNIbU4QnTGsOxB6deRw7PoYzCekQpz70UwhNLaVXkW/Y4th1PotZ4CMQnXKnyhJjJypAuiGwVIf81QPJUlDVHnAtGOOu6I1okCyfblwBJQ0SUIc6HqAJHIQiwzTzAQ/kk27KTcHx3P2QE0bb7sshX9VYm6Zd6HDoXJRIPM0q4kn80+rNJujm8LfLNFx/he6eMm1lJFyfpc0XZ/jj8XGd6aFsEhhHRXzkVLzYYQyifjyGhIhsdhy4l9m/d31D7gUt5z6PdoC7+jtgW7Cq050QSgYHhu3pd+mpzGqpE5oMoX8pD3+Pc1ViVUpEvWn6bfY5rjLaliPzAta2u0CeILkFcR5Lq+/FNwhjHuQzbRHvSgPDJWEUxY3G9IYYJy0bHENuCyzYNphDdkRQfOga2Dc4vLA+RtKu4nh4r5BWGjvfGYHN9adDx7otFxFDJ8uH4w3FZV+HOUYLjlzNqpMlsluv7KIQdl3PSOK8+xg01POB7/8lBfZJ+o0NCuwO+R9HDB8tC4Bx4GNAuF1XY/HqF7f7JMdtwRdcQ9nML8ATLRWkk8+Cn4XxXSy2tFLhPc+DLNlU49ssv8zpweZl9ypXvkPPhK/+VfZw/P+YO1oiE82VYL6ZE8j81Kz8j6gWxhz2FlME67R2zD7I5kCHNOK+xDNsDOdcQzeZB/waODIfHmHpEvdQcaQgRi1aCEOxItQOGkucBfVTy1bpJXL4wgXk44vYvKL+2luPyXQM37bmGHOhXVhlvhAiNV+40RL7/fsDla0LIukYQ1LJEo7cZEv5+lu/eR3hcUA7stSrbNMQW7gylH5kAznEJ9r8zWYlS6UR8/zcAHXWrI/c3uK5frMC+pc/rUP6X5L2vPcH7ghzM9zt35MS+A3vy6hGP+eKGvN9ciff11QqnMwrfNAPtQmUuK+JSiE77qZPnqrmC4xlxEytqAUM/tR1xPaJUtmUF5vIesS29N+Y6fSRzUVzzzCyX9WTE9dhU2B30QYKY53gvketNM4T1P8flWSvKdydFn+txo8v2E1GYRGK7STOALdHvb9A/QUzW1rgl8i24bJTw/YveK6LrcgLrQz3L5dPvec4X2V6uFfgGa+W+yNcGfM1N2GPd6cr74T7Ih3Gu1+7HAZ3WgTX5RPnTxwH6QYAcUhshxLbgUG5IU0BHQGY7CPiiFkkMMMEQCZpcX9+Vjf436uq6r+tWj8cH4u6IiHYHXMAuzKGmWgoQW7QUM3L30GmqovLaGwPmNVDjaOxy4y5nGA3bIWkzYuI5ivvSelbO8ZOQnzWCeb1akMibZxp8j3s98FtIygd/AlFqZXjsvZ60H0cwb47gfcE6SaRyHTZcLbBhGi+FbY5jVr/DG8ZE4RRfTct25yaTyWQymUwmk8lkMplMJpPJZDJNkb1EN5lMJpPJZDKZTCaTyWQymUwmk2mK7CW6yWQymUwmk8lkMplMJpPJZDKZTFNkTHRQRPeZ6Ceh5JPlPeB1Agyw7EvumOcwJwk50prvtDfkz8gGbCoID/I7lwE1WAYOcs6XZY2AmYg8c1fxmB6ucPmKGR4GCpFGg5jvNwsMsZFiM+4MmZPUFExZecOlPNfxA3UGWHXV/Q5DYEzGwFVN5f2eALbXzAqzo4ofkbxk+vBjk2Tu5p9zPYBZeK0igV0t4CwejbgMS3kJUMIybQIfq6/Y+ueAIbZe4H7T1MQrwAy9DZzx5ki20ZUqPwuRy//vVy6IfN85x+zH5Tw/txnJ+yHnPgMDpq64r37EDLG2wxzJDrC8LoyviWsuFLnf73aRESrbcjHPAz2FuearAYzXHbsHk/Sl9ILIVwNmHTKIl4uyTr0ecy8HwAUtumWR77b3pUm6HTEbPpeV+Qox16M4AJZwIv/d8jKg45dyXKfOWOZ7rc2fkQV4EHBfzCn+bdFDbjGnkS1GROQnPM5rLpd7heT5A9u0P0lvhnyPsuLGLyfMOAwdbstyKs+GKAIL0CsCR92RfdOJd7geLtugmeHTIh9y2lfShybpWnFJ5IuAq5o4PC7nslyGimKxtgAki6zpVcXnRTZzABDogjpsIudhf/Lfs8oYVH3/FHPadFrzuYSKXkw3e9KluQfnLzSAm6ftCbIfsb2HY9n2eJkD6zryfomInq7zXETe+uaA52vRk7ZvCLYhA2xLPMeCSOCI6SDg744VixndjuUCf1DZBKcS0aeNnHwucgNxzatGcmyPwNZ0RpzOqnMB0B434AyIldFDIl89Zdu6VNak6/vSjGXsw0Mo+InyR5B1fhJxn3VTyUutJ+zb+XBNXtUJzxlB5mJXItaF33EezGIjKysSwZioQNWzruzETWA47wbs+3ykoWz9dW6L53eZw9kdn11uIqJ+HJ2ZLvuSu43n36BNcxSnNZ8FFnCGMzbUgw+AZYut8mZHthHyMB+q8HePVCXHG3mxqzU21HsduXbPFNgvrV5mf2m8K33tnSG3H7L19ZjAeYP2o/yA3Rdybpfz3F4XFMv2cMjz5gvH7Ez0ZFHFeQlYuuNUHiYwcHi8LMB5IXEqOahN8PUyqTz/BlX32d5lofLHao9V9c9mmuvzqRLcfwH1/lqdjcmlkrz37T42NF//SluWG33tEOZnR9k3tJGzMPYa2VTlS2ioD4IwnVLRdyjnOmJPSkRU9Ljt7g24rzXee9bnfAHYywsVafu+elyfpPGcgeZI7vtKHj8LbRIytFsj6fNuXOd74568lJWGvwBr/gtNHm+uspENsEkpu7/i/AciuTf24R4LeWmDtsBseDAPlakSPiva+kYs/Zu1EvKN2QYNlXMRxLjv2IC/t/k54wvimhH0YTEDdSrIwvYjOPMlA0x0ta8NoEwh+Cb6XLdshvvtGPb7nUiOS7xfFs5oySmfsgN7hmNgwDfdQ5Evl7CNzLk8rpbVe4ZNOBMNzxXrwXuiBXXmy0Mwp67W2GYf9+TajVqFdwRhLBcpHC/ow2im9HoBGdp8vxtqr4j36IPZvt6SN9wZyXXvLS1l5Z5ya8Tr2ZZ7i79QPk0G9qgJjNGtvmL/w3rxRov3q/sB33Cg1loUnsfTVU5qFQDzxQyvr36gDryBcRrA+uyqc+JQuE7i2W1ERPec1yfpcv+ZqffIunz/iovceO3f8OdX29JvRl2pcB3R98HzfUpqn3w8YkOdJy7DSK2rwRj3dvx3VXVhn7ZhiOl3r3nPEfuYB8l+iW4ymUwmk8lkMplMJpPJZDKZTCbTFNlLdJPJZDKZTCaTyWQymUwmk8lkMpmmyHAuoBP3mDwnS8cqBOAp98IkjaFQVRV+GkKIge+cjV8gIjoYQqgFhF2MVL4IQl0wvOhLJxxWMpeVYTkYsoqIlHJGVmo+Nz4z3VNYlRUIba1VOVRj86Au8oUxhj8hrkNkoxyEIWNom+/KjDsBh2dt9vh+z8zI8LjZMofsjjo8nMM/kGGqh7/6Ct+7t0hnqRPJ6YAh0xi+3lWoDYzSwfrqEHO8XwPC/LaHWZHvv+3XJ2lEcujwX7z/LQjv6qhG3+jXJumHqny/19simxinuwH3dUzyfhXi+9VSLqtPPHb6qvL3Blz4Kzx8KSUZ3pWDeLEWhNjsBTI0e8fd4zIkjGIpKKzCUcD3uNnm8YahikREHeJwsb7L+JsZuDcRUT2zPkl3Yy5DXqFKMvDvkyGgBY4CkY1Wi1zeQ8D11NR8XQCc09dOuB6fP+RrHqnJkOtXoX+PQ65vy22KfIsJh6WvAy5hVuEcqu01vt+IKxKmMp7Nh2WlmHL/ugpcVCTGDsTO5Um67chwxzDlioxSHgeJCoG9WMG+ZxRNa1QT+Vpwj0MYR0chh/XuqfCuisdzdKEA6ZwMO3y1iaHoaN9k3VeKgISAvtXYrWMZaWyaonImoaKXUMWfHk68N+T+2JXmhN4cnUzSBbBJs56cU1drOD/43n0V0nkCKLDHamxblmE9Pbcg5+GNHR6zX4JQb1xPiYhaEIqO+LWUZN23+jzn05TH7HpJIU28s9O+QoZg6C26PkNdd7D9M3DDulzmqA31QATGB8sLIh+GcSNdpwOoB41f6sMa3RxxOySqjXoQ3nlC7DPk1LqESL6ZLPftQsFV+TiNbVRVFJo2LD/Yv3kVoduDfBg2q++HtcqDn4A+JBHRdosRFudLPAlORvz3G8ovGAD+ag7wZjNqfcAlH6YafbUl2/K1zsokjfOmpPwbHNtLMAZaIx3Wz7VfhFD0i7OyIqXy2cZ0W2E95ho8X/0nGGny2i9pnAt3AtYdEUZEEk+yBOakA337Wkuj/wBLASi8WPXnMSAbXwK/QCNl8DMiRhC3RkTk0dm4pJ4jw+lxja6kvL7Ws/J6DGdHvFRbkucoherH8EGjDh8tss+wCHPvgzNsVzPKL3ijywMLsUWrBekD4ngLEGOpEEvni2fvW05Gsu4nI1fMWdPZ+q65iIqeewr7tBdMR7igKoBz+diH703N99U/rU/Su2B4jh1pJ1qAsIhSXpPPgz056kp//2jAE3sL0hrLiPZ4G3wQRY2jCHApD6IK3O4mZ+arqLUW514TWFtZ5XA2YPFYcGDeqP0comzPCX9C3u+FI55jMQEKzOM1Pk1lBV9tcToPk1LPJCz6TBbXZ7mIItoK11dE0BIRHQbsiKNt1u9lENV7BDi91ZJ87mqJ+76e401vqXNJ5KvnYB2BR321Je+H9e9Bf2DXrCm85FyRB1k2yw0RKEzLTJ7t53yJbf0jgVy7sY1u9Di90ZfP9WDsIHII0UtERDHY8BeOACU0VkhlB20BoPpGEh8ycLi+LryPCFOJdoocwC0Sv0PT4/wrTe63mo8I5Omo2Sr4ooiT3BnJMvQibucF2AtXPLmOXEr5fVWYsD3qpnLtPgKsbQDvM5JU7lEzgNTpp9wOM65CtiXctuey7CPpdzu3OtyHe4R4X4Ut6jNS+TjkvimAb4JoPiKi5bwcf2+pPZJ12hpwW6wCGnZO4Y3w3RjawaqyGRcrX1+7D+gbyn6JbjKZTCaTyWQymUwmk8lkMplMJtMU2Ut0k8lkMplMJpPJZDKZTCaTyWQymabIcC6giEaUUEq74Uvi724KJ7CPGO/gD6efjusBzsVXIX0rELe6VuSwkHEq893qAHqjxWEIy4ACOA5l6Me0cOKcCs3eGHC81zN1CPlRWJVajsMkIjilfrsvMTI5OEl9BU6VvjeQbYRRpgchl2Ff1aMLIc4QdSFOICciqkAY56DD9/u922si38aAy47topErqA/NcN0fq3J40a2eDDHBfkPkSrUi+xNPhP9f7/I9cir0FogaIpw7p05m70HsVzPk8t1Ot0W+qLc6SS8VMM5P3u844Hv4Dvebp/6tbQzhVDkwIRgqr8PmMYxuDU7ubo/k+BCnp8M99hwZVzMiDjUqEIdcV3x5PzyN+SDmEKddV4Z99hK+fx5wNTMkcS6LyblJ2vG4XaqJRIaU4DRrjFbUoYstgTTgNIaOERG92ecvT6IQ0pynGebxEjoE5EoW+rMBIaq6rCgdPvn4DN/jTpdDv742kGFbx+7OJF1NuW9WST4XW2LgSPwSqpxZgmsAmaXGZT17dtxr0ZPLXN3jvnJCwFDBmJp1ZLjuapEnJWIpRiq8fqmIIZxnt6sW9vtQnXLvOUSJ7gjTKb3U9innZk+FRR+MeG0bEk+2BVciHK7kZifp7AOQbQ9XuIO2YP3fHcixdwLjqghImH0IY5xtl8U1m7BWYhgo4ky0+mDfdNh3kLA97wBeJixMDxPGtVYjbwowjbYgRnqUTA/1xvUwVGP7Xo/7A/Faj9Tk3MPiNuEeaNtV1DylU0LgMRSYiCgivmEG1rKqI21pBNfdHDLuqxPJsi5BI61AaPWlkgxPDmBSv97hCrYUcWQGXI1V8BUPAjkmEIviQI++2pHlqwImrJLhup8ALnBGRdCuxRzSPA+dodFTh4Aqw+82eqrjQYi8qmWlncZweBwDD1VkHyLC5ZGZ1iSt0ScjmHsn0C63e9KXTe8xauC51xn3dTJcEvmAVEKXyxCKfgrhwf2G0cWIAggUMqAL6/obXe4QjRzEOV/KQHtlZBkQ0zQApEkhlSHcM7DurZb4uXtD6Z+PIUR8OcN2TCLVpH99ExhGGqfnObxWIoaq6sv6zkMDIqbh+ROer0t5OT6W8txID1d4jV8tS0RNlPC950dcnp5qc9y3NCNul+NQoXbC9BTO0nRaRW9MpYxHObX3RJwLYlC2+gr7AmOz9AL753pPeQDzPwKMwZikfXLAr9wb8nO/vMt24UMr++Kar8B3Xzjm8XKgHLohYCrysH+I1OJd9rnsuMYrioHAerQirlM1kr4nrtEhzN2iK8f2XP7sV0L9sXzw7Q5gKqAeGbUoHI4H8B3bk3p6NlKVSNZpY8B18tSO5KEaz3l8bEvhtNAXyIBd3BzI8YF4J0SLOY58bh3WKc85+z0FEdEqLCvoBxU86Vugz9qNcI2XNjLv8fgdQ6UWAYGh51BzyM86BsxQ1p2+JucB+5Lx5P2y3tnXHYVyHcF97RBwLhpXh1ivIYyxRk72Da5n2J97Q9lGccxrUQJ7xcCR5UvgJscp+3PxSPrkt3q8Hi7moQyQR9v4I/DhsA8P3F1ZVuJ2jhDNopp4AfbQiGxsOy2RzwdkYznl/c3YlW2Ee+gsoU2UFUG8C7a/RmuhH5QDfGtIEl9z3bk9SbuAjT4X8bu6J2fl3PiOOW6MWz0eE/d6ck7egRdliJuZUwPOB/8BEY1qm0d1P6XAnbKJULJfoptMJpPJZDKZTCaTyWQymUwmk8k0RfYS3WQymUwmk8lkMplMJpPJZDKZTKYpMpwLqJzWyKOsCOfSQrTFgGTsrSO+4zCmN+MTkW+jtTxJr/Q4VHY2J7sD8RoY/uhCqLEO28RQdF9UQ957WqDCtXlZ1kqNY3RTCKnzFFYlAqTJBxp8ivl6UYZn4Gn2uxBqpDEtRShuBSJ7fB3yt8vtVy1xWR+uyFDNAMJ8MPz3coVDTvoqbHOuyKGf4t6BDJUL4LoDCJnaUydb3+pxvmnYHSKip+scwpKHcKr/vi/LhyGAGEZXG8+IfBj+hNGFzzVkW97OY2gqp4Ox7JvjEMLeIXT/lSFjPdZ9GSJ9pczjtwex9zVFvEC0TR5OqT5uSqyKByFOK3lu83Nlfboz1/FWtz5Jv9qWfbPp8vioJ9x+K5mqyNca85z3kvOT9Jwn64uhd7HmLID647O/e7Mvw9k2AO8w6yMKiOs7Us9BpBQiTfxUtlEMA2R3gOGmsgw4TvN4ojbJELgucX9geJerQj0RN1EkDj8bOrJ8vZRROwWH+0aH/C3muI0erk1HbWG43TjhsmNo5qw61XseQvmO4frjUObDcExsIx0Ch/gFiMKn7b6ck70opiidHnppuq/2iCjnEu30JTaj47B9x3BHDKUmIlot8XjBcf6gcPw5GG9jNacC6LJWBBgkiHB8aSTnwzjhi0IIv3ZVOPFTYN7xq1tdPebzdJaOA2UnXFzX+e96zGJoej/mdj5M2yLf+pjROGgHEWdGRHQYc2Pcc16fpJP2UyJfZ8S2FcsnwjEVEm2tDPa3yxX5Uvy6yBemXIaMw7aqnDwm8sUQAhuC3xfEso1nwCY9VOE1b6B4M4hPwfF2pSrrcbHE97g3YBu+pewE4kAO2G0RNoiIKAvjPoUxO4CYckRSEREVwED5MFaaoRwgWIYS+KWdSM5JtNuIFthV4es5WFPR/u4FsnyL0AXop3kZ2UbtDq9LvRGXYa0o/fhqjgfq5pcwPFwKEYld6F8Nc0Hc3w70Ddr9R+rTMSgB4GGGChUzn+O2fXYOcItDmQ/NYsmHdS6QuJ/5PNvIRcBIFTKK8dNl/+tSlcfl9y9Iv/v1LvfhUcB1dJQtQHRfJ+b+0KioDiAlA7CXS3kun7btjSzfA33yo1A6n4t5fC7f495A1v1rTQfy8d/1nijrOadQEKbTGiYeObF3Cg+l94RvCTFZREQ1n/N9EbA+11vS7tQA9xfBuFpSyEb0lbFMdwD/sdqSOLhSBtFpnG8Uy/GL9u4gYoRhlSRu4jXAt7aTgKapAmvWYp7rrvcCBZfnXgEQDhoviVNHrLXKqr2ZMLIRDeMyLYh8iPu8mFzjcgPaRSNgjkdsHBC1seTWRT5f7Zu53PJ+SMnF2kZqeMWQsQDo1KwqH37ENiqqt2noA6Lv0wwVNgNwLEMYL+gfERGNY76hC++a8kGdr1H4iq/BOMX1aikvDfBijZ8VAWpjsy3HOTZ5Pcv3OFeUNhLXLLSLrUh2Ghb3YeBk1uTWjrDJEJ2o1wd8P4fzuuhIfFAEg7YPaKdqRq4JiOi5BcTRKqyhGiHZgsIi+udyeEHkO0l5rRw67Bgc0h2RL3A5H/qrZUfarQvp+iSNPn2TZGMGgFnB/VFeod0WPfYNpuF0iIhqWf7uoTz7/knwkMiXTbkcPcC31n3wOXJyr/uBhaNJ2of6Zl1ZpwTKXoXyrBYU/hl8QvTZ2pGc41U/Jt9wLiaTyWQymUwmk8lkMplMJpPJZDL95WQv0U0mk8lkMplMJpPJZDKZTCaTyWSaInuJbjKZTCaTyWQymUwmk8lkMplMJtMUGRMdVEqLlKEcrec+JP6+njLDvJplFs/mSHKlmu7xJH2c3Jukk0SyHheclUm6M+Z77I27It+ixzyqLDDNNofMW5/NSCZnCUBhiEXb6U/n+1yr8DBY+58kF8m5cnmSTl9iVtP3zGyIfO1dZmKlwBa7dOmYpulwh3mT91qSPb2veOJvKUwUW3TI5V2YY77Tkw/tiXzLu9yW2C7nnuY292pyOngf4boTsCLJV8CuN7f5ml9iftXmQPbNo1Xu6/Ui9+erbVmnPz3Ez5zeHcrxhuy3ITBqsyS5aCslLi/yu15py3yAJKMGNP+NQEK/9sdcx4zDfXg1OzdJP1qXdbpQYrbdYcDtopD+lAF+Wh7GfC+SbLYQ+Kt1YMo1snKcXyox53I+x/frRIrVP1ybpEvAMVNIWfJiftZqhuteVoA+HGMx8MmQR0pEtAzc0SwwuLYGMh+y45YBvoc83ZdO5PjopsxZi6Hj64pvmqb83RC4ez11RgBy0ZCX7DtyHA2J+7rrMCe/RysiX5l4zi+kPHY6qeTLFx3OtwL5VhS7ty145Px3V7E1L8P0fbx29r8ja6bsjTa3C3KB9ZkUbeDhISlaM5u7Ed/veofnbi+VfN4SZWmcSran6bRq2ftM9L2B/Dty0BPgIB5FkjNaGPKYq8Gk6kXS9n0eGPgXytN5t2WYOjhX0P4ehLKvfTgL4NE6z9EVxfXD+z1c5no8WZPj5DP7bDPf7PB4G6hzGMAsinH6cE1ko+uwTo1iruBgLH2GCOzJGy1+bncsy9d3uLMSYh8pIGnHDgN+1lHK63UG1rknK/IcELSr9+Dvg6Qp8rnQ5p2IOa8HGcnQzKa8XmSBMTlS5xWg3V8pcN8cKn/mEJjoy9C/FxWfewfOX9kDm6RZu8j4jelsO00kueUjSA/AfxgliqEJa8IQOJI91Z8ZbEvg/Wumbx7OCPCgrDnFuURm+GyOy4rsWiKiENZk3+eyNttyHXnxkJmayGl9uCb9bge+2+0yF7SYkX1d9bn+yM3WR6BgH+wMuJ+wFppD3xnxRX2YkxdL2hbwd/jcGeUHFfB8jiw+WY7LHKxtyPRtKHd8FRi4yA/uj5XPAGcYIasc/Qwiue5lwOfVfGMc58hzxrGsjsUQc7wFW7G8J9vocMrZC5oBi0vCIXD8NTd6sZih8QPOwzHdV5o6lKYORYlsv3WwhQmc77GvzkR4s8fXIY94P5LOwCFwb08c9s4uuJLjPZfnMbw74HEaxGy3Sp7cr14pw/lZ4D9sDOUZAWPwQRJxzoYctHfoLpc1uT1JX3Q/LPJdzvMaf6HCz32zI9cHZKQvFrgeoVogOiPkOZ99FgYRUS2uT9IrWd4HPTEj9wK4LUIuuLBvyh9pZHnNW/XmaZrw/KGL4Ivpc2yut/jB2Lcz6jAytKUzPt+kmNG2mb9Df0nbHbSfWz08305mHMW87uE5Z3mSfO42cKSTFH1UftBxKNu/MmXP1lJ7uw68q/A9OPtG2XM8o+1oyOXLKX50C+aaA3ZRM697sHankE81uTjLbSYH562MZPlCOH8F/emsOidjzod3XiVey9aLsnx3YfoegBuD70oqGTWH0M4AVD3nyr4pxHyTCHxe35H+dDHlOT4g9l/zqTzP5EqV+xDtoBfKNhrBeY1HDry7cqT/1YvZXwq6vO/Gc1OI5Lk7BWjny/A+iEieLdAewRmFYI8Stdf1wH94fI7Pa0xoVuQLoC1n4QyUx2vS/jbAJ++N+JpCRr6jnSkPqRuNiF6hbyj7JbrJZDKZTCaTyWQymUwmk8lkMplMU2Qv0U0mk8lkMplMJpPJZDKZTCaTyWSaIsO5gIbOkDwnplwqwymy7tn/1tBzZBgoIlzGCYdMzGQuiHx+ws0eEYd7jEmGknRjDvEoqlDXtzRQoUH5DOdDZMBQ5fMg1OVWj8NZ/r//ToacPD6zNUlfepTrm/tfPi7yLd18c5Ju/j84BiKjwk/9BS5fqc3hevnu9FDZAoQXVXwVdlHk8Ixsla/Jn5ftNRNwfxwccSgJRNSRuypD9NKbHN7d+VNGxRRXZZ2GuxymcqO1Pkm3xyqEKBdBmvs2VviKjQFft9njAgaJ7MMYCr+U5XvosPkAwuWGEHpzpyvD5ose9/0Ts5w+X5ZjIog5jAijVcuAKnmoLO+9UOWwGuyznU5Z5NuFsPcqhNQ9MyvDCb9ywp83etyuM1nZ74chf65BqPeHG7IP34Q5gHiDkrKQQcy24W6Xy6dDEufyEKYGX3kqLLIMoWAtCIHXGJmLlbPDrAXaoSbDrKqD+iS9H3CbqyJQN+F5iKF32wopc5JwH6LtcxQ+yAPcQdbhcantaglCdEseX9PwZKgWjr9rVa78jC/H2MkIkFcDRPeIbLSQ53rNQpOVoC/2h7I/D0Juv+U81wOxEUREKYRZHgY8DwNp3ighvv82MXqqC0gwIqL5ZJ1ihbcwnVYQnw67JyJquowTwvHnpjpE92wcQ3MkBw+GiPYjTp9Eo6n5zpUwbBDQTp4KswRUAYYkd8dyjGH46A6E1GaVj4C2C0PMOyreOQMGIYbKF+S0pqtVzncMIaKXSK6b93pqsH9dGusxcBn7VCYOqXdU37RSXru7bmuSnkv4mvwpXBKnMUS66ErsC+JdchCir21aDrAXIbG9rDly7YaobfpKk9c2RaWgq7A+FsG/ea0jbSSqCt3bU6HopYzqrK9roML1+xACn4VL4pQ/HI9kPyGu6yRmbEHPkaGyccqNvpoyDmfGLah8XIZyhseRDptHlFJ/zGW4oEKuZ7L83Nt7vHbc7ErfYgPWhK0Bl+HVjlxvnqxh3wAmR4UaI7YQER967cjB2MSQ5j6wnUZqyiAmAFvlaCT7eaPP7YddrYaH8CcqPqanI6lwvF1V/hz6EGgvT1R4/e0+lxeREnWFRHRgvsawNtaz8n4ZmEgDaD+M1tfr/XHI92tCmHtVIfgQE4A+YFn5gB6gC/S+CtUZucL2mM6W7ybkuwnpkXgPEBNfawEKSPnaXZh8wzFPpJbTEflcmEldh/2sDYUI7feZY4b75LzH5TkeyWt8xGFARTz1O8Uh+HIRDHrtj2BZXcAsxI40FGhb5gB5VW3I1rzbO/sdhsYMpjDNtwdqIoHmPF73HqnzZHmoIq/ZGSI+BXwf8IMOA1mGJVgu1gs8v2735UTEUYA4kTCW92smvGYFA0CxJZJRdQQ+zSOwz9BoEWxzTOtx2Ydxie9p9DjH9RD3BRoPNXD4HcQYxk4p4QbbH8j1wQU/xgHb2crJSp2MeA1cKfC90TchIjoAvNyNHreXRtk0weYi0mpBOZWICcPa6v0Sag6uySsfOgP+Ez63pLCbayX+/FSd/bmH6m2Rjw4YaTJKEDsCZVX4IPTxA8RBPQDLidjJCjXEd8UU0CfOZfi79D0RYYqI24VAIpGaDr/XChK2kSVXPrcE9y+CnzZSPiW4lHQZ9ggbyuZgM3kO3w/LPZ+T9sOFeb1ylcsav6aRflz2J2Zbk/S1vyGRXuFd9m2/9GXGy+plunuSpf5YohWnyX6JbjKZTCaTyWQymUwmk8lkMplMJtMU2Ut0k8lkMplMJpPJZDKZTCaTyWQymabIcC6giEaU0OlQXhlmDac2pzJctAKnfJec+iR9wZGnfwcQknVErUm658hQkhBOrF9N5Um3b6mbypCDWsLhSjMQCllNZVdjmPotoNLc6cm6bw75ufVDDvn5kf/7/y7y5f9vPzBJVx7hv7tVGcrrfuAS59t5dZJeGUo0zhyE78YQbrfTk20e9zmsI7/LdfLrMiS59h0cmlLtQ4hHMn0KbP0299Odk6VJemZLtjmWbwZwM7FCPQzgdPddCIs6DOW/ZWEo7kM1CCfMyPAdxLSsljhfU92vBPGFeCp6lMg4pDZ8bkLY8HmF5JmHE87fhG7zIFxsX50I/fwuz4EATuTOurIMGJY3l+N2vtXL0zRh2NtQhYHtwWnl+5D+4IwMRV/OwwnpEKa9mJd9vTPkcmxDeOFyUbb5d87x+OvASehfbcl2wTD/p2HM3unLUMO7MM4xJBnDkHAMnC4Tl7sTybCyPZdxImOCEFP1b6xN2pykB6MjroPKV85yKH/BkfgEVFaEdHHf6NC7CyWu70qe4033AomvGUO/YaS2Dj/b6vPnHiA5FsR8ldfgGEN0hEbjhPCsIObBeJBK2546XKf95MYknaQKH5LJC0yC6WydBCll3YT2Ixm+1yIOXcRwxbEj58BMwoiqIOR+8xSHI0j4O0Rb3HVviHydMT/3Zv+hSfpyemGSriikAeJobnd4nO8rVMc8hMQ2Qy7fUajHCdcR8TIZNWgRMYHpo1Dn4/S5EobryqfinEJbpbEeR4B9igCRcuDuiHxoh1aTtUn6kSr3WUXR7l465vbbAWTLrLMk8hVcRrhkwEcqqVBZDF/vxLI/UBjC3R1jSL7Mh5gxH+xgkMiMiAk5CLjNd1WoPY7TEhg/jawogm2FYU5NGNdNhUFopIw38ME3jkmWoZbWJ+k6YNVC5WeI8HWw2QeBRIaMUkB35bivs65sfwzpPoJ2PVLIhQNwCQ+GPDe0v9SNeF0R66scEnSpxOXF/g3UeoOYu1oRw52hL1Q8cdU/u59uyK4R420M7VpTPBcXxschjCNEYRARnQN02mKOH4xYGyK5OuK6O6fCsXcCvh/iUq7VZPn2A/Sr+Ll5b/r4rWTRj+Q8ilZF93psW9qAhqpECtsJuI4xzMOBisJH/2S5wNf0VVvWsy6N1Hw2ndaNbpbyXo6KnpwDuwF36r7uBNBRxBN7w707SR+O5ZpczPD+NYK12/Pk4lFIuU89xFwFbGcKGXnNENaEFgzAakb6qLAkU5JyvsiRtm8u4XUq7/KeF8tDRPRGjzdgMzley66W5UbocoXb8hjW9a2+HLNBzJ+3Y0CdkdyPXMrzs5AIc28g9zcn8KylAvfvcp6fc7Usy5AXCC2+fqUg8+F0w7FTy8o2WvZ57TiMuN8PQ9nmZZ/7HX2xRk7O4YvyFcREG32Zbw/WGERU4TsfIqI7MWNa8imXtUryQaWE23zg8jVD8J1KfkVcswgmDhEpe0ORjW534R1BlftaY7ze6HL/IrZM+yODhD8jsrgXyUU0rfL9FvJ8v0ZWj0uwx5DW/jki4faGXPZ5hZH53nluv2eeYH+zeyDH+WKb64EIp+uwnRurtXsM8zoHvko7lu+kmoDuRNTkQirxK0/WeRwgivE4lDaxBQgddPvKrrRBc+l5fi7M62WnJvKVgGlWgRtq9MkivJrB8XISyuci4gfRbni7yzNyn9wbcvm2XuDxnyrUEd4jhHeHN/9AvjcaAYr4ZMTl2xpKe553UxrGZyO0teyX6CaTyWQymUwmk8lkMplMJpPJZDJNkb1EN5lMJpPJZDKZTCaTyWQymUwmk2mK7CW6yWQymUwmk8lkMplMJpPJZDKZTFNkTHRQLi1ShrJUIsnROSTm9GSgyXIkmTnLyfokHRKzgTaSI5EPWZKHdGeSThQ7MuswP6oAzKRFj7lX/ViyqLpj5iStl5knNJ9XvLkBf3cAsKyrVcmOQkY1MqH+7IU1ke/Z/+Uzk3QM+er/k+Q7xd/x0Uk6X+N6rH/5ush38nvMrPqjN/lZd/rTOUXuETOTntiX/Knv+d6tSbp3j+v4hTdXJ+n5vLxmtsj5kIW9OZAsRWSzXSkzbKySkcyq17t83faAmU4KoSmYjojyulCWHKhrFR5jEbAX7w5kG3VgiCC2DZmjREQLOa4jMiZdxYf2HeTccfqpWgh/l9C1rzS5b5A5nqby3/EQcfZYlutXUkzOGvD+SxlOL+U1nwzK7WJa5lsvMyN9psRln7sg2emt589N0qslbudVxetDNity2VcK8rkzPl/XAAa8ZsUfhGwLkA+3C6zIjSPJg17N8zX9Mfe1tkeRw8+tpMyNPADbRETUDZgdF46Zl5hxJecOmeguMHTLwPsjkhx0UR7FNN0YwLkAHnPMKhk5fld8bosg4bl2uyPb/M6oNUnPBcybG8R871BzbYGBicjFl47lHL8ZM18e7fyxc0/k60eHk3SKbMxYjresV6Yknc4DNd1XexST78Z0w/my+PswYu5gPss2aEgSLrxHPF6QO7qimP5ZYPTGY7Y7o7Qn8rUHtyfpIMNzxYU1ZXW0Iq5B+4uMZc3avQ689EHK6ZjkfJj3eGznXC73aUY1p7PedIYvshBxjepE8ppOBLZmxMzWEklGYom4Pw5cnh9RKu1Y1uF6SL4mF6itQMh7CftskcPzcD6V/blA/Bnbsuk0Rb5awjZuCL7dMJXM0MWU64gM0gW1Lknzcjb/mohoE85vuN1n2zAgeVaHB3Z2FRm6iufagC5A/n0Pzu1oJgfimjG0ObLih460VeWU/bkIBktzLP2qBPwJTPfU/VbhDAO83+5Q14nrjvNG+1VDcAZmc1zfzkg2+kaPxwuea9GL5Pj1HF4TsESNnDqTBnZZfTDlyE7XPPjlAj9rDc46SSI9P/kegDOmQ3U4DLbfAawx2mbkhzwflgv83NJYrtWtiD/PZvkeKyXZh/NwbsnekO9XycjOeXqJfa4ujMUvt9QZTT2uJDLRF2DLtq94v1U4e6JKzH1NU1kG5MbPABp3dyDzoanB81YG8Vjl807x2U2ntTt0KOc6gq1PJOfN+Qp/OByqRoX9TSfZnaR7wy2RLcy0znx+qSj3qLiv7xKvRddT9pv323IdWcnwWoY27XAs1zIUnvVRTaviu3qGB2A/5sG955zIm8DQfK3JNuQokHtAPI8A1yV9fkM35nsU4D1I2ZGs6HNlLvujVb7Gc+T99sBvxn1Lb8zpui/7Hc/MQjuD1xDJsxiW4aykw0DaX+Rme9DmRVe+/kLe9BeOON+i3O4L2zVtHSeS+9Ih7L8SZXc8OIsFz4ZxSfLNny7xuWLjhMdsD/xDdRQGlaCs4HrSjhqWuAY68I7rQlna8xD6xoWzXMJYtmUIg6wFvnFecaYjOJOuD/1bUGdhzPp8vy743dpdxbbFfp/Py4zzJW6AoMll2jiqi3xzcCba98E7jZLH68iuGm99sEctcEj6A+k/dKHfcc8xm5HvH/O4BEIHt0ayTnjWTN7FPbMs33rMY2cOzoKbL0z/TfVA+E6qLXM8/vD9nD6n6AAY+niu2CNVOHulIP3ak+P6JL0H56ZE6qyRJtiJ54+5b5CfT0T0vQutSfqhGu8B7/QbIt9B4FL4Ns8zsV+im0wmk8lkMplMJpPJZDKZTCaTyTRF9hLdZDKZTCaTyWQymUwmk8lkMplMpikynAtoxilRxsmdCvPpxxyOiqFejYyM8ylm+N8kbgccTtx2jkW+GQjFKbgcFraQSkTKosvhPN+3dDbG5HP7MuQAw8dGCYdgLeVk2GYd2BaFDIdC/LV5FZYe8XNv9jgcpT+WbXTjTa5TrcDtVWvLuCHvM384SQ9/m0Pem1uyLa8fMhLiRpfLsDtUMbqgMUT5+Y4Mien+t/OT9PGIy44hYjsqBO4jgIdA/McrTRW6BE1b9bmNygo3EcRvLzwkhtDbWxD/O05k3esQpupDGN0jFRXqDd+91OZ2znsSrwHR8dQM+UM3kv/WhqGz/pR/hpuvyDCwZyB9EvK43AlkiNMBoE/u9bmsOioWo5UxpHashkfR4z9gqPfuUIYnXgV0TLHI6S99WSIXbvT4ugzcD5EjRNND+UsqjHk3gPoOOAxJhwHj2G5AONUuTK8eyTjmGBBQt2JGsRwqTEuJ5iZpP+UxdZ4eEfm8In/XH3PIfzWzKvKtJZcm6Q4xzkGH61fAPomQS0e2JQ57HOcfWJRYmt0O20tETByP5Hw4cjnkdzNlO/1yn8tzLrkqrrmQ53ufBNwZX05eFflCYvtZBBSITxJ5U8+yPcK2TNVIz1CeEpLILtNpncRDyiQJZZTdz2fqk3RW9QHq2OHQ7wDGBCVPinzrDuOOEC3i0tnrM5FE9BwndyfpsSv71R/xvPnwHNejF0mkwS7YGn/M6SiVYwejhhHhEiSKGQLFqGZxjsv1ahlQVAP4bqMnn3sS8g0juPmQpq9/GeK5hxg7IqI6+EtFh9cLRGAkCjnWcHi++g7bVV2EEvg+gxFgWpStOgQ8TBb6Oq8QNUVYFLBEqXZboBz7EA681ZcZW+BcIMJl4Eg/bS7lcNQyLMo6vBvxcLhGebA4xokcl32nNUlf9ZYm6Xqs5hPUCfuj5Ko28jBEn/2beU/eD5ErGLadVxQwbDFch+925bgcwiJaz/FN2mNZX8T6zEGZfFcOnqOQPyN6oqxMAdKY2iPEr3B6JisvauS5HueL3FEnyhdrwdK2P+L1X+N+Iofr1HElqgi1F3B9qz32H/pjWb66z2UPwa99rS0RBCejsx1EPR0WCuzIoO+eceTYGcNEaoXcsLt9QKIpO4hzvAR7tFhNShxXjwOi4npbDrgbASM1fLAFPTUnh/0ajVPZD6bTaoYpZd1EYCmIiJZhUs2Bu95RDIeiC8gQQPXFsewPzwVcpc/jNEgU2s1lJF8nZd9snHBfBq68dyV+aJLOOTxeOk5X5Bs4/KwiIFwWHDlvZsA+xQGP04uORM/gEMZ1/TiQaKEB+Alox/Se8k33xiSNa/KH/YdFvkslftb3XGXfyVGYzP/taxcn6W3AcCFK7NW27M8+bESXilzuQLstUPa9IfgFQ2nPO2PEzZzdDkREByG/txiGiPSSfVOHRRXvUJOmig5hO4Y2aZlmRb4AcI0pWEZHOSuIKlmGdhklnEbEHRHRQQDYnCx/t1KU914p8lxbK3B5Hvtu+e7q4j3+PIa+dUjayOUirx1HgOFoK3Qa7t1jwPNo33MGlp9KhttSry+rgD5bBCTauaJ8bhfeQexu8nuQ/UC+F0Bk0Pkid+izDfYH7/Xle5TXOvieDBE18r3RasLvuObyfI1+p7LV57KjidwZS7vlAqoMfUDcWxMR5V3uqwsVQEr5cuwcg38TwD00brUJPskejPm+Gos4l8vQFA9XuB5BKP2MCMY2ooQKCu97oci2eWuoJuKU+8VQbj3etvoJjRLtpZwt+yW6yWQymUwmk8lkMplMJpPJZDKZTFNkL9FNJpPJZDKZTCaTyWQymUwmk8lkmiLDuYAWcjnKujkRJkREVIg5xAMRLvWsDGHBEMysW5+kd4c6lJ9DFq55HPJfUqdUYxhGHU4nLgEm5ENzMnThFiAN8F9IbvZkVzdyXEfETWwNZDj8AEIo5rMcQqFPrg1ibovFLIQnnSgUQZdDpg7vcRjMH28tiWw7gYrZ/bpyKgQLQ2cxym+sQjXvDTjjKMGwMghRV7iVjW6ZvxtjGKgsEyJEbve5P1ojmQ8RHXe63C5LBdk31SyXw4ETsHWo1sstzrcAYw/7jIhopcDPqvtciG2FILnR4VgcH8LedKhxBUKF6oBLuQ64n3sDiUF5ZobDHxchLGpbhd5UIVRrH9AuOpQPwzv3IYRoayD7sACDAvvtRiKfuznkELviIc/JNzqyjTAcfhbmUFlZ0jbgRC6VuP1LnqzIYcj25IUjHjAzWXlDHPZbA55Dm+42ZJJluIWRxA8gCfVTxqIcxhzOmXVlmNog4HzjmBu9nJFzt+lwuPMq8XfNVIbAln0M1wcMgppfq0X+w2Ke2+hOsy7y9cbcObc6PI6GCofiwbLXjRhzMxxxqGK2LG32XMR4l92Yw89iheSoOgv8HDh9PeOoOZRw2TPuFS63+6LIl3eqFBvO5Rtqy71DruNTnqri7+fSa5N0Jpn+m4E+oJCOXR4TA5JjdgOwKIhF6EeHIl/G5fGT8znMMkr4OW3aFtc0icM7jyCs1FNrHvonJwljR0JHogOuZHgsYohoPJITDMPA92FNQOQFEVEG5ug8+Ax9xdDC8ORVrz5Jt2NZvo7D47qXcNh83q2JfLPEPs3FsvRP3pKK8CeMGkZMgEY49MeccQ/6PUhlqGwV0EyI2dMh4Q/B8IvgWSqiVoQhv9yEMPxRIPKFJMPy3xLiW4iIcuAndMDRKGTkg3chpH4I/dbIcXlWAStEJMcV1tdP5XzKA6YlgnlSyci1LAedVc3yOA9iDW1jIc4sVL7AvT6X42ab26sbS7vZA7TGIWKQVBtXHF6TVwApcVlG9dNpKMl9af9wG3wSrGMJJuWlimzLJ2u8zlV9rkdBhZt3wPEYJVzu5ki2+SHMyRxxvtg5e3wREQUwPhCPRiTRJyeAm2qOZD7ECSE6QvtVdWEjYYwpkz0HDz4Aph9Ow6OxREgewXg5l3InLhalfbtS5r7BkdiN5IBDlJILThcic4iIDimimNQmwHRKm0GfMs6Y5n2J9KyDi45Io62+HBTJKVbWfTmuXCsqed6TFD1GGI5J2tx2yjiX/pjX9QQwV7lcWVyDCCgPMCiFVNYJVYWx2EvlOHGgSKuAUcyo9eY67Nk2AEm3HC/L8iU81n1nuh9UAP8pAiRUXXHBejAl/p9f5vVC73lPAAuaANIE/f3jUNqgEawdGUCBnS/Lur/e4ny9CN+PyHm9VgCEKayHB0OFvIH64pq3pzCv12LeFz0C2KdIvRM5DNgG4/uRvCfbsuGejfLQ7bI95DLhOo77c+2PICIF9/7Pzkob2Yu4rMUMvL9RvuLhPo/ZCryHeqouy4ptgZjcV5qyb9Avwr31h2YkTq8BiOCbAhkm1zl8h4N1133zepvn71db/OCh8mXn8oBODfgaxMRm3bPtz/3nAk5WbcIL4C8hnkfj6vYBYdwC385TCB30bXGtzSnnuB2xHbvT5ec+Myvz4XtBxC/1lcuwA0NJ+JR5eT98rXoA9u1PjtjmPBxIm12CsdiH91o1XxYCx+wKPFe/C8N5+GqH5/Fthf7bHA7eNorNfoluMplMJpPJZDKZTCaTyWQymUwm0xTZS3STyWQymUwmk8lkMplMJpPJZDKZpsheoptMJpPJZDKZTCaTyWQymUwmk8k0RcZEBzlf/99TAMuVLLNzkDGp0OnUCqcxEiVvB7mIy8A7m5G4Q8HH8hy+9wrwoRTykiJgPe8CS+luTzKEZnPc9Q8DgjRV3KZrFQYeYS12hpJdhPypAjC4h69P5/kuf4AruHIs+UP7ITNlH6kwd+xIsR73grNZ5b2x5stzupThjHhNV3Efv9RUHfJ1ISuKSDLWkc+N3HMioiyw0Pox9wdyz4kkE+tCiZ+lua83u8iV478PFdvdd/n+2F7bfVk+ZOWVgGOqMGuCXTaAdn61CTzSsRxvd/vMMfuuuemcyHIG68sj7kDxjHFu4HTtK248MtGRq6pwv2IuB9B+ywppmAP+WQD9nldctJUCP2w3QB6/5Otp1vtbao/kFzh2EmCxBsBsLqV1cU3D4TbPwxkNjXhW5Hsl/bNJOk35uVlXch+HwDp3gfGdJckP94EJibZP81fvRsyURs7oOZI8R4L7RTAOjkeyLfGsiBlg/C6PFdt9zAav5q9P0hWfmZnryTlxDU5RyUSVA+lcujZJY923nR2RLwPXzUE7t0jyiAtpkeLUuKrfSJV0jjzKUqS44CXi9bAPfNqM+v3AEsyPhZTHxyiV8/DYafN3BGtjKm2pA7Z0NncZ7sdjZ5zIso5hfuwBrzOj/JHdseS0v6WL3rz4XAUj2QEu+CDRdp/HIq5LWbWGjoCrivYyqxYI5HMXMlyGzVhyLjfiL3P5Bncn6UpBzr2+99Ak7QLPGfnjs1npY+3CmSqtEaeVG0StEOzTiL8cpZIZ2nb5vIQYeOTLebl241kYkm8qn3uny+OqFQFXleR4Q143MjWLrrR9WRfO44E1T/uyuO5tA9wSOa1lkr4dMoeRIR+qudGNuR6XS2xzNRsT/YeKz/c+DqUtxTNkOsBm3R+qtRHG38mYfeOyK889aUN5Q+Df5kj6eTmYD5pvjsIjjPYCZM3LfOiTYH/koZ/wjCIiyQK9B+cUabYrINtprcTlriqGcTVkBnQffLPT593w/fGMpoZyhbEUeA6NPnplocSd2IQJ0VR7pVs9LsfTde7DKJV37MO5J9iWrTHbtBNXnk+B54/MwtlBVyqyDD74cH90wGNnO94X+Vyw7Q6s41nFv86lORrb2v0NNesVyHdzNKtgwDNZ7o8KnAmWV8xrF8ZBI8vny+Q8eT7KvMffFRP2uQJH2Xri/kZ/OE64L9vRlrjmLpzltJry+Tm+esXipnBWCnzXUmevdGCtrEa8rq+W5Hy4UGLbMOjxGSjav9kjXr8O0zuTdN6RbZQnbpcM8ZzcU0YN31sgB13veYcx+ORga/Dcr0ZOthGeYYJs51eaco1HO4b71efm9XkLfN2rcBZDO5T5ag73je9wO8yoNR7XpTacB9Efy/utFbmOHTgDZTeU4+1yiZ+FDG7tM+C61IOxuBtymy9k5V6sAuZ9Ds6qK2RkfyITvZ7jB5/ckGvo7Rb7xnguHtplImm3cT+9JIsn3l/he55Lsy2Rz8X9NTDR9bsJ/FyDfp/PyXF5vcsLGp6t8cpIzut8jwt8Ocv75gsVru/5klxH0G61YB+qfTG8Cs+7i9WaN5Pjz2Wfn7tWqot8eDbRFrzb6ap9Rhf2LV14j3cxkmcRPVrFNuPxoc89wfmK50Jqdjq+b8JzC3z0XZVf++wsl/3Dc/y+4LWmLCueW3ClzC/hVkuy39+AsXML3p8dBrKN7tvPt/cbc/sluslkMplMJpPJZDKZTCaTyWQymUxTZC/RTSaTyWQymUwmk8lkMplMJpPJZJoiw7mAhuOExm4iQpCJiFZKGHrAf7/VlvEKG2MONyhAiOiYZBjSjMffXanwd3M5Gab6cptDFG73uauilENli5689yyEkrza5PsFsQoTFqHZGJ44PX41jDHMVYZ0YIvtHHGImH8iy7d2vsXPHfCzzlW7Il8AzxrEGCLiqHxQBh1LCkJMyCyE+exA2LeKJqQKhOidQDhWJSPbKIZw5y8eQBhueiDyuYCiqKYcwpWmMnx6AMMqC2FMCyokqQ6DMYA2OolkRRADU4fwaY0JcF2+bgFirmPFLcIwHQwPL8K8OYrk3GhDOPZuwGNZjyMf2rIOIZx9hai52+PPcMmpcEesO5a7pP75EMcOTv8fWJV9OIx4Tl6H0CCN0MHnnkCYX1NF964WuPBrJb53M5TzBhFQQ+Jx4EHI5Xw6J67Je9w5aNN8V4YdPxQ9N0kfZfcm6WIqQz2XsowaQaTMTCLzIc7h4TqP7Vsdudzcdjh0LoY6pdNNEI2S6f/um3e5jS6W+Sa3OtL2NQDX4RGjNmYghLPmy7CyMsy11eEilzuRCI3lAth9qEgULIp8HlhMDLerUl3km3FKFJGK6zSdkvv1/wIVFr0PCKK+05qk66rfcsnZrhDONSKiBNZyxBjlM3VZHsAdFYnH26P0+CTtqKGcgT9gGHQvlf2PKIrYwbEt5yFqBPYjfcAaX4dxrzEtiI44Cvm7USJtVRXuUQesRC6UTIgkPRtplnFlvgGE20dJDfJxHo1vQ2TbeoHbqKtCrl+H8O5ZmKOJK23GXAKIKbiFttPny2xzIcqV9gYy386Ix2nNZRt5oVQR+XAdQUxLSz33JOS2bMU8ZhFrRUTkAHADQ+WbcE05I21fBnyLhQJ/l1GhtyhEp2m3rA4+ahn8V1f9pgfJbB1YN301bzCEeJzw2qaRchiuPwDMRhvsAhFRCuN5Zlzn+0XT1x70LbYUM+hwzOP3XJ7XmJUiziE5J1/vcj0OABWjKC1iJj8IPfP4DPgCHvoCMh9i6boqHBuFZdroAWpDMQc7gDc6AiRErBb5vSHg74Y8/2dz0/197HfPYexAMJQ4OMQg4WMzjizDzS630ettxmn03Y7Il4f9V+hw6PiIhiJfShWKyXAu30jrFY9yboY+PCvX2grg+fZDtjUFtU8rgp97LmJ/7pYrsRn749cn6ZzH60gO/D4iolHC13mwFqWIJlSIhCBhzFsPxktZ+dAe4H8QyeUm0u6jf41zqqzcFPyuP+Z63AnkfnoneZXLOm5N0o/7HxT5Glmu75egve7FRyLf8nBpkka70xrLdukQzyM/mpmk0bcYKsOl/Y7JvSKFRwUMBLbDRl9e386evcfH9zpERGuAS1krcplqGbnW3gQX89U2X6ORV1gr3Bs7oSxfV+BrAAETy7YsgE+5kuf1AdtL29+VPNpmTus91mqZ++nyk+wzj1ryfuEW7PEBrbsfyHy4nCHeZE4ji8EGI7L4pQO5l8X3BLgH9JUNP1/kNeahGs/DUSz7ugVln4NB0YtaNE134B1QNWS0y6LC1SFacLnA17jO9LWsGYKPqrDCS0V8z8P1rfoy3wE4TIhsxHFDRDTjcdlX4Z0Dvhsikoi1AyasnfIzFqCOl8tc98/LVyfUA+QaYiNzHvu/evz6sKfH8REorF0TfLNXO+znfcdCU+R7pc39/rUWT2SNUSw7OUoM52IymUwmk8lkMplMJpPJZDKZTCbTX072Et1kMplMJpPJZDKZTCaTyWQymUymKTKcC6jku5R13VNhF3jqbQDp/XFf5MNQbwzRHacyHAhPlX7+CBAOIxkncTdh3EEjrU/SzzU4bAuRDUQyjBlPO++qENMI4m32h/h3OSQ2h/wsxJg8UpWhiyWfQzX2+hxOMV8IRL7dTQ5v+7MXOGQHQymJZFvU4STvjmojxFzgKcFFheRxyhDym0K7QITYYl7eu+jB6fAQJjQ6FUrC9wsSaGcVlRY4PF56EEJc618R+TBMrT3iez9ck7FQj0MfBDGHsx4EOiSGP2Mo5Jzi12B48ZM1DiXbGspwIAyzzkH4LyJliOQx3Bh6hOFAd1Xo3UWIrLxa5XCbzWFd5MPw5wMIE94eyD6sZ3k8Y9gbnoZNRJSDNvchrKmoEDozNQ71HEF9y1mZ7wROvQ4TbotZedg5rRe4nY9HHKKn53UPwpVyIw79zIzOcblJXtOAU+VxnkQqdPxqgefkWsz3TlXM3wywe3IQdtUeSfuGEVk4xB6qyvEbtVfoLK0VZb65LIdaVXzu66wrn3sRwvdiQLY8Ny/v92aHr3NgrviA09BthFiEhQL37Y227PeDgMO3ix63f9GRHV/J8HeIHag4sqwzucypUH/TaW0lL5PrZCjnyvDpPIRqF1OYN8r16RGvU3vu5iQ9TGU4YNXhMOZa0pikB47M1xszFmkzeWGSLvlcvhzJvq4RjytEoixn5NjJenwPxGb0I+1n8OduzOPSU7+d8GChwjuMH8BVQnwFokm0hmO2aa5aEGv+2iSd87hvZtx1kQ9xUR2o4/UO4MdUSC1+wnndVEgOnOfnMhxuvgTYGCKiGtjSDqDKZIvLEN0HzdoCYLgi8A/3htKeLOQ53yJQuPoq5BfD/yv+dMxKEYb9s/M8rjrwWGXOKUn5fohv0/i7k5DL0IZ+qmV1KDWE8sM611bkC3C1KYT1K+/JPkQkyWKR7xf2ZEX2AGNwCHN8EB+LfOfcJydpNL1bfdmj+Al9p7JipMQpdxyum8/Uee3HUHsiols9tgUe7Ef0mEL/ddo+5X4Z+B4aRyjzcfoYsAMd5S/d7LAfhMirGYWKOwwAiZiy/xs6ElWAOLejDtvs89m6yPdQjft3LgeIQBgTYSzLUAUGDvaT9h8qMG2WAZcQB0siH861PYd9jjCVdtBzMqJeprPlOff/3wvkmnw34Q75b7s8drKOtCfzebZj9YTnzVxyXuQ7AVsTpow7CeO2yDcYsT3IAGogTqajeVJYCXCfp+XDmr/s4j5ervFtQB/kM2fvnYiI9ob8B1z/s2ovUPB4bat5q5N0IyN9ELFPAJ9o6Eg0zgjMFaKZEN9CRDRweU4kgCosQp18hW/B+6HvW1B7Ih/2bHsB25PdofLdAVEzBwioUkY+F1EUT9Z5fNzplUS+Jqxz2B8NhZ6az+Gei587k5P4ICwGvqYpBAqrBn4MroFrgDAN5DJCA8CMljJ88/mq7KelD4N/eIHfy2z8f6bbryslHh9P1uT7pdc63GYRvC/xFH6lA3i9L7e4n7rqsbinxH56vCZtbhH2h8sLbJtvb8+KfDd7PJb2h9xoUSrfa3UdeKcECF5E8LzZU/4NOElqiAmh74l+ht534zubPrTLTYWTfjNmfsqCw/O9nJF2Fd+NLUqSsBC+SxnF0/OVoZIC46NQgojQwb2AbAfZYF88Zt9/Ht4DaIwP4nTR/coqvwr3LdpWodb8MkWJ9u7Plv0S3WQymUwmk8lkMplMJpPJZDKZTKYpspfoJpPJZDKZTCaTyWQymUwmk8lkMk3Re+Yl+qc//Wm6cOEC5fN5eu655+j555//dhfJZDKZTCbTA2Rrt8lkMplM7y7Z2m0ymUym96veE0z0X/u1X6NPfvKT9Iu/+Iv03HPP0S/8wi/QD/7gD9L169dpYWHhbd/nPpvNESxRIsnA7CbM3lr2JStrJsf8nhaAJTWnqhPzPUbAY9p3TkS+rsNstp3k1Uk6Pf6OSfpySfK1EuApDYGLGKaSn3QUcvnGCfPY7vYkawi5Y4/WGc5Uykho1fIs86cIqtEbSa7Xy01uozc6yKWSbY5s4SEw7xYL8t99BlDHk4jbtT2W+cYp36MGjERkrOcUa7MCs6MIHEmFcKMscMFXgJe8mErmcwosxS+P7k7SkWLmL+S4zY5h8LySKBaow9xGRFb6D/insTvQvxp5WwMWOPIiH69JdlQR2GrtEbdrBOzvY8Vl70IBsbYXy5pDz98WgLPfVQzYdsjtsjdmft2xeyDyPRZdnqTrwGZthnJSnsCYWCvA3B1JE3nU5Toi321lSXIVV4g/z+4xX3es+vAe8PaQn6bHGLI8FwCSluvx/NcM43ngBOeAJ3gYyDLk4GGIUh0qriryXJGNq5mGJyEwg+EW82ruXigBNx7ggg9LHDGtwbkK1RzbhfPr0l7evseM6t6YC7iUl/OrNeJyDOEsAeTuZVWdvClsO83nbY65rFvx/tkXEdEVsA1zOS5DyVd8yLeHZXvX6pu3dufIdTKUdeRZDMikzVFBXzYR2o1ewmlf3a+bHk7SQ5fXvH50KPL1g7vwicf23Wp9kq54krUbJ3y+wazL+R6pyzEBWE/aGDyAhzk+e/CUPLkmIy+97vB3+syMK2V+ALI2b3Qkz/U1eom/S5il2nAlo3YpuThJr3r1SVodSUOIScTZhvYpVixFMOe0PeQPx6G0absDtlU94NAewpklRETHAY+DrsO2/RFvTeRDk7kLTMlOJDtnDKtgk3gctVJpM+qDxUnadZYn6RnFX3XhzIWBstsoxOajB4fmbqTAu3kwfmtFYPCOZRmQCbsEa1TVl/e71+N8ISwQ4QPG7xDSiSKDD8DWN2DM9mPp87bcQ/iO077ieM+kVfiO/95WYFBcb2fBZ1svyzXhCvjQB3BEUISccsXnRlYs9k1BceixP4eyukLijADwV49HsqzoZyEH/SiQ5TuGOfCQz3Y6qxbKW0Me2wOH+bXZVMJYy8R+EHLV61lZvgXgoDeAkXoUsp+2VpaN9Fj1bKBroM42wnHaACN7EEof8ID4/Isc8O4r6YzIl1JCcTqdo/1u1zdr7d7sJeS7Mb3UlG0VpGyhDly2iwmp/gwunHnfKpyBQkTUcrkf8w47mf1E7hk84JOPY56wWZ/vl3Xl3r8X8hkoocdjfjZ3WeRbS/j8qxIYl65aH5CLjKNUHW1GfZijuG5WPck6/5DzzCSNzPH+WD4Xz/S6TNcm6dmMnK/o9u4CqLntyvMlULjPXSxwYc8XZRkOR7heczpJ5bzeGfB1I3i/oc+GQguCLpE+++Nckd8f9MewBx/J+2kb95Y0xxvPAauAbamqI0twjcFz2OpZ2ebgqtA8fHW1zPPmdeWL4buJRpF53wvXJPvb+5//L5O087U3JumZ2q7I9wjseTc7vE7uDGVZ6z73Da5fkbK5uM4h31+f6XG5yvkeAZ771RU53rYPeV7vH/J8PQhk+XCtxDMKa7Qo8hUT3l8vuHy/HszXsTqzKgB/pDTlLLj7+fgPyA/XbH1cl1LBEpfjMhxyn+6C7cwnyyJfa8jjZWuI66lsI9wDo/+l/aqKz9dhFZfVdgvPfcAphONjT51p90aLn7VeBh+rJPMJFrvPZb3XkWtACZbyy1lm/+spXfHdU2cfTtN74pfo//bf/lv6iZ/4CfqxH/sxevTRR+kXf/EXqVgs0i//8i9/u4tmMplMJpPpDNnabTKZTCbTu0u2dptMJpPp/ax3/S/RR6MRvfjii/QzP/Mzk7+5rksf+9jH6Atf+MKZ14RhSGEIv1pu3/9VxejrvzKP1L8sRfCvnOOUr9P/EjSCX5niPaJU/4KF7+HAv5rGJP8lPoHfCiVYBoJfsieyC7HoWAZ85n2dXVb9a1bMF8Sc7o/l/boRl70H3+lfww3hHiG01+hUm4/O/C5U/zo0mlLHVP37EObDf2HCf40OYnnvIfwM1INfCul/TcV2wfvpf3nEX6LjL1Qi1TdYVjwheKR+xQxBDOIXcLoe0+6tuxrvh/2bkvyXxwQ+9+HBwyntoIXP0f8C6EAb9cbcRmESiHzYLtjvcRqpfDhXeK7pORkm/Bn7Hctw//PZJ0l3oum/OMJ7jNXYGcQ8f4MYx7m8RwaKiyUfPWDu4vzCb/Rcw9+64DjS+VwsuzM9XwRtiW3+oLkr56EsHY5F1+O27Ix034DdgV+Y47jU5cBxJIsny4rXyF9WyI4SY5Gmj4lp49JX/R4nvC7pU9vf7fpmrt1vrY+Jmv8xuDhoc/WPddFu4FqbKNuXEv4iFuZXqn/tmJ6ZFvdWZcWxM4Lxoe05DoNQrGVy8EgbieuNHrPwXXL2+kwk7eIwRpsh169pfov+VeZ4yhronFo3WZ7wGfBX+LKNEtFGnH6QrRpDf+i5O07P9tNwHt9/1tk+DT7n/v3OthMJqfEr/CBeA9+uLdXrsL6On3P2vYiIHPiJ4xAy6jYPp/g++MtnXYZp5SbSPvT0X6JHYo3h9tfjMp4yLhMVpSnnIfrJet7geg39qyMppozFQQxRMsrPCIRfMPXWU8e2bkvsK5zHp/vw7L7ReyLhv+IcUGEk08Z5nKpIUfyFXYpzSM4b9B0HUI8QfkGrfafhlHAu/Uv08RS7enocQT1gTOk64S/Rbe2evna/NX7Gp9YHbFu0kWqdO7W3fet6tS7B/cQ6nupfgceQTs78u7YZ0/O9vTX+9PrA14l1WBmAaWtMrGwk7lXwOz0sx+I9w/Q9r7CLYt1UP8cW9z7bp9HzE7/Deai3LaMpe0BH/RI9EjYS/HhVJ9xn4Jo1fICNROnoYblW8t/11Wh38Dv9nHDKHgnXkUC15bS9bCeUc83v8q+YnQG3QzfSeyyI9BL7rekb/gf9Ej2AXzW/3fULn3u6fPxdFubkINZ+GpQpxXXp7fmoKC/R75r4cwbqq98HYX1x36f7Xb4XOHt91mV3oO663GOxv0H7oUfm2fvkSNk+bMu36x9O+yX66feAPNDRxwpUY06zJ33lV729vf99O/G2993pu1zb29spEaV/9md/Jv7+z//5P0+fffbZM6/5uZ/7uZTu783sf/vf/rf/7X/7/x3//+bm5rdiSf2WydZu+9/+t//tf/v/vf6/rd22dtv/9r/9b//b/++u/7/R2v2u/yX6X0Q/8zM/Q5/85Ccnn1utFp0/f542NjaoVqt9G0v2rVen06H19XXa3NykarX6jS94j8nqb/W3+lv938n1T9OUut0uraysfOPM73HZ2s16N4zdv0pZ/a3+Vn+r/zu5/rZ2s2ztZr0bxu5fpaz+Vn+rv9X/nVz/t7t2v+tfos/NzZHnebS/Lw9k2t/fp6WlpTOvyeVylMvlTv29Vqu9Yzv0r1rVavV9W3ciq7/V3+pv9X/n1v+9uMm0tfubo3f62P2rltXf6m/1t/q/U2Vr933Z2n1a7/Sx+1ctq7/V3+pv9X+n6u2s3e/6g0Wz2Sx98IMfpM9+9rOTvyVJQp/97Gfpox/96LexZCaTyWQymc6Srd0mk8lkMr27ZGu3yWQymd7vetf/Ep2I6JOf/CT9w3/4D+lDH/oQPfvss/QLv/AL1O/36cd+7Me+3UUzmUwmk8l0hmztNplMJpPp3SVbu00mk8n0ftZ74iX63/27f5cODw/pZ3/2Z2lvb4+efvpp+v3f/31aXFx8W9fncjn6uZ/7uTNDzd7rej/Xncjqb/W3+lv937/1/3bL1u6/uN7PdSey+lv9rf5W//dv/b/dsrX7L673c92JrP5Wf6u/1f+9UX8nTdP0210Ik8lkMplMJpPJZDKZTCaTyWQymd6Jetcz0U0mk8lkMplMJpPJZDKZTCaTyWT6q5K9RDeZTCaTyWQymUwmk8lkMplMJpNpiuwluslkMplMJpPJZDKZTCaTyWQymUxTZC/RTSaTyWQymUwmk8lkMplMJpPJZJqi9/1L9E9/+tN04cIFyufz9Nxzz9Hzzz//7S7SX4k+9alP0Yc//GGqVCq0sLBAP/RDP0TXr18XeYIgoE984hPUaDSoXC7Txz/+cdrf3/82lfivTj//8z9PjuPQT/3UT03+9l6v+/b2Nv39v//3qdFoUKFQoCeeeIJeeOGFyfdpmtLP/uzP0vLyMhUKBfrYxz5GN2/e/DaW+JunOI7pX/7Lf0kXL16kQqFAly9fpn/1r/4V4ZnK76X6f+5zn6O//bf/Nq2srJDjOPSbv/mb4vu3U9eTkxP6kR/5EapWq1Sv1+nHf/zHqdfrfQtr8RfXg+ofRRH99E//ND3xxBNUKpVoZWWF/sE/+Ae0s7Mj7vFurv/7RbZ2s97r69dbsrXb1m5bu23ttrX73S1bu1nv9fXrLdnabWu3rd22dr/n1u70faz/8l/+S5rNZtNf/uVfTl999dX0J37iJ9J6vZ7u7+9/u4v2TdcP/uAPpr/yK7+SvvLKK+lXv/rV9G/+zb+Znjt3Lu31epM8//gf/+N0fX09/exnP5u+8MIL6Uc+8pH0O77jO76Npf7m6/nnn08vXLiQPvnkk+k//af/dPL393LdT05O0vPnz6c/+qM/mv75n/95evv27fQP/uAP0lu3bk3y/PzP/3xaq9XS3/zN30xfeuml9O/8nb+TXrx4MR0Oh9/Gkn9z9K//9b9OG41G+ju/8zvpnTt30l//9V9Py+Vy+u///b+f5Hkv1f93f/d303/xL/5F+l//639NiSj9jd/4DfH926nrX//rfz196qmn0i9+8Yvpn/zJn6RXrlxJ/97f+3vf4pr8xfSg+rdarfRjH/tY+mu/9mvpG2+8kX7hC19In3322fSDH/yguMe7uf7vB9nabWv3W3ov193Wblu7UbZ229r9bpet3bZ2v6X3ct1t7ba1G2Vr93tz7X5fv0R/9tln00984hOTz3EcpysrK+mnPvWpb2OpvjU6ODhIiSj94z/+4zRN7w9y3/fTX//1X5/kef3111MiSr/whS98u4r5TVW3202vXr2afuYzn0m/53u+Z7KYv9fr/tM//dPpd33Xd039PkmSdGlpKf03/+bfTP7WarXSXC6X/uf//J+/FUX8K9Xf+lt/K/1H/+gfib/98A//cPojP/IjaZq+t+uvF7O3U9fXXnstJaL0S1/60iTP7/3e76WO46Tb29vfsrJ/M3SWM6P1/PPPp0SU3rt3L03T91b936uytdvW7jR979fd1m5bu9/S/7+9+4+pqv7jOP663Cuoc/xQ8oLJLVwuSq1QhpF/ZMmWjc3qj1rMEmTTmbAwt9Jo/mn1R9OSGv0U3FIZf1CQa23EDzeaoiGUTgNWDPpDomyITRTkfr5/NM44X7yGfYH7Pec+Hxub95wD9/O+E57bBzjQ7olot/PQbtptjPtnp920ewztnsgt7Y7Y27kMDw+rtbVV2dnZ1rGoqChlZ2frxIkTYVzZzLh8+bIkaf78+ZKk1tZWjYyM2F6PtLQ0BQIB17wehYWFysnJsc0ouX/22tpaZWRk6Nlnn9XChQuVnp6uTz75xDrf3d2tvr4+2/xxcXFavXq1K+Z/5JFHVF9fr87OTknSDz/8oObmZj355JOS3D//eJOZ9cSJE4qPj1dGRoZ1TXZ2tqKiotTS0jLja55uly9flsfjUXx8vKTIm99paDftHuP22Wk37R5Duyei3c5Cu2n3GLfPTrtp9xjaPZFb2u0L9wLC5Y8//tDo6Kj8fr/tuN/v108//RSmVc2MYDCoHTt2aM2aNVq+fLkkqa+vT9HR0dZ/6DF+v199fX1hWOXUqqys1JkzZ3T69OkJ59w++y+//KKysjLt3LlTJSUlOn36tF5++WVFR0crLy/PmvFmnwtumH/37t0aHBxUWlqavF6vRkdHtXfvXm3cuFGSXD//eJOZta+vTwsXLrSd9/l8mj9/vutej2vXrmnXrl3Kzc1VbGyspMia34loN+0e4/bZaTftHkO77Wi389Bu2j3G7bPTbto9hnbbuandEbuJHskKCwt17tw5NTc3h3spM+LXX39VcXGx6urqNHv27HAvZ8YFg0FlZGTozTfflCSlp6fr3Llz+vDDD5WXlxfm1U2/qqoqHT58WEeOHNGyZcvU3t6uHTt2aNGiRRExP25uZGREzz33nIwxKisrC/dygH9EuyML7abdmIh2w2lod2Sh3bQbE7mt3RF7O5fExER5vd4Jfwn6t99+U1JSUphWNf2Kiop07NgxNTY2avHixdbxpKQkDQ8Pa2BgwHa9G16P1tZW9ff3a+XKlfL5fPL5fDp+/LgOHDggn88nv9/v2tklKTk5Wffff7/t2H333afe3l5JsmZ06+fCq6++qt27d+v555/XihUr9OKLL+qVV17RW2+9Jcn98483mVmTkpLU399vO3/jxg39+eefrnk9xkLe09Ojuro667vhUmTM72S0m3bTbtotuX/+8Wj332i3c9Fu2k27abfk/vnHo91/c2O7I3YTPTo6WqtWrVJ9fb11LBgMqr6+XllZWWFc2fQwxqioqEhffPGFGhoalJqaaju/atUqzZo1y/Z6dHR0qLe31/Gvx7p163T27Fm1t7dbbxkZGdq4caP1b7fOLklr1qxRR0eH7VhnZ6fuuusuSVJqaqqSkpJs8w8ODqqlpcUV81+9elVRUfYvdV6vV8FgUJL75x9vMrNmZWVpYGBAra2t1jUNDQ0KBoNavXr1jK95qo2FvKurS99++60WLFhgO+/2+Z2OdtNu2k27JffPPx7tpt1OR7tpN+2m3ZL75x+Pdru43eH8q6bhVllZaWJiYkxFRYU5f/682bp1q4mPjzd9fX3hXtqUe+mll0xcXJxpamoyFy9etN6uXr1qXbNt2zYTCARMQ0OD+f77701WVpbJysoK46qnz/i/Em6Mu2c/deqU8fl8Zu/evaarq8scPnzYzJ0713z++efWNW+//baJj483NTU15scffzRPPfWUSU1NNUNDQ2Fc+dTIy8szd955pzl27Jjp7u421dXVJjEx0bz22mvWNW6a/8qVK6atrc20tbUZSWbfvn2mra3N+ivYk5l1/fr1Jj093bS0tJjm5mazdOlSk5ubG66Rbsut5h8eHjYbNmwwixcvNu3t7bavhdevX7c+hpPnjwS0m3aPcfPstJt2027a7Sa0m3aPcfPstJt20273tzuiN9GNMaa0tNQEAgETHR1tMjMzzcmTJ8O9pGkh6aZv5eXl1jVDQ0Nm+/btJiEhwcydO9c888wz5uLFi+Fb9DT675i7ffavvvrKLF++3MTExJi0tDTz8ccf284Hg0GzZ88e4/f7TUxMjFm3bp3p6OgI02qn1uDgoCkuLjaBQMDMnj3bLFmyxLzxxhu2L95umr+xsfGmn+t5eXnGmMnNeunSJZObm2vmzZtnYmNjzebNm82VK1fCMM3tu9X83d3dIb8WNjY2Wh/DyfNHCtpdbl3j9n6NR7tpN+2m3bTbuWh3uXWN2/s1Hu2m3bSbdrup3R5jjPn3P8cOAAAAAAAAAIB7Rew90QEAAAAAAAAA+CdsogMAAAAAAAAAEAKb6AAAAAAAAAAAhMAmOgAAAAAAAAAAIbCJDgAAAAAAAABACGyiAwAAAAAAAAAQApvoAAAAAAAAAACEwCY6AAAAAAAAAAAhsIkOIOzy8/P19NNPh3sZAABgkmg3AADOQruB/w2b6ACUn58vj8ejbdu2TThXWFgoj8ej/Pz8KX3Onp4ezZkzR3/99deUflwAACIB7QYAwFloN+BsbKIDkCSlpKSosrJSQ0ND1rFr167pyJEjCgQCU/58NTU1euyxxzRv3rwp/9gAAEQC2g0AgLPQbsC52EQHIElauXKlUlJSVF1dbR2rrq5WIBBQenq6dWzt2rUqKipSUVGR4uLilJiYqD179sgYY11z/fp17dq1SykpKYqJidE999yjzz77zPZ8NTU12rBhg+3YO++8o+TkZC1YsECFhYUaGRmZpmkBAHA+2g0AgLPQbsC52EQHYCkoKFB5ebn1+ODBg9q8efOE6w4dOiSfz6dTp07pvffe0759+/Tpp59a5zdt2qSjR4/qwIEDunDhgj766CPbd74HBgbU3Nxsi3ljY6N+/vlnNTY26tChQ6qoqFBFRcX0DAoAgEvQbgAAnIV2A87kC/cCAPz/eOGFF/T666+rp6dHkvTdd9+psrJSTU1NtutSUlK0f/9+eTwe3XvvvTp79qz279+vLVu2qLOzU1VVVaqrq1N2drYkacmSJbb3//rrr/XAAw9o0aJF1rGEhAS9//778nq9SktLU05Ojurr67Vly5bpHRoAAAej3QAAOAvtBpyJn0QHYLnjjjuUk5OjiooKlZeXKycnR4mJiROue/jhh+XxeKzHWVlZ6urq0ujoqNrb2+X1evXoo4+GfJ6b/UrZsmXL5PV6rcfJycnq7++fgqkAAHAv2g0AgLPQbsCZ+El0ADYFBQUqKiqSJH3wwQe3/f5z5sy55fnh4WF98803KikpsR2fNWuW7bHH41EwGLzt5wcAINLQbgAAnIV2A87DT6IDsFm/fr2Gh4c1MjKiJ5544qbXtLS02B6fPHlSS5culdfr1YoVKxQMBnX8+PGbvm9TU5MSEhL04IMPTvnaAQCIRLQbAABnod2A87CJDsDG6/XqwoULOn/+vO3XvMbr7e3Vzp071dHRoaNHj6q0tFTFxcWSpLvvvlt5eXkqKCjQl19+qe7ubjU1NamqqkqSVFtbO+FXygAAwL9HuwEAcBbaDTgPt3MBMEFsbOwtz2/atElDQ0PKzMyU1+tVcXGxtm7dap0vKytTSUmJtm/frkuXLikQCFi/RlZbW6uDBw9O6/oBAIg0tBsAAGeh3YCzeIwxJtyLAOAca9eu1UMPPaR33333tt/3zJkzevzxx/X7779PuBcbAACYHrQbAABnod3A/x9u5wJgxty4cUOlpaWEHAAAh6DdAAA4C+0Gpge3cwEwYzIzM5WZmRnuZQAAgEmi3QAAOAvtBqYHt3MBAAAAAAAAACAEbucCAAAAAAAAAEAIbKIDAAAAAAAAABACm+gAAAAAAAAAAITAJjoAAAAAAAAAACGwiQ4AAAAAAAAAQAhsogMAAAAAAAAAEAKb6AAAAAAAAAAAhMAmOgAAAAAAAAAAIfwHkbHAyD6toR4AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sum_over = box_size[0] // 8\n", + "\n", + "# Function to create subplots\n", + "def plot_subplots(proj_axis, input , row, axes, title):\n", + " slicing = [slice(None)] * input.ndim\n", + " slicing[proj_axis] = slice(None, sum_over)\n", + " slicing = tuple(slicing)\n", + "\n", + " # Plot initial conditions\n", + " axes[row, proj_axis].imshow(input[slicing].sum(axis=proj_axis), cmap='magma', extent=[0, box_size + 5, 0, box_size + 5])\n", + " axes[row, proj_axis].set_xlabel('Mpc/h')\n", + " axes[row, proj_axis].set_ylabel('Mpc/h')\n", + " axes[row, proj_axis].set_title(title)\n", + "\n", + "# Initialize figure and axes\n", + "if only_final_fields:\n", + " nb_rows = len(final_fields)\n", + " field_start = 0\n", + "else:\n", + " nb_rows = 2 + len(final_fields)\n", + " field_start = 2\n", + " \n", + "nb_cols = 3\n", + "fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))\n", + "\n", + "# Plot initial conditions and LPT field for each projection\n", + "if not only_final_fields:\n", + " for proj_axis in range(3):\n", + " plot_subplots(proj_axis,initial_conditions , 0, axes, f'Initial conditions projection {proj_axis}')\n", + " plot_subplots(proj_axis, field , 1, axes, f'LPT density field projection {proj_axis}')\n", + "\n", + "if len(final_fields) == 1: # Check if axes is 1-dimensional\n", + " axes = np.expand_dims(axes,axis=0)\n", + "# Plot final fields for each projection\n", + "for indx, final_field in enumerate(final_fields):\n", + " for proj_axis in range(3):\n", + " slicing = [slice(None)] * final_field.ndim\n", + " slicing[proj_axis] = slice(None, sum_over)\n", + " slicing = tuple(slicing)\n", + " axes[indx + field_start, proj_axis].imshow(final_fields[indx][slicing].sum(axis=proj_axis) + 1, cmap='magma', extent=[0, box_size[0] + 5, 0, box_size[0] + 5])\n", + " axes[indx + field_start, proj_axis].set_xlabel('Mpc/h')\n", + " axes[indx + field_start, proj_axis].set_ylabel('Mpc/h')\n", + " axes[indx + field_start, proj_axis].set_title(f'ODE Step {indx} projection {proj_axis}')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ea64b6d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/scripts/particle_mesh.slurm b/scripts/particle_mesh.slurm new file mode 100644 index 0000000..2585d5d --- /dev/null +++ b/scripts/particle_mesh.slurm @@ -0,0 +1,167 @@ +#!/bin/bash +########################################## +## SELECT EITHER tkc@a100 OR tkc@v100 ## +########################################## +#SBATCH --account tkc@a100 +########################################## +#SBATCH --job-name=Particle-Mesh # nom du job +# Il est possible d'utiliser une autre partition que celle par default +# en activant l'une des 5 directives suivantes : +########################################## +## SELECT EITHER a100 or v100-32g ## +########################################## +#SBATCH -C a100 +########################################## +#****************************************** +########################################## +## SELECT Number of nodes and GPUs per node +## For A100 ntasks-per-node and gres=gpu should be 8 +## For V100 ntasks-per-node and gres=gpu should be 4 +########################################## +#SBATCH --nodes=1 # nombre de noeud +#SBATCH --ntasks-per-node=8 # nombre de tache MPI par noeud (= nombre de GPU par noeud) +#SBATCH --gres=gpu:8 # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5) +########################################## +## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant +## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant +## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: +########################################## +#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) +########################################## +# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm +#SBATCH --hint=nomultithread # hyperthreading desactive +#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) +#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie +#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) +#SBATCH --qos=qos_gpu-dev +#SBATCH --exclusive # ressources dediees + +# Nettoyage des modules charges en interactif et herites par defaut +num_nodes=$SLURM_JOB_NUM_NODES +num_gpu_per_node=$SLURM_NTASKS_PER_NODE +OUTPUT_FOLDER_ARGS=1 +# Calculate the number of GPUs +nb_gpus=$(( num_nodes * num_gpu_per_node)) + +module purge + +# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" +# pour avoir acces aux modules compatibles avec cette partition + +if [ $num_gpu_per_node -eq 8 ]; then + module load cpuarch/amd + source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate +else + source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate +fi + +# Chargement des modules +module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake +module load nvidia-nsight-systems/2024.1.1.59 + +echo "The number of nodes allocated for this job is: $num_nodes" +echo "The number of GPUs allocated for this job is: $nb_gpus" + +export EQX_ON_ERROR=nan +export CUDA_ALLOC=1 + +function profile_python() { + if [ $# -lt 1 ]; then + echo "Usage: profile_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="prof_traces/$script_name" + local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="prof_traces/$script_name/$args" + report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" + fi + + mkdir -p "$output_dir" + mkdir -p "$report_dir" + + srun nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + +function run_python() { + if [ $# -lt 1 ]; then + echo "Usage: run_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="traces/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="traces/$script_name/$args" + fi + + mkdir -p "$output_dir" + + srun python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + + +# Echo des commandes lancees +set -x + +# Pour la partition "gpu_p5", le code doit etre compile avec les modules compatibles +# Execution du code avec binding via bind_gpu.sh : 1 GPU par tache + + + + +declare -A pdims_table +# Define the table +pdims_table[4]="2x2 1x4" +pdims_table[8]="2x4 1x8" +pdims_table[16]="2x8 1x16" +pdims_table[32]="4x8 1x32" +pdims_table[64]="4x16 1x64" +pdims_table[128]="8x16 16x8 4x32 32x4 1x128 128x1 2x64 64x2" +pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4" + + +#mpch=(128 256 512 1024 2048 4096) +grid=(1024 2048 4096) + +pdim="${pdims_table[$nb_gpus]}" +echo "pdims: $pdim" + +# Check if pdims is not empty +if [ -z "$pdim" ]; then + echo "pdims is empty" + echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160" + echo "Number of nodes selected: $num_nodes" + echo "Number of gpus per node: $num_gpu_per_node" + exit 1 +fi + +# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 + +if [ $num_gpu_per_node -eq 8 ]; then + gpu_name="a100" +else + gpu_name="v100" +fi + +out_dir="out/$gpu_name/$nb_gpus" + +echo "Output dir is : $out_dir" + +for g in ${grid[@]}; do + for p in ${pdim[@]}; do + # halo is 1/4 of the grid size + halo_size=$((g / 4)) + slaunch scripts/fastpm_jaxdecomp.py -m $g -b $g -p $p -hs $halo_size -ode diffrax -o $out_dir + done +done