
Linting for ARPL;
Removed unused constructor args for ARPL
famura committed Dec 1, 2021
1 parent 87c0fc2 commit d5d3d29
Showing 5 changed files with 25 additions and 53 deletions.
16 changes: 4 additions & 12 deletions Pyrado/pyrado/algorithms/meta/arpl.py
@@ -37,10 +37,8 @@
)
from pyrado.environment_wrappers.state_augmentation import StateAugmentationWrapper
from pyrado.environments.sim_base import SimEnv
from pyrado.exploration.stochastic_action import StochasticActionExplStrat
from pyrado.logger.step import StepLogger
from pyrado.policies.base import Policy
from pyrado.sampling.parallel_rollout_sampler import ParallelRolloutSampler
from pyrado.sampling.sequences import *


@@ -61,11 +59,7 @@ def __init__(
env: Union[SimEnv, StateAugmentationWrapper],
subrtn: Algorithm,
policy: Policy,
expl_strat: StochasticActionExplStrat,
max_iter: int,
num_rollouts: int = None,
steps_num: int = None,
apply_dynamics_noise: bool = False,
logger: StepLogger = None,
):
"""
@@ -75,11 +69,7 @@
:param env: the environment in which the agent should be trained
:param subrtn: algorithm which performs the policy / value-function optimization
:param policy: policy to be updated
:param expl_strat: the exploration strategy
:param max_iter: the maximum number of iterations
:param num_rollouts: the number of rollouts to be performed for each update step
:param steps_num: the number of steps to be performed for each update step
:param apply_dynamics_noise: whether adversarially generated dynamics noise should be applied
:param logger: logger for every step of the algorithm, if `None` the default logger will be created
"""
assert isinstance(subrtn, Algorithm)
@@ -107,7 +97,7 @@ def wrap_env(
halfspan: float = 0.25,
proc_eps: float = 0.01,
proc_phi: float = 0.05,
torch_observation = None,
torch_observation=None,
obs_eps: float = 0.01,
obs_phi: float = 0.05,
):
@@ -127,7 +117,9 @@
"""
# Initialize adversarial wrappers in the correct order
if dynamics:
assert isinstance(env, StateAugmentationWrapper), pyrado.TypeErr(env, given_name='env', expected_type=StateAugmentationWrapper)
assert isinstance(env, StateAugmentationWrapper), pyrado.TypeErr(
env, given_name="env", expected_type=StateAugmentationWrapper
)
env = AdversarialDynamicsWrapper(env, policy, dyn_eps, dyn_phi, halfspan)
if process:
env = AdversarialStateWrapper(env, policy, proc_eps, proc_phi, torch_observation=torch_observation)
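With the unused arguments gone, the ARPL constructor is reduced to the environment, the subroutine, the policy, the iteration budget, and an optional step logger. A minimal instantiation sketch under that trimmed signature (the experiment directory, environment, policy, and PPO subroutine are assumed to be set up as in the training script further below; this snippet is not part of the diff itself):

    # Hedged sketch of the trimmed ARPL interface after this commit
    from pyrado.algorithms.meta.arpl import ARPL

    algo = ARPL(
        ex_dir,        # save directory, passed positionally as in the call sites below
        env,           # a SimEnv or StateAugmentationWrapper
        subrtn,        # e.g. a PPO instance that optimizes `policy`
        policy,        # policy to be updated
        max_iter=500,  # the only remaining required hyper-parameter
    )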
17 changes: 9 additions & 8 deletions Pyrado/pyrado/environment_wrappers/adversarial.py
@@ -31,17 +31,17 @@
from typing import Callable, Optional

import numpy as np
from numpy.core import numeric
import pyrado
from pyrado.algorithms.base import Algorithm
from pyrado.environments.base import Env
from pyrado.policies.base import Policy
import torch as to
from init_args_serializer import Serializable
from numpy.core import numeric

import pyrado
from pyrado.algorithms.base import Algorithm
from pyrado.environment_wrappers.base import EnvWrapper
from pyrado.environment_wrappers.state_augmentation import StateAugmentationWrapper
from pyrado.environment_wrappers.utils import inner_env, typed_env
from pyrado.environments.base import Env
from pyrado.policies.base import Policy


class AdversarialWrapper(EnvWrapper, ABC):
@@ -107,7 +107,9 @@ def get_arpl_grad(self, state):
class AdversarialStateWrapper(AdversarialWrapper, Serializable):
""" " Wrapper to apply adversarial perturbations to the state (used in ARPL)"""

def __init__(self, wrapped_env: Env, policy: Policy, eps: numeric, phi, torch_observation:Optional[Callable]=None):
def __init__(
self, wrapped_env: Env, policy: Policy, eps: numeric, phi, torch_observation: Optional[Callable] = None
):
"""
Constructor
@@ -119,10 +121,9 @@ def __init__(self, wrapped_env: Env, policy: Policy, eps: numeric, phi, torch_ob
Serializable._init(self, locals())
AdversarialWrapper.__init__(self, wrapped_env, policy, eps, phi)
if not torch_observation:
raise pyrado.TypeErr(msg='The observation must be passed as torch')
raise pyrado.TypeErr(msg="The observation must be passed as torch")
self.torch_obs = torch_observation


def step(self, act: np.ndarray) -> tuple:
obs, reward, done, info = self.wrapped_env.step(act)
saw = typed_env(self.wrapped_env, StateAugmentationWrapper)
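Note that AdversarialStateWrapper now refuses to run without a torch observation function: passing nothing raises pyrado.TypeErr. A rough usage sketch, assuming the sine/cosine observation layout used by the Quanser qube scripts below (other environments need their own mapping; the eps/phi values are copied from the qq-su training script):

    # Hedged usage sketch; not part of the diff itself
    import torch as to
    from pyrado.environment_wrappers.adversarial import AdversarialStateWrapper

    def torch_observation(state: to.Tensor) -> to.Tensor:
        # Observation the policy sees: angles as sin/cos pairs plus the velocities
        return to.stack([to.sin(state[0]), to.cos(state[0]), to.sin(state[1]), to.cos(state[1]), state[2], state[3]])

    env = AdversarialStateWrapper(env, policy, eps=0.03, phi=0.1, torch_observation=torch_observation)
    # Omitting torch_observation now raises pyrado.TypeErr("The observation must be passed as torch")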
17 changes: 5 additions & 12 deletions Pyrado/scripts/training/qq-su_arpl.py
@@ -17,14 +17,10 @@
from pyrado.utils.argparser import get_argparser
from pyrado.utils.data_types import EnvSpec


def torch_observation(state: to.tensor) -> to.tensor:
return to.stack([
to.sin(state[0]),
to.cos(state[0]),
to.sin(state[1]),
to.cos(state[1]),
state[2],
state[3]])
return to.stack([to.sin(state[0]), to.cos(state[0]), to.sin(state[1]), to.cos(state[1]), state[2], state[3]])


if __name__ == "__main__":
# Parse command line arguments
@@ -47,8 +43,6 @@ def torch_observation(state: to.tensor) -> to.tensor:
policy_hparam = dict(hidden_sizes=[32, 32], hidden_nonlin=to.tanh) # FNN
policy = FNNPolicy(spec=env.spec, **policy_hparam)



env = ARPL.wrap_env(
env,
policy,
@@ -62,7 +56,7 @@ def torch_observation(state: to.tensor) -> to.tensor:
obs_eps=0.05,
proc_phi=0.1,
proc_eps=0.03,
torch_observation=torch_observation
torch_observation=torch_observation,
)

# Critic
@@ -94,10 +88,9 @@ def torch_observation(state: to.tensor) -> to.tensor:
)
algo_hparam = dict(
max_iter=500,
steps_num=23 * env.max_steps,
)
subrtn = PPO(ex_dir, env, policy, critic, **subrtn_hparam)
algo = ARPL(ex_dir, env, subrtn, policy, subrtn.expl_strat, **algo_hparam)
algo = ARPL(ex_dir, env, subrtn, policy, **algo_hparam)

# Save the hyper-parameters
save_dicts_to_yaml(
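Putting the pieces of the updated script together, the training setup now reads roughly as follows. This is a condensed sketch, not a verbatim excerpt: the environment, policy, critic, and hyper-parameter dicts are as above, and the part of the wrap_env call that is collapsed in this diff view is not reproduced here.

    # Condensed sketch of the updated qq-su_arpl.py flow
    env = ARPL.wrap_env(
        env, policy,
        # ... wrapper flags and dynamics-noise settings from the collapsed part of the call ...
        obs_phi=0.1, obs_eps=0.05, proc_phi=0.1, proc_eps=0.03,
        torch_observation=torch_observation,
    )
    subrtn = PPO(ex_dir, env, policy, critic, **subrtn_hparam)
    algo = ARPL(ex_dir, env, subrtn, policy, max_iter=500)  # no expl_strat, no steps_num anymore
    algo.train(snapshot_mode="best")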
17 changes: 6 additions & 11 deletions Pyrado/tests/algorithms/test_algorithms.py
@@ -502,17 +502,12 @@ def test_arpl_wrappers(env):
env = StateAugmentationWrapper(env, domain_param=None)
assert len(inner_env(env).domain_param) == env.obs_space.flat_dim - env.offset
env.reset()
env.step(0.0)[0][env.offset:]
env.step(0.0)[0][env.offset :]


def _qqsu_torch_observation(state: to.tensor) -> to.tensor:
return to.stack([
to.sin(state[0]),
to.cos(state[0]),
to.sin(state[1]),
to.cos(state[1]),
state[2],
state[3]])
return to.stack([to.sin(state[0]), to.cos(state[0]), to.sin(state[1]), to.cos(state[1]), state[2], state[3]])


@pytest.mark.parametrize("env", ["default_qqsu"], ids=["qqsu"], indirect=True)
def test_arpl(ex_dir, env):
@@ -535,7 +530,7 @@ def test_arpl(ex_dir, env):
obs_eps=0.05,
proc_phi=0.1,
proc_eps=0.03,
torch_observation=_qqsu_torch_observation
torch_observation=_qqsu_torch_observation,
)

vfcn_hparam = dict(hidden_sizes=[32, 32], hidden_nonlin=to.tanh) # FNN
@@ -564,13 +559,13 @@
)
algo_hparam = dict(
max_iter=2,
steps_num=3 * env.max_steps,
)
subrtn = PPO(ex_dir, env, policy, critic, **subrtn_hparam)
algo = ARPL(ex_dir, env, subrtn, policy, subrtn.expl_strat, **algo_hparam)
algo = ARPL(ex_dir, env, subrtn, policy, **algo_hparam)

algo.train(snapshot_mode="best")


@pytest.mark.parametrize("env", ["default_qqsu", "default_bob"], ids=["qqsu", "bob"], indirect=True)
def test_arpl_observation(env):
policy_hparam = dict(hidden_sizes=[32, 32], hidden_nonlin=to.tanh) # FNN
11 changes: 1 addition & 10 deletions Pyrado/tests/algorithms/test_meta.py
@@ -409,18 +409,9 @@ def test_arpl(ex_dir, env: SimEnv):
)
arpl_hparam = dict(
max_iter=2,
steps_num=23 * env.max_steps,
halfspan=0.05,
dyn_eps=0.07,
dyn_phi=0.25,
obs_phi=0.1,
obs_eps=0.05,
proc_phi=0.1,
proc_eps=0.03,
torch_observation=True,
)
ppo = PPO(ex_dir, env, policy, critic, **algo_hparam)
algo = ARPL(ex_dir, env, ppo, policy, ppo.expl_strat, **arpl_hparam)
algo = ARPL(ex_dir, env, ppo, policy, **arpl_hparam)

algo.train(snapshot_mode="best")
assert algo.curr_iter == algo.max_iter
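As in the other call sites, the keys dropped from arpl_hparam here (halfspan, dyn_eps, dyn_phi, obs_eps, obs_phi, proc_eps, proc_phi, torch_observation) are consumed by ARPL.wrap_env, as shown in the training script above, not by the ARPL constructor, which now only takes the iteration budget. A hedged before/after sketch of the constructor call in this test:

    # After this commit:
    algo = ARPL(ex_dir, env, ppo, policy, max_iter=2)
    # Old form, no longer accepted (the exploration strategy and steps_num were unused):
    # algo = ARPL(ex_dir, env, ppo, policy, ppo.expl_strat, max_iter=2, steps_num=23 * env.max_steps)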
