-
Notifications
You must be signed in to change notification settings - Fork 0
/
MCTS.py
73 lines (61 loc) · 1.68 KB
/
MCTS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import time
from copy import deepcopy
import numpy as np
import random
import Chessboard
INF = np.inf


def uct(node, C=1.414):
    """Return the UCT (Upper Confidence bound applied to Trees) score of *node*.

    An unvisited node (``N == 0``) scores ``INF`` so it is always tried before
    any visited sibling. Otherwise the score is the mean value ``U / N`` plus
    an exploration bonus ``C * sqrt(ln(parent.N) / N)``.

    Args:
        node: a tree node exposing ``U``, ``N`` and ``parent.N``.
        C: exploration constant (1.414, roughly sqrt(2), by default).
    """
    if node.N == 0:
        return INF
    exploitation = node.U / node.N
    exploration = C * np.sqrt(np.log(node.parent.N) / node.N)
    return exploitation + exploration
class MctsNode(object):
def __init__(self, parent=None, state=None, U=0, N=0):
self.parent = parent
self.children = {}
self.state = state
self.U = U
self.N = N
def select(self):
if self.children:
return max(self.children.keys(), key=uct)
else:
return self
def expand(self):
if not self.children and not Chessboard.is_terminal(self.state):
self.children = {
MctsNode(parent=self, state=Chessboard.flip_chess(self.state, action)): action
for action in Chessboard.get_valid_moves(self.state)
}
return self.select()
def simulate(self):
player = self.state.to_move
_state = deepcopy(self.state)
no_action = False
while not Chessboard.is_terminal(_state):
actions = Chessboard.get_valid_moves(_state)
if not actions:
if no_action:
break
_state = Chessboard.flip_chess(_state, (-1, -1))
no_action = True
continue
else:
no_action = False
action = random.choice(actions)
_state = Chessboard.flip_chess(_state, action)
_util = Chessboard.utility(_state, player)
return -_util
def back_prop(self, value):
self.N += 1
if value > 0:
self.U += value
if self.parent:
self.parent.back_prop(-value)
def go(self, start_time, time_out):
while True:
leaf = self.select()
child = leaf.expand()
res = child.simulate()
child.back_prop(res)
if time.time() - start_time > time_out * 0.95:
break
node, action = max(self.children.items(), key=lambda x: x[0].N)
return node.N, action