Skip to content

Commit

Permalink
Move export from examples to library functions
Browse files Browse the repository at this point in the history
  • Loading branch information
m-wojnar committed Jul 15, 2023
1 parent 8c2390f commit 7a423aa
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 19 deletions.
95 changes: 93 additions & 2 deletions reinforced_lib/agents/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from abc import ABC, abstractmethod
from typing import Any
from functools import wraps
from typing import Any, Tuple, Callable

import gymnasium as gym
from chex import dataclass, PRNGKey
import jax
from chex import dataclass, PRNGKey, ArrayTree

from reinforced_lib.utils.exceptions import UnimplementedSpaceError
from reinforced_lib.utils import is_array, is_dict


@dataclass
Expand Down Expand Up @@ -79,3 +84,89 @@ def action_space(self) -> gym.spaces.Space:
"""

raise NotImplementedError()

def export(self, init_key: PRNGKey, state: AgentState = None) -> Tuple[Any, Any, Any]:
"""
Exports the agent to TensorFlow Lite format.
Parameters
----------
init_key : PRNGKey
Key used to initialize the agent.
state : AgentState, optional
State of the agent to be exported. If not specified, the agent is initialized with ``init_key``.
"""

import tensorflow as tf

@dataclass
class TfLiteState:
state: ArrayTree
key: PRNGKey

def add_state(state: TfLiteState, args: Any) -> Any:
if args is None:
raise UnimplementedSpaceError()
elif is_dict(args):
return {**args, 'state': state}
elif is_array(args):
return [state] + list(args)
else:
return [state, args]

def flatten_args(tree_args_fun: Callable, treedef: ArrayTree) -> Callable:
@wraps(tree_args_fun)
def flat_args_fun(*leaves):
tree_args = jax.tree_util.tree_unflatten(treedef, leaves)

if is_dict(tree_args):
tree_ret = tree_args_fun(**tree_args)
else:
tree_ret = tree_args_fun(*tree_args)

return jax.tree_util.tree_leaves(tree_ret)

return flat_args_fun

def make_converter(fun: Callable, arguments: Any) -> tf.lite.TFLiteConverter:
leaves, treedef = jax.tree_util.tree_flatten(arguments)
flat_fun = flatten_args(fun, treedef)

inputs = [[(f'arg{i}', l) for i, l in enumerate(leaves)]]
converter = tf.lite.TFLiteConverter.experimental_from_jax([flat_fun], inputs)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]

return converter

def init() -> TfLiteState:
return TfLiteState(
state=self.init(init_key),
key=init_key
)

def sample(state: TfLiteState, *args, **kwargs) -> Tuple[Any, TfLiteState]:
sample_key, key = jax.random.split(state.key)
action = self.sample(state.state, sample_key, *args, **kwargs)
return action, TfLiteState(state=state.state, key=key)

def update(state: TfLiteState, *args, **kwargs) -> TfLiteState:
update_key, key = jax.random.split(state.key)
new_state = self.update(state.state, update_key, *args, **kwargs)
return TfLiteState(state=new_state, key=key)

if state is None:
state = init()
else:
state = TfLiteState(state=state, key=init_key)

update_args = add_state(state, self.update_observation_space.sample())
sample_args = add_state(state, self.sample_observation_space.sample())

tfl_init = make_converter(init, []).convert()
tfl_update = make_converter(update, update_args).convert()
tfl_sample = make_converter(sample, sample_args).convert()

return tfl_init, tfl_update, tfl_sample
5 changes: 2 additions & 3 deletions reinforced_lib/logs/csv_logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import os.path
from datetime import datetime
from typing import Any, Dict, List

import jax.numpy as jnp
Expand All @@ -9,6 +8,7 @@

from reinforced_lib.logs import BaseLogger, Source
from reinforced_lib.utils.exceptions import UnsupportedCustomLogsError
from utils import timestamp


class CsvLogger(BaseLogger):
Expand All @@ -30,8 +30,7 @@ def __init__(self, csv_path: str = None, **kwargs) -> None:
super().__init__(**kwargs)

if csv_path is None:
now = datetime.now()
csv_path = f'rlib-logs-{now.strftime("%Y%m%d")}-{now.strftime("%H%M%S")}.csv'
csv_path = f'rlib-logs-{timestamp()}.csv'
csv_path = os.path.join(os.path.expanduser("~"), csv_path)

self._file = open(csv_path, 'w')
Expand Down
9 changes: 4 additions & 5 deletions reinforced_lib/logs/logs_observer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from collections import defaultdict
from typing import Any, Callable, Dict, List

import jax.numpy as jnp

from reinforced_lib.agents import BaseAgent
from reinforced_lib.logs import BaseLogger, Source, SourceType
from reinforced_lib.utils.exceptions import IncorrectLoggerTypeError, IncorrectSourceTypeError
from reinforced_lib.utils import is_scalar, is_array, is_dict


class LogsObserver:
Expand Down Expand Up @@ -157,11 +156,11 @@ def _update(loggers: Dict[BaseLogger, List[Source]], get_value: Callable) -> Non

custom = name is None

if jnp.isscalar(value) or (hasattr(value, 'ndim') and value.ndim == 0):
if is_scalar(value):
logger.log_scalar(source, value, custom)
elif isinstance(value, dict):
elif is_dict(value):
logger.log_dict(source, value, custom)
elif isinstance(value, (list, tuple)) or (hasattr(value, 'ndim') and value.ndim == 1):
elif is_array(value):
logger.log_array(source, value, custom)
else:
logger.log_other(source, value, custom)
6 changes: 2 additions & 4 deletions reinforced_lib/logs/plots_logger.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os.path
from collections import defaultdict
from datetime import datetime
from typing import List

import jax.numpy as jnp
import matplotlib.pyplot as plt
from chex import Array, Scalar

from reinforced_lib.logs import BaseLogger, Source
from utils import timestamp


class PlotsLogger(BaseLogger):
Expand Down Expand Up @@ -80,10 +80,8 @@ def scatterplot(values: List, label: bool = False) -> None:
plt.scatter(xs, val, c=f'C{i % 10}', label=i if label else '', marker='.', s=4)
plt.legend()

now = datetime.now()

for name, values in self._plots_values.items():
filename = f'rlib-plot-{name}-{now.strftime("%Y%m%d")}-{now.strftime("%H%M%S")}.{self._plots_ext}'
filename = f'rlib-plot-{name}-{timestamp()}.{self._plots_ext}'

if self._plots_scatter:
scatterplot(values, True)
Expand Down
57 changes: 52 additions & 5 deletions reinforced_lib/rlib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import datetime
import os
import pickle
from typing import Any, Dict, List, Tuple, Union
Expand All @@ -9,14 +8,15 @@
import gymnasium as gym
import jax.random
import lz4.frame
import numpy as np
from chex import dataclass

from reinforced_lib.agents import BaseAgent
from reinforced_lib.exts import BaseExt
from reinforced_lib.logs import Source
from reinforced_lib.logs.logs_observer import LogsObserver
from reinforced_lib.utils.exceptions import *
from reinforced_lib.utils import is_scalar
from utils import timestamp


@dataclass
Expand Down Expand Up @@ -461,14 +461,13 @@ def save(self, agent_ids: Union[int, List[int]] = None, path: str = None) -> str

if agent_ids is None:
agent_ids = list(range(len(self._agent_containers)))
elif np.isscalar(agent_ids) or (hasattr(agent_ids, 'ndim') and agent_ids.ndim == 0):
elif is_scalar(agent_ids):
agent_ids = [agent_ids]

agent_containers = [self._agent_containers[agent_id] for agent_id in agent_ids]

if path is None:
timestamp = datetime.datetime.now()
path = os.path.join(self._save_directory, f"rlib-checkpoint-{timestamp.date()}-{timestamp.time()}.pkl.lz4")
path = os.path.join(self._save_directory, f"rlib-checkpoint-{timestamp()}.pkl.lz4")
elif path[-8:] != self._lz4_ext:
path = path + self._lz4_ext

Expand Down Expand Up @@ -572,3 +571,51 @@ def log(self, name: str, value: Any) -> None:
"""

