This repository has been archived by the owner on Oct 26, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 14
/
policy_value_net.py
72 lines (62 loc) · 2.64 KB
/
policy_value_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# -*- coding: utf-8 -*-
# @Time : 2021/3/29 21:01
# @Author : He Ruizhi
# @File : policy_value_net.py
# @Software: PyCharm
import numpy as np
import paddle
class PolicyValueNet(paddle.nn.Layer):
    """Policy-value network: a shared convolutional trunk feeding two heads.

    Follows the AlphaGo Zero layout ("one body, two heads"):

    * ``conv_layer``   -- feature extraction: three 3x3 convolutions
      (input_channels -> 32 -> 64 -> 128) with ReLU.
    * ``policy_layer`` -- policy head: softmax distribution over
      ``board_size * board_size + 1`` actions (the extra action is
      presumably "pass" -- confirm with the game implementation).
    * ``value_layer``  -- value head: a single scalar squashed by ``tanh``
      into [-1, 1].
    """

    def __init__(self, input_channels: int = 10,
                 board_size: int = 9):
        """
        :param input_channels: Number of input planes, 10 by default: the most
            recent 4 moves of each side, plus one plane indicating the side to
            move, plus one plane marking the position of the most recent move.
        :param board_size: Side length of the square board.
        """
        super(PolicyValueNet, self).__init__()
        # All convolutions below use stride 1 with either a 3x3 kernel and
        # padding=1 or a 1x1 kernel, so they preserve the spatial size.
        # Assuming the input planes are board_size x board_size, a head that
        # flattens C channels therefore sees C * board_size * board_size
        # features. These sizes were previously hard-coded for a 9x9 board,
        # which broke every other board_size.
        num_cells = board_size * board_size

        # Feature-extraction trunk shared by both heads.
        self.conv_layer = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=input_channels, out_channels=32, kernel_size=3, padding=1),
            paddle.nn.ReLU(),
            paddle.nn.Conv2D(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            paddle.nn.ReLU(),
            paddle.nn.Conv2D(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            paddle.nn.ReLU()
        )
        # Policy head: move probabilities over every board point plus one
        # extra action.
        self.policy_layer = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=128, out_channels=8, kernel_size=1),
            paddle.nn.ReLU(),
            paddle.nn.Flatten(),
            paddle.nn.Linear(in_features=num_cells * 8, out_features=256),
            paddle.nn.ReLU(),
            paddle.nn.Linear(in_features=256, out_features=num_cells + 1),
            paddle.nn.Softmax()
        )
        # Value head: single position evaluation in [-1, 1] (Tanh output).
        self.value_layer = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=128, out_channels=4, kernel_size=1),
            paddle.nn.ReLU(),
            paddle.nn.Flatten(),
            paddle.nn.Linear(in_features=num_cells * 4, out_features=128),
            paddle.nn.ReLU(),
            paddle.nn.Linear(in_features=128, out_features=64),
            paddle.nn.ReLU(),
            paddle.nn.Linear(in_features=64, out_features=1),
            paddle.nn.Tanh()
        )

    def forward(self, x):
        """Run the trunk once and feed its features to both heads.

        :param x: Batch of input planes, presumably shaped
            (batch, input_channels, board_size, board_size) -- TODO confirm.
        :return: ``(policy, value)``; ``policy`` has ``board_size**2 + 1``
            softmax probabilities per sample, ``value`` has one tanh output
            per sample.
        """
        x = self.conv_layer(x)
        policy = self.policy_layer(x)
        value = self.value_layer(x)
        return policy, value

    def policy_value_fn(self, simulate_game_state):
        """Evaluate a single game state, restricting the policy to legal moves.

        :param simulate_game_state: Game-state object. It must provide
            ``valid_move_idcs()``, which returns indices of legal actions usable
            for numpy indexing, and ``get_board_state()``, which returns an
            array of input planes (presumably shaped
            (input_channels, board_size, board_size) -- verify against the
            game implementation).
        :return: ``(act_probs, value)``.
            ``act_probs`` is a ``zip`` iterator of
            ``(action_index, probability)`` pairs for the legal actions only.
            It is single-pass, and the probabilities are NOT renormalized over
            the legal subset.
            ``value`` is the value-head output tensor for this one state.
        """
        legal_positions = simulate_game_state.valid_move_idcs()
        # Prepend a batch axis of size 1; the network expects batched input.
        current_state = paddle.to_tensor(simulate_game_state.get_board_state()[np.newaxis], dtype='float32')
        # NOTE(review): this is pure inference; wrapping it in paddle.no_grad()
        # would avoid building an autograd graph, if callers never backprop
        # through `value` -- confirm before changing.
        act_probs, value = self.forward(current_state)
        # Pick out the probabilities of the legal actions from the flattened
        # (1, board_size**2 + 1) policy output.
        act_probs = zip(legal_positions, act_probs.numpy().flatten()[legal_positions])
        return act_probs, value