Replaced conditional_actnorm_wrapper
famura committed Oct 15, 2020 · 1 parent 6637c3b · commit 3794a92
Showing 5 changed files with 28 additions and 44 deletions.
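
In short: the removed helper conditional_actnorm_wrapper re-applied only the action normalization wrapper, deciding by reloading the idx-th experiment. The evaluation scripts now mirror all training-time wrappers onto the evaluation environment via wrap_like_other_env. The call-site change, sketched with the names from the diffs below:

    # Before: re-apply only ActNormWrapper, based on the i-th experiment directory
    env = conditional_actnorm_wrapper(env, ex_dirs, i)

    # After: mirror the wrappers of the loaded simulation environment
    env_sim, policy, _ = load_experiment(ex_dir, args)
    env = wrap_like_other_env(env, env_sim)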
Pyrado/pyrado/sampling/parallel_evaluation.py (0 additions, 21 deletions)

@@ -148,24 +148,3 @@ def eval_randomized_domain(pool: SamplerPool,
     # Run with progress bar
     with tqdm(leave=False, file=sys.stdout, unit='rollouts', desc='Sampling') as pb:
         return pool.run_map(_run_rollout_nom, init_states, pb)
-
-
-def conditional_actnorm_wrapper(env: Env, ex_dirs: list, idx: int):
-    """
-    Wrap the environment with an action normalization wrapper if the simulated environment had one.
-
-    :param env: environment to sample from
-    :param ex_dirs: list of experiment directories that will be loaded
-    :param idx: index of the current directory
-    :return: modified environment
-    """
-    # Get the simulation environment
-    env_sim, _, _ = load_experiment(ex_dirs[idx])
-
-    if typed_env(env_sim, ActNormWrapper) is not None:
-        env = ActNormWrapper(env)
-        print_cbt(f'Added an action normalization wrapper to {idx + 1}-th evaluation policy.', 'y')
-    else:
-        env = remove_env(env, ActNormWrapper)
-        print_cbt(f'Removed an action normalization wrapper to {idx + 1}-th evaluation policy.', 'y')
-    return env
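
Two things stand out about the removed helper: it handled exactly one wrapper type (ActNormWrapper), and it called load_experiment a second time just to inspect the training setup. The body of wrap_like_other_env is not part of this diff; conceptually it performs per-wrapper checks of roughly this shape (a sketch only, with the ActNormWrapper import path assumed from Pyrado's module layout):

    from pyrado.environment_wrappers.action_normalization import ActNormWrapper
    from pyrado.environment_wrappers.utils import typed_env

    def mirror_act_norm(env_targ, env_src):
        # Hypothetical helper: re-apply action normalization iff the source env used it
        if typed_env(env_src, ActNormWrapper) is not None:
            env_targ = ActNormWrapper(env_targ)
        return env_targ

The advantage over the removed version is that env_src is the already-loaded simulation environment, so no experiment directory has to be re-read.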
Pyrado/pyrado/utils/experiments.py (3 additions, 1 deletion)

@@ -55,6 +55,7 @@
 from pyrado.environment_wrappers.observation_partial import ObsPartialWrapper
 from pyrado.environment_wrappers.utils import typed_env
 from pyrado.environments.base import Env
+from pyrado.environments.real_base import RealEnv
 from pyrado.environments.sim_base import SimEnv
 from pyrado.logger.experiment import load_dict_from_yaml
 from pyrado.policies.adn import pd_linear, pd_cubic, pd_capacity_21_abs, pd_capacity_21, pd_capacity_32, \
@@ -213,7 +214,8 @@ def load_experiment(ex_dir: str, args: Any = None) -> (Union[SimEnv, EnvWrapper]
     return env, policy, kwout


-def wrap_like_other_env(env_targ: Env, env_src: [SimEnv, EnvWrapper], use_downsampling: bool = False) -> Env:
+def wrap_like_other_env(env_targ: Union[SimEnv, RealEnv], env_src: [SimEnv, EnvWrapper], use_downsampling: bool = False
+                        ) -> Union[SimEnv, RealEnv]:
     """
     Wrap a given real environment like it's simulated counterpart (except the domain randomization of course).
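
The widened signature documents the intended use: the target may now be a simulated or a real platform. A usage sketch (env_real stands for an already-constructed RealEnv instance, which is an assumption here; ex_dir for the directory of a finished experiment):

    # Load the training environment and policy of a finished experiment ...
    env_sim, policy, _ = load_experiment(ex_dir)
    # ... then apply the same wrappers to the real counterpart
    env_real = wrap_like_other_env(env_real, env_sim)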
Pyrado/scripts/evaluation/eval_policies_domain_grid_1d.py (8 additions, 7 deletions)

@@ -42,12 +42,12 @@
 from pyrado.environment_wrappers.utils import typed_env
 from pyrado.environments.pysim.quanser_qube import QQubeSwingUpSim
 from pyrado.logger.experiment import setup_experiment, save_list_of_dicts_to_yaml
-from pyrado.sampling.parallel_evaluation import eval_domain_params, conditional_actnorm_wrapper
+from pyrado.sampling.parallel_evaluation import eval_domain_params
 from pyrado.sampling.sampler_pool import SamplerPool
 from pyrado.utils.argparser import get_argparser
 from pyrado.utils.checks import check_all_lengths_equal
 from pyrado.utils.data_types import dict_arraylike_to_float
-from pyrado.utils.experiments import load_experiment
+from pyrado.utils.experiments import load_experiment, wrap_like_other_env
 from pyrado.utils.input_output import print_cbt


@@ -174,10 +174,11 @@

 # Loading the policies
 ex_dirs = [osp.join(p, e) for p, e in zip(prefixes, ex_names)]
-policies = []
+env_sim_list = []
+policy_list = []
 for ex_dir in ex_dirs:
     _, policy, _ = load_experiment(ex_dir, args)
-    policies.append(policy)
+    policy_list.append(policy)

 # Create one-dim results grid and ensure right number of rollouts
 param_list = param_grid(param_spec)
@@ -190,7 +191,7 @@
 df = pd.DataFrame(columns=['policy', 'ret', 'len', varied_param_key])

 # Evaluate all policies
-for i, policy in enumerate(policies):
+for i, (env_sim, policy) in enumerate(zip(env_sim_list, policy_list)):
     # Create a new sampler pool for every policy to synchronize the random seeds i.e. init states
     pool = SamplerPool(args.num_workers)

@@ -201,8 +202,8 @@
     else:
         print_cbt('No seed was set', 'y')

-    # Add an action normalization wrapper if the policy was trained with one
-    env = conditional_actnorm_wrapper(env, ex_dirs, i)
+    # Add the same wrappers as during training
+    env = wrap_like_other_env(env, env_sim)

     # Sample rollouts
     ros = eval_domain_params(pool, env, policy, param_list, init_state)
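
One detail to flag: the new loop header iterates over zip(env_sim_list, policy_list), but as the hunks of this script stand, the loading loop still unpacks "_, policy, _" and never appends to env_sim_list. For the zip to yield anything, the loading loop would also have to capture and store the simulation environments, along these lines (a sketch, not part of the shown diff):

    for ex_dir in ex_dirs:
        env_sim, policy, _ = load_experiment(ex_dir, args)
        env_sim_list.append(env_sim)
        policy_list.append(policy)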
Pyrado/scripts/evaluation/eval_policies_dr.py (9 additions, 8 deletions)

@@ -39,12 +39,12 @@
 from pyrado.environments.pysim.quanser_ball_balancer import QBallBalancerSim
 from pyrado.environments.pysim.quanser_cartpole import QCartPoleSwingUpSim, QCartPoleStabSim
 from pyrado.logger.experiment import setup_experiment, save_list_of_dicts_to_yaml
-from pyrado.sampling.parallel_evaluation import eval_randomized_domain, conditional_actnorm_wrapper
+from pyrado.sampling.parallel_evaluation import eval_randomized_domain
 from pyrado.sampling.sampler_pool import SamplerPool
 from pyrado.utils.argparser import get_argparser
 from pyrado.utils.checks import check_all_lengths_equal
 from pyrado.utils.data_types import dict_arraylike_to_float
-from pyrado.utils.experiments import load_experiment
+from pyrado.utils.experiments import load_experiment, wrap_like_other_env
 from pyrado.utils.input_output import print_cbt


@@ -103,10 +103,11 @@

 # Loading the policies
 ex_dirs = [osp.join(p, e) for p, e in zip(prefixes, ex_names)]
-policies = []
+env_sim_list = []
+policy_list = []
 for ex_dir in ex_dirs:
-    _, policy, _ = load_experiment(ex_dir, args)
-    policies.append(policy)
+    env_sim, policy, _ = load_experiment(ex_dir, args)
+    policy_list.append(policy)

 # Fix initial state (set to None if it should not be fixed)
 init_state_list = [None]*args.num_ro_per_config
@@ -115,7 +116,7 @@
 df = pd.DataFrame(columns=['policy', 'ret', 'len'])

 # Evaluate all policies
-for i, policy in enumerate(policies):
+for i, (env_sim, policy) in enumerate(zip(env_sim_list, policy_list)):
     # Create a new sampler pool for every policy to synchronize the random seeds i.e. init states
     pool = SamplerPool(args.num_workers)

@@ -126,8 +127,8 @@
     else:
         print_cbt('No seed was set', 'y')

-    # Add an action normalization wrapper if the policy was trained with one
-    env = conditional_actnorm_wrapper(env, ex_dirs, i)
+    # Add the same wrappers as during training
+    env = wrap_like_other_env(env, env_sim)

     # Sample rollouts
     ros = eval_randomized_domain(pool, env, pert, policy, init_state_list)  # internally calls DomainRandWrapperLive
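
The trailing comment hints at the one structural difference to the nominal evaluation: according to it, eval_randomized_domain wraps the environment with DomainRandWrapperLive internally before sampling, i.e. conceptually something like the following sketch (the import path is assumed from Pyrado's layout; pert is the randomizer configured earlier in this script):

    from pyrado.environment_wrappers.domain_randomization import DomainRandWrapperLive

    # Re-draw the domain parameters from the randomizer at every reset
    env = DomainRandWrapperLive(env, pert)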
Pyrado/scripts/evaluation/eval_policies_nominal_domain.py (8 additions, 7 deletions)

@@ -37,12 +37,12 @@
 from pyrado.environments.pysim.quanser_ball_balancer import QBallBalancerSim
 from pyrado.environments.pysim.quanser_cartpole import QCartPoleSwingUpSim, QCartPoleStabSim
 from pyrado.logger.experiment import setup_experiment, save_list_of_dicts_to_yaml
-from pyrado.sampling.parallel_evaluation import eval_nominal_domain, conditional_actnorm_wrapper
+from pyrado.sampling.parallel_evaluation import eval_nominal_domain
 from pyrado.sampling.sampler_pool import SamplerPool
 from pyrado.utils.argparser import get_argparser
 from pyrado.utils.checks import check_all_lengths_equal
 from pyrado.utils.data_types import dict_arraylike_to_float
-from pyrado.utils.experiments import load_experiment
+from pyrado.utils.experiments import load_experiment, wrap_like_other_env
 from pyrado.utils.input_output import print_cbt


@@ -96,10 +96,11 @@

 # Loading the policies
 ex_dirs = [osp.join(p, e) for p, e in zip(prefixes, ex_names)]
-policies = []
+env_sim_list = []
+policy_list = []
 for ex_dir in ex_dirs:
     _, policy, _ = load_experiment(ex_dir, args)
-    policies.append(policy)
+    policy_list.append(policy)

 # Fix initial state (set to None if it should not be fixed)
 init_state_list = [None]*args.num_ro_per_config
@@ -108,7 +109,7 @@
 df = pd.DataFrame(columns=['policy', 'ret', 'len'])

 # Evaluate all policies
-for i, policy in enumerate(policies):
+for i, (env_sim, policy) in enumerate(zip(env_sim_list, policy_list)):
     # Create a new sampler pool for every policy to synchronize the random seeds i.e. init states
     pool = SamplerPool(args.num_workers)

@@ -119,8 +120,8 @@
     else:
         print_cbt('No seed was set', 'y')

-    # Add an action normalization wrapper if the policy was trained with one
-    env = conditional_actnorm_wrapper(env, ex_dirs, i)
+    # Add the same wrappers as during training
+    env = wrap_like_other_env(env, env_sim)

     # Sample rollouts
     ros = eval_nominal_domain(pool, env, policy, init_state_list)
