cartpole_td_learning.py
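"""Tabular TD (Q-learning) on CartPole-v0 with a discretized state space.

The four continuous observation dimensions are bucketed into a small grid,
and a Q-table over (state bucket, action) pairs is learned with an
epsilon-greedy policy and decaying learning and exploration rates.
"""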
import gym
import gym.spaces
import numpy as np
import math
import random
import matplotlib.pyplot as plt
from collections import deque
# create the cart-pole environment
env = gym.make('CartPole-v0')
class CartPole():
    def __init__(self, buckets=(1, 1, 6, 3,), n_episodes=1000, solved_t=195,
                 min_epsilon=0.1, min_alpha=0.1, gamma=0.99):
        self.buckets = buckets  # number of discrete values per state dimension
                                # (position, velocity, angle, angular velocity)
        self.n_episodes = n_episodes  # training episodes
        self.min_alpha = min_alpha  # lower bound on the learning rate
        self.min_epsilon = min_epsilon  # lower bound on the exploration rate
        self.gamma = gamma  # discount factor
        self.solved_t = solved_t  # mean reward over 100 episodes needed to count as solved
        self.epsilon = min_epsilon
        self.Q_table = np.zeros(self.buckets + (env.action_space.n,))  # one entry per (state bucket, action)
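        # With the default buckets (1, 1, 6, 3) and the two actions (left, right),
        # the Q-table has shape (1, 1, 6, 3, 2), i.e. 36 state-action values.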
    def discretize_state(self, state):
        # upper and lower bounds of the state dimensions
        upper_bounds = env.observation_space.high
        lower_bounds = env.observation_space.low
        # set manual bounds for velocity and angular velocity,
        # whose environment bounds are effectively infinite
        upper_bounds[1] = 0.5
        upper_bounds[3] = math.radians(50)
        lower_bounds[1] = -0.5
        lower_bounds[3] = -math.radians(50)
        # discretize each input dimension into one of its buckets
        width = [upper_bounds[i] - lower_bounds[i] for i in range(len(state))]
        ratios = [(state[i] - lower_bounds[i]) / width[i] for i in range(len(state))]
        bucket_indices = [int(round(ratios[i] * (self.buckets[i] - 1))) for i in range(len(state))]
        # clip the indices into the range [0, bucket_length - 1]
        bucket_indices = [max(0, min(bucket_indices[i], self.buckets[i] - 1)) for i in range(len(state))]
        return tuple(bucket_indices)
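    # Example (assuming the default buckets above): a centered observation maps
    # to the middle buckets, e.g. discretize_state([0, 0, 0, 0]) == (0, 0, 2, 1),
    # and any dimension with a single bucket always maps to index 0.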
    def select_action(self, state, epsilon):
        # epsilon-greedy action selection
        if random.random() <= epsilon:
            return env.action_space.sample()  # random action with probability epsilon
        else:
            return np.argmax(self.Q_table[state])  # greedy action with the highest Q-value
    def get_epsilon(self, episode_number):
        # decaying epsilon in the range [min_epsilon, 1]
        return max(self.min_epsilon, min(1, 1 - math.log10((episode_number + 1) / 25)))
    def get_alpha(self, episode_number):
        # decaying alpha in the range [min_alpha, 1]
        return max(self.min_alpha, min(1, 1 - math.log10((episode_number + 1) / 25)))
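    # Both schedules stay at 1.0 for the first 25 episodes (log10 of a value
    # below 1 is negative), then decay and bottom out at the minimum from
    # episode index 249 onwards, e.g. get_alpha(0) == 1.0, get_alpha(249) == 0.1.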
    def update_table(self, old_state, action, reward, new_state, alpha):
        # TD update of the state-action pair towards the one-step target
        current = self.Q_table[old_state][action]  # current estimate for (state, action)
        # value of the best action in the next state (0 if there is no next state)
        Qsa_next = np.max(self.Q_table[new_state]) if new_state is not None else 0
        target = reward + (self.gamma * Qsa_next)  # construct the TD target
        self.Q_table[old_state][action] = current + (alpha * (target - current))
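    # The rule above is the standard Q-learning (TD) update
    #   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)).
    # For instance, with Q(s, a) = 0, r = 1, gamma = 0.99, max_a' Q(s', a') = 0
    # and alpha = 1, the new estimate is 0 + 1 * (1 + 0.99 * 0 - 0) = 1.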
    def run(self):
        # run episodes until the mean reward over the last 100 episodes is at least self.solved_t
        total_epochs, total_penalties = 0, 0
        scores = deque(maxlen=100)
        episodes_result = deque(maxlen=200)
        penalties_result = deque(maxlen=100)
        for episode in range(self.n_episodes):
            obs = env.reset()
            curr_state = self.discretize_state(obs)
            done = False
            alpha = self.get_alpha(episode)
            epsilon = self.get_epsilon(episode)
            epochs, penalties, episode_reward = 0, 0, 0
            while not done:
                #env.render()
                action = self.select_action(curr_state, epsilon)
                obs, reward, done, info = env.step(action)
                new_state = self.discretize_state(obs)
                self.update_table(curr_state, action, reward, new_state, alpha)
                curr_state = new_state
                episode_reward += reward
                if reward == -10:  # never triggers in CartPole-v0 (every step yields +1); kept for other environments
                    penalties += 1
                epochs += 1
            total_penalties += penalties
            total_epochs += epochs
            scores.append(episode_reward)
            episodes_result.append(epochs)
            penalties_result.append(total_penalties)
            mean_reward = np.mean(scores)
            if mean_reward > self.solved_t and (episode + 1) >= 100:
                print("Ran {} episodes, solved after {} trials".format(episode + 1, episode + 1 - 100))
                return episode + 1 - 100
            elif (episode + 1) % 50 == 0 and (episode + 1) >= 100:
                print("Episode number: {}, mean reward over past 100 episodes is {}".format(episode + 1, mean_reward))
            else:
                print("Episode {}, reward {}".format(episode + 1, episode_reward))
        print(f"Results after {episode + 1} episodes:")
        print(f"Average timesteps per episode: {total_epochs / (episode + 1)}")
        print(f"Average rewards per episode: {np.mean(scores)}")
        print(f"Average penalties per episode: {total_penalties / (episode + 1)}")
        print("Training finished.\n")
        plt.hist(episodes_result, 50, density=True, facecolor='g', alpha=0.75)
        plt.xlabel('Episodes required to reach goal')
        plt.ylabel('Frequency')
        plt.title('Episode histogram for CartPole solved by TD learning')
        plt.show()
        plt.hist(scores, 50, density=True, facecolor='g', alpha=0.75)
        plt.xlabel('Rewards achieved per episode')
        plt.ylabel('Frequency')
        plt.title('Reward histogram for CartPole solved by TD learning')
        plt.show()
        plt.hist(penalties_result, 50, density=True, facecolor='g', alpha=0.75)
        plt.xlabel('Penalties per episode')
        plt.ylabel('Frequency')
        plt.title('Penalty histogram for CartPole solved by TD learning')
        plt.show()
if __name__ == "__main__":
    cartpole = CartPole()
    solved_after = cartpole.run()
    if solved_after is not None:
        print("Environment solved after {} episodes".format(solved_after))
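# For reproducible runs one could seed the random number generators before
# training; a minimal sketch, assuming the classic gym API used above:
#   random.seed(0)
#   np.random.seed(0)
#   env.seed(0)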