-
Notifications
You must be signed in to change notification settings - Fork 1
/
base_env.py
90 lines (73 loc) · 2.4 KB
/
base_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# --------------------------------------------------------
# Copyright (c) 2023 Princeton University
# Email: kaichieh@princeton.edu
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, Union
import random
import numpy as np
import gym
import torch
from .utils import GenericAction, GenericState
from .agent import Agent
class BaseEnv(gym.Env, ABC):
    """Abstract base class for environments built on top of `gym.Env`.

    Stores the bookkeeping read from the environment config (`timeout`,
    `end_criterion`) together with a counter `cnt`, provides concrete
    `reset` and `seed` helpers, and declares the hooks every concrete
    environment has to implement: `step`, `get_obs`, `render`, `report`
    and `get_constraints_all`.
    """

    # Annotation only: no value is assigned in this class.
    agent: Agent

    def __init__(self, cfg_env) -> None:
        """Initializes the bookkeeping attributes.

        Args:
            cfg_env: environment config object; its `timeout` and
                `end_criterion` attributes are read and stored.
        """
        gym.Env.__init__(self)
        # Zeroed again by `reset()`; presumably counts the steps taken in
        # the current episode -- confirm against the subclasses.
        self.cnt = 0
        self.timeout = cfg_env.timeout
        self.end_criterion = cfg_env.end_criterion

    @abstractmethod
    def step(self,
             action: GenericAction) -> Tuple[GenericState, float, bool, Dict]:
        """Advances the environment with the given action.

        Args:
            action (GenericAction): the action to apply.

        Returns:
            Tuple[GenericState, float, bool, Dict]: presumably the usual gym
                (state, reward, done, info) tuple -- confirm against the
                subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get_obs(self, state: np.ndarray) -> np.ndarray:
        """Gets the observation associated with a state.

        Args:
            state (np.ndarray): the state to convert.

        Returns:
            np.ndarray: the observation; its layout is defined by the
                subclass.
        """
        raise NotImplementedError

    def reset(
        self, state: Optional[GenericState] = None, cast_torch: bool = False,
        **kwargs
    ) -> GenericState:
        """
        Resets the environment bookkeeping.

        This base implementation only sets `self.cnt` back to zero. It does
        not use any of its arguments and does not return a state; subclasses
        are presumably expected to call it and then build and return the new
        state themselves.

        Args:
            state (Optional[GenericState], optional): reset to this state if
                provided. Defaults to None. Unused here.
            cast_torch (bool): cast state to torch if True. Unused here.
            **kwargs: unused here.

        Returns:
            None from this base implementation, despite the annotation, which
                describes what overriding implementations return.
        """
        self.cnt = 0

    @abstractmethod
    def render(self):
        """Renders the environment; the output is defined by the subclass."""
        raise NotImplementedError

    def seed(self, seed: int = 0) -> None:
        """Seeds the random number generators used by the environment.

        Stores the seed in `self.seed_val`, creates a NumPy generator in
        `self.rng`, seeds Python's `random` module and PyTorch (CPU and
        CUDA), sets the cuDNN flags, and seeds both gym spaces. The legacy
        global NumPy generator (`np.random.seed`) is not touched.

        `self.action_space` and `self.observation_space` are not assigned in
        this class, so they have to be set before this method is called.

        Args:
            seed (int, optional): the seed value. Defaults to 0.
        """
        self.seed_val = seed
        self.rng = np.random.default_rng(seed)
        random.seed(self.seed_val)
        torch.manual_seed(self.seed_val)
        torch.cuda.manual_seed(self.seed_val)
        torch.cuda.manual_seed_all(self.seed_val)  # if using multi-GPU.
        # Turn off the cuDNN auto-tuner and ask for deterministic kernels.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        self.action_space.seed(seed)
        self.observation_space.seed(seed)

    @abstractmethod
    def report(self):
        """Reports on the environment; contents are defined by the subclass."""
        raise NotImplementedError

    @abstractmethod
    def get_constraints_all(
        self, states: np.ndarray, actions: Union[np.ndarray, dict]
    ) -> Dict[str, np.ndarray]:
        """
        Gets the values of all constraint functions evaluated at the given
        states and actions.

        Args:
            states (np.ndarray): the states to evaluate.
            actions (Union[np.ndarray, dict]): the actions (controls) to
                evaluate.

        Returns:
            Dict: each (key, value) pair is the name and values of a constraint
                evaluated at the states and actions input.
        """
        raise NotImplementedError