-
Notifications
You must be signed in to change notification settings - Fork 2
/
convert.py
143 lines (125 loc) · 4.48 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import dill
import numpy as np
import pickle
import types
import inspect
from train_ddqn import AgentWithNormalMemory, AgentNormalMultiReward, AgentWithPER, AgentWithPERAndMultiRewards
from tensorflow.python.keras.activations import relu, linear
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
from argparse import ArgumentParser
import tensorflow as tf
# Hide every GPU from TensorFlow so that loading and converting the model runs on the CPU.
tf.config.set_visible_devices([], device_type='GPU')
class MyDenseLayer:
    """NumPy re-implementation of a Keras dense layer's forward pass.

    Holds the kernel (``weight``), the ``bias`` and the activation *name*.
    The name is resolved to one of the ``my_*`` methods at call time, so a
    layer restored with ``from_json`` needs only NumPy to run.
    """

    weight = None
    bias = None
    activation = None
    # Negative-side slope for 'leaky_relu'. 0.1 is the slope this class has
    # always used; layers converted from a Keras LeakyReLU store their own.
    leaky_alpha = 0.1

    def my_relu(self, x):
        """Rectified linear unit: ``x`` where positive, else 0."""
        return (x > 0) * x

    def my_linear(self, x):
        """Identity activation."""
        return x

    def my_leaky_relu(self, x):
        """Leaky ReLU: ``x`` where positive, else ``x * self.leaky_alpha``."""
        return np.where(x > 0, x, x * self.leaky_alpha)

    def __init__(self, layer=None):
        """Copy kernel, bias and activation from ``layer`` (a Keras layer), if given.

        Without ``layer`` an empty shell is created; ``from_json`` fills it in.

        Raises:
            ValueError: if the layer's activation is not relu, linear or LeakyReLU.
        """
        # Activation name -> NumPy implementation, used by __call__.
        self.activation_remap = {
            'relu': self.my_relu,
            'linear': self.my_linear,
            'leaky_relu': self.my_leaky_relu,
        }
        if layer is None:
            return
        # Keras activation -> portable name. Built only when converting a real
        # layer so the NumPy-only path never touches TensorFlow symbols.
        self.activation_functions = {
            relu: 'relu',
            linear: 'linear',
            LeakyReLU: 'leaky_relu',
        }
        self.weight = layer.weights[0].numpy()
        if layer.bias is not None:
            self.bias = layer.bias.numpy()
        else:
            # Keras leaves ``bias`` as None for bias-free layers; a zero bias
            # keeps __call__ and to_json uniform.
            self.bias = np.zeros(self.weight.shape[-1], dtype=self.weight.dtype)
        self.activation = self._activation_name(layer.activation)
        if self.activation == 'leaky_relu':
            # Keras' LeakyReLU defaults to alpha=0.3, so copy the layer's real
            # slope rather than assuming this class's 0.1.
            self.leaky_alpha = float(getattr(layer.activation, 'alpha', self.leaky_alpha))

    def _activation_name(self, activation):
        """Map a Keras activation (function or callable object) to its name.

        Returns None when ``activation`` is None.

        Raises:
            ValueError: if the activation has no NumPy counterpart here.
        """
        if activation is None:
            return None
        # Plain functions (relu, linear) are keyed by identity; callable
        # objects such as a LeakyReLU instance are keyed by their class.
        key = activation if isinstance(activation, types.FunctionType) else type(activation)
        try:
            return self.activation_functions[key]
        except KeyError:
            supported = ', '.join(sorted(self.activation_functions.values()))
            raise ValueError(
                f'unsupported activation {activation!r}; supported: {supported}'
            ) from None

    def __call__(self, x):
        """Return ``activation(x . weight + bias)``.

        ``x`` may be any array-like whose last axis matches the kernel's input
        dimension (assuming a 2-D kernel).
        """
        val = np.asarray(x).dot(self.weight) + self.bias
        if self.activation is not None:
            val = self.activation_remap[self.activation](val)
        return val

    def to_json(self):
        """Return the layer state as a plain dict (NumPy arrays, not JSON text)."""
        return {
            'weight': self.weight,
            'bias': self.bias,
            'activation': self.activation,
            'leaky_alpha': self.leaky_alpha,
        }

    @classmethod
    def from_json(cls, data):
        """Rebuild a layer from a dict produced by ``to_json``.

        ``leaky_alpha`` is optional, so dicts written before it existed load
        with the historical slope of 0.1.
        """
        layer = cls()
        layer.weight = data['weight']
        layer.bias = data['bias']
        layer.activation = data['activation']
        layer.leaky_alpha = data.get('leaky_alpha', cls.leaky_alpha)
        return layer
