
Commit

new version with rlberry
KohlerHECTOR committed Jul 15, 2024
1 parent ac69381 commit c1a1e9c
Showing 7 changed files with 43 additions and 31 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,3 +1,4 @@
__pycache__
.coverage
*.pkl
*.pkl
rlberry_data/
2 changes: 1 addition & 1 deletion README.md
@@ -17,7 +17,7 @@ In the provided ```ObliqueDTPolicy``` class, the method get_oblique_data generat

# Usage
```bash
pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.1.5
pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.2.0
```

```python
# (the rest of this example is collapsed in this diff view)
```
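
For readers skimming this commit, here is a hedged end-to-end usage sketch assembled from the pieces visible in the diffs below (the `Interpreter` constructor and `fit`/`policy`/`eval` signatures in `interpreter/interpreter.py`, and the PPO oracle used in the tests). The import path and the tree-policy constructor are not shown in this commit, so they are marked as assumptions in the comments.

```python
import gymnasium as gym
from stable_baselines3 import PPO

from interpreter import Interpreter  # import path assumed from the package layout

env = gym.make("Ant-v4")
oracle = PPO("MlpPolicy", env)  # neural oracle, as in tests/test_policies_interpreter.py

# The tree policy's constructor is not part of this diff; the repository's README
# mentions an ObliqueDTPolicy built around a scikit-learn tree.
tree_policy = ...  # placeholder

interpret = Interpreter(oracle, tree_policy, env, data_per_iter=5000)
interpret.fit(10_000)                      # total environment timesteps to distill over
action = interpret.policy(env.reset()[0])  # query the distilled tree policy
score = interpret.eval(eval_horizon=1000, n_simulations=10)  # new signature in this commit
```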
2 changes: 1 addition & 1 deletion docs/usage.md
@@ -1,6 +1,6 @@
## Installation
```bash
pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.1.5
pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.2.0
```


2 changes: 1 addition & 1 deletion examples/half_cheetah.py
@@ -43,4 +43,4 @@
dump(interpret._policy.clf, f)

with open("tree_halfcheetah.pkl", "rb") as f:
clf = load(f)
clf = load(f)
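
The tail of `examples/half_cheetah.py` (only partially visible in this hunk) serializes the fitted scikit-learn tree held in `interpret._policy.clf` and reads it back. A self-contained sketch of the same round trip is below; the example's own `dump`/`load` import is outside this hunk, so using the standard-library `pickle` here is an assumption, and the array shapes are only chosen to mimic HalfCheetah's 17-dimensional observations and 6-dimensional actions.

```python
import pickle

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Stand-in for the distilled tree (interpret._policy.clf in the example).
clf = DecisionTreeRegressor(max_depth=3)
clf.fit(np.random.rand(100, 17), np.random.rand(100, 6))

with open("tree_halfcheetah.pkl", "wb") as f:
    pickle.dump(clf, f)  # save the fitted tree

with open("tree_halfcheetah.pkl", "rb") as f:
    clf = pickle.load(f)  # reload it for later inference

print(clf.predict(np.random.rand(1, 17)))
```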
31 changes: 23 additions & 8 deletions interpreter/interpreter.py
@@ -6,10 +6,13 @@
from rlberry.agents import AgentWithSimplePolicy

from gymnasium.spaces import Discrete, Box
from gymnasium.wrappers.time_limit import TimeLimit

import numpy as np
from copy import deepcopy
from tqdm import tqdm


class Interpreter(AgentWithSimplePolicy):
"""
A class to interpret a neural net policy using a decision tree policy.
@@ -59,13 +62,17 @@ def __init__(self, oracle, tree_policy, env, data_per_iter=5000, **kwargs):
self._data_per_iter = data_per_iter

check_for_correct_spaces(
self.env, self._tree_policy.observation_space, self._tree_policy.action_space
self.env,
self._tree_policy.observation_space,
self._tree_policy.action_space,
)
check_for_correct_spaces(
self.env, self._oracle.observation_space, self._oracle.action_space
)
check_for_correct_spaces(
self.eval_env, self._tree_policy.observation_space, self._tree_policy.action_space
self.eval_env,
self._tree_policy.observation_space,
self._tree_policy.action_space,
)
check_for_correct_spaces(
self.eval_env, self._oracle.observation_space, self._oracle.action_space
@@ -92,8 +99,12 @@ def fit(self, nb_timesteps):

for t in range(1, nb_iter + 1):
print("Fitting tree nb {} ...".format(t + 1))
S_tree, _ = self.generate_data(self._tree_policy, int((t/nb_iter) * self._data_per_iter))
S_oracle, A_oracle = self.generate_data(self._oracle, int((1 - t/nb_iter) * self._data_per_iter))
S_tree, _ = self.generate_data(
self._tree_policy, int((t / nb_iter) * self._data_per_iter)
)
S_oracle, A_oracle = self.generate_data(
self._oracle, int((1 - t / nb_iter) * self._data_per_iter)
)

S = np.concatenate((S, S_tree, S_oracle))
A = np.concatenate((A, self._oracle.predict(S_tree)[0], A_oracle))
@@ -107,12 +118,16 @@

# self.tree_policies += [deepcopy(self._tree_policy)]
# self.tree_policies_rewards += [tree_reward]

def policy(self, obs):
return self._policy.predict(obs)

def eval(self, n_simulations):
return evaluate_policy(self._policy, self.eval_env, n_eval_episodes=n_simulations)[0]
def eval(self, eval_horizon=10**5, n_simulations=10, gamma=1.0):
return evaluate_policy(
self._policy,
TimeLimit(self.eval_env, eval_horizon),
n_eval_episodes=n_simulations,
)[0]

def generate_data(self, policy, nb_data):
"""
@@ -132,7 +147,7 @@ def generate_data(self, policy, nb_data):
A : np.ndarray
The generated actions.
"""
assert (nb_data >= 0)
assert nb_data >= 0
if isinstance(self.env.action_space, Discrete):
A = np.zeros((nb_data))
elif isinstance(self.env.action_space, Box):
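
The reworked `fit` follows a DAgger-style schedule: at iteration `t` of `nb_iter`, it gathers a fraction `t/nb_iter` of `data_per_iter` states by rolling out the current tree policy (relabelling those states with the oracle's actions) and the remaining fraction by rolling out the oracle directly, then refits the tree on the growing aggregate dataset. The new `eval` additionally caps evaluation episodes with Gymnasium's `TimeLimit` wrapper before calling Stable-Baselines3's `evaluate_policy`. The toy sketch below only illustrates the mixing arithmetic; how `nb_iter` is derived from `nb_timesteps` is not visible in this hunk, so the division used here is an assumption.

```python
# Toy illustration of the data-mixing schedule in Interpreter.fit (no rollouts).
data_per_iter = 5000                      # default from __init__
nb_timesteps = 20_000
nb_iter = nb_timesteps // data_per_iter   # assumed relation, not shown in this hunk

for t in range(1, nb_iter + 1):
    n_tree = int((t / nb_iter) * data_per_iter)        # states visited by the tree policy
    n_oracle = int((1 - t / nb_iter) * data_per_iter)  # states visited by the oracle
    print(f"iteration {t}: {n_tree} tree-visited states (relabelled by the oracle) "
          f"+ {n_oracle} oracle-visited states")
# iteration 1: 1250 + 3750, ..., iteration 4: 5000 + 0 — the tree gradually takes over data collection.
```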
18 changes: 10 additions & 8 deletions setup.py
@@ -24,12 +24,14 @@
long_description_content_type="text/markdown",
author="Hector Kohler",
author_email="hector.kohler@inria.fr",
install_requires=["scikit-learn>=1.3.0",
"stable-baselines3",
"rlberry",
"gymnasium[mujoco]",
"huggingface-sb3",
"tqdm",
"Shimmy==1.3.0",
"gym"],
install_requires=[
"scikit-learn>=1.3.0",
"stable-baselines3",
"rlberry",
"gymnasium[mujoco]",
"huggingface-sb3",
"tqdm",
"Shimmy==1.3.0",
"gym",
],
)
16 changes: 5 additions & 11 deletions tests/test_policies_interpreter.py
@@ -127,7 +127,6 @@ def test_interpreter_oblique_ctnuous_actions_high_dim():
interpret.policy(env.reset()[0])



def test_interpreter_ctnuous_actions_high_dim():
env = gym.make("Ant-v4")
model = PPO("MlpPolicy", env)
@@ -138,6 +137,7 @@ def test_interpreter_ctnuous_actions_high_dim():
interpret.fit(3)
interpret.policy(env.reset()[0])


def test_interpreter_rlberry():
env = gym.make("Ant-v4")
model = PPO("MlpPolicy", env)
@@ -147,20 +147,14 @@ def test_interpreter_rlberry():

exp = ExperimentManager(
agent_class=Interpreter,
train_env=(gym_make, {"id":"Ant-v4"}),
train_env=(gym_make, {"id": "Ant-v4"}),
fit_budget=1e4,
init_kwargs=dict(oracle=oracle, tree_policy=tree_policy),
n_fit=2,
seed=42
seed=42,
)
exp.fit()

# output = plot_writer_data(
# [exp],
# tag="reward",
# smooth=True,
# title="Episode Reward smoothed",
# )
_ = evaluate_agents(
[exp], n_simulations=50, show=False
) # Evaluate the trained agent on
[exp], n_simulations=50, show=False
) # Evaluate the trained agent on
