from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Optional, Type, Union

from gym.envs.registration import EnvSpec
from simple_parsing import field
from simple_parsing.helpers import choice

from sequoia.common.gym_wrappers.utils import is_monsterkong_env
from sequoia.settings.assumptions.context_discreteness import DiscreteContextAssumption

# NOTE: `registry` comes from the continual tasks module (not from
# `gym.envs.registration`); it is what `supported_envs` below iterates over.
from sequoia.settings.rl.continual.tasks import TaskSchedule, registry
from sequoia.utils.logging_utils import get_logger
from sequoia.utils.utils import dict_union

from ..continual.setting import ContinualRLSetting
from ..continual.setting import supported_envs as _parent_supported_envs
from .tasks import DiscreteTask, is_supported, make_discrete_task
from .test_environment import DiscreteTaskAgnosticRLTestEnvironment

logger = get_logger(__name__)

supported_envs: Dict[str, EnvSpec] = dict_union(
    _parent_supported_envs,
    {
        spec.id: spec
        for env_id, spec in registry.env_specs.items()
        if spec.id not in _parent_supported_envs and is_supported(env_id)
    },
)

# Mapping from "dataset" name to the env id to use for it. The mapping is the
# identity here: each supported env id maps to itself.
available_datasets: Dict[str, str] = {env_id: env_id for env_id in supported_envs}
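# For illustration (the exact contents depend on the installed gym version and
# extras, so treat these as hypothetical entries):
#   available_datasets["CartPole-v0"]    == "CartPole-v0"
#   available_datasets["MountainCar-v0"] == "MountainCar-v0"
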
from .results import DiscreteTaskAgnosticRLResults

@dataclass
class DiscreteTaskAgnosticRLSetting(DiscreteContextAssumption, ContinualRLSetting):
    """Continual Reinforcement Learning setting where there are clear task
    boundaries, but where the task information isn't available.

    (See the usage sketch at the bottom of this module for how the fields below
    are typically set.)
    """

    # TODO: Update the type of results that we get for this Setting.
    Results: ClassVar[
        Type[DiscreteTaskAgnosticRLResults]
    ] = DiscreteTaskAgnosticRLResults

    # The type of wrapper used to wrap the test environment, and which produces
    # the results.
    TestEnvironment: ClassVar[
        Type[DiscreteTaskAgnosticRLTestEnvironment]
    ] = DiscreteTaskAgnosticRLTestEnvironment

    # The function used to create the tasks for the chosen env.
    _task_sampling_function: ClassVar[Callable[..., DiscreteTask]] = make_discrete_task

    # Class variable that holds the dict of available environments.
    available_datasets: ClassVar[Dict[str, Union[str, Any]]] = available_datasets

    # Which environment (a.k.a. "dataset") to learn on.
    # The dataset can be either a string (an env id or a key of the
    # `available_datasets` dict), a gym.Env, or a callable that returns a
    # single environment.
    dataset: str = choice(available_datasets, default="CartPole-v0")

    # The number of "tasks" that will be created for the training, validation
    # and test environments. Defaults to 5.
    nb_tasks: int = field(5, alias=["n_tasks", "num_tasks"])

    # Maximum number of training steps per task.
    train_steps_per_task: Optional[int] = None

    # Number of test steps per task.
    test_steps_per_task: Optional[int] = None

    # # Maximum number of episodes in total.
    # train_max_episodes: Optional[int] = None
    # # Maximum number of steps per training task. When left unset and when
    # # `train_max_steps` is set, takes the value of `train_max_steps` divided
    # # by `nb_tasks`.
    # train_max_steps_per_task: Optional[int] = None
    # # (WIP): Maximum number of episodes per training task. When left unset and
    # # when `train_max_episodes` is set, takes the value of `train_max_episodes`
    # # divided by `nb_tasks`.
    # # TODO: Add tests for this 'max episodes' and 'episodes_per_task'.
    # train_max_episodes_per_task: Optional[int] = None
    # # Maximum number of steps per task in the test loop. (Also acts as the
    # # "length" of the testing environment.) When left unset and when
    # # `test_max_steps` is set, takes the value of `test_max_steps` divided by
    # # `nb_tasks`.
    # test_max_steps_per_task: Optional[int] = None
    # # (WIP): Maximum number of episodes per test task. When left unset and when
    # # `test_max_episodes` is set, takes the value of `test_max_episodes` divided
    # # by `nb_tasks`.
    # test_max_episodes_per_task: Optional[int] = None

    # def warn(self, warning: Warning):
    #     logger.warning(warning)
    #     warnings.warn(warning)

    def __post_init__(self):
        # TODO: Rework the messy fields from before, e.g. by treating these as
        # the *maximum* number of steps per task rather than as a fixed number
        # of steps per task.
        assert not self.smooth_task_boundaries
        super().__post_init__()
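        # If no maximum episode length was given, fall back to a default for
        # envs that need one (here MonsterKong; the 500-step cap is simply the
        # default this Setting picks, not a value taken from the env itself).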
        if self.max_episode_steps is None:
            if is_monsterkong_env(self.dataset):
                self.max_episode_steps = 500

    def create_train_task_schedule(self) -> TaskSchedule[DiscreteTask]:
        # IDEA: Could convert max_episodes into max_steps if
        # max_steps_per_episode is set.
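        # A `TaskSchedule` maps an (integer) environment step to the task to
        # apply from that step onward, e.g. `{0: first_task, 10_000: second_task}`
        # (where `first_task` and `second_task` stand in for hypothetical
        # `DiscreteTask` entries).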
        return super().create_train_task_schedule()

    def create_val_task_schedule(self) -> TaskSchedule[DiscreteTask]:
        # Always the same as the train task schedule for now.
        return super().create_val_task_schedule()

    def create_test_task_schedule(self) -> TaskSchedule[DiscreteTask]:
        return super().create_test_task_schedule()
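

# Minimal usage sketch: construct the setting for one of the supported envs and
# read back a few of the fields defined above. This only exercises names defined
# in this module; interacting further with the Setting (e.g. applying a Method
# to it) is out of scope here, and "CartPole-v0" is assumed to be registered.
if __name__ == "__main__":
    setting = DiscreteTaskAgnosticRLSetting(
        dataset="CartPole-v0",
        nb_tasks=5,
        train_steps_per_task=10_000,
    )
    logger.info(
        f"dataset={setting.dataset}, nb_tasks={setting.nb_tasks}, "
        f"train_steps_per_task={setting.train_steps_per_task}"
    )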