class MyPyNetwork:
    """Sequential stack of NumPy layers converted from a Keras model.

    ``__call__`` applies the layers in order. ``dump``/``load`` persist them as
    a pickled list of per-layer dicts produced by ``MyDenseLayer.to_json``.
    """

    def __init__(self):
        self.layers = []

    def from_network(self, network):
        """Append a converted ``MyDenseLayer`` for every entry of ``network.layers``.

        Every layer must be convertible by ``MyDenseLayer`` (it reads the
        layer's weights, bias and activation). Layers are appended, so calling
        this twice duplicates them.
        """
        print('from network', len(network.layers), len(self.layers))
        self.layers.extend(MyDenseLayer(layer) for layer in network.layers)

    def __call__(self, x):
        """Run ``x`` through every layer in order and return the final output.

        An empty network returns ``x`` unchanged.
        """
        for layer in self.layers:
            x = layer(x)
        return x

    def dump(self, path):
        """Pickle the layers to ``path`` as a list of ``to_json()`` dicts."""
        layers_data = [layer.to_json() for layer in self.layers]
        with open(path, 'wb') as f:
            pickle.dump(layers_data, f)

    def dump_pkl(self, path):
        """Pickle the layer objects themselves to ``path`` (highest protocol).

        Unlike ``dump``, loading this file requires the layer class to be importable.
        """
        with open(path, 'wb') as f:
            pickle.dump(self.layers, f, protocol=pickle.HIGHEST_PROTOCOL)

    def load(self, path):
        """Replace the layers with those stored in ``path`` by ``dump``.

        NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
        """
        with open(path, 'rb') as f:
            json_data = pickle.load(f)
        # Build the full list first so a malformed file leaves self.layers intact.
        self.layers = [MyDenseLayer.from_json(d) for d in json_data]
def _convert_and_dump(keras_model, path):
    """Convert every layer of ``keras_model`` to NumPy, pickle it to ``path``, and return it."""
    network = MyPyNetwork()
    network.from_network(keras_model)
    network.dump(path)
    return network


def main():
    """Command-line entry point: load a trained agent and export its Q-network(s)."""
    # Agent key -> (agent class, whether it has separate offensive/defensive Q-nets).
    agent_specs = {
        'normal': (AgentWithNormalMemory, False),
        'per': (AgentWithPER, False),
        'per_multi': (AgentWithPERAndMultiRewards, True),
        'normal_multi': (AgentNormalMultiReward, True),
    }
    parser = ArgumentParser()
    # Derive the choices from the table so the two cannot drift apart.
    parser.add_argument('--agent', type=str, choices=list(agent_specs), required=True)
    args = parser.parse_args()
    Agent, multi_reward = agent_specs[args.agent]
    print('start converting model for ', Agent)
    my_agent = Agent()
    my_agent.load_model()
    print('convert to numpy weights')
    if multi_reward:
        offensive_network = _convert_and_dump(my_agent.q_net_offensive, 'q_offensive.pickle')
        print('q_offensive', len(offensive_network.layers))
        defensive_network = _convert_and_dump(my_agent.q_net_defensive, 'q_defensive.pickle')
        print('q_defensive', len(defensive_network.layers), len(my_agent.q_net_defensive.layers))
    else:
        _convert_and_dump(my_agent.q_net, 'my_network.pickle')
        print('q model')
    print('conversion done')


if __name__ == '__main__':
    main()