self._logs_observer.update_custom(value, name)

def to_tflite(self, path: str = None, agent_id: int = None, sample_only: bool = False) -> None:
"""
Converts the agent to a TensorFlow Lite model and saves it to a file.
Parameters
----------
path : str, optional
Path to the output file.
agent_id : int, optional
The identifier of the agent instance to convert. If specified,
state of the selected agent will be saved.
sample_only : bool
Flag indicating if the method should save only the sample function.
"""

if not self._agent:
raise NoAgentError()

if len(self._agent_containers) == 0:
self.init()

if path is None:
path = self._save_directory

if agent_id is None:
init_tfl, update_tfl, sample_tfl = self._agent.export(
init_key=jax.random.PRNGKey(42)
)
else:
init_tfl, update_tfl, sample_tfl = self._agent.export(
init_key=self._agent_containers[agent_id].key,
state=self._agent_containers[agent_id].state
)

base_name = self._agent.__class__.__name__
base_name += f'-{agent_id}-' if agent_id is not None else '-'
base_name += timestamp()

with open(os.path.join(path, f'rlib-{base_name}-sample.tflite'), 'wb') as f:
f.write(sample_tfl)

if not sample_only:
with open(os.path.join(path, f'rlib-{base_name}-init.tflite'), 'wb') as f:
f.write(init_tfl)

