Commit: environment/tabular/tests/examples

hamishs committed Apr 3, 2021
1 parent d5876c5 commit ff785c8
Showing 10 changed files with 316 additions and 86 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -22,7 +22,12 @@ Main libraries used:
* Expected SARSA

## TODO
### v0
* More unit tests
* environment
* documentation

### v1
* Other PPO based algorithms?
* Prioritised experience replay
* Multi-agent
79 changes: 79 additions & 0 deletions examples/tabular.py
@@ -0,0 +1,79 @@
'''
Solve a simple grid world with tabular RL.
'''

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from jax_rl.tabular import QLearning, DoubleQLearning, SARSA, ExpectedSARSA
from jax_rl.policies import EpsilonGreedy
from jax_rl.environment import GridWorld

# setup gridworld

grid = np.array([
[0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 1, 1, 1, 0],
[0, 1, 0, 1, 1, 1, 1, 1, 0],
[0, 1, 0, 1, 0, 1, 1, 1, 0],
[0, 1, 0, 1, 0, 1, 1, 1, 0],
[0, 1, 1, 1, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0]])

start = (1, 2)
end = (6, 7)
wall_reward = -10.0
step_reward = -0.1

class FlatWrapper:
    ''' Wrapper to return integer states for grid world.'''

    def __init__(self, env, rows, cols):
        self.env = env
        self.rows = rows
        self.cols = cols

    def reset(self):
        s = self.env.reset()
        return s[0] * self.cols + s[1]

    def step(self, a):
        s_, r, d, info = self.env.step(a)
        s = s_[0] * self.cols + s_[1]
        return s, r, d, info

env = FlatWrapper(GridWorld(grid, start, end, wall_reward, step_reward),
grid.shape[0], grid.shape[1])
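# FlatWrapper flattens a (row, col) state to the index row * cols + col,
# e.g. the start square (1, 2) on this 9-column grid becomes state 1 * 9 + 2 = 11.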

# train each algorithm on the grid problem
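# NOTE: the positional arguments below are assumed (by analogy with the DRQN example
# elsewhere in this repo) to be: PRNG seed, number of states, number of actions,
# discount factor, exploration policy, learning rate.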

q_learning = QLearning(4, grid.size, 4, 0.99, EpsilonGreedy(0.1), 0.1)
q_rs = q_learning.train_on_env(env, 30, verbose = 5)

double_q = DoubleQLearning(5, grid.size, 4, 0.99, EpsilonGreedy(0.1), 0.1)
qd_rs = double_q.train_on_env(env, 30, verbose = 5)

sarsa = SARSA(6, grid.size, 4, 0.99, EpsilonGreedy(0.1), 0.1)
s_rs = sarsa.train_on_env(env, 30, verbose = 5)

expected_sarsa = ExpectedSARSA(7, grid.size, 4, 0.99, EpsilonGreedy(0.1), 0.1)
es_rs = expected_sarsa.train_on_env(env, 30, verbose = 5)

# plot results
plt.plot(q_rs, label = 'Q-learning')
plt.plot(qd_rs, label = 'Double Q-learning')
plt.plot(s_rs, label = 'SARSA')
plt.plot(es_rs, label = 'Expected SARSA')
plt.legend()
plt.show()
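If the learning curves above look noisy over only 30 episodes, a simple moving average makes the comparison easier to read. The sketch below is illustrative only; it uses nothing beyond NumPy and matplotlib and assumes (as the plotting code above suggests) that train_on_env returns a flat sequence of per-episode returns.

def smooth(returns, window = 3):
    ''' Moving average of a 1D sequence of episode returns.'''
    kernel = np.ones(window) / window
    return np.convolve(np.asarray(returns), kernel, mode = 'valid')

plt.plot(smooth(q_rs), label = 'Q-learning (smoothed)')
plt.plot(smooth(s_rs), label = 'SARSA (smoothed)')
plt.legend()
plt.show()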








2 changes: 1 addition & 1 deletion src/jax_rl/algorithms/dqn.py
@@ -40,7 +40,7 @@ def act(self, s, exploration = True):
assert s.shape == (1, self.n_states)
q_values = self.q_network(self.params, s)

return self.policy(next(self.prng), self.n_actions, q_values, exploration)
return self.policy(next(self.prng), q_values, exploration)

def train(self, batch_size):
''' Train the agent on a single episode. Uses the double q-learning target.
50 changes: 3 additions & 47 deletions src/jax_rl/algorithms/drqn.py
@@ -12,7 +12,8 @@
class DRQN(BaseAgent):
''' Deep Recurrent Q-network using double Q-learning.'''

def __init__(self, key, n_states, n_actions, gamma, buffer_size, max_len, policy, model, init_state, lr):
def __init__(self, key, n_states, n_actions, gamma, buffer_size,
max_len, policy, model, init_state, lr):
'''
model must take sequential inputs and a hidden state.
init_state must provide the initial state for a given batch_size.
@@ -40,7 +41,7 @@ def act(self, s, exploration = True):
assert s.shape == (1, 1, self.n_states)
q_values, self.hidden_state = self.q_network(self.params, s, self.hidden_state)

return self.policy(next(self.prng), self.n_actions, q_values, exploration)
return self.policy(next(self.prng), q_values, exploration)

def train(self):
''' Train the agent on a single episode. Uses the double q-learning target.
@@ -133,48 +134,3 @@ def train_on_env(self, env, episodes, update_freq, verbose = None):

return ep_rewards, losses



if __name__ == '__main__':

from policies import EpsilonGreedy
from utils import lstm_initial_state

import gym
env = gym.make('CartPole-v0')

# Example LSTM network
def forward(s, hidden = None):
'''
Apply the LSTM over the input sequence with given initial state.
s : (batch, seq_len, features)
hidden : LSTM state (h, c).
'''

# extract features
mlp1 = hk.nets.MLP([16, 16])
s = hk.BatchApply(mlp1)(s) # (batch, seq_len, hidden_features)

# LSTM
lstm = hk.LSTM(32)
if hidden is None: hidden = lstm.initial_state(s.shape[0])
s, hidden = hk.dynamic_unroll(lstm, jnp.transpose(s, (1, 0, 2)), hidden)

# output fully connected
mlp2 = hk.nets.MLP([16, 1])
s = hk.BatchApply(mlp2)(jnp.transpose(s, (1, 0, 2)))

# s : (batch, seq_len, 1)
# hidden = (h, c)
# h : ()

return s, hidden

model = hk.without_apply_rng(hk.transform(forward))
init_state = lambda batch_size: lstm_initial_state(32, batch_size = batch_size)

policy = EpsilonGreedy(0.1)

drqn = DRQN(0, 4, 2, 0.99, 1000, 200, policy, model, init_state, 1e-5)
ep_rewards, losses = drqn.train_on_env(env, 500, 1, verbose = 10)

106 changes: 106 additions & 0 deletions src/jax_rl/environment.py
@@ -0,0 +1,106 @@
import numpy as np
from abc import ABC, abstractmethod

class Environment(ABC):
    '''
    An environment for a Markov Decision Process.
    Uses similar syntax to env from gym (could subclass instead).
    '''
    def __init__(self):
        pass

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def step(self, a):
        pass

class GridWorld(Environment):
    '''
    Discrete grid world for tabular algorithms. The grid is represented
    by a numpy array indicating the walls and valid squares. Actions are
    movements up, down, left or right. There is a start square and a
    single end square which is the only terminal state.
    Actions:
    0 = up
    1 = down
    2 = left
    3 = right
    '''
    def __init__(self, grid, start, end, wall_reward, step_reward, max_steps = 100):
        '''
        grid : 2D binary numpy array with 0 for wall and 1 for valid square.
        start : tuple indicating row and column of the starting square.
        end : tuple indicating row and column of the end square.
        wall_reward : the reward for bumping into a wall square.
        step_reward : the reward for each step taken in the environment.
        max_steps : maximum number of steps before the episode is terminated.
        '''
        self.grid = grid
        self.start = start
        self.end = end
        self.wall_reward = wall_reward
        self.step_reward = step_reward
        self.max_steps = max_steps

        self.n_rows, self.n_cols = self.grid.shape

    def reset(self):
        self.state = self.start
        self.done = False
        self._t = 0
        return list(self.state)

    def step(self, a):

        if self.done:
            raise ValueError('Episode has ended; call reset() before stepping again.')

        row, col = self.state
        reward = self.step_reward
        info = ''

        # process action
        if a == 0:
            r_next, c_next = row + 1, col
        elif a == 1:
            r_next, c_next = row - 1, col
        elif a == 2:
            r_next, c_next = row, col - 1
        elif a == 3:
            r_next, c_next = row, col + 1
        else:
            raise ValueError(f'Invalid action {a}; expected 0, 1, 2 or 3.')

        # make sure the move stays in the grid
        if (r_next < 0) or (r_next >= self.n_rows):
            r_next = row
            reward += self.wall_reward
            info = 'Tried to leave the grid'
        elif (c_next < 0) or (c_next >= self.n_cols):
            c_next = col
            reward += self.wall_reward
            info = 'Tried to leave the grid'

        # if a wall then stay still
        if self.grid[r_next, c_next] == 0:
            r_next = row
            c_next = col
            reward += self.wall_reward
            info = 'Hit a wall'

        self.state = (r_next, c_next)

        # check if done
        self._t += 1
        if (self.state == self.end) or (self._t >= self.max_steps):
            self.done = True

        return list(self.state), reward, self.done, info
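For reference, here is a minimal rollout of the GridWorld class above under a uniformly random policy. This sketch is illustrative only (not part of the commit) and relies solely on the reset/step API defined above plus the default max_steps = 100 cap to guarantee termination.

import numpy as np
from jax_rl.environment import GridWorld

# a small 4x4 grid: 0 = wall, 1 = valid square
grid = np.array([
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0]])

env = GridWorld(grid, (1, 1), (2, 2), wall_reward = -10.0, step_reward = -0.1)

s = env.reset()
done = False
total_reward = 0.0
while not done:
    a = np.random.randint(4) # 0 = up, 1 = down, 2 = left, 3 = right
    s, r, done, info = env.step(a)
    total_reward += r

print('episode return:', total_reward)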



18 changes: 9 additions & 9 deletions src/jax_rl/policies.py
@@ -13,7 +13,7 @@ def __call__(self, *args, **kwargs):
return self.call(*args, **kwargs)

@abstractmethod
def call(self, key, q_values, exploration = True, return_distibution = False):
def call(self, key, q_values, exploration = True, return_distribution = False):
pass

class EpsilonGreedy(Policy):
@@ -25,7 +25,7 @@ def __init__(self, epsilon):
self.epsilon = epsilon
self.t = 0 # step counter

def call(self, key, q_values, exploration = True, return_distibution = False):
def call(self, key, q_values, exploration = True, return_distribution = False):
'''
key : jax.random.PRNGKey.
state : jnp.array (1, n_states) - current state.
@@ -34,14 +34,14 @@ def call(self, key, q_values, exploration = True, return_distibution = False):
exploration : bool = True - whether to allow exploration or to be greedy.
return_distribution : bool = False - whether to return the distribution over actions.
'''
if return_distibution:
n_actions = q_values.shape[0]
if return_distribution:
dist = jnp.ones(n_actions,) * self.epsilon / n_actions
dist[jnp.argmax(q_values)] += 1 - self.epsilon
dist = jax.ops.index_add(dist, jnp.argmax(q_values), 1 - self.epsilon)
return dist
else:
self.t += 1
eps = self.epsilon(self.t) if callable(self.epsilon) else self.epsilon
n_actions = q_values.shape[0]
if exploration and (jax.random.uniform(key, shape = (1,))[0] > 1 - eps):
return int(jax.random.randint(key, shape = (1,), minval = 0, maxval = n_actions))
else:
@@ -58,22 +58,22 @@ def __init__(self, T):
self.T = T
self.t = 0 # step counter

def call(self, key, q_values, exploration = True, return_distibution = False):
def call(self, key, q_values, exploration = True, return_distribution = False):
'''
key : jax.random.PRNGKey.
state : jnp.array (1, n_states) - current state.
q_values : (n_actions) - estimated q-values for the current state.
exploration : bool = True - whether to allow exploration or to be greedy.
return_distribution : bool = False - whether to return the distribution over actions.
'''
if return_distibution:
n_actions = q_values.shape[0]
T = self.T(self.t) if callable(self.T) else self.T
if return_distribution:
prefs = jnp.exp(q_values / T)
prefs /= prefs.sum()
return prefs
else:
self.t += 1
T = self.T(self.t) if callable(self.T) else self.T
n_actions = q_values.shape[0]
if exploration:
prefs = jnp.exp(q_values / T)
prefs /= prefs.sum()
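Because call() computes eps = self.epsilon(self.t) if callable(self.epsilon) else self.epsilon, EpsilonGreedy accepts a schedule as well as a constant. A decaying-epsilon setup might look like the sketch below (illustrative only; note the return_distribution branch above uses self.epsilon directly, so the schedule applies to action sampling):

from jax_rl.policies import EpsilonGreedy

# decay epsilon linearly from 1.0 to a floor of 0.05 over the first 10000 policy calls
schedule = lambda t: max(0.05, 1.0 - 0.95 * t / 10000)
policy = EpsilonGreedy(schedule)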
