-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
316 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
''' | ||
Solve a simple grid world with tabular RL. | ||
''' | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
from jax_rl.tabular import QLearning, DoubleQLearning, SARSA, ExpectedSARSA | ||
from jax_rl.policies import EpsilonGreedy | ||
from jax_rl.environment import GridWorld | ||
|
||
# setup gridworld | ||
|
||
grid = np.array([ | ||
[0, 0, 0, 0, 0, 0, 0, 0, 0], | ||
[0, 1, 1, 1, 1, 1, 1, 1, 0], | ||
[0, 1, 0, 1, 1, 1, 1, 1, 0], | ||
[0, 1, 0, 1, 0, 1, 1, 1, 0], | ||
[0, 1, 0, 1, 0, 1, 1, 1, 0], | ||
[0, 1, 1, 1, 0, 1, 1, 1, 0], | ||
[0, 0, 0, 0, 0, 0, 0, 1, 0], | ||
[0, 0, 0, 0, 0, 0, 0, 0 ,0]]) | ||
|
||
start = (1, 2) | ||
end = (6, 7) | ||
wall_reward = -10.0 | ||
step_reward = -0.1 | ||
|
||
class FlatWrapper: | ||
''' Wrapper to return integer states for grid world.''' | ||
|
||
def __init__(self, env, rows, cols): | ||
self.env = env | ||
self.rows = rows | ||
self.cols = cols | ||
|
||
def reset(self): | ||
s = self.env.reset() | ||
return s[0] * self.cols + s[1] | ||
|
||
def step(self, a): | ||
s_, r, d, info = self.env.step(a) | ||
s = s_[0] * self.cols + s_[1] | ||
return s, r, d, info | ||
|
||
env = FlatWrapper(GridWorld(grid, start, end, wall_reward, step_reward), | ||
grid.shape[0], grid.shape[1]) | ||
|
||
# train each algorithm on the grid problem | ||
|
||
q_learning = QLearning(4, grid.size, 4, 0.99, EpsilonGreedy(0.1), 0.1) | ||
q_rs = q_learning.train_on_env(env, 30, verbose = 5) | ||
|
||
double_q = DoubleQLearning(5, grid.size, 4, 0.99, EpsilonGreedy(0.1), 0.1) | ||
qd_rs = double_q.train_on_env(env, 30, verbose = 5) | ||
|
||
sarsa = SARSA(6, grid.size, 4, 0.99, EpsilonGreedy(0.1), 0.1) | ||
s_rs = sarsa.train_on_env(env, 30, verbose = 5) | ||
|
||
expected_sarsa = ExpectedSARSA(7, grid.size, 4, 0.99, EpsilonGreedy(0.1), 0.1) | ||
es_rs = expected_sarsa.train_on_env(env, 30, verbose = 5) | ||
|
||
# plot results | ||
plt.plot(q_rs, label = 'Q-learning') | ||
plt.plot(qd_rs, label = 'Double Q-learning') | ||
plt.plot(s_rs, label = 'SARSA') | ||
plt.plot(es_rs, label = 'Expected SARSA') | ||
plt.legend() | ||
plt.show() | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import numpy as np | ||
from abc import ABC, abstractmethod | ||
|
||
class Environment(ABC): | ||
''' | ||
An environment for a Markov Decision Process | ||
Uses similar syntax to env from gym (could subclass instead). | ||
''' | ||
def __init__(self): | ||
pass | ||
|
||
@abstractmethod | ||
def reset(self): | ||
pass | ||
|
||
@abstractmethod | ||
def step(self, a): | ||
pass | ||
|
||
class GridWorld(Environment): | ||
''' | ||
Discrete grid world for tabular algorithms. Grid is represented | ||
by a np array indicating the walls and valid squares. Actions are | ||
movements up, down, left or right. There is a start square and a | ||
single end square which is the only terminal state. | ||
Actions: | ||
0 = up | ||
1 = down | ||
2 = left | ||
3 = right | ||
''' | ||
def __init__(self, grid, start, end, wall_reward, step_reward, max_steps = 100): | ||
''' | ||
grid : 2D binary numpy array with 0 for wall and 1 for valid square. | ||
start : tuple indicating row and column of starting square. | ||
end : tuple indicating row and column of end square. | ||
wall_reward : the reward of bumping into a wall square. | ||
step_reward : the reward for each step taken in the environment. | ||
''' | ||
self.grid = grid | ||
self.start = start | ||
self.end = end | ||
self.wall_reward = wall_reward | ||
self.step_reward = step_reward | ||
self.max_steps = max_steps | ||
|
||
self.n_rows, self.n_cols = self.grid.shape | ||
|
||
def reset(self): | ||
self.state = self.start | ||
self.done = False | ||
self._t = 0 | ||
return list(self.state) | ||
|
||
def step(self, a): | ||
|
||
if self.done: | ||
print('Done') | ||
raise ValueError | ||
|
||
row, col = self.state | ||
reward = self.step_reward | ||
info = '' | ||
|
||
# process action | ||
if a == 0: | ||
r_next, c_next = row + 1, col | ||
elif a == 1: | ||
r_next, c_next = row - 1, col | ||
elif a == 2: | ||
r_next, c_next = row, col - 1 | ||
elif a == 3: | ||
r_next, c_next = row, col + 1 | ||
else: | ||
print('Invalid action') | ||
raise ValueError | ||
|
||
# make sure in grid | ||
if (r_next < 0) or (r_next >= self.n_rows): | ||
r_next = row | ||
reward += self.wall_reward | ||
info = 'Tried to leave the grid' | ||
elif (r_next < 0) or (r_next >= self.n_rows): | ||
c_next = col | ||
reward += self.wall_reward | ||
info = 'Tried to leave the grid' | ||
|
||
# if a wall then stay still | ||
if self.grid[r_next, c_next] == 0: | ||
r_next = row | ||
c_next = col | ||
reward += self.wall_reward | ||
info = 'Hit a wall' | ||
|
||
self.state = (r_next, c_next) | ||
|
||
# check if done | ||
self._t += 1 | ||
if (self.state == self.end) or (self._t >= self.max_steps): | ||
self.done = True | ||
|
||
return list(self.state), reward, self.done, info | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.