Commit

Merge pull request #10 from KohlerHECTOR/KohlerHECTOR/issue9
Kohler hector/issue9
KohlerHECTOR authored Jul 25, 2024
2 parents c6c204b + 8a7bfb7 commit fe1d485
Showing 11 changed files with 173 additions and 55 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -17,7 +17,7 @@ In the provided ```ObliqueDTPolicy``` class, the method get_oblique_data generat

# Usage
```bash
pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.2.1
pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.3.0
```

```python
@@ -51,11 +51,11 @@ print(evaluate_policy(oracle, Monitor(env))[0])
clf = DecisionTreeRegressor(
max_leaf_nodes=32
) # Change to DecisionTreeClassifier for discrete Actions.
tree_policy = ObliqueDTPolicy(clf, env) #
learner = ObliqueDTPolicy(clf, env) #
# You can replace by DTPolicy(clf, env) for interpretable axis-parallel DTs.

# Start the imitation learning
interpret = Interpreter(oracle, tree_policy, env)
interpret = Interpreter(oracle, learner, env)
interpret.fit(10)

# Eval and save the best tree
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -9,7 +9,7 @@
project = "interpreter"
copyright = "2024, Hector Kohler"
author = "Hector Kohler"
release = "0.2.1"
release = "0.3.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
6 changes: 3 additions & 3 deletions docs/usage.md
@@ -1,6 +1,6 @@
## Installation
```bash
pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.2.1
pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.3.0
```


@@ -36,11 +36,11 @@ print(evaluate_policy(oracle, Monitor(env))[0])
clf = DecisionTreeRegressor(
max_leaf_nodes=32
) # Change to DecisionTreeClassifier for discrete Actions.
tree_policy = ObliqueDTPolicy(clf, env) #
learner = ObliqueDTPolicy(clf, env) #
# You can replace by DTPolicy(clf, env) for interpretable axis-parallel DTs.

# Start the imitation learning
interpret = Interpreter(oracle, tree_policy, env)
interpret = Interpreter(oracle, learner, env)
interpret.fit(10)

# Eval and save the best tree
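
For readers who only see this fragment, a self-contained version of the updated usage snippet is sketched below. Only the renamed `learner` lines come from this diff; the environment choice and the stand-in oracle are illustrative assumptions, not part of the commit.

```python
import gymnasium as gym
from sklearn.tree import DecisionTreeRegressor
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

from interpreter import Interpreter, ObliqueDTPolicy, SB3Policy

env = gym.make("Pendulum-v1")  # any continuous-action task works here

# Stand-in oracle (assumes SB3Policy wraps an SB3 policy object);
# the documented workflow loads a pre-trained agent instead.
oracle = SB3Policy(SAC("MlpPolicy", env).policy)
print(evaluate_policy(oracle, Monitor(env))[0])

clf = DecisionTreeRegressor(max_leaf_nodes=32)
learner = ObliqueDTPolicy(clf, env)  # or DTPolicy(clf, env) for axis-parallel trees

interpret = Interpreter(oracle, learner, env)
interpret.fit(10)
```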
4 changes: 2 additions & 2 deletions examples/half_cheetah.py
@@ -28,11 +28,11 @@
clf = DecisionTreeRegressor(
max_leaf_nodes=32
) # Change to DecisionTreeClassifier for discrete Actions.
tree_policy = ObliqueDTPolicy(clf, env) #
learner = ObliqueDTPolicy(clf, env) #
# You can replace by DTPolicy(clf, env) for interpretable axis-parallel DTs.

# Start the imitation learning
interpret = Interpreter(oracle, tree_policy, env)
interpret = Interpreter(oracle, learner, env)
interpret.fit(10)

# Eval and save the best tree
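
The example is folded right at the "Eval and save the best tree" comment. A hedged sketch of what that step can look like, continuing from the snippet above (`interpret` and `env` come from it): the `_policy` and `clf` attributes appear elsewhere in this diff, while the saving approach itself is an assumption.

```python
import joblib
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

best = interpret._policy                       # deepcopy of the best learner found during fit()
print(evaluate_policy(best, Monitor(env))[0])  # reward of the best tree
joblib.dump(best.clf, "half_cheetah_tree.joblib")  # persist the underlying sklearn tree
```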
2 changes: 1 addition & 1 deletion interpreter/__init__.py
@@ -1,2 +1,2 @@
from .policies import ObliqueDTPolicy, SB3Policy, DTPolicy
from .policies import ObliqueDTPolicy, SB3Policy, DTPolicy, SymbPolicy
from .interpreter import Interpreter
45 changes: 23 additions & 22 deletions interpreter/interpreter.py
@@ -1,8 +1,9 @@
from .policies import DTPolicy, SB3Policy, ObliqueDTPolicy, SymbPolicy

from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.utils import check_for_correct_spaces
from stable_baselines3.common.monitor import Monitor

from .policies import DTPolicy, SB3Policy, ObliqueDTPolicy
from rlberry.agents import AgentWithSimplePolicy

from gymnasium.spaces import Discrete, Box
@@ -24,8 +25,8 @@ class Interpreter(AgentWithSimplePolicy):
oracle : object
The oracle model that generates the data for training.
Usually a stable-baselines3 model from the hugging face hub.
tree_policy : object
The decision tree policy to be trained.
learner : object
The decision tree policy or symbolic equation to be trained.
env : object
The environment in which the policies are evaluated (gym.Env).
data_per_iter : int, optional
@@ -36,7 +37,7 @@ class Interpreter(AgentWithSimplePolicy):
----------
oracle : object
The oracle model that generates the data for training.
tree_policy : object
learner : object
The decision tree policy to be trained.
data_per_iter : int
The number of data points to generate per iteration.
@@ -48,31 +49,31 @@
A list to store the rewards of the trained tree policies over iterations.
"""

def __init__(self, oracle, tree_policy, env, data_per_iter=5000, **kwargs):
def __init__(self, oracle, learner, env, data_per_iter=5000, **kwargs):
assert isinstance(oracle, SB3Policy) and (
isinstance(tree_policy, DTPolicy)
or isinstance(tree_policy, ObliqueDTPolicy)
isinstance(learner, DTPolicy)
or isinstance(learner, ObliqueDTPolicy) or isinstance(learner, SymbPolicy)
)
AgentWithSimplePolicy.__init__(self, env, **kwargs)
if not isinstance(self.eval_env, Monitor):
self.eval_env = Monitor(self.eval_env)
self._oracle = oracle
self._tree_policy = tree_policy
self._policy = deepcopy(tree_policy)
self._learner = learner
self._policy = deepcopy(learner)
self._data_per_iter = data_per_iter

check_for_correct_spaces(
self.env,
self._tree_policy.observation_space,
self._tree_policy.action_space,
self._learner.observation_space,
self._learner.action_space,
)
check_for_correct_spaces(
self.env, self._oracle.observation_space, self._oracle.action_space
)
check_for_correct_spaces(
self.eval_env,
self._tree_policy.observation_space,
self._tree_policy.action_space,
self._learner.observation_space,
self._learner.action_space,
)
check_for_correct_spaces(
self.eval_env, self._oracle.observation_space, self._oracle.action_space
@@ -90,17 +91,17 @@ def fit(self, nb_timesteps):
print("Fitting tree nb {} ...".format(0))
nb_iter = int(max(1, nb_timesteps // self._data_per_iter))
S, A = self.generate_data(self._oracle, self._data_per_iter)
self._tree_policy.fit_tree(S, A)
self._policy = deepcopy(self._tree_policy)
tree_reward, _ = evaluate_policy(self._tree_policy, self.eval_env)
self._learner.fit(S, A)
self._policy = deepcopy(self._learner)
tree_reward, _ = evaluate_policy(self._learner, self.eval_env)
current_max_reward = tree_reward
# self.tree_policies = [deepcopy(self._tree_policy)]
# self.tree_policies = [deepcopy(self._learner)]
# self.tree_policies_rewards = [tree_reward]

for t in range(1, nb_iter + 1):
print("Fitting tree nb {} ...".format(t + 1))
S_tree, _ = self.generate_data(
self._tree_policy, int((t / nb_iter) * self._data_per_iter)
self._learner, int((t / nb_iter) * self._data_per_iter)
)
S_oracle, A_oracle = self.generate_data(
self._oracle, int((1 - t / nb_iter) * self._data_per_iter)
@@ -109,14 +110,14 @@ def fit(self, nb_timesteps):
S = np.concatenate((S, S_tree, S_oracle))
A = np.concatenate((A, self._oracle.predict(S_tree)[0], A_oracle))

self._tree_policy.fit_tree(S, A)
tree_reward, _ = evaluate_policy(self._tree_policy, self.eval_env)
self._learner.fit(S, A)
tree_reward, _ = evaluate_policy(self._learner, self.eval_env)
if tree_reward > current_max_reward:
current_max_reward = tree_reward
self._policy = deepcopy(self._tree_policy)
self._policy = deepcopy(self._learner)
print("New best tree reward: {}".format(tree_reward))

# self.tree_policies += [deepcopy(self._tree_policy)]
# self.tree_policies += [deepcopy(self._learner)]
# self.tree_policies_rewards += [tree_reward]

def policy(self, obs):
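
The training loop above follows a DAgger-style schedule: each iteration draws a growing share of states from the learner's own rollouts and a shrinking share from the oracle, and the oracle's `predict` relabels the learner-visited states before everything is aggregated. A small illustration of that schedule, with hypothetical numbers:

```python
# Illustrative only: the per-iteration data mix implied by fit() above.
data_per_iter = 5000          # default shown in the constructor
nb_timesteps = 20_000         # hypothetical budget passed to fit()
nb_iter = int(max(1, nb_timesteps // data_per_iter))  # 4 iterations

for t in range(1, nb_iter + 1):
    n_learner = int((t / nb_iter) * data_per_iter)     # states visited by the current tree/equation
    n_oracle = int((1 - t / nb_iter) * data_per_iter)  # states visited by the oracle
    print(f"iter {t}: {n_learner} learner states, {n_oracle} oracle states")
# iter 1: 1250 learner states, 3750 oracle states
# iter 2: 2500 learner states, 2500 oracle states
# iter 3: 3750 learner states, 1250 oracle states
# iter 4: 5000 learner states, 0 oracle states
```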
64 changes: 60 additions & 4 deletions interpreter/policies.py
@@ -1,10 +1,9 @@
from pysr import PySRRegressor
import gymnasium as gym
from abc import ABC, abstractmethod
import numpy as np
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.base import RegressorMixin, ClassifierMixin
from stable_baselines3.common.utils import is_vectorized_box_observation
from tqdm import tqdm


class Policy(ABC):
@@ -55,7 +54,64 @@ def predict(self, obs, state=None, deterministic=True, episode_start=0):
"""
raise NotImplementedError

class SymbPolicy(Policy):
def __init__(self, model, env):
assert isinstance(model, PySRRegressor)
assert isinstance(env.action_space, gym.spaces.Box), "Symbolic regression only works for continuous actions"
self.model = model
self.model.temp_equation_file = True

super().__init__(env.observation_space, env.action_space)

S = [self.observation_space.sample() for _ in range(10)]
A = [self.action_space.sample() for _ in range(10)]
self.model.fit(S, A, )
self.model.warm_start = True
self.model.batching = True

def predict(self, obs, state=None, deterministic=True, episode_start=0):
"""
Predict the action to take given an observation.
Parameters
----------
obs : np.ndarray
The observation input.
state : object, optional
The state of the policy (default is None).
deterministic : bool, optional
Whether to use a deterministic policy (default is True).
episode_start : int, optional
The episode start index (default is 0).
Returns
-------
action : np.ndarray
The action to take.
state : object
The updated state of the policy.
"""
if not is_vectorized_box_observation(obs, self.observation_space):
if isinstance(self.action_space, gym.spaces.Discrete):
action = self.model.predict(obs.reshape(1, -1)).squeeze().astype(int)
else:
if self.action_space.shape[0] > 1:
action = self.model.predict(obs.reshape(1, -1)).squeeze()
else:
action = self.model.predict(obs.reshape(1, -1))
return action, state
else:
if isinstance(self.action_space, gym.spaces.Discrete):
return self.model.predict(obs).astype(int), None
else:
if self.action_space.shape[0] > 1:
return self.model.predict(obs), None
else:
return self.model.predict(obs)[:, np.newaxis], None

def fit(self, X, y):
return self.model.fit(X, y)

class SB3Policy(Policy):
def __init__(self, base_policy):
self.base_policy = base_policy
@@ -142,7 +198,7 @@ def predict(self, obs, state=None, deterministic=True, episode_start=0):
else:
return self.clf.predict(obs)[:, np.newaxis], None

def fit_tree(self, S, A):
def fit(self, S, A):
"""
Fit the decision tree with the provided observations and actions.
@@ -269,7 +325,7 @@ def predict(self, obs, state=None, deterministic=True, episode_start=0):
None,
)

def fit_tree(self, S, A):
def fit(self, S, A):
"""
Fit the decision tree with the provided oblique observations and actions.
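
The new `SymbPolicy` pre-fits the regressor on ten random observation/action samples at construction time, then switches on warm-starting and batching so that subsequent `fit` calls from the `Interpreter` refine the same set of equations. A minimal construction sketch; the environment and the `PySRRegressor` settings are assumptions, only the `SymbPolicy` interface comes from the diff above:

```python
import gymnasium as gym
from pysr import PySRRegressor
from interpreter import SymbPolicy

env = gym.make("Pendulum-v1")  # SymbPolicy asserts a Box (continuous) action space

# Hypothetical search configuration; any PySRRegressor settings would do.
model = PySRRegressor(niterations=20, binary_operators=["+", "-", "*"], unary_operators=["sin"])

learner = SymbPolicy(model, env)   # triggers the small warm-up fit described above
action, _ = learner.predict(env.observation_space.sample())
```

Passing this `learner` to `Interpreter(oracle, learner, env)` then reuses the same imitation loop as the tree policies, which is what the new `isinstance(learner, SymbPolicy)` check in `interpreter.py` allows.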
3 changes: 2 additions & 1 deletion setup.py
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

__version__ = "0.2.1"
__version__ = "0.3.0"

packages = find_packages(
exclude=[
@@ -32,5 +32,6 @@
"huggingface-sb3",
"tqdm",
"gym",
"pysr"
],
)
4 changes: 2 additions & 2 deletions tests/long_test_half_cheetah.py
@@ -30,11 +30,11 @@ def long_test():
clf = DecisionTreeRegressor(
max_leaf_nodes=32
) # Change to DecisionTreeClassifier for discrete Actions.
tree_policy = ObliqueDTPolicy(clf, env) #
learner = ObliqueDTPolicy(clf, env) #
# You can replace by DTPolicy(clf, env) for interpretable axis-parallel DTs.

# Start the imitation learning
interpret = Interpreter(oracle, tree_policy, env)
interpret = Interpreter(oracle, learner, env)
interpret.fit(3)

# Eval and save the best tree