diff --git a/examples/lite/lite.py b/examples/lite/lite.py deleted file mode 100644 index 43d2bc5..0000000 --- a/examples/lite/lite.py +++ /dev/null @@ -1,136 +0,0 @@ -from functools import wraps -from functools import wraps -from typing import Tuple, Any - -import chex -import jax -import jax.numpy as jnp -import jax.tree_util as tree -import tensorflow as tf -from chex import PRNGKey, ArrayTree - -from reinforced_lib.agents import BaseAgent -from reinforced_lib.agents.mab import ThompsonSampling - - -def flatten_args(tree_args_fun, treedef): - @wraps(tree_args_fun) - def flat_args_fun(*leaves): - tree_args = tree.tree_unflatten(treedef, leaves) - tree_ret = tree_args_fun(*tree_args) - return tree.tree_leaves(tree_ret) - - return flat_args_fun - - -def make_converter(f, example): - """ - - Parameters - ---------- - f a function - example arguments in tree format - """ - leaves, treedef = tree.tree_flatten(example) - - flat_fun = flatten_args(f, treedef) - - inputs = [[(f'arg{i}', l) for i, l in enumerate(leaves)]] - converter = tf.lite.TFLiteConverter.experimental_from_jax([flat_fun], inputs) - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. - tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. - ] - - return converter - - -@chex.dataclass -class TfLiteState: - agent_state: ArrayTree - key: PRNGKey - - -def export(agent: BaseAgent, base_name: str, init_args, update_args, sample_args): - def init(*args, **kwargs) -> TfLiteState: - return TfLiteState( - agent_state=agent.init(*args, **kwargs), - key=jax.random.PRNGKey(42) - ) - - def sample(state: TfLiteState, *args, **kwargs) -> Tuple[Any, TfLiteState]: - sample_key, key = jax.random.split(state.key) - action = agent.sample(state.agent_state, sample_key, *args, **kwargs) - return action, TfLiteState(agent_state=state.agent_state, key=key) - - def update(state: TfLiteState, *args, **kwargs) -> TfLiteState: - update_key, key = jax.random.split(state.key) - new_state = agent.update(state.agent_state, update_key, *args, **kwargs) - return TfLiteState(agent_state=new_state, key=key) - - s0 = init(*init_args) - converter = make_converter(init, init_args) - tfl_init = converter.convert() - with open(f'{base_name}_init.tflite', 'wb') as f: - f.write(tfl_init) - - args = [s0] + update_args - s1 = update(*args) - converter = make_converter(update, args) - tfl_update = converter.convert() - - with open(f'{base_name}_update.tflite', 'wb') as f: - f.write(tfl_update) - - args = [s1] + sample_args - converter = make_converter(sample, args) - tfl_sample = converter.convert() - with open(f'{base_name}_sample.tflite', 'wb') as f: - f.write(tfl_sample) - - return - - -def main(): - k = jax.random.PRNGKey(42) - ts = ThompsonSampling(16) - - export(ts, 'ts', init_args=[k], update_args=[1, 2, 1, 0.1], sample_args=[1.]) - - with open('ts_init.tflite', 'rb') as f: - interpreter = tf.lite.Interpreter(model_content=f.read()) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - interpreter.set_tensor(input_details[0]['index'], k) - interpreter.invoke() - outs = [interpreter.get_tensor(od['index']) for od in interpreter.get_output_details()] - - with open('ts_update.tflite', 'rb') as f: - interpreter = tf.lite.Interpreter(model_content=f.read()) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - # types - ins = outs + tree.tree_map(jnp.asarray,[1, 2, 1, 0.1]) - for d, a in zip(input_details, ins): - interpreter.set_tensor(d['index'], a) - interpreter.invoke() - next_outs = [interpreter.get_tensor(od['index']) for od in interpreter.get_output_details()] - - with open('ts_sample.tflite', 'rb') as f: - interpreter = tf.lite.Interpreter(model_content=f.read()) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - ins = next_outs + [jnp.asarray(1.)] - for d, a in zip(input_details, ins): - interpreter.set_tensor(d['index'], a) - interpreter.invoke() - action_and_state = [interpreter.get_tensor(od['index']) for od in interpreter.get_output_details()] - - action = action_and_state[0] - next_state = action_and_state[1:] - - -if __name__ == '__main__': - main() - - ... diff --git a/examples/lite/lite_test.py b/examples/lite/lite_test.py deleted file mode 100644 index 2c42d44..0000000 --- a/examples/lite/lite_test.py +++ /dev/null @@ -1,139 +0,0 @@ -import jax -import jax.numpy as jnp -import tensorflow as tf - -from reinforced_lib.agents.mab import ThompsonSampling - - -def main2(): - k = jax.random.PRNGKey(42) - ts = ThompsonSampling(16) - - state = ts.init(k) - new_state = ts.update(state, k, 1, 2, 1, 0.1) - a = ts.sample(new_state, k, 1.) - - arrays, defs = jax.tree_util.tree_flatten(state) - - def _update(*args): - s = jax.tree_util.tree_unflatten(defs, args[:2]) - tmp = ts.update(s, *args[2:]) - arrs, _ = jax.tree_util.tree_flatten(tmp) - x = args[0].at[1].set(4) - - return x - - converter = tf.lite.TFLiteConverter.experimental_from_jax([_update], - [[('alpha', state.alpha), - ('beta', state.beta), - ('k', k), - ('action', jnp.asarray(1)), - ('n_successful', jnp.asarray(1)), - ("n_failed", jnp.asarray(0)), - ("delta_time", jnp.asarray(0.1))]]) - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. - tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. - - ] - - tflite_update = converter.convert() - with open('update.tflite', 'wb') as f: - f.write(tflite_update) - - interpreter = tf.lite.Interpreter(model_content=tflite_update) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - args = jax.tree_util.tree_leaves(state) + jax.tree_util.tree_map(jnp.asarray, [k, 1, 2, 1, 0.1]) - - # for a, d in zip(args, input_details): - # interpreter.set_tensor(d['index'], a) - - interpreter.invoke() - return - - -def main3(): - print(f'JAX {jax.__version__}') - print(f'tf {tf.__version__}') - - @jax.jit - def _update(x): - return x.at[0].set(4) - - converter = tf.lite.TFLiteConverter.experimental_from_jax([_update], - [[('x', jnp.ones(2))]]) - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. - tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. - - ] - - tflite_update = converter.convert() - with open('update.tflite', 'wb') as f: - f.write(tflite_update) - - interpreter = tf.lite.Interpreter(model_content=tflite_update) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - args = jnp.ones(2) - - expected = _update(args) - print("Expected output:", expected) - - interpreter.set_tensor(input_details[0]['index'], args) - interpreter.invoke() - - interpreter.set_tensor(input_details[0]['index'], args) - interpreter.invoke() - - output = interpreter.get_tensor(output_details[0]['index']) - print("Output:", output) - - return - - -def main4(): - print(f'JAX {jax.__version__}') - print(f'tf {tf.__version__}') - - @jax.jit - def _update(x): - return x + 4 * jax.nn.one_hot(0, x.shape[0]) - - converter = tf.lite.TFLiteConverter.experimental_from_jax([_update], - [[('x', jnp.ones(2))]]) - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. - tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. - - ] - - tflite_update = converter.convert() - with open('update.tflite', 'wb') as f: - f.write(tflite_update) - - interpreter = tf.lite.Interpreter(model_content=tflite_update) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - args = jnp.ones(2) - - expected = _update(args) - - interpreter.set_tensor(input_details[0]['index'], args) - - interpreter.invoke() - assert jnp.allclose(interpreter.get_tensor(output_details[0]['index']), expected) - return - - -if __name__ == '__main__': - main2() - main3() - main4() diff --git a/reinforced_lib/agents/base_agent.py b/reinforced_lib/agents/base_agent.py index fa00ee3..577fc97 100644 --- a/reinforced_lib/agents/base_agent.py +++ b/reinforced_lib/agents/base_agent.py @@ -110,7 +110,7 @@ def append_value(value: any, value_name: str, args: any) -> any: if args is None: raise UnimplementedSpaceError() elif is_dict(args): - return args | {value_name: value} + return {value_name: value} | args elif is_array(args): return [value] + list(args) else: diff --git a/test/test_rlib_to_tflite.py b/test/test_rlib_to_tflite.py new file mode 100644 index 0000000..56c7811 --- /dev/null +++ b/test/test_rlib_to_tflite.py @@ -0,0 +1,163 @@ +import os +import unittest +from glob import glob + +import haiku as hk +import numpy as np +import optax + +from reinforced_lib.agents.deep import DDPG +from reinforced_lib.agents.mab import ThompsonSampling +from reinforced_lib.exts import Gymnasium +from reinforced_lib.rlib import RLib + + +class TestRLibToTflite(unittest.TestCase): + def test_sample_only_export(self): + try: + import tensorflow as tf + except ModuleNotFoundError: + return + + rl = RLib( + agent_type=ThompsonSampling, + agent_params={'n_arms': 4}, + no_ext_mode=True + ) + rl.to_tflite(agent_id=0, sample_only=True) + + with open(glob(f'{os.path.expanduser("~")}/rlib-ThompsonSampling-0-*-init.tflite')[0], 'rb') as f: + interpreter = tf.lite.Interpreter(model_content=f.read()) + interpreter.allocate_tensors() + interpreter.invoke() + key = [interpreter.get_tensor(od['index']) for od in interpreter.get_output_details()] + + with open(glob(f'{os.path.expanduser("~")}/rlib-ThompsonSampling-0-*-sample.tflite')[0], 'rb') as f: + interpreter = tf.lite.Interpreter(model_content=f.read()) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + ins = key + [np.arange(4).astype(np.float32)] + + for d, a in zip(input_details, ins): + interpreter.set_tensor(d['index'], a) + + interpreter.invoke() + action, key = [interpreter.get_tensor(od['index']) for od in interpreter.get_output_details()] + + assert isinstance(action, (int, np.int32)) + assert isinstance(key, np.ndarray) + assert key.shape == (2,) + + def test_full_export(self): + try: + import tensorflow as tf + except ModuleNotFoundError: + return + + rl = RLib( + agent_type=ThompsonSampling, + agent_params={'n_arms': 4}, + no_ext_mode=True + ) + rl.to_tflite() + + with open(glob(f'{os.path.expanduser("~")}/rlib-ThompsonSampling-*-init.tflite')[0], 'rb') as f: + interpreter = tf.lite.Interpreter(model_content=f.read()) + interpreter.allocate_tensors() + interpreter.invoke() + state = [interpreter.get_tensor(od['index']) for od in interpreter.get_output_details()] + + with open(glob(f'{os.path.expanduser("~")}/rlib-ThompsonSampling-*-update.tflite')[0], 'rb') as f: + interpreter = tf.lite.Interpreter(model_content=f.read()) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + ins = state + [ + np.asarray(2, dtype=np.int32), + np.asarray([1.], dtype=np.float32), + np.asarray([1], dtype=np.int32), + np.asarray([0], dtype=np.int32) + ] + + for d, a in zip(input_details, ins): + interpreter.set_tensor(d['index'], a) + + interpreter.invoke() + state = [interpreter.get_tensor(od['index']) for od in interpreter.get_output_details()] + + with open(glob(f'{os.path.expanduser("~")}/rlib-ThompsonSampling-*-sample.tflite')[0], 'rb') as f: + interpreter = tf.lite.Interpreter(model_content=f.read()) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + ins = state + [np.arange(4).astype(np.float32)] + + for d, a in zip(input_details, ins): + interpreter.set_tensor(d['index'], a) + + interpreter.invoke() + action, key, state_alpha, state_beta = [interpreter.get_tensor(od['index']) for od in interpreter.get_output_details()] + + assert isinstance(action, (int, np.int32)) + assert isinstance(key, np.ndarray) + assert key.shape == (2,) + assert isinstance(state_alpha, np.ndarray) + assert state_alpha.shape == (4, 1) + assert isinstance(state_beta, np.ndarray) + assert state_beta.shape == (4, 1) + + def test_drl_sample_only_export(self): + try: + import tensorflow as tf + except ModuleNotFoundError: + return + + @hk.transform_with_state + def q_network(s, a): + return hk.nets.MLP([64, 1])(s) + + @hk.transform_with_state + def a_network(s): + return hk.nets.MLP([32, 32, 1])(s) + + rl = RLib( + agent_type=DDPG, + agent_params={ + 'q_network': q_network, + 'a_network': a_network, + 'q_optimizer': optax.adam(2e-3), + 'a_optimizer': optax.adam(1e-3) + }, + ext_type=Gymnasium, + ext_params={'env_id': 'Pendulum-v1'} + ) + rl.to_tflite(agent_id=0, sample_only=True) + + def test_drl_full_export(self): + try: + import tensorflow as tf + except ModuleNotFoundError: + return + + @hk.transform_with_state + def q_network(s, a): + return hk.nets.MLP([64, 1])(s) + + @hk.transform_with_state + def a_network(s): + return hk.nets.MLP([32, 32, 1])(s) + + rl = RLib( + agent_type=DDPG, + agent_params={ + 'q_network': q_network, + 'a_network': a_network, + 'q_optimizer': optax.adam(2e-3), + 'a_optimizer': optax.adam(1e-3) + }, + ext_type=Gymnasium, + ext_params={'env_id': 'Pendulum-v1'} + ) + rl.to_tflite() + + +if __name__ == '__main__': + unittest.main()