"""
Adapted from https://github.com/tensorflow/minigo/blob/master/mcts.py
Implementation of the Monte-Carlo tree search algorithm as detailed in the
AlphaGo Zero paper (https://www.nature.com/articles/nature24270).
"""
import math
import random as rd
import collections
import numpy as np
# Exploration constant
c_PUCT = 1.38
# Dirichlet noise alpha parameter.
D_NOISE_ALPHA = 0.03
# Number of steps into the episode after which we always select the
# action with the highest visit count instead of sampling proportionally
# to the visit counts.
TEMP_THRESHOLD = 5
class DummyNode:
"""
Special node that is used as the node above the initial root node to
prevent having to deal with special cases when traversing the tree.
"""
def __init__(self):
self.parent = None
self.child_N = collections.defaultdict(float)
self.child_W = collections.defaultdict(float)
def revert_virtual_loss(self, up_to=None): pass
def add_virtual_loss(self, up_to=None): pass
def revert_visits(self, up_to=None): pass
def backup_value(self, value, up_to=None): pass
class MCTSNode:
"""
Represents a node in the Monte-Carlo search tree. Each node holds a single
environment state.
"""
def __init__(self, state, n_actions, TreeEnv, action=None, parent=None):
"""
:param state: State that the node should hold.
:param n_actions: Number of actions that can be performed in each
state. Equal to the number of outgoing edges of the node.
:param TreeEnv: Static class that defines the environment dynamics,
e.g. which state follows from another state when performing an action.
:param action: Index of the action that led from the parent node to
this node.
:param parent: Parent node.
"""
self.TreeEnv = TreeEnv
if parent is None:
self.depth = 0
parent = DummyNode()
else:
self.depth = parent.depth+1
self.parent = parent
self.action = action
self.state = state
self.n_actions = n_actions
self.is_expanded = False
self.n_vlosses = 0 # Number of virtual losses on this node
self.child_N = np.zeros([n_actions], dtype=np.float32)
self.child_W = np.zeros([n_actions], dtype=np.float32)
# Save copy of original prior before it gets mutated by Dirichlet noise
self.original_prior = np.zeros([n_actions], dtype=np.float32)
self.child_prior = np.zeros([n_actions], dtype=np.float32)
self.children = {}
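# Per-child visit counts (child_N) and total values (child_W) are stored in
# the parent; a node reads its own N and W through its parent's arrays (see
# the N and W properties below). The root therefore gets a DummyNode as
# parent so it has somewhere to store its own statistics.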
@property
def N(self):
"""
Returns the current visit count of the node.
"""
return self.parent.child_N[self.action]
@N.setter
def N(self, value):
self.parent.child_N[self.action] = value
@property
def W(self):
"""
Returns the current total value of the node.
"""
return self.parent.child_W[self.action]
@W.setter
def W(self, value):
self.parent.child_W[self.action] = value
@property
def Q(self):
"""
Returns the current action value of the node.
"""
return self.W / (1 + self.N)
@property
def child_Q(self):
return self.child_W / (1 + self.child_N)
@property
def child_U(self):
return (c_PUCT * math.sqrt(1 + self.N) *
self.child_prior / (1 + self.child_N))
@property
def child_action_score(self):
"""
Action_Score(s, a) = Q(s, a) + U(s, a) as in the paper. A high value
means the node should be traversed.
"""
return self.child_Q + self.child_U
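# With the properties above, the score expands to the MiniGo-style PUCT rule
#   score(s, a) = W(s, a) / (1 + N(s, a))
#                 + c_PUCT * P(s, a) * sqrt(1 + N(s)) / (1 + N(s, a)),
# i.e. the textbook Q + U term with "+1" smoothing in the visit counts.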
def select_leaf(self):
"""
Traverses the MCT rooted in the current node until it finds a leaf
(i.e. a node that only exists in its parent node in terms of its
child_N and child_W values but not as a dedicated node in the parent's
children-mapping). Nodes are selected according to child_action_score.
It expands the leaf by adding a dedicated MCTSNode. Note that the
estimated value and prior probabilities still have to be set with
`incorporate_estimates` afterwards.
:return: Expanded leaf MCTSNode.
"""
current = self
while True:
current.N += 1
# Encountered leaf node (i.e. node that is not yet expanded).
if not current.is_expanded:
break
# Choose action with highest score.
best_move = np.argmax(current.child_action_score)
current = current.maybe_add_child(best_move)
return current
def maybe_add_child(self, action):
"""
Adds a child node for the given action if it does not yet exist, and
returns it.
:param action: Action to take in current state which leads to desired
child node.
:return: Child MCTSNode.
"""
if action not in self.children:
# Obtain state following given action.
new_state = self.TreeEnv.next_state(self.state, action)
self.children[action] = MCTSNode(new_state, self.n_actions,
self.TreeEnv,
action=action, parent=self)
return self.children[action]
def add_virtual_loss(self, up_to):
"""
Propagate a virtual loss up to a given node.
:param up_to: The node to propagate until.
"""
self.n_vlosses += 1
self.W -= 1
if self.parent is None or self is up_to:
return
self.parent.add_virtual_loss(up_to)
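# The W decrement above temporarily lowers Q along the selected path, so
# that other traversals collected into the same evaluation batch are steered
# towards different branches; revert_virtual_loss undoes it once the network
# estimate for the leaf has been incorporated.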
def revert_virtual_loss(self, up_to):
"""
Undo adding virtual loss.
:param up_to: The node to propagate until.
"""
self.n_vlosses -= 1
self.W += 1
if self.parent is None or self is up_to:
return
self.parent.revert_virtual_loss(up_to)
def revert_visits(self, up_to):
"""
Revert visit increments.
Sometimes, repeated calls to select_leaf return the same node.
This is rare and we're okay with the wasted computation of evaluating
the position multiple times with the agent network. But select_leaf has the
side effect of incrementing visit counts. Since we want the value to
only count once for the repeatedly selected node, we also have to
revert the incremented visit counts.
:param up_to: The node to propagate until.
"""
self.N -= 1
if self.parent is None or self is up_to:
return
self.parent.revert_visits(up_to)
def incorporate_estimates(self, action_probs, value, up_to):
"""
Call this after the node has just been selected via `select_leaf` to
expand it with the prior action probabilities and state value estimated
by the neural network.
:param action_probs: Action probabilities for the current node's state
predicted by the neural network.
:param value: Value of the current node's state predicted by the neural
network.
:param up_to: The node to propagate until.
"""
# A done node (i.e. episode end) should not go through this code path.
# Rather it should directly call `backup_value` on the final node.
# TODO: Add assert here
# This node may already have been expanded earlier in the same batch.
# Ignore the wasted computation but correct the visit counts.
if self.is_expanded:
self.revert_visits(up_to=up_to)
return
self.is_expanded = True
self.original_prior = self.child_prior = action_probs
# This is a deviation from the paper that led to better results in
# practice (following the MiniGo implementation).
self.child_W = np.ones([self.n_actions], dtype=np.float32) * value
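# Initializing every child's W to the parent's value estimate means an
# unvisited child starts with Q = value / (1 + 0) = value instead of 0,
# which biases the first visits towards the parent's own evaluation.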
self.backup_value(value, up_to=up_to)
def backup_value(self, value, up_to):
"""
Propagates a value estimation up to the root node.
:param value: Value estimate to be propagated.
:param up_to: The node to propagate until.
"""
self.W += value
if self.parent is None or self is up_to:
return
self.parent.backup_value(value, up_to)
def is_done(self):
return self.TreeEnv.is_done_state(self.state, self.depth)
def inject_noise(self):
dirch = np.random.dirichlet([D_NOISE_ALPHA] * self.n_actions)
self.child_prior = self.child_prior * 0.75 + dirch * 0.25
def visits_as_probs(self, squash=False):
"""
Returns the child visit counts as a probability distribution.
:param squash: If True, exponentiate the probabilities by a temperature
slightly larger than 1 to encourage diversity in early steps.
:return: Numpy array of shape (n_actions).
"""
probs = self.child_N
if squash:
probs = probs ** .95
return probs / np.sum(probs)
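# Raising the counts to the power 0.95 corresponds to sampling with a
# temperature of roughly 1 / 0.95 ≈ 1.05, since pi(a) ∝ N(a)^(1 / T).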
def print_tree(self, level=0):
node_string = "\033[94m|" + "----"*level
node_string += "Node: action={}\033[0m".format(self.action)
node_string += "\n• state:\n{}".format(self.state)
node_string += "\n• N={}".format(self.N)
node_string += "\n• score:\n{}".format(self.child_action_score)
node_string += "\n• Q:\n{}".format(self.child_Q)
node_string += "\n• P:\n{}".format(self.child_prior)
print(node_string)
for _, child in sorted(self.children.items()):
child.print_tree(level+1)
class MCTS:
"""
Represents a Monte-Carlo search tree and provides methods for performing
the tree search.
"""
def __init__(self, agent_netw, TreeEnv, seconds_per_move=None,
simulations_per_move=800, num_parallel=8):
"""
:param agent_netw: Network for predicting action probabilities and
state value estimate.
:param TreeEnv: Static class that defines the environment dynamics,
e.g. which state follows from another state when performing an action.
:param seconds_per_move: Currently unused.
:param simulations_per_move: Number of traversals through the tree
before performing a step.
:param num_parallel: Number of leaf nodes to collect before evaluating
them in a single batch.
"""
self.agent_netw = agent_netw
self.TreeEnv = TreeEnv
self.seconds_per_move = seconds_per_move
self.simulations_per_move = simulations_per_move
self.num_parallel = num_parallel
self.temp_threshold = None # Overwritten in initialize_search
self.qs = []
self.rewards = []
self.searches_pi = []
self.obs = []
self.root = None
def initialize_search(self, state=None):
init_state = self.TreeEnv.initial_state()
n_actions = self.TreeEnv.n_actions
self.root = MCTSNode(init_state, n_actions, self.TreeEnv)
# Number of steps into the episode after which we always select the
# action with the highest visit count instead of sampling proportionally
# to the visit counts.
self.temp_threshold = TEMP_THRESHOLD
self.qs = []
self.rewards = []
self.searches_pi = []
self.obs = []
def tree_search(self, num_parallel=None):
"""
Performs multiple simulations in the tree (following trajectories)
until a given number of leaves to expand has been encountered.
Then it expands and evaluates these leaf nodes in a single batch.
:param num_parallel: Number of leaf states which the agent network can
evaluate at once. Limits the number of simulations.
:return: The leaf nodes which were expanded.
"""
if num_parallel is None:
num_parallel = self.num_parallel
leaves = []
# Failsafe: if we almost only encounter done states, the loop could
# otherwise never collect enough leaves and would never end.
failsafe = 0
while len(leaves) < num_parallel and failsafe < num_parallel * 2:
failsafe += 1
# self.root.print_tree()
# print("_"*50)
leaf = self.root.select_leaf()
# If we encounter a done state, we do not need the agent network to
# bootstrap. We can back up the true return right away.
if leaf.is_done():
value = self.TreeEnv.get_return(leaf.state, leaf.depth)
leaf.backup_value(value, up_to=self.root)
continue
# Otherwise, discourage subsequent traversals from taking the same
# trajectory via a virtual loss and enqueue the leaf for evaluation by
# the agent network.
leaf.add_virtual_loss(up_to=self.root)
leaves.append(leaf)
# Evaluate the leaf states all at once and back up the value estimates.
if leaves:
action_probs, values = self.agent_netw.step(
self.TreeEnv.get_obs_for_states([leaf.state for leaf in leaves]))
for leaf, action_prob, value in zip(leaves, action_probs, values):
leaf.revert_virtual_loss(up_to=self.root)
leaf.incorporate_estimates(action_prob, value, up_to=self.root)
return leaves
def pick_action(self):
"""
Selects an action for the root state based on the visit counts.
"""
if self.root.depth > self.temp_threshold:
action = np.argmax(self.root.child_N)
else:
cdf = self.root.child_N.cumsum()
cdf /= cdf[-1]
selection = rd.random()
action = cdf.searchsorted(selection)
assert self.root.child_N[action] != 0
return action
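# The cumulative-sum sampling above draws an action in proportion to its
# visit count; a sketch of the same idea with NumPy would be:
#   probs = self.root.child_N / self.root.child_N.sum()
#   action = np.random.choice(self.root.n_actions, p=probs)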
def take_action(self, action):
"""
Takes the specified action for the root state. The subsequent child
state becomes the new root state of the tree.
:param action: Action to take for the root state.
"""
# Store data to be used as experience tuples.
ob = self.TreeEnv.get_obs_for_states([self.root.state])
self.obs.append(ob)
self.searches_pi.append(
self.root.visits_as_probs()) # TODO: Use self.root.position.n < self.temp_threshold as argument
self.qs.append(self.root.Q)
reward = (self.TreeEnv.get_return(self.root.children[action].state,
self.root.children[action].depth)
- sum(self.rewards))
self.rewards.append(reward)
# Resulting state becomes new root of the tree.
self.root = self.root.maybe_add_child(action)
del self.root.parent.children
def execute_episode(agent_netw, num_simulations, TreeEnv):
"""
Executes a single episode of the task using Monte-Carlo tree search with
the given agent network. It returns the experience tuples collected during
the search.
:param agent_netw: Network for predicting action probabilities and state
value estimate.
:param num_simulations: Number of simulations (traverses from root to leaf)
per action.
:param TreeEnv: Static class that defines the environment dynamics.
:return: The observations for each step of the episode, the search policies
produced by the MCTS (not the raw neural network outputs), the per-step
value targets, the total reward of this episode and the final state of
this episode.
"""
mcts = MCTS(agent_netw, TreeEnv)
mcts.initialize_search()
# Must run this once at the start, so that noise injection actually affects
# the first action of the episode.
first_node = mcts.root.select_leaf()
probs, vals = agent_netw.step(
TreeEnv.get_obs_for_states([first_node.state]))
first_node.incorporate_estimates(probs[0], vals[0], first_node)
while True:
mcts.root.inject_noise()
current_simulations = mcts.root.N
# We want `num_simulations` simulations per action, not counting
# simulations from previous actions.
while mcts.root.N < current_simulations + num_simulations:
mcts.tree_search()
# mcts.root.print_tree()
# print("_"*100)
action = mcts.pick_action()
mcts.take_action(action)
if mcts.root.is_done():
break
# Build the per-step value targets. Here the return of the terminal root
# state is used as the target for every step of the episode.
ret = [TreeEnv.get_return(mcts.root.state, mcts.root.depth) for _
in range(len(mcts.rewards))]
total_rew = np.sum(mcts.rewards)
obs = np.concatenate(mcts.obs)
return (obs, mcts.searches_pi, ret, total_rew, mcts.root.state)
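

# ---------------------------------------------------------------------------
# Minimal usage sketch. `ToyEnv` and `UniformNetwork` below are illustrative
# stand-ins (not part of this repository): a real TreeEnv must provide the
# same static interface (initial_state, n_actions, next_state, is_done_state,
# get_obs_for_states, get_return), and a real agent network must provide
# step(obs) -> (action_probs, values).
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ToyEnv:
        """Chain of length 8: action 1 moves right, action 0 moves left."""
        n_actions = 2

        @staticmethod
        def initial_state():
            return 0

        @staticmethod
        def next_state(state, action):
            return state + 1 if action == 1 else max(state - 1, 0)

        @staticmethod
        def is_done_state(state, depth):
            return state >= 7 or depth >= 20

        @staticmethod
        def get_obs_for_states(states):
            # One observation row per state (here simply the scalar position).
            return np.array(states, dtype=np.float32).reshape(-1, 1)

        @staticmethod
        def get_return(state, depth):
            # Reward for reaching the right end, minus a small per-step cost.
            return float(state == 7) - 0.01 * depth

    class UniformNetwork:
        """Stand-in for agent_netw: uniform priors and zero value estimates."""

        @staticmethod
        def step(obs):
            batch_size = obs.shape[0]
            probs = np.full((batch_size, ToyEnv.n_actions),
                            1.0 / ToyEnv.n_actions, dtype=np.float32)
            values = np.zeros(batch_size, dtype=np.float32)
            return probs, values

    obs, pis, returns, total_reward, final_state = execute_episode(
        UniformNetwork(), num_simulations=32, TreeEnv=ToyEnv)
    print("final state:", final_state, "total reward:", total_reward)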