diff --git a/acme/jax/inference_server.py b/acme/jax/inference_server.py index 9e4f27a684..e87071e6f3 100644 --- a/acme/jax/inference_server.py +++ b/acme/jax/inference_server.py @@ -98,6 +98,9 @@ def _dereference_params(self, arg): key=self._keys, update_period=self._config.update_period) + if self._variable_client is None: + raise ValueError('_variable_client not set') + params = self._variable_client.params device_idx = self._call_cnt % len(self._devices) # Select device via round robin, and update its params if they changed.