-
Notifications
You must be signed in to change notification settings - Fork 0
/
value_iteration.py
64 lines (42 loc) · 1.59 KB
/
value_iteration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from gamblers_problem import GP
def value_iteration(env, discount_factor=1.0, epsilon=0.0001):
    """Run in-place value iteration and return the greedy policy and values.

    States are swept in index order and ``vfn[s]`` is overwritten as soon as
    it is computed, so states later in a sweep already see the refreshed
    values of earlier ones.

    Args:
        env: Environment object exposing ``nS`` (number of states), ``nA``
            (number of actions) and ``P``. ``P[s]`` is iterated as a sequence
            with one entry per action; each entry is a sequence of
            ``(prob, next_state, reward, done)`` tuples.
        discount_factor: Weight applied to the successor state's value.
        epsilon: Convergence threshold. Sweeping stops after the first sweep
            whose largest per-state value change is below it. Must be
            greater than zero.

    Returns:
        A ``(policy, vfn)`` tuple. ``policy`` is an ``(env.nS, env.nA)`` array
        whose rows are one-hot on the greedy action found in the final sweep;
        ``vfn`` is a length-``env.nS`` array of state values.

    Raises:
        ValueError: If ``epsilon`` is not greater than zero.
    """
    # delta is never negative, so `delta < epsilon` can never hold for a
    # non-positive (or NaN) epsilon and the sweep loop would never terminate.
    if not epsilon > 0:
        raise ValueError("epsilon must be positive, got {!r}".format(epsilon))

    def one_step_lookahead(s, vfn):
        """Return the expected value of each action taken from state ``s``."""
        actions = np.zeros(env.nA)
        for a, outcomes in enumerate(env.P[s]):
            # The `done` flag is carried in the tuple but not used by the backup.
            for prob, next_state, reward, _done in outcomes:
                actions[a] += prob * (reward + discount_factor * vfn[next_state])
        return actions

    policy = np.zeros([env.nS, env.nA])
    vfn = np.zeros(env.nS)
    while True:
        delta = 0
        for s in range(env.nS):
            action_values = one_step_lookahead(s, vfn)
            # argmax returns the first maximal index, so ties go to the
            # lowest-numbered action.
            best_action = np.argmax(action_values)
            best_action_value = action_values[best_action]
            delta = max(delta, abs(vfn[s] - best_action_value))
            vfn[s] = best_action_value
            # Rewrite the one-hot row in place rather than allocating an
            # identity matrix for every state of every sweep.
            policy[s] = 0.0
            policy[s, best_action] = 1.0
        if delta < epsilon:
            return policy, vfn
if __name__ == "__main__":
    # Solve the environment first, then draw the value function and the
    # greedy action per state side by side.
    environment = GP()
    greedy_policy, state_values = value_iteration(environment)

    # Backend is selected after solving and before any figure is created.
    matplotlib.use('TkAgg')
    figure = plt.figure(figsize=(20, 20))
    value_axes = figure.add_subplot(1, 2, 1)
    policy_axes = figure.add_subplot(1, 2, 2)

    # Collapse each one-hot policy row to the index of its chosen action.
    chosen_actions = np.asarray(greedy_policy).argmax(axis=1)

    # (axes, series, y-axis label, title) for the left and right panels.
    panels = (
        (value_axes, state_values, 'State Value',
         'Value function Gamblers Problem'),
        (policy_axes, chosen_actions, 'Stake',
         'Policy Gamblers Problem'),
    )
    for axes, series, y_label, title in panels:
        axes.plot(series)
        axes.set_xlabel('Capital', fontsize=16)
        axes.set_ylabel(y_label, fontsize=16)
        axes.set_title(title, fontsize=22)

    plt.show()