-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
11c30b2
commit e80d790
Showing
15 changed files
with
1,476 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# In-Context RL | ||
|
||
This eval tests models' ability to solve RL environments simply by interacting with them in-context, without dedicated training or fine-tuning. | ||
|
||
## Usage | ||
|
||
Run with: | ||
|
||
```bash | ||
oaieval <solver> incontext_rl | ||
``` | ||
|
||
For examples of tested solvers, see [`./scripts/run_experiments.sh`](./scripts/run_experiments.sh). | ||
|
||
## Dataset | ||
|
||
The eval is currently set up to test models on the following canonical RL environments: | ||
1. [FrozenLake-v1](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) (non-slippery version, default map), 4x4 gridworld where the agent has to reach the goal without falling into traps. | ||
2. [CliffWalking-v0](https://gymnasium.farama.org/environments/toy_text/cliff_walking/). 4x12 gridworld where the agent has to reach the other side of the map without falling off a cliff. | ||
3. [BanditTwoArmedHighLowFixed-v1](https://github.com/james-aung/gymasium-bandits). Stochastic two-armed bandit setup where Arm 1 pays out 80% of the time with reward 1, and Arm 2 pays out 20% of the time with reward 1. | ||
4. [BanditTenArmedRandomFixed-v1](https://github.com/james-aung/gymasium-bandits). Stochastic ten-armed bandit setup where each arm has some randomly-initialized probability of payout. | ||
|
||
Besides these four environments, our eval is also built to be compatible with any environments that have discrete action and observation spaces using the Gymnasium API. Future work may generalize our eval to work with environments with other types of action/observation spaces. | ||
|
||
## Evaluation Process | ||
|
||
Each run of the eval tests the model on all four environments in the dataset, and has the model take steps in each environment until 200 steps are taken or the model’s context limit is reached. | ||
|
||
At each step, the eval provides the following to the model: | ||
- The next observation and the reward from the last action. The model is also told when the environment has reset due to its action leading to a termination. | ||
- How many of the maximum number of steps it has already taken. | ||
- The total reward it has accumulated so far across all episodes. | ||
|
||
If an episode ends, the environment resets and a new episode begins. | ||
|
||
If the eval receive 4 responses in a row where we cannot parse an action selection, we end the evaluation for that environment. (This provides a natural end for runs where the model’s context window is exceeded.) | ||
|
||
|
||
## Prompts | ||
|
||
We refer readers to the [`./defaults.py`](./defaults.py) file for the `TASK_DESCRIPTION` and other prompts used in the eval. | ||
|
||
## Metrics | ||
<!-- prettier-ignore-start --> | ||
We provide the following metrics per evaluated environment: | ||
|
||
| **Metric** | **Notes** | | ||
|---|---| | ||
| `average_episode_reward` | The average reward achieved per episode | | ||
| `total_steps` | The number of steps taken across all episodes before the environment sample ended | | ||
| `invalid_response_rate` | % of responses that were in an invalid format for the eval | | ||
<!-- prettier-ignore-end --> | ||
|
||
## Token Usage Estimates | ||
|
||
<!-- prettier-ignore-start --> | ||
| Model | Token Usage Per Run | | ||
|---|---| | ||
| **gpt-3.5-turbo** | 4200000 ± 400000 | | ||
| **gpt-4-turbo-preview** | 21900000 ± 10100000 | | ||
| **mixtral-8x7b** | 2700000 ± 800000 | | ||
<!-- prettier-ignore-end --> | ||
|
||
## Future modifications | ||
|
||
- Extend the eval to work with other observation and action spaces beyond Discrete spaces | ||
|
||
## Version History | ||
|
||
- v0: Initial version released | ||
|
||
## Contribution Statement | ||
|
||
Eval design, implementation, and results evaluation were primarily conducted by James Aung. Chan Jun Shern was responsible for code reviews throughout the implementation process, along with fine-grained feedback on the project in general. Additional guidance was provided by Steven Adler, who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import Any | ||
from evals.solvers.solver import NestedSolver, Solver, SolverResult, SolverSpec | ||
from evals.task_state import Message, TaskState | ||
|
||
ANTI_COT_TEMPLATE = "RESPOND ONLY WITH YOUR FINAL ANSWER IN THE FORMAT REQUESTED. DO NOT OUTPUT ANY ADDITIONAL REASONING OR TEXT." | ||
|
||
class AntiCoTSolver(NestedSolver): | ||
""" | ||
Instructs the model to not do any further reasoning and just respond with the final answer. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
solver: SolverSpec, | ||
registry: Any = None, | ||
): | ||
super().__init__(solver=solver) | ||
|
||
@property | ||
def solver(self) -> Solver: | ||
return self.get_solver("solver") | ||
|
||
def _solve( | ||
self, | ||
task_state: TaskState, | ||
**kwargs, | ||
) -> SolverResult: | ||
task_state.messages += ( | ||
[ | ||
Message(role="system", content=ANTI_COT_TEMPLATE), | ||
] | ||
) | ||
solver_result = self.solver(task_state=task_state, **kwargs) | ||
return solver_result | ||
|
||
@property | ||
def name(self) -> str: | ||
return f"Anti-CoT_{self.solver.name}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import random | ||
|
||
import numpy as np | ||
|
||
from evals.elsuite.incontext_rl.eval import CurrentState | ||
from evals.record import record_sampling | ||
from evals.solvers.solver import Solver, SolverResult | ||
from evals.task_state import TaskState | ||
|
||
|
||
class RandomSolver(Solver): | ||
def __init__(self, *args, **kwargs): | ||
pass | ||
|
||
def _solve( | ||
self, | ||
task_state: TaskState, | ||
**kwargs, | ||
) -> SolverResult: | ||
|
||
cs: CurrentState = task_state.current_state | ||
|
||
try: | ||
action = cs.action_space.sample() | ||
response = f"[SELECT: {action}]" | ||
except Exception as e: | ||
response = f"Error: {e}" | ||
|
||
record_sampling( | ||
prompt=cs.observations[-1], | ||
sampled=response, | ||
model="incontext_rl_random", | ||
) | ||
|
||
return SolverResult(response) | ||
|
||
|
||
class QlearningSolver(Solver): | ||
def __init__( | ||
self, | ||
learning_rate=0.7, | ||
gamma=0.95, | ||
epsilon=1.0, | ||
min_epsilon=0.05, | ||
max_epsilon=1.0, | ||
decay_rate=0.0005, | ||
*args, | ||
**kwargs, | ||
): | ||
super().__init__(*args, **kwargs) | ||
self.learning_rate = learning_rate | ||
self.gamma = gamma | ||
self.epsilon = epsilon | ||
self.min_epsilon = min_epsilon | ||
self.max_epsilon = max_epsilon | ||
self.decay_rate = decay_rate | ||
self.q_table = None | ||
|
||
def initialize_q_table(self, observation_space_size, action_space_size): | ||
self.q_table = np.zeros((observation_space_size, action_space_size)) | ||
|
||
def select_action(self, state, action_space): | ||
if random.uniform(0, 1) < self.epsilon: | ||
return action_space.sample() # Explore action space | ||
else: | ||
return np.argmax(self.q_table[state][:]) # Exploit learned values | ||
|
||
def update_q_table(self, state, action, reward, next_state): | ||
next_max = np.max(self.q_table[next_state]) | ||
self.q_table[state, action] = self.q_table[state, action] + self.learning_rate * ( | ||
reward + self.gamma * next_max - self.q_table[state, action] | ||
) | ||
|
||
def reduce_epsilon(self, episode_number): | ||
self.epsilon = self.min_epsilon + (self.max_epsilon - self.min_epsilon) * np.exp( | ||
-self.decay_rate * episode_number | ||
) | ||
|
||
def _solve(self, task_state: TaskState, **kwargs) -> SolverResult: | ||
|
||
cs: CurrentState = task_state.current_state | ||
|
||
# TODO these might not be true if environment is not discrete | ||
assert ( | ||
cs.observation_space_n is not None | ||
), "Environment must have discrete observation space" | ||
assert cs.action_space_n is not None, "Environment must have discrete action space" | ||
|
||
if self.q_table is None: | ||
print("Initializing Q-table") | ||
self.initialize_q_table( | ||
observation_space_size=cs.observation_space_n, action_space_size=cs.action_space_n | ||
) | ||
|
||
# This shouln't run on the first step | ||
if len(cs.actions) >= 1 and len(cs.rewards) >= 1 and len(cs.observations) >= 2: | ||
print(cs.actions) | ||
self.update_q_table( | ||
state=cs.observations[-2], | ||
action=cs.actions[-1], | ||
reward=cs.rewards[-1], | ||
next_state=cs.observations[-1], | ||
) | ||
print( | ||
f"The last action {cs.actions[-1]} resulted in reward {cs.rewards[-1]}. We went from state {cs.observations[-2]} to state {cs.observations[-1]}" | ||
) | ||
self.reduce_epsilon(episode_number=len(cs.episode_end_steps)) | ||
|
||
action = self.select_action(state=cs.observations[-1], action_space=cs.action_space) | ||
response = f"[SELECT: {action}]" | ||
|
||
record_sampling( | ||
prompt=cs.observations[-1], | ||
sampled=response, | ||
model="incontext_rl_qlearning", | ||
) | ||
|
||
return SolverResult(response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from string import Template | ||
|
||
task_description_template = Template("""$explanations | ||
You have the following $action_space actions available to you: $actions. | ||
On each turn, you will receive an observation from the environment, which maybe be one of the following $observation_space possible observations: | ||
$observations | ||
Each action you take will give you a reward (which might be negative) and may move you to a new state. | ||
You should take actions such that you accumulate the highest possible reward across all your actions. | ||
This is how we will play the game: | ||
1. You will consider all the information you have received to decide upon the best action you can take now. | ||
2. You will select your action from the list above by specifying the number key of the action in the command [SELECT: x], where x is the number key of the action. | ||
3. Your selected action will be taken. | ||
4. As a result of your action, you will be given an observation from the environment and you may receive some reward. | ||
5. Repeat from step 1. | ||
""") | ||
|
||
step_counter = Template("Total actions taken so far: $step_count") | ||
reward_counter = Template("Total reward so far: $reward_count") | ||
reset_msg = Template("""After the game reset you are now in $observation. | ||
Please pick an action, providing your reasoning. You must format your final action choice as [SELECT: x]""") | ||
step_result = Template("""You took Action $action. You are now in $next_observation. | ||
The last step you did provided reward: $reward. | ||
Please pick an action, providing your reasoning. You must format your final action choice as [SELECT: x]""") | ||
step_result_reset = Template("""You took Action $action. You arrived at $next_observation. | ||
The last step made the game reset. | ||
The last step you did provided reward: $reward.""") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
Optional setup scripts for specific environments. | ||
""" | ||
|
||
def setup_GymnasiumBandits(): | ||
import gymnasium_bandits | ||
return | ||
|
||
ENV_SETUP_FUNCS = { | ||
"BanditTwoArmedHighLowFixed-v0": setup_GymnasiumBandits, | ||
"BanditTenArmedRandomFixed-v0": setup_GymnasiumBandits, | ||
} |
Oops, something went wrong.