forked from b4be1/gh_gym
-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.py
157 lines (125 loc) · 4.05 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import socket
import struct
import pickle
from math import sin, cos, pi
class Connection:
    """Send and receive pickled objects over a connected stream socket.

    Each message on the wire is a 4-byte little-endian unsigned length header
    followed by that many bytes of pickled payload. Bytes that arrive past the
    end of one message are kept in an internal buffer for the next call to
    :meth:`receive_object`.

    SECURITY: ``pickle.loads`` can execute arbitrary code while decoding, so
    only use this with a trusted peer.
    """

    # Pre-compiled framing header: unsigned 32-bit length, little-endian.
    _HEADER = struct.Struct("<L")
    # Maximum bytes requested per recv(); any surplus stays in the buffer.
    _RECV_SIZE = 4096

    def __init__(self, s):
        """Wrap the already-connected socket *s*."""
        self._socket = s
        self._buffer = bytearray()

    def _buffered_frame_size(self):
        """Return header+body size of the first buffered message, or None if incomplete."""
        header_size = self._HEADER.size
        if len(self._buffer) < header_size:
            return None
        frame_size = header_size + self._HEADER.unpack_from(self._buffer)[0]
        if len(self._buffer) < frame_size:
            return None
        return frame_size

    def receive_object(self):
        """Block until one complete message is available and return it unpickled.

        Returns:
            The decoded object, or None if the peer closed the connection
            before a complete message arrived.
        """
        frame_size = self._buffered_frame_size()
        while frame_size is None:
            chunk = self._socket.recv(self._RECV_SIZE)
            if not chunk:  # empty read: peer closed the connection
                return None
            self._buffer += chunk
            frame_size = self._buffered_frame_size()
        body = bytes(self._buffer[self._HEADER.size:frame_size])
        # Consume the frame before decoding so a corrupt payload is not
        # re-parsed on every subsequent call.
        del self._buffer[:frame_size]
        return pickle.loads(body)

    def send_object(self, d):
        """Pickle *d* and send it as one length-prefixed message."""
        body = pickle.dumps(d)
        # sendall() keeps writing until every byte is sent; send() may stop short.
        self._socket.sendall(self._HEADER.pack(len(body)) + body)
class Env:
    """Torque-controlled pendulum simulation with a ``reset()`` / ``step()`` API.

    The dynamics, limits and quadratic cost appear to follow the classic gym
    ``Pendulum`` task (NOTE(review): confirm against the agent's expectations).
    The internal state is ``[theta, theta_dot]``, with ``theta = 0`` the
    zero-cost angle. The observation is ``[cos(theta), sin(theta), theta_dot]``.
    ``done`` is reported every ``max_episode_steps`` calls to :meth:`step`.
    """

    def __init__(self, g=10., mass=1., length=1., max_episode_steps=200):
        """Create an environment; :meth:`reset` must be called before :meth:`step`.

        Args:
            g: gravitational constant used by the dynamics (default 10).
            mass: pendulum mass (default 1).
            length: pendulum length (default 1).
            max_episode_steps: number of steps after which ``done`` is
                reported (default 200).
        """
        self.max_speed = 8.
        self.max_torque = 2.
        self.dt = .05
        self.g = float(g)
        self.m = float(mass)
        self.l = float(length)
        self.max_episode_steps = max_episode_steps
        self.action_space = {"space_type": "box", "high": [self.max_torque]}
        self.observation_space = {"space_type": "box", "high": [1., 1., self.max_speed]}
        self._state = None  # [theta, theta_dot]; None until reset()
        self._iter = None   # steps taken in the current episode

    @staticmethod
    def _clip(value, low, high):
        """Limit *value* to the closed interval [low, high]."""
        return min(max(value, low), high)

    @staticmethod
    def _angle_normalize(x):
        """Wrap angle *x* (radians) into [-pi, pi)."""
        return ((x + pi) % (2 * pi)) - pi

    def step(self, u):
        """Advance the simulation by one time step of length ``dt``.

        Args:
            u: indexable action whose first element is the applied torque. The
                torque is clipped to ``[-max_torque, max_torque]``; *u* itself
                is not modified.

        Returns:
            ``(observation, reward, done, info)``. ``reward`` is the negated
            cost and ``info`` is an empty dict. When ``done`` is True the step
            counter restarts, but the pendulum state is left as is.

        Raises:
            RuntimeError: if called before :meth:`reset`.
            TypeError: if *u* has no usable first element.
        """
        if self._state is None:
            raise RuntimeError("reset() must be called before step()")
        try:
            torque = float(u[0])
        except (TypeError, IndexError, KeyError):
            raise TypeError("action must be a non-empty sequence of numbers, got %r" % (u,))
        # Clip a local value so the caller's action object is left untouched.
        torque = self._clip(torque, -self.max_torque, self.max_torque)

        th, thdot = self._state
        g, m, l, dt = self.g, self.m, self.l, self.dt

        costs = self._angle_normalize(th) ** 2 + .1 * thdot ** 2 + .001 * (torque ** 2)
        newthdot = thdot + (-3 * g / (2 * l) * sin(th + pi) + 3. / (m * l ** 2) * torque) * dt
        # The angle integrates the unclipped velocity; only the stored velocity is clipped.
        newth = th + newthdot * dt
        newthdot = self._clip(newthdot, -self.max_speed, self.max_speed)
        self._state = [newth, newthdot]

        self._iter += 1
        done = self._iter >= self.max_episode_steps
        if done:
            self._iter = 0
        return self._get_obs(), -costs, done, {}

    def reset(self):
        """Start a new episode at angle 0 with zero velocity; return the observation."""
        self._iter = 0
        self._state = [0., 0.]
        return self._get_obs()

    def _get_obs(self):
        """Return ``[cos(theta), sin(theta), theta_dot]`` for the current state."""
        theta, thetadot = self._state
        return [cos(theta), sin(theta), thetadot]
