Skip to content

Commit

Permalink
Fix sample_only export mode
Browse files: browse the repository at this point in the history
  • Loading branch information
m-wojnar committed Jul 17, 2023
1 parent c80b84c commit 69b8073
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 16 deletions.
39 changes: 26 additions & 13 deletions reinforced_lib/agents/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from functools import wraps
from functools import wraps, partial
from typing import Any, Tuple, Callable

import gymnasium as gym
Expand Down Expand Up @@ -104,15 +104,15 @@ class TfLiteState:
state: ArrayTree
key: PRNGKey

def append_value(value: Any, value_name: str, args: Any) -> Any:
    """
    Combine an extra value with a sampled argument structure.

    Parameters
    ----------
    value : Any
        The value to add (for example an agent state or a PRNG key).
    value_name : str
        Key under which ``value`` is stored when ``args`` is a dictionary.
        It is ignored for the other argument kinds.
    args : Any
        The sampled arguments to extend.

    Returns
    -------
    Any
        - if ``is_dict(args)``: a new dict with the entries of ``args`` plus
          ``value_name: value`` (an existing entry with that key is replaced);
        - if ``is_array(args)``: a list with ``value`` first, followed by the
          elements of ``args``;
        - otherwise: the two-element list ``[value, args]``.

    Raises
    ------
    UnimplementedSpaceError
        If ``args`` is ``None``.
    """

    if args is None:
        raise UnimplementedSpaceError()
    elif is_dict(args):
        # A new dict is built, so the caller's ``args`` is not mutated.
        return {**args, value_name: value}
    elif is_array(args):
        return [value] + list(args)
    else:
        return [value, args]


def add_state(state: TfLiteState, args: Any) -> Any:
    """
    Add ``state`` to ``args`` under the fixed name ``'state'``.

    This is the pre-rename spelling of :func:`append_value`; it is kept as a
    thin wrapper because nearby lines still call it by this name.
    """

    return append_value(state, 'state', args)

def flatten_args(tree_args_fun: Callable, treedef: ArrayTree) -> Callable:
@wraps(tree_args_fun)
Expand Down Expand Up @@ -157,16 +157,29 @@ def update(state: TfLiteState, *args, **kwargs) -> TfLiteState:
new_state = self.update(state.state, update_key, *args, **kwargs)
return TfLiteState(state=new_state, key=key)

def get_key() -> PRNGKey:
    """
    Return the ``init_key`` captured from the enclosing scope.

    Takes no arguments; the enclosing code passes it to ``make_converter``
    with an empty argument list.
    """

    return init_key

def sample_without_state(state: AgentState, key: PRNGKey, *args, **kwargs) -> Tuple[Any, PRNGKey]:
    """
    Sample an action from a given agent state and advance the PRNG key.

    Parameters
    ----------
    state : AgentState
        Agent state forwarded unchanged to ``self.sample``.
    key : PRNGKey
        PRNG key; it is split so that one half drives this sample and the
        other half is handed back to the caller.
    *args, **kwargs
        Extra arguments forwarded to ``self.sample``.

    Returns
    -------
    tuple
        ``(action, key)`` - the value returned by ``self.sample`` and the
        remaining half of the split key.
    """

    # Never reuse ``key`` directly: consume one sub-key, return the other.
    sample_key, key = jax.random.split(key)
    action = self.sample(state, sample_key, *args, **kwargs)
    return action, key

if state is None:
state = init()
else:
state = TfLiteState(state=state, key=init_key)

update_args = add_state(state, self.update_observation_space.sample())
sample_args = add_state(state, self.sample_observation_space.sample())
update_args = append_value(state, 'state', self.update_observation_space.sample())
sample_args = append_value(state, 'state', self.sample_observation_space.sample())

init_tfl = make_converter(init, []).convert()
update_tfl = make_converter(update, update_args).convert()
sample_tfl = make_converter(sample, sample_args).convert()

return init_tfl, update_tfl, sample_tfl
else:
sample_args = append_value(init_key, 'key', self.sample_observation_space.sample())

tfl_init = make_converter(init, []).convert()
tfl_update = make_converter(update, update_args).convert()
tfl_sample = make_converter(sample, sample_args).convert()
init_tfl = make_converter(get_key, []).convert()
sample_tfl = make_converter(partial(sample_without_state, state), sample_args).convert()

return tfl_init, tfl_update, tfl_sample
return init_tfl, None, sample_tfl
9 changes: 6 additions & 3 deletions reinforced_lib/rlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,9 @@ def to_tflite(self, path: str = None, agent_id: int = None, sample_only: bool =
if len(self._agent_containers) == 0:
self.init()

if sample_only and agent_id is None:
raise ValueError("Agent ID must be specified when saving sample function only.")

if path is None:
path = self._save_directory

Expand All @@ -609,12 +612,12 @@ def to_tflite(self, path: str = None, agent_id: int = None, sample_only: bool =
base_name += f'-{agent_id}-' if agent_id is not None else '-'
base_name += timestamp()

with open(os.path.join(path, f'rlib-{base_name}-init.tflite'), 'wb') as f:
f.write(init_tfl)

with open(os.path.join(path, f'rlib-{base_name}-sample.tflite'), 'wb') as f:
f.write(sample_tfl)

if not sample_only:
with open(os.path.join(path, f'rlib-{base_name}-init.tflite'), 'wb') as f:
f.write(init_tfl)

with open(os.path.join(path, f'rlib-{base_name}-update.tflite'), 'wb') as f:
f.write(update_tfl)

0 comments on commit 69b8073

Please sign in to comment.