-
Notifications
You must be signed in to change notification settings - Fork 12
/
species.py
226 lines (172 loc) · 6.12 KB
/
species.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
from dataclasses import dataclass
from agent import Agent
from human_agent import HumanAgent
from mcts_agent import MCTSAgent
from environment_registry import get_env_module
import intuition_model
from paths import build_model_paths
import settings
# Set of Species subclasses that should be selectable by name. Each class
# below adds itself right after its definition; near the end of this module
# every entry is instantiated and indexed by name in SPECIES_BY_NAME.
SPECIES_REGISTRY = set()
@dataclass
class Species:
    """Base description of a species: a name, an agent class and settings hooks.

    The three ``*_settings`` methods are hooks for subclasses; on this base
    class each of them only raises ``NotImplementedError``.
    """

    # Identifier of the species; SPECIES_BY_NAME in this module is keyed on it.
    name: str
    # Agent class associated with the species (e.g. HumanAgent, MCTSAgent).
    AgentClass: Agent

    def agent_settings(self, environment, generation, play_setting):
        """Return the agent settings for this species (subclass hook).

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def self_play_settings(self, environment, generation):
        """Return the self-play settings for this species (subclass hook).

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def training_settings(self, environment, generation):
        """Return the training settings for this species (subclass hook).

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
@dataclass
class Human(Species):
    """Species named "human" whose agent class is ``HumanAgent``."""

    name: str = "human"
    AgentClass: Agent = HumanAgent

    def agent_settings(self, environment, generation, play_setting):
        """Return a dict holding only the species name and the generation.

        ``environment`` and ``play_setting`` are part of the shared signature
        but are not read here.
        """
        return {"species": self.name, "generation": generation}


SPECIES_REGISTRY.add(Human)  # noqa
@dataclass
class GBDT(Species):
    """Species named "gbdt": an ``MCTSAgent`` using GBDT value/policy models."""

    name: str = "gbdt"
    AgentClass: Agent = MCTSAgent

    def agent_settings(self, environment, generation, play_setting):
        """Assemble the agent settings dict for the given generation.

        Model selection:
          * generation 1 uses ``UniformPolicy`` and, for the value model, the
            environment module's ``BootstrapValue`` when it defines one,
            otherwise ``UnopinionatedValue``;
          * any other generation loads ``GBDTValue`` / ``GBDTPolicy`` from the
            paths returned by ``build_model_paths``.

        Play-setting adjustments:
          * "self_play": temperature .3 and the lower-bound consideration time;
          * "evaluation": the lower-bound consideration time;
          * anything else: temperature 0.0 and 3.0 consideration time.
        """
        env_module = get_env_module(environment)

        # Setup value/policy models
        if generation != 1:
            value_model_path, policy_model_path = build_model_paths(
                environment,
                self.name,
                generation,
            )
            value_model = intuition_model.GBDTValue()
            value_model.load(value_model_path)
            policy_model = intuition_model.GBDTPolicy()
            policy_model.load(policy_model_path)
        else:
            # No trained models yet: start from the default value model and
            # swap in the environment's bootstrap value model if it has one.
            value_model = intuition_model.UnopinionatedValue()
            if hasattr(env_module, "BootstrapValue"):
                value_model = env_module.BootstrapValue()
            policy_model = intuition_model.UniformPolicy()

        # Settings
        full_search_proportion = 1.0
        full_search_steps = 800
        partial_search_steps = full_search_steps // 5

        # Play-setting dependent
        # - XXX Add Puct params
        lower_bound_time = 0.01
        is_self_play = play_setting == "self_play"
        is_evaluation = play_setting == "evaluation"
        # XXX Make a temp profile per move
        temperature = .3 if is_self_play else 0.0
        if is_self_play or is_evaluation:
            move_consideration_time = lower_bound_time
        else:
            move_consideration_time = 3.0

        return {
            "species": self.name,
            "generation": generation,
            "game_tree": None,
            "current_node": None,
            "feature_extractor": env_module.generate_features,
            "value_model": value_model,
            "policy_model": policy_model,
            "move_consideration_time": move_consideration_time,  # at least N seconds
            "puct_explore_factor": 1.0,
            "puct_noise_alpha": 0.4,
            "puct_noise_influence": 0.25,
            "full_search_proportion": full_search_proportion,
            "full_search_steps": full_search_steps,
            "partial_search_steps": partial_search_steps,
            "temperature": temperature,
        }

    def self_play_settings(self, environment, generation):
        """Return the self-play settings: a fresh dict with ``num_games`` 3000."""
        return {"num_games": 3000}

    def training_settings(self, environment, generation):
        """Return the model classes and their settings used for training.

        The value model gets an empty settings dict; the policy model gets
        ``num_workers`` taken from ``settings.GBDT_TRAINING_THREADS``.
        """
        return {
            "ValueModel": intuition_model.GBDTValue,
            "value_model_settings": {},
            # Value model settings for reweighting
            # dict(weighting_strat=strat, highest_generation=generation - 1)
            "PolicyModel": intuition_model.GBDTPolicy,
            "policy_model_settings": {
                "num_workers": settings.GBDT_TRAINING_THREADS,
            },
        }


SPECIES_REGISTRY.add(GBDT)  # noqa
@dataclass
class GBDTR1(GBDT):
    """``GBDT`` with the name "gbdtR1"; nothing else is overridden."""

    name: str = "gbdtR1"


SPECIES_REGISTRY.add(GBDTR1)  # noqa
@dataclass
class GBDTR2(GBDT):
    """``GBDT`` with the name "gbdtR2"; nothing else is overridden."""

    name: str = "gbdtR2"


SPECIES_REGISTRY.add(GBDTR2)  # noqa
@dataclass
class PCR(GBDT):
    """``GBDT`` variant "gbdt_pcr" with a reduced self-play full-search proportion."""

    name: str = "gbdt_pcr"

    def agent_settings(self, environment, generation, play_setting):
        """Return the ``GBDT`` agent settings, adjusted for self-play.

        When ``play_setting`` is "self_play", ``full_search_proportion`` is
        overridden to .2; otherwise the parent's settings are returned as is.
        """
        agent_config = super().agent_settings(environment, generation, play_setting)
        if play_setting != "self_play":
            return agent_config
        agent_config["full_search_proportion"] = .2
        return agent_config


SPECIES_REGISTRY.add(PCR)  # noqa
@dataclass
class PCRR1(PCR):
    """``PCR`` with the name "gbdt_pcrR1"; nothing else is overridden."""

    name: str = "gbdt_pcrR1"


SPECIES_REGISTRY.add(PCRR1)  # noqa
@dataclass
class PCRR2(PCR):
    """``PCR`` with the name "gbdt_pcrR2"; nothing else is overridden."""

    name: str = "gbdt_pcrR2"


SPECIES_REGISTRY.add(PCRR2)  # noqa
@dataclass
class PCRP1(GBDT):
    """``GBDT`` variant "gbdt_pcrp1" using a .1 self-play full-search proportion."""

    name: str = "gbdt_pcrp1"

    def agent_settings(self, environment, generation, play_setting):
        """Return the ``GBDT`` agent settings, adjusted for self-play.

        When ``play_setting`` is "self_play", ``full_search_proportion`` is
        overridden to .1; otherwise the parent's settings are returned as is.
        """
        agent_config = super().agent_settings(environment, generation, play_setting)
        if play_setting != "self_play":
            return agent_config
        agent_config["full_search_proportion"] = .1
        return agent_config


SPECIES_REGISTRY.add(PCRP1)  # noqa
@dataclass
class PCRV(GBDT):
    """``GBDT`` variant "gbdt_pcrv": no required full steps, reduced self-play search."""

    name: str = "gbdt_pcrv"

    def agent_settings(self, environment, generation, play_setting):
        """Return the ``GBDT`` agent settings with this variant's overrides.

        ``require_full_steps`` is always set to False. In addition, when
        ``play_setting`` is "self_play", ``full_search_proportion`` becomes .2.
        """
        agent_config = super().agent_settings(environment, generation, play_setting)
        agent_config["require_full_steps"] = False
        if play_setting != "self_play":
            return agent_config
        agent_config["full_search_proportion"] = .2
        return agent_config


SPECIES_REGISTRY.add(PCRV)  # noqa
@dataclass
class PCRVR1(PCRV):
    """``PCRV`` with the name "gbdt_pcrvR1"; nothing else is overridden."""

    name: str = "gbdt_pcrvR1"


SPECIES_REGISTRY.add(PCRVR1)  # noqa
@dataclass
class PCRVR2(PCRV):
    """``PCRV`` with the name "gbdt_pcrvR2"; nothing else is overridden."""

    name: str = "gbdt_pcrvR2"


SPECIES_REGISTRY.add(PCRVR2)  # noqa
@dataclass
class RVE(GBDT):
    """``GBDT`` variant "gbdt_rve" that enables ``revisit_violated_expectations``."""

    name: str = "gbdt_rve"

    def agent_settings(self, environment, generation, play_setting):
        """Return the ``GBDT`` agent settings with this variant's overrides.

        ``require_full_steps`` is always set to False and
        ``revisit_violated_expectations`` to True. In addition, when
        ``play_setting`` is "self_play", ``full_search_proportion`` becomes .2.
        """
        agent_config = super().agent_settings(environment, generation, play_setting)
        agent_config["require_full_steps"] = False
        agent_config["revisit_violated_expectations"] = True
        if play_setting != "self_play":
            return agent_config
        agent_config["full_search_proportion"] = .2
        return agent_config

    def self_play_settings(self, environment, generation):
        """Return the parent's self-play settings with ``num_games`` set to 900."""
        self_play_config = super().self_play_settings(environment, generation)
        self_play_config["num_games"] = 900
        return self_play_config


