From 7a423aa712d46879a9100d55160a9e5f0d33087a Mon Sep 17 00:00:00 2001 From: Maksymilian Wojnar Date: Sat, 15 Jul 2023 21:41:03 +0200 Subject: [PATCH] Move export from examples to library functions --- reinforced_lib/agents/base_agent.py | 95 +++++++++++++++++++++++++++- reinforced_lib/logs/csv_logger.py | 5 +- reinforced_lib/logs/logs_observer.py | 9 ++- reinforced_lib/logs/plots_logger.py | 6 +- reinforced_lib/rlib.py | 57 +++++++++++++++-- reinforced_lib/utils/__init__.py | 89 ++++++++++++++++++++++++++ reinforced_lib/utils/exceptions.py | 9 +++ 7 files changed, 251 insertions(+), 19 deletions(-) diff --git a/reinforced_lib/agents/base_agent.py b/reinforced_lib/agents/base_agent.py index cf78e22..0e5316a 100644 --- a/reinforced_lib/agents/base_agent.py +++ b/reinforced_lib/agents/base_agent.py @@ -1,8 +1,13 @@ from abc import ABC, abstractmethod -from typing import Any +from functools import wraps +from typing import Any, Tuple, Callable import gymnasium as gym -from chex import dataclass, PRNGKey +import jax +from chex import dataclass, PRNGKey, ArrayTree + +from reinforced_lib.utils.exceptions import UnimplementedSpaceError +from reinforced_lib.utils import is_array, is_dict @dataclass @@ -79,3 +84,89 @@ def action_space(self) -> gym.spaces.Space: """ raise NotImplementedError() + + def export(self, init_key: PRNGKey, state: AgentState = None) -> Tuple[Any, Any, Any]: + """ + Exports the agent to TensorFlow Lite format. + + Parameters + ---------- + init_key : PRNGKey + Key used to initialize the agent. + state : AgentState, optional + State of the agent to be exported. If not specified, the agent is initialized with ``init_key``. + """ + + import tensorflow as tf + + @dataclass + class TfLiteState: + state: ArrayTree + key: PRNGKey + + def add_state(state: TfLiteState, args: Any) -> Any: + if args is None: + raise UnimplementedSpaceError() + elif is_dict(args): + return {**args, 'state': state} + elif is_array(args): + return [state] + list(args) + else: + return [state, args] + + def flatten_args(tree_args_fun: Callable, treedef: ArrayTree) -> Callable: + @wraps(tree_args_fun) + def flat_args_fun(*leaves): + tree_args = jax.tree_util.tree_unflatten(treedef, leaves) + + if is_dict(tree_args): + tree_ret = tree_args_fun(**tree_args) + else: + tree_ret = tree_args_fun(*tree_args) + + return jax.tree_util.tree_leaves(tree_ret) + + return flat_args_fun + + def make_converter(fun: Callable, arguments: Any) -> tf.lite.TFLiteConverter: + leaves, treedef = jax.tree_util.tree_flatten(arguments) + flat_fun = flatten_args(fun, treedef) + + inputs = [[(f'arg{i}', l) for i, l in enumerate(leaves)]] + converter = tf.lite.TFLiteConverter.experimental_from_jax([flat_fun], inputs) + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. + tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. + ] + + return converter + + def init() -> TfLiteState: + return TfLiteState( + state=self.init(init_key), + key=init_key + ) + + def sample(state: TfLiteState, *args, **kwargs) -> Tuple[Any, TfLiteState]: + sample_key, key = jax.random.split(state.key) + action = self.sample(state.state, sample_key, *args, **kwargs) + return action, TfLiteState(state=state.state, key=key) + + def update(state: TfLiteState, *args, **kwargs) -> TfLiteState: + update_key, key = jax.random.split(state.key) + new_state = self.update(state.state, update_key, *args, **kwargs) + return TfLiteState(state=new_state, key=key) + + if state is None: + state = init() + else: + state = TfLiteState(state=state, key=init_key) + + update_args = add_state(state, self.update_observation_space.sample()) + sample_args = add_state(state, self.sample_observation_space.sample()) + + tfl_init = make_converter(init, []).convert() + tfl_update = make_converter(update, update_args).convert() + tfl_sample = make_converter(sample, sample_args).convert() + + return tfl_init, tfl_update, tfl_sample diff --git a/reinforced_lib/logs/csv_logger.py b/reinforced_lib/logs/csv_logger.py index 0f12f23..d8f3e58 100644 --- a/reinforced_lib/logs/csv_logger.py +++ b/reinforced_lib/logs/csv_logger.py @@ -1,6 +1,5 @@ import json import os.path -from datetime import datetime from typing import Any, Dict, List import jax.numpy as jnp @@ -9,6 +8,7 @@ from reinforced_lib.logs import BaseLogger, Source from reinforced_lib.utils.exceptions import UnsupportedCustomLogsError +from utils import timestamp class CsvLogger(BaseLogger): @@ -30,8 +30,7 @@ def __init__(self, csv_path: str = None, **kwargs) -> None: super().__init__(**kwargs) if csv_path is None: - now = datetime.now() - csv_path = f'rlib-logs-{now.strftime("%Y%m%d")}-{now.strftime("%H%M%S")}.csv' + csv_path = f'rlib-logs-{timestamp()}.csv' csv_path = os.path.join(os.path.expanduser("~"), csv_path) self._file = open(csv_path, 'w') diff --git a/reinforced_lib/logs/logs_observer.py b/reinforced_lib/logs/logs_observer.py index ba31a4e..7b4419a 100644 --- a/reinforced_lib/logs/logs_observer.py +++ b/reinforced_lib/logs/logs_observer.py @@ -1,11 +1,10 @@ from collections import defaultdict from typing import Any, Callable, Dict, List -import jax.numpy as jnp - from reinforced_lib.agents import BaseAgent from reinforced_lib.logs import BaseLogger, Source, SourceType from reinforced_lib.utils.exceptions import IncorrectLoggerTypeError, IncorrectSourceTypeError +from reinforced_lib.utils import is_scalar, is_array, is_dict class LogsObserver: @@ -157,11 +156,11 @@ def _update(loggers: Dict[BaseLogger, List[Source]], get_value: Callable) -> Non custom = name is None - if jnp.isscalar(value) or (hasattr(value, 'ndim') and value.ndim == 0): + if is_scalar(value): logger.log_scalar(source, value, custom) - elif isinstance(value, dict): + elif is_dict(value): logger.log_dict(source, value, custom) - elif isinstance(value, (list, tuple)) or (hasattr(value, 'ndim') and value.ndim == 1): + elif is_array(value): logger.log_array(source, value, custom) else: logger.log_other(source, value, custom) diff --git a/reinforced_lib/logs/plots_logger.py b/reinforced_lib/logs/plots_logger.py index 01b4dbb..76af2f0 100644 --- a/reinforced_lib/logs/plots_logger.py +++ b/reinforced_lib/logs/plots_logger.py @@ -1,6 +1,5 @@ import os.path from collections import defaultdict -from datetime import datetime from typing import List import jax.numpy as jnp @@ -8,6 +7,7 @@ from chex import Array, Scalar from reinforced_lib.logs import BaseLogger, Source +from utils import timestamp class PlotsLogger(BaseLogger): @@ -80,10 +80,8 @@ def scatterplot(values: List, label: bool = False) -> None: plt.scatter(xs, val, c=f'C{i % 10}', label=i if label else '', marker='.', s=4) plt.legend() - now = datetime.now() - for name, values in self._plots_values.items(): - filename = f'rlib-plot-{name}-{now.strftime("%Y%m%d")}-{now.strftime("%H%M%S")}.{self._plots_ext}' + filename = f'rlib-plot-{name}-{timestamp()}.{self._plots_ext}' if self._plots_scatter: scatterplot(values, True) diff --git a/reinforced_lib/rlib.py b/reinforced_lib/rlib.py index d32f0b0..f77ce66 100644 --- a/reinforced_lib/rlib.py +++ b/reinforced_lib/rlib.py @@ -1,6 +1,5 @@ from __future__ import annotations -import datetime import os import pickle from typing import Any, Dict, List, Tuple, Union @@ -9,7 +8,6 @@ import gymnasium as gym import jax.random import lz4.frame -import numpy as np from chex import dataclass from reinforced_lib.agents import BaseAgent @@ -17,6 +15,8 @@ from reinforced_lib.logs import Source from reinforced_lib.logs.logs_observer import LogsObserver from reinforced_lib.utils.exceptions import * +from reinforced_lib.utils import is_scalar +from utils import timestamp @dataclass @@ -461,14 +461,13 @@ def save(self, agent_ids: Union[int, List[int]] = None, path: str = None) -> str if agent_ids is None: agent_ids = list(range(len(self._agent_containers))) - elif np.isscalar(agent_ids) or (hasattr(agent_ids, 'ndim') and agent_ids.ndim == 0): + elif is_scalar(agent_ids): agent_ids = [agent_ids] agent_containers = [self._agent_containers[agent_id] for agent_id in agent_ids] if path is None: - timestamp = datetime.datetime.now() - path = os.path.join(self._save_directory, f"rlib-checkpoint-{timestamp.date()}-{timestamp.time()}.pkl.lz4") + path = os.path.join(self._save_directory, f"rlib-checkpoint-{timestamp()}.pkl.lz4") elif path[-8:] != self._lz4_ext: path = path + self._lz4_ext @@ -572,3 +571,51 @@ def log(self, name: str, value: Any) -> None: """ self._logs_observer.update_custom(value, name) + + def to_tflite(self, path: str = None, agent_id: int = None, sample_only: bool = False) -> None: + """ + Converts the agent to a TensorFlow Lite model and saves it to a file. + + Parameters + ---------- + path : str, optional + Path to the output file. + agent_id : int, optional + The identifier of the agent instance to convert. If specified, + state of the selected agent will be saved. + sample_only : bool + Flag indicating if the method should save only the sample function. + """ + + if not self._agent: + raise NoAgentError() + + if len(self._agent_containers) == 0: + self.init() + + if path is None: + path = self._save_directory + + if agent_id is None: + init_tfl, update_tfl, sample_tfl = self._agent.export( + init_key=jax.random.PRNGKey(42) + ) + else: + init_tfl, update_tfl, sample_tfl = self._agent.export( + init_key=self._agent_containers[agent_id].key, + state=self._agent_containers[agent_id].state + ) + + base_name = self._agent.__class__.__name__ + base_name += f'-{agent_id}-' if agent_id is not None else '-' + base_name += timestamp() + + with open(os.path.join(path, f'rlib-{base_name}-sample.tflite'), 'wb') as f: + f.write(sample_tfl) + + if not sample_only: + with open(os.path.join(path, f'rlib-{base_name}-init.tflite'), 'wb') as f: + f.write(init_tfl) + + with open(os.path.join(path, f'rlib-{base_name}-update.tflite'), 'wb') as f: + f.write(update_tfl) diff --git a/reinforced_lib/utils/__init__.py b/reinforced_lib/utils/__init__.py index e69de29..76100f6 100644 --- a/reinforced_lib/utils/__init__.py +++ b/reinforced_lib/utils/__init__.py @@ -0,0 +1,89 @@ +from datetime import datetime +from typing import Any + +import numpy as np + + +def is_scalar(x: Any) -> bool: + """ + Checks whether the input is a scalar. + + Parameters + ---------- + x : any + Input to check. + + Returns + ------- + bool + ``True`` if the input is a scalar, ``False`` otherwise. + """ + + return np.isscalar(x) or (hasattr(x, 'ndim') and x.ndim == 0) + + +def is_array(x: Any) -> bool: + """ + Checks whether the input is an array. + + Parameters + ---------- + x : any + Input to check. + + Returns + ------- + bool + ``True`` if the input is an array, ``False`` otherwise. + """ + + return isinstance(x, (list, tuple)) or (hasattr(x, 'ndim') and x.ndim == 1) + + +def is_tensor(x: Any) -> bool: + """ + Checks whether the input is a tensor. + + Parameters + ---------- + x : any + Input to check. + + Returns + ------- + bool + ``True`` if the input is a tensor, ``False`` otherwise. + """ + + return hasattr(x, 'ndim') and x.ndim > 1 + + +def is_dict(x: Any) -> bool: + """ + Checks whether the input is a dictionary. + + Parameters + ---------- + x : any + Input to check. + + Returns + ------- + bool + ``True`` if the input is a dictionary, ``False`` otherwise. + """ + + return isinstance(x, dict) + + +def timestamp() -> str: + """ + Returns the current timestamp. + + Returns + ------- + str + Current timestamp. + """ + + return datetime.now().strftime('%Y%m%d-%H%M%S') diff --git a/reinforced_lib/utils/exceptions.py b/reinforced_lib/utils/exceptions.py index 518311a..2ec0553 100644 --- a/reinforced_lib/utils/exceptions.py +++ b/reinforced_lib/utils/exceptions.py @@ -135,6 +135,15 @@ def __str__(self) -> str: return 'Cannot find corresponding Gymnasium space.' +class UnimplementedSpaceError(Exception): + """ + Raised when an observation space is required but not implemented. + """ + + def __str__(self) -> str: + return 'Appropriate observation space is not implemented.' + + class IncompatibleSpacesError(Exception): """ Raised when the observation spaces of two different modules are not compatible.