with open(os.path.join(path, f'rlib-{base_name}-update.tflite'), 'wb') as f:
f.write(update_tfl)
89 changes: 89 additions & 0 deletions reinforced_lib/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from datetime import datetime
from typing import Any

import numpy as np


def is_scalar(x: Any) -> bool:
"""
Checks whether the input is a scalar.
Parameters
----------
x : any
Input to check.
Returns
-------
bool
``True`` if the input is a scalar, ``False`` otherwise.
"""

return np.isscalar(x) or (hasattr(x, 'ndim') and x.ndim == 0)


def is_array(x: Any) -> bool:
"""
Checks whether the input is an array.
Parameters
----------
x : any
Input to check.
Returns
-------
bool
``True`` if the input is an array, ``False`` otherwise.
"""

return isinstance(x, (list, tuple)) or (hasattr(x, 'ndim') and x.ndim == 1)


def is_tensor(x: Any) -> bool:
"""
Checks whether the input is a tensor.
Parameters
----------
x : any
Input to check.
Returns
-------
bool
``True`` if the input is a tensor, ``False`` otherwise.
"""

return hasattr(x, 'ndim') and x.ndim > 1


def is_dict(x: Any) -> bool:
"""
Checks whether the input is a dictionary.
Parameters
----------
x : any
Input to check.
Returns
-------
bool
``True`` if the input is a dictionary, ``False`` otherwise.
"""

return isinstance(x, dict)


def timestamp() -> str:
"""
Returns the current timestamp.
Returns
-------
str
Current timestamp.
"""

return datetime.now().strftime('%Y%m%d-%H%M%S')
9 changes: 9 additions & 0 deletions reinforced_lib/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ def __str__(self) -> str:
return 'Cannot find corresponding Gymnasium space.'


class UnimplementedSpaceError(Exception):
"""
Raised when an observation space is required but not implemented.
"""

def __str__(self) -> str:
return 'Appropriate observation space is not implemented.'


class IncompatibleSpacesError(Exception):
"""
Raised when the observation spaces of two different modules are not compatible.
Expand Down

0 comments on commit 7a423aa

Please sign in to comment.