# Simulation state shared between ticks: written by environment(), read by agent().
env = observation = reward = done = info = None
# Connection to the agent process, set up by agent().
agent_socket = agent_conn = None
def environment(action, reset):
    """Apply one agent command to the shared environment.

    Lazily creates the module-level ``env`` on first use, then either resets it
    (when *reset* is truthy) or steps it with *action*. Results are stored in
    the module globals ``observation``, ``reward``, ``done`` and ``info``,
    which agent() sends back on its next call.

    Raises:
        RuntimeError: if neither *reset* nor an *action* is given.
    """
    global env, observation, reward, done, info
    if env is None:
        env = Env()
    if reset:
        observation = env.reset()
        # A reset is not a transition: clear reward/done so the previous
        # episode's final values (e.g. done=True) are not reported alongside
        # the fresh observation. This matches the very first reset.
        reward = None
        done = None
        # NOTE: these strings duplicate Env.max_torque / Env.max_speed; keep them in sync.
        info = {"action_space": "gym.spaces.Box(low=-np.array([2.]), high=np.array([2.]))",
                "observation_space": "gym.spaces.Box(low=-np.array([1., 1., 8.]), high=np.array([1., 1., 8.]))"}
    elif action is not None:
        # Explicit None check: truthiness is ambiguous for array-like actions.
        observation, reward, done, info = env.step(action)
    else:
        raise RuntimeError("Either reset or action must be provided")
def _close_agent_connection():
    """Close the agent socket, if any, and forget the connection."""
    global agent_socket, agent_conn
    if agent_socket is not None:
        agent_socket.close()
    agent_socket = None
    agent_conn = None


def _open_agent_connection():
    """Connect to the agent at 127.0.0.1:50710, replacing any previous connection."""
    global agent_socket, agent_conn
    _close_agent_connection()  # don't leak a socket from an earlier session
    addr = ("127.0.0.1", 50710)
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.connect(addr)
    except socket.error:
        sock.close()
        raise
    agent_socket = sock
    agent_conn = Connection(sock)


def _receive_from_agent():
    """Return the next message from the agent.

    Raises:
        RuntimeError: if the agent closed the socket (the connection is dropped first).
    """
    msg = agent_conn.receive_object()
    if msg is None:
        _close_agent_connection()
        raise RuntimeError("Agent closed the connection unexpectedly")
    return msg


def agent(iter):
    """Exchange one round of messages with the agent and return its command.

    On the first tick (``iter == 0``), or whenever there is no open connection
    (e.g. after the agent sent ``"close"``), this connects and waits for the
    agent's opening ``"reset"``. Otherwise it sends the latest
    observation/reward/done/info and reads the agent's reply.

    Args:
        iter: tick counter from the driving loop; 0 forces a fresh connection.

    Returns:
        ``(action, reset)``: ``(None, True)`` for ``"reset"`` / ``"close"`` /
        a new connection, otherwise ``(message, None)``.

    Raises:
        RuntimeError: if a new connection does not start with ``"reset"``, or
            if the agent disconnects mid-exchange.
    """
    if iter == 0 or agent_conn is None:
        _open_agent_connection()
        if _receive_from_agent() != "reset":
            raise RuntimeError("First message must be 'reset'")
        return None, True

    agent_conn.send_object({"observation": observation,
                            "reward": reward,
                            "done": done,
                            "info": info})
    msg_in = _receive_from_agent()
    if msg_in == "reset":
        return None, True
    if msg_in == "close":
        # Drop the socket; the next call reconnects instead of sending on a closed socket.
        _close_agent_connection()
        return None, True
    return msg_in, None
# Drive the simulation indefinitely: each tick asks the agent for its next
# command, then applies that command to the environment.
iter = 0
while True:
    environment(*agent(iter))
    iter += 1