diff --git a/hopes/policy/policies.py b/hopes/policy/policies.py index 0a57d67..24d9e50 100644 --- a/hopes/policy/policies.py +++ b/hopes/policy/policies.py @@ -426,7 +426,7 @@ def log_probabilities(self, obs: np.ndarray) -> np.ndarray: """Compute the log-probabilities of the actions under the HTTP policy for a given set of observations.""" all_log_probs = [] - for chunk in np.array_split(obs, len(obs) // self.batch_size): + for chunk in np.array_split(obs, self.batch_size): # Send HTTP request to server response = requests.post( f"http{'s' if self.ssl else ''}://{self.host}:{self.port}/{self.path}",