SPECIES_REGISTRY.add(RVE)  # noqa
@dataclass
class RVER1(RVE):
    """``RVE`` with the name "gbdt_rveR1"; nothing else is overridden."""

    name: str = "gbdt_rveR1"


SPECIES_REGISTRY.add(RVER1)  # noqa
@dataclass
class RVER2(RVE):
    """``RVE`` with the name "gbdt_rveR2"; nothing else is overridden."""

    name: str = "gbdt_rveR2"


SPECIES_REGISTRY.add(RVER2)  # noqa
# Instantiate every registered species class once (using its dataclass
# defaults) and index the instances by their ``name``.
#
# SPECIES_REGISTRY is a set, so iteration order is arbitrary. If two classes
# shared a name (e.g. a subclass that forgot to override ``name``), one would
# silently replace the other and which one survived could vary between runs.
# Fail loudly at import time instead.
SPECIES_BY_NAME = {}
for species_class in SPECIES_REGISTRY:
    sp = species_class()
    if sp.name in SPECIES_BY_NAME:
        raise ValueError(
            f"Duplicate species name {sp.name!r}: "
            f"{type(SPECIES_BY_NAME[sp.name]).__name__} and "
            f"{species_class.__name__}"
        )
    SPECIES_BY_NAME[sp.name] = sp
def get_species(name):
    """Return the species instance registered under ``name``.

    Args:
        name: The ``name`` attribute of a registered species (for example
            "human" or "gbdt").

    Returns:
        The instance stored in ``SPECIES_BY_NAME`` for that name.

    Raises:
        KeyError: If no species is registered under ``name``. The message
            lists the registered names to make typos easy to spot.
    """
    try:
        return SPECIES_BY_NAME[name]
    except KeyError:
        known = ", ".join(sorted(SPECIES_BY_NAME))
        # ``from None`` hides the uninformative inner KeyError from the traceback.
        raise KeyError(
            f"Unknown species {name!r}; known species: {known}"
        ) from None