From 2878b51a15540f43c770d8ba7a0b8f5ee30d6806 Mon Sep 17 00:00:00 2001
From: tvmarino <145081464+tvmarino@users.noreply.github.com>
Date: Tue, 10 Sep 2024 10:00:06 -0400
Subject: [PATCH] Policy combiner (#357)

* Combine two tf_agents policies, with the time_step_spec and action_spec
  given by the registry problem config. The combiner policy uses a new
  time_step_spec feature, "model_selector", to select the requested policy
  at the current state. The feature is computed as an md5 hash of the
  respective policy names. An illustrative invocation of the combiner tool
  is given at the end of this patch.
---
 compiler_opt/tools/combine_tfa_policies.py    |  89 ++++++++++
 .../tools/combine_tfa_policies_lib.py         | 116 +++++++++++++
 .../tools/combine_tfa_policies_lib_test.py    | 152 ++++++++++++++++++
 3 files changed, 357 insertions(+)
 create mode 100755 compiler_opt/tools/combine_tfa_policies.py
 create mode 100644 compiler_opt/tools/combine_tfa_policies_lib.py
 create mode 100644 compiler_opt/tools/combine_tfa_policies_lib_test.py

diff --git a/compiler_opt/tools/combine_tfa_policies.py b/compiler_opt/tools/combine_tfa_policies.py
new file mode 100755
index 00000000..c3146711
--- /dev/null
+++ b/compiler_opt/tools/combine_tfa_policies.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Runs the policy combiner."""
+from absl import app
+from absl import flags
+from absl import logging
+
+import sys
+
+import gin
+
+import tensorflow as tf
+
+from compiler_opt.rl import policy_saver
+from compiler_opt.rl import registry
+from compiler_opt.tools import combine_tfa_policies_lib as cfa_lib
+
+_COMBINE_POLICIES_NAMES = flags.DEFINE_multi_string(
+    'policies_names', [],
+    'List in order of policy names for combined policies. '
+    'Order must match that of policies_paths.')
+_COMBINE_POLICIES_PATHS = flags.DEFINE_multi_string(
+    'policies_paths', [],
+    'List in order of policy paths for combined policies. '
+    'Order must match that of policies_names.')
+_COMBINED_POLICY_PATH = flags.DEFINE_string(
+    'combined_policy_path', '', 'Path to save the combined policy.')
+_GIN_FILES = flags.DEFINE_multi_string(
+    'gin_files', [], 'List of paths to gin configuration files.')
+_GIN_BINDINGS = flags.DEFINE_multi_string(
+    'gin_bindings', [],
+    'Gin bindings to override the values set in the config files.')
+
+
+def main(_):
+  flags.mark_flag_as_required('policies_names')
+  flags.mark_flag_as_required('policies_paths')
+  flags.mark_flag_as_required('combined_policy_path')
+  if len(_COMBINE_POLICIES_NAMES.value) != len(_COMBINE_POLICIES_PATHS.value):
+    logging.error(
+        'Length of policies_names: %d must equal length of policies_paths: %d.',
+        len(_COMBINE_POLICIES_NAMES.value), len(_COMBINE_POLICIES_PATHS.value))
+    sys.exit(1)
+  gin.parse_config_files_and_bindings(
+      _GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False)
+
+  problem_config = registry.get_configuration()
+  expected_signature, action_spec = problem_config.get_signature_spec()
+  expected_signature.observation.update({
+      'model_selector':
+          tf.TensorSpec(shape=(2,), dtype=tf.uint64, name='model_selector')
+  })
+  # TODO(359): We only support combining two policies. Generalize this to
+  # handle multiple policies.
+  if len(_COMBINE_POLICIES_NAMES.value) != 2:
+    logging.error('Policy combiner only supports two policies, %d given.',
+                  len(_COMBINE_POLICIES_NAMES.value))
+    sys.exit(1)
+  policy1_name = _COMBINE_POLICIES_NAMES.value[0]
+  policy1_path = _COMBINE_POLICIES_PATHS.value[0]
+  policy2_name = _COMBINE_POLICIES_NAMES.value[1]
+  policy2_path = _COMBINE_POLICIES_PATHS.value[1]
+  policy1 = tf.saved_model.load(policy1_path, tags=None, options=None)
+  policy2 = tf.saved_model.load(policy2_path, tags=None, options=None)
+  combined_policy = cfa_lib.CombinedTFPolicy(
+      tf_policies={
+          policy1_name: policy1,
+          policy2_name: policy2
+      },
+      time_step_spec=expected_signature,
+      action_spec=action_spec)
+  combined_policy_path = _COMBINED_POLICY_PATH.value
+  policy_dict = {'combined_policy': combined_policy}
+  saver = policy_saver.PolicySaver(policy_dict=policy_dict)
+  saver.save(combined_policy_path)
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/compiler_opt/tools/combine_tfa_policies_lib.py b/compiler_opt/tools/combine_tfa_policies_lib.py
new file mode 100644
index 00000000..4d1de2b3
--- /dev/null
+++ b/compiler_opt/tools/combine_tfa_policies_lib.py
@@ -0,0 +1,116 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Combines two tf-agent policies with the given state and action spec.""" +from typing import Dict, Optional, Tuple + +import tensorflow as tf +import hashlib + +import tf_agents +from tf_agents.trajectories import time_step as ts +from tf_agents.typing import types +from tf_agents.trajectories import policy_step +import tensorflow_probability as tfp + + +class CombinedTFPolicy(tf_agents.policies.TFPolicy): + """Policy which combines two target policies.""" + + def __init__(self, *args, tf_policies: Dict[str, tf_agents.policies.TFPolicy], + **kwargs): + super().__init__(*args, **kwargs) + + self.tf_policies = [] + self.tf_policy_names = [] + for name, policy in tf_policies.items(): + self.tf_policies.append(policy) + self.tf_policy_names.append(name) + + self.expected_signature = self.time_step_spec + self.sorted_keys = sorted(self.expected_signature.observation.keys()) + + high_low_tensors = [] + for name in self.tf_policy_names: + m = hashlib.md5() + m.update(name.encode("utf-8")) + high_low_tensors.append( + tf.stack([ + tf.constant( + int.from_bytes(m.digest()[8:], "little"), dtype=tf.uint64), + tf.constant( + int.from_bytes(m.digest()[:8], "little"), dtype=tf.uint64) + ])) + self.high_low_tensors = tf.stack(high_low_tensors) + # Related LLVM commit: https://github.com/llvm/llvm-project/pull/96276 + m = hashlib.md5() + m.update(self.tf_policy_names[0].encode("utf-8")) + self.high = int.from_bytes(m.digest()[8:], "little") + self.low = int.from_bytes(m.digest()[:8], "little") + self.high_low_tensor = tf.constant([self.high, self.low], dtype=tf.uint64) + + def _process_observation( + self, observation: types.NestedSpecTensorOrArray + ) -> Tuple[types.NestedSpecTensorOrArray, types.TensorOrArray]: + assert "model_selector" in self.sorted_keys + high_low_tensor = self.high_low_tensor + for name in self.sorted_keys: + if name in ["model_selector"]: + # model_selector is a Tensor of shape (1,) which requires indexing [0] + switch_tensor = observation.pop(name)[0] + high_low_tensor = switch_tensor + + tf.debugging.Assert( + tf.equal( + tf.reduce_any( + tf.reduce_all( + tf.equal(high_low_tensor, self.high_low_tensors), + axis=1)), True), + [high_low_tensor, self.high_low_tensors]) + + return observation, high_low_tensor + + def _action(self, + time_step: ts.TimeStep, + policy_state: types.NestedTensorSpec, + seed: Optional[types.Seed] = None) -> policy_step.PolicyStep: + new_observation = time_step.observation + new_observation, switch_tensor = self._process_observation(new_observation) + updated_step = ts.TimeStep( + step_type=time_step.step_type, + reward=time_step.reward, + discount=time_step.discount, + observation=new_observation) + + # TODO(359): We only support combining two policies. Generalize this to + # handle multiple policies. 
+    def f0():
+      return tf.cast(
+          self.tf_policies[0].action(updated_step).action[0], dtype=tf.int64)
+
+    def f1():
+      return tf.cast(
+          self.tf_policies[1].action(updated_step).action[0], dtype=tf.int64)
+
+    action = tf.cond(
+        tf.math.reduce_all(tf.equal(switch_tensor, self.high_low_tensor)), f0,
+        f1)
+    return policy_step.PolicyStep(action=action, state=policy_state)
+
+  def _distribution(
+      self, time_step: ts.TimeStep,
+      policy_state: types.NestedTensorSpec) -> policy_step.PolicyStep:
+    """Placeholder for distribution as every TFPolicy requires it."""
+    return policy_step.PolicyStep(
+        action=tfp.distributions.Deterministic(2.), state=policy_state)
diff --git a/compiler_opt/tools/combine_tfa_policies_lib_test.py b/compiler_opt/tools/combine_tfa_policies_lib_test.py
new file mode 100644
index 00000000..7cd873c5
--- /dev/null
+++ b/compiler_opt/tools/combine_tfa_policies_lib_test.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the combine_tfa_policies_lib module."""
+
+from absl.testing import absltest
+
+import tensorflow as tf
+from compiler_opt.tools import combine_tfa_policies_lib
+from tf_agents.trajectories import time_step as ts
+import tf_agents
+from tf_agents.specs import tensor_spec
+from tf_agents.trajectories import policy_step
+from tf_agents.typing import types
+import hashlib
+import numpy as np
+
+
+def client_side_model_selector_calculation(policy_name: str) -> types.Tensor:
+  # Mirrors CombinedTFPolicy: split the md5 digest of the policy name into
+  # (high, low) little-endian uint64 halves, with a leading batch dimension.
+  m = hashlib.md5()
+  m.update(policy_name.encode('utf-8'))
+  high = int.from_bytes(m.digest()[8:], 'little')
+  low = int.from_bytes(m.digest()[:8], 'little')
+  model_selector = tf.constant([[high, low]], dtype=tf.uint64)
+  return model_selector
+
+
+class AddOnePolicy(tf_agents.policies.TFPolicy):
+  """Test policy which increments the obs feature."""
+
+  def __init__(self):
+    obs_spec = {'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)}
+    time_step_spec = ts.time_step_spec(obs_spec)
+
+    act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)
+
+    super().__init__(time_step_spec=time_step_spec, action_spec=act_spec)
+
+  def _distribution(self, time_step):
+    """Boilerplate function for TFPolicy."""
+    pass
+
+  def _variables(self):
+    """Boilerplate function for TFPolicy."""
+    return ()
+
+  def _action(self, time_step, policy_state, seed):
+    """Boilerplate function for TFPolicy."""
+    observation = time_step.observation['obs'][0]
+    action = tf.reshape(observation + 1, (1,))
+    return policy_step.PolicyStep(action, policy_state)
+
+
+class SubtractOnePolicy(tf_agents.policies.TFPolicy):
+  """Test policy which decrements the obs feature."""
+
+  def __init__(self):
+    obs_spec = {'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)}
+    time_step_spec = ts.time_step_spec(obs_spec)
+
+    act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)
+
+    super().__init__(time_step_spec=time_step_spec, action_spec=act_spec)
+
+  def _distribution(self, time_step):
+    """Boilerplate function for TFPolicy."""
+    pass
+
+  def _variables(self):
+    """Boilerplate function for TFPolicy."""
+    return ()
+
+  def _action(self, time_step, policy_state, seed):
+    """Boilerplate function for TFPolicy."""
+    observation = time_step.observation['obs'][0]
+    action = tf.reshape(observation - 1, (1,))
+    return policy_step.PolicyStep(action, policy_state)
+
+
+observation_spec = ts.time_step_spec({
+    'obs':
+        tf.TensorSpec(dtype=tf.int32, shape=(), name='obs'),
+    'model_selector':
+        tf.TensorSpec(shape=(2,), dtype=tf.uint64, name='model_selector')
+})
+
+action_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)
+
+
+class CombinedTFPolicyTest(absltest.TestCase):
+  """Test for CombinedTFPolicy."""
+
+  def test_select_add_policy(self):
+    policy1 = AddOnePolicy()
+    policy2 = SubtractOnePolicy()
+    combined_policy = combine_tfa_policies_lib.CombinedTFPolicy(
+        tf_policies={
+            'add_one': policy1,
+            'subtract_one': policy2
+        },
+        time_step_spec=observation_spec,
+        action_spec=action_spec)
+
+    model_selector = client_side_model_selector_calculation('add_one')
+
+    state = ts.TimeStep(
+        discount=tf.constant(np.array([0.]), dtype=tf.float32),
+        observation={
+            'obs': tf.constant(np.array([42]), dtype=tf.int64),
+            'model_selector': model_selector
+        },
+        reward=tf.constant(np.array([0]), dtype=tf.float64),
+        step_type=tf.constant(np.array([0]), dtype=tf.int64))
+
+    self.assertEqual(
+        combined_policy.action(state).action, tf.constant(43, dtype=tf.int64))
+
+  def test_select_subtract_policy(self):
+    policy1 = AddOnePolicy()
+    policy2 = SubtractOnePolicy()
+    combined_policy = combine_tfa_policies_lib.CombinedTFPolicy(
+        tf_policies={
+            'add_one': policy1,
+            'subtract_one': policy2
+        },
+        time_step_spec=observation_spec,
+        action_spec=action_spec)
+
+    model_selector = client_side_model_selector_calculation('subtract_one')
+
+    state = ts.TimeStep(
+        discount=tf.constant(np.array([0.]), dtype=tf.float32),
+        observation={
+            'obs': tf.constant(np.array([42]), dtype=tf.int64),
+            'model_selector': model_selector
+        },
+        reward=tf.constant(np.array([0]), dtype=tf.float64),
+        step_type=tf.constant(np.array([0]), dtype=tf.int64))
+
+    self.assertEqual(
+        combined_policy.action(state).action, tf.constant(41, dtype=tf.int64))
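
Illustrative invocation of the combiner tool, using only the flags defined in
this patch. All names and paths below are placeholders, not values shipped
with this change; gin_files is expected to point at a gin config that sets up
registry.get_configuration for the target problem:

  python3 compiler_opt/tools/combine_tfa_policies.py \
    --policies_names=policy_a --policies_names=policy_b \
    --policies_paths=/path/to/policy_a --policies_paths=/path/to/policy_b \
    --combined_policy_path=/path/to/combined_policy \
    --gin_files=/path/to/problem_config.gin

Note that policies_names and policies_paths are multi-string flags, so each
value is passed as a separate occurrence of the flag, and the two lists must
be in matching order. The names given here must be the same names clients
later hash into the model_selector feature to pick a policy at runtime.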