Skip to content

Commit

Permalink
Fix exporting updatable agent with a given initial state
Browse files Browse the repository at this point in the history
  • Loading branch information
m-wojnar committed Jul 17, 2023
1 parent 755e7f9 commit 6644e5c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
15 changes: 11 additions & 4 deletions reinforced_lib/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def action_space(self) -> gym.spaces.Space:

raise NotImplementedError()

def export(self, init_key: PRNGKey, state: AgentState = None) -> tuple[any, any, any]:
def export(self, init_key: PRNGKey, state: AgentState = None, sample_only: bool = False) -> tuple[any, any, any]:
"""
Exports the agent to TensorFlow Lite format.
Expand All @@ -95,6 +95,8 @@ def export(self, init_key: PRNGKey, state: AgentState = None) -> tuple[any, any,
Key used to initialize the agent.
state : AgentState, optional
State of the agent to be exported. If not specified, the agent is initialized with ``init_key``.
sample_only : bool, optional
If ``True``, the exported agent will only be able to sample actions, but not update its state.
"""

import tensorflow as tf
Expand Down Expand Up @@ -165,8 +167,11 @@ def sample_without_state(state: AgentState, key: PRNGKey, *args, **kwargs) -> tu
action = self.sample(state, sample_key, *args, **kwargs)
return action, key

if state is None:
state = init()
if not sample_only:
if state is None:
state = init()
else:
state = TfLiteState(state=state, key=init_key)

update_args = append_value(state, 'state', self.update_observation_space.sample())
sample_args = append_value(state, 'state', self.sample_observation_space.sample())
Expand All @@ -176,10 +181,12 @@ def sample_without_state(state: AgentState, key: PRNGKey, *args, **kwargs) -> tu
sample_tfl = make_converter(sample, sample_args).convert()

return init_tfl, update_tfl, sample_tfl
else:
elif state is not None:
sample_args = append_value(init_key, 'key', self.sample_observation_space.sample())

init_tfl = make_converter(get_key, []).convert()
sample_tfl = make_converter(partial(sample_without_state, state), sample_args).convert()

return init_tfl, None, sample_tfl
else:
raise ValueError('Either `state` must be provided or `sample_only` must be False.')
3 changes: 2 additions & 1 deletion reinforced_lib/rlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,8 @@ def to_tflite(self, path: str = None, agent_id: int = None, sample_only: bool =
else:
init_tfl, update_tfl, sample_tfl = self._agent.export(
init_key=self._agent_containers[agent_id].key,
state=self._agent_containers[agent_id].state
state=self._agent_containers[agent_id].state,
sample_only=sample_only
)

base_name = self._agent.__class__.__name__
Expand Down

0 comments on commit 6644e5c

Please sign in to comment.