Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
KohlerHECTOR committed Jul 15, 2024
1 parent a8a9b57 commit c6c204b
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.2.1
```

```python
from interpreter import Interpreter
ffrom interpreter import Interpreter
from interpreter import ObliqueDTPolicy, SB3Policy, DTPolicy

from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

import gymnasium as gym
from gymnasium.wrappers.time_limit import TimeLimit
from sklearn.tree import DecisionTreeRegressor
from huggingface_sb3 import load_from_hub

Expand Down Expand Up @@ -60,21 +59,14 @@ interpret = Interpreter(oracle, tree_policy, env)
interpret.fit(10)

# Eval and save the best tree
best_tree_policy, _ = interpret.get_best_tree_policy()
final_tree_reward, _ = evaluate_policy(best_tree_policy, env=env, n_eval_episodes=10)
final_tree_reward, _ = evaluate_policy(interpret._policy, env=env, n_eval_episodes=10)
print(final_tree_reward)
# Here you can replace pickle with joblib or cloudpickle
with open("tree_halfcheetah.pkl", "wb") as f:
dump(best_tree_policy.clf, f, protocol=5)
dump(interpret._policy.clf, f)

with open("tree_halfcheetah.pkl", "rb") as f:
clf = load(f)
# Render
evaluate_policy(
ObliqueDTPolicy(clf, env),
env=Monitor(gym.make("HalfCheetah-v4", render_mode="human")),
render=True,
)

```

Expand Down

0 comments on commit c6c204b

Please sign in to comment.