-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathnmcts_generate.py
81 lines (66 loc) · 2.82 KB
/
nmcts_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
import pathlib
import random
import torch
import torch.nn as nn
from utttpy.game.ultimate_tic_tac_toe import UltimateTicTacToe
from utttpy.selfplay.policy_value_network import PolicyValueNetwork
from utttpy.selfplay.neural_monte_carlo_tree_search import (
NeuralMonteCarloTreeSearch,
serialize_evaluated_state,
serialize_evaluated_actions,
)
def run_argparse(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the evaluation generator.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns:
        Namespace with ``policy_value_net_path`` and ``output_path`` as
        ``pathlib.Path``, ``uttt_state`` as ``str``, ``num_simulations`` and
        ``random_seed`` as ``int``, ``exploration_strength`` as ``float`` and
        ``device`` as ``torch.device``.
    """
    # Fall back to the CPU when CUDA is not available so that omitting --device
    # does not fail later, when the network is moved to the device.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser()
    parser.add_argument("--policy_value_net_path", type=pathlib.Path, required=True)
    parser.add_argument("--uttt_state", type=str, required=True)
    parser.add_argument("--num_simulations", type=int, required=True)
    parser.add_argument("--exploration_strength", type=float, required=True)
    parser.add_argument("--random_seed", type=int, required=True)
    parser.add_argument("--output_path", type=pathlib.Path, required=True)
    # argparse applies `type` to string defaults, so the default also ends up
    # as a torch.device instance.
    parser.add_argument("--device", type=torch.device, default=default_device)
    return parser.parse_args(argv)
def load_policy_value_net(state_dict_path: pathlib.Path, device: torch.device) -> nn.Module:
    """Create a PolicyValueNetwork on `device` and restore its weights from disk.

    Args:
        state_dict_path: File containing the saved state dict.
        device: Device the network is moved to; also used as `map_location`
            when the state dict is read.

    Returns:
        The network with the loaded weights, switched to evaluation mode.
    """
    net = PolicyValueNetwork()
    net.to(device=device)
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version and its weights-only setting; only load trusted checkpoints.
    net.load_state_dict(torch.load(state_dict_path, map_location=device))
    net.eval()
    return net
def main() -> None:
    """Play one game with neural MCTS from a given state and save the evaluations.

    Parses the command line, seeds Python's `random` module, loads the
    policy-value network and then, until the game is terminated, repeatedly:
    runs the search, serializes the evaluated state and actions into one text
    line, samples an action, executes it on the game and synchronizes the
    search with the updated game. All lines are written to `--output_path`.
    """
    args = run_argparse()
    print(args)

    # Create the output directory before the search starts, so that a missing
    # directory cannot discard the results after all the computation is done.
    args.output_path.parent.mkdir(parents=True, exist_ok=True)

    # NOTE(review): only Python's `random` is seeded; if the search samples via
    # torch or numpy, those generators would need seeding too — confirm.
    random.seed(args.random_seed)

    policy_value_net = load_policy_value_net(
        state_dict_path=args.policy_value_net_path,
        device=args.device,
    )

    # Every character of --uttt_state is converted with int() into one byte of the state.
    uttt = UltimateTicTacToe(state=bytearray(map(int, args.uttt_state)))

    # The search receives its own copy of the game; it is re-aligned with
    # `uttt` through synchronize() after every executed action.
    nmcts = NeuralMonteCarloTreeSearch(
        uttt=uttt.clone(),
        num_simulations=args.num_simulations,
        exploration_strength=args.exploration_strength,
        policy_value_net=policy_value_net,
    )

    # Collected as a list and joined once instead of growing a string with `+=`.
    evaluation_lines = []

    while not uttt.is_terminated():
        nmcts.run(progress_bar=True)
        print(nmcts)

        evaluated_state = nmcts.get_evaluated_state()
        evaluated_actions = nmcts.get_evaluated_actions()
        evaluated_state_str = serialize_evaluated_state(evaluated_state=evaluated_state)
        evaluated_actions_str = serialize_evaluated_actions(evaluated_actions=evaluated_actions)
        evaluation_str = f"{evaluated_state_str} {evaluated_actions_str}"
        print(evaluation_str)
        evaluation_lines.append(f"{evaluation_str}\n")

        selected_action = nmcts.select_action(evaluated_actions=evaluated_actions, selection_method="sample")
        print("selected", selected_action)
        uttt.execute(action=selected_action)
        nmcts.synchronize(uttt=uttt)

    # NOTE(review): indentation was lost in the reviewed source; this print is
    # assumed to run once after the loop (final search state) — confirm.
    print(nmcts)

    with open(args.output_path, "w", encoding="utf-8") as f:
        f.write("".join(evaluation_lines))
    print(f"evaluations saved to {args.output_path} successfully!")
# Script entry point: run the generator only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()