Skip to content

Commit

Permalink
Merge pull request #129 from edbeeching/godot_env_py-_get_data-refactor
Browse files Browse the repository at this point in the history
Refactor godot_env.py
  • Loading branch information
Ivan-267 authored Nov 30, 2023
2 parents a5363cb + 5689d95 commit ac7241b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
41 changes: 28 additions & 13 deletions godot_rl/core/godot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def step(self, action, order_ij=False):
"""
self.step_send(action, order_ij=order_ij)
return self.step_recv()


def step_send(self, action, order_ij=False):
"""
Expand All @@ -172,7 +172,7 @@ def step_send(self, action, order_ij=False):
"action": self.from_numpy(action, order_ij=order_ij),
}
self._send_as_json(message)

def step_recv(self):
"""
Receive the step response from the Godot environment.
Expand All @@ -190,8 +190,8 @@ def step_recv(self):
np.array(response["done"]).tolist(), # TODO update API to term, trunc
[{}] * len(response["done"]),
)



def _process_obs(self, response_obs: dict):
"""
Expand Down Expand Up @@ -387,19 +387,34 @@ def _clear_socket(self):

def _get_data(self):
try:
data = self.connection.recv(4)
if not data:
time.sleep(0.000001)
return self._get_data()
length = int.from_bytes(data, "little")
string = ""
while len(string) != length: # TODO: refactor as string concatenation could be slow
string += self.connection.recv(length).decode()
# Receive the size (in bytes) of the remaining data to receive
string_size_bytes: bytearray = bytearray()
received_length: int = 0

# The first 4 bytes contain the length of the remaining data
length: int = 4

while received_length < length:
data = self.connection.recv(length - received_length)
received_length += len(data)
string_size_bytes.extend(data)

length = int.from_bytes(string_size_bytes, "little")

# Receive the rest of the data
string_bytes: bytearray = bytearray()
received_length = 0

while received_length < length:
data = self.connection.recv(length - received_length)
received_length += len(data)
string_bytes.extend(data)

string: str = string_bytes.decode()

return string
except socket.timeout as e:
print("env timed out", e)

return None

def _send_string(self, string):
Expand Down

0 comments on commit ac7241b

Please sign in to comment.