Policy combiner (#357)
* Combine two tf_agents policies with the time_step_spec and action_spec
given by the registry problem config. The combiner policy uses a new
time_step_spec feature, "model_selector", to select the requested policy
at the current state. The feature is computed as an md5 hash of the
respective policy names.
tvmarino committed Sep 10, 2024
1 parent 8d90940 commit 2878b51
Showing 3 changed files with 357 additions and 0 deletions.
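
The "model_selector" feature is derived deterministically from each policy's name. A minimal sketch of the computation, mirroring the hashing code in combine_tfa_policies_lib.py below ('policy_a' is a placeholder name and model_selector_for a hypothetical helper, not part of this commit):

import hashlib

import tensorflow as tf


def model_selector_for(policy_name: str) -> tf.Tensor:
  # Split the 16-byte md5 digest into two little-endian uint64 halves,
  # stored as [high, low] with the high half taken from digest[8:].
  digest = hashlib.md5(policy_name.encode('utf-8')).digest()
  high = int.from_bytes(digest[8:], 'little')
  low = int.from_bytes(digest[:8], 'little')
  return tf.constant([high, low], dtype=tf.uint64)


print(model_selector_for('policy_a'))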
89 changes: 89 additions & 0 deletions compiler_opt/tools/combine_tfa_policies.py
@@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs the policy combiner."""
from absl import app
from absl import flags
from absl import logging

import sys

import gin

import tensorflow as tf

from compiler_opt.rl import policy_saver
from compiler_opt.rl import registry
from compiler_opt.tools import combine_tfa_policies_lib as cfa_lib

_COMBINE_POLICIES_NAMES = flags.DEFINE_multi_string(
    'policies_names', [],
    'List in order of policy names for combined policies. '
    'Order must match that of policies_paths.')
_COMBINE_POLICIES_PATHS = flags.DEFINE_multi_string(
    'policies_paths', [],
    'List in order of policy paths for combined policies. '
    'Order must match that of policies_names.')
_COMBINED_POLICY_PATH = flags.DEFINE_string(
    'combined_policy_path', '', 'Path to save the combined policy.')
_GIN_FILES = flags.DEFINE_multi_string(
    'gin_files', [], 'List of paths to gin configuration files.')
_GIN_BINDINGS = flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files.')


def main(_):
  flags.mark_flag_as_required('policies_names')
  flags.mark_flag_as_required('policies_paths')
  flags.mark_flag_as_required('combined_policy_path')
  if len(_COMBINE_POLICIES_NAMES.value) != len(_COMBINE_POLICIES_PATHS.value):
    logging.error(
        'Length of policies_names: %d must equal length of policies_paths: %d.',
        len(_COMBINE_POLICIES_NAMES.value), len(_COMBINE_POLICIES_PATHS.value))
    sys.exit(1)
  gin.parse_config_files_and_bindings(
      _GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False)

  problem_config = registry.get_configuration()
  expected_signature, action_spec = problem_config.get_signature_spec()
  expected_signature.observation.update({
      'model_selector':
          tf.TensorSpec(shape=(2,), dtype=tf.uint64, name='model_selector')
  })
  # TODO(359): We only support combining two policies. Generalize this to
  # handle multiple policies.
  if len(_COMBINE_POLICIES_NAMES.value) != 2:
    logging.error('Policy combiner only supports two policies, %d given.',
                  len(_COMBINE_POLICIES_NAMES.value))
    sys.exit(1)
  policy1_name = _COMBINE_POLICIES_NAMES.value[0]
  policy1_path = _COMBINE_POLICIES_PATHS.value[0]
  policy2_name = _COMBINE_POLICIES_NAMES.value[1]
  policy2_path = _COMBINE_POLICIES_PATHS.value[1]
  policy1 = tf.saved_model.load(policy1_path, tags=None, options=None)
  policy2 = tf.saved_model.load(policy2_path, tags=None, options=None)
  combined_policy = cfa_lib.CombinedTFPolicy(
      tf_policies={
          policy1_name: policy1,
          policy2_name: policy2
      },
      time_step_spec=expected_signature,
      action_spec=action_spec)
  combined_policy_path = _COMBINED_POLICY_PATH.value
  policy_dict = {'combined_policy': combined_policy}
  saver = policy_saver.PolicySaver(policy_dict=policy_dict)
  saver.save(combined_policy_path)


if __name__ == '__main__':
  app.run(main)
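
For reference, a possible invocation of the tool (a sketch: the policy names, paths, and gin file are placeholders, and the gin file must configure the problem so registry.get_configuration() resolves):

python3 compiler_opt/tools/combine_tfa_policies.py \
  --policies_names=policy_a --policies_names=policy_b \
  --policies_paths=/tmp/policy_a --policies_paths=/tmp/policy_b \
  --combined_policy_path=/tmp/combined_policy \
  --gin_files=/path/to/problem_config.gin

Each absl multi_string flag is repeated once per value, and the order of policies_names must match policies_paths.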
116 changes: 116 additions & 0 deletions compiler_opt/tools/combine_tfa_policies_lib.py
@@ -0,0 +1,116 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Combines two tf-agent policies with the given state and action spec."""
from typing import Dict, Optional, Tuple

import tensorflow as tf
import hashlib

import tf_agents
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.trajectories import policy_step
import tensorflow_probability as tfp


class CombinedTFPolicy(tf_agents.policies.TFPolicy):
  """Policy which combines two target policies."""

  def __init__(self, *args,
               tf_policies: Dict[str, tf_agents.policies.TFPolicy], **kwargs):
    super().__init__(*args, **kwargs)

    self.tf_policies = []
    self.tf_policy_names = []
    for name, policy in tf_policies.items():
      self.tf_policies.append(policy)
      self.tf_policy_names.append(name)

    self.expected_signature = self.time_step_spec
    self.sorted_keys = sorted(self.expected_signature.observation.keys())

    high_low_tensors = []
    for name in self.tf_policy_names:
      m = hashlib.md5()
      m.update(name.encode("utf-8"))
      high_low_tensors.append(
          tf.stack([
              tf.constant(
                  int.from_bytes(m.digest()[8:], "little"), dtype=tf.uint64),
              tf.constant(
                  int.from_bytes(m.digest()[:8], "little"), dtype=tf.uint64)
          ]))
    self.high_low_tensors = tf.stack(high_low_tensors)
    # Related LLVM commit: https://github.com/llvm/llvm-project/pull/96276
    m = hashlib.md5()
    m.update(self.tf_policy_names[0].encode("utf-8"))
    self.high = int.from_bytes(m.digest()[8:], "little")
    self.low = int.from_bytes(m.digest()[:8], "little")
    self.high_low_tensor = tf.constant([self.high, self.low], dtype=tf.uint64)

  def _process_observation(
      self, observation: types.NestedSpecTensorOrArray
  ) -> Tuple[types.NestedSpecTensorOrArray, types.TensorOrArray]:
    assert "model_selector" in self.sorted_keys
    high_low_tensor = self.high_low_tensor
    for name in self.sorted_keys:
      if name in ["model_selector"]:
        # model_selector is a Tensor of shape (1,) which requires indexing [0]
        switch_tensor = observation.pop(name)[0]
        high_low_tensor = switch_tensor

        tf.debugging.Assert(
            tf.equal(
                tf.reduce_any(
                    tf.reduce_all(
                        tf.equal(high_low_tensor, self.high_low_tensors),
                        axis=1)), True),
            [high_low_tensor, self.high_low_tensors])

    return observation, high_low_tensor

  def _action(self,
              time_step: ts.TimeStep,
              policy_state: types.NestedTensorSpec,
              seed: Optional[types.Seed] = None) -> policy_step.PolicyStep:
    new_observation = time_step.observation
    new_observation, switch_tensor = self._process_observation(new_observation)
    updated_step = ts.TimeStep(
        step_type=time_step.step_type,
        reward=time_step.reward,
        discount=time_step.discount,
        observation=new_observation)

    # TODO(359): We only support combining two policies. Generalize this to
    # handle multiple policies.
    def f0():
      return tf.cast(
          self.tf_policies[0].action(updated_step).action[0], dtype=tf.int64)

    def f1():
      return tf.cast(
          self.tf_policies[1].action(updated_step).action[0], dtype=tf.int64)

    action = tf.cond(
        tf.math.reduce_all(tf.equal(switch_tensor, self.high_low_tensor)), f0,
        f1)
    return policy_step.PolicyStep(action=action, state=policy_state)

  def _distribution(
      self, time_step: ts.TimeStep,
      policy_state: types.NestedTensorSpec) -> policy_step.PolicyStep:
    """Placeholder for distribution as every TFPolicy requires it."""
    return policy_step.PolicyStep(
        action=tfp.distributions.Deterministic(2.), state=policy_state)
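
At inference time, a caller routes to one of the wrapped policies by sending that policy's hash as the "model_selector" observation feature. A minimal sketch, assuming a combined_policy already built from hypothetical policies named 'policy_a' and 'policy_b' with a single 'obs' feature (the test module below exercises the same path):

import hashlib

import tensorflow as tf
from tf_agents.trajectories import time_step as ts

# Route to 'policy_a': a batched [high, low] pair of little-endian uint64
# halves of its md5 digest.
digest = hashlib.md5('policy_a'.encode('utf-8')).digest()
selector = tf.constant([[int.from_bytes(digest[8:], 'little'),
                         int.from_bytes(digest[:8], 'little')]],
                       dtype=tf.uint64)
state = ts.TimeStep(
    step_type=tf.constant([0], dtype=tf.int64),
    reward=tf.constant([0.], dtype=tf.float64),
    discount=tf.constant([0.], dtype=tf.float32),
    observation={
        'obs': tf.constant([42], dtype=tf.int64),
        'model_selector': selector
    })
action = combined_policy.action(state).action  # dispatched to 'policy_a'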
152 changes: 152 additions & 0 deletions compiler_opt/tools/combine_tfa_policies_lib_test.py
@@ -0,0 +1,152 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the combine_tfa_policies_lib.py module"""

from absl.testing import absltest

import tensorflow as tf
from compiler_opt.tools import combine_tfa_policies_lib
from tf_agents.trajectories import time_step as ts
import tf_agents
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import policy_step
from tf_agents.typing import types
import hashlib
import numpy as np


def client_side_model_selector_calculation(policy_name: str) -> types.Tensor:
  m = hashlib.md5()
  m.update(policy_name.encode('utf-8'))
  high = int.from_bytes(m.digest()[8:], 'little')
  low = int.from_bytes(m.digest()[:8], 'little')
  model_selector = tf.constant([[high, low]], dtype=tf.uint64)
  return model_selector


class AddOnePolicy(tf_agents.policies.TFPolicy):
  """Test policy which increments the obs feature."""

  def __init__(self):
    obs_spec = {'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)}
    time_step_spec = ts.time_step_spec(obs_spec)

    act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)

    super().__init__(time_step_spec=time_step_spec, action_spec=act_spec)

  def _distribution(self, time_step):
    """Boilerplate function for TFPolicy."""
    pass

  def _variables(self):
    """Boilerplate function for TFPolicy."""
    return ()

  def _action(self, time_step, policy_state, seed):
    """Boilerplate function for TFPolicy."""
    observation = time_step.observation['obs'][0]
    action = tf.reshape(observation + 1, (1,))
    return policy_step.PolicyStep(action, policy_state)


class SubtractOnePolicy(tf_agents.policies.TFPolicy):
  """Test policy which decrements the obs feature."""

  def __init__(self):
    obs_spec = {'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)}
    time_step_spec = ts.time_step_spec(obs_spec)

    act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)

    super().__init__(time_step_spec=time_step_spec, action_spec=act_spec)

  def _distribution(self, time_step):
    """Boilerplate function for TFPolicy."""
    pass

  def _variables(self):
    """Boilerplate function for TFPolicy."""
    return ()

  def _action(self, time_step, policy_state, seed):
    """Boilerplate function for TFPolicy."""
    observation = time_step.observation['obs'][0]
    action = tf.reshape(observation - 1, (1,))
    return policy_step.PolicyStep(action, policy_state)


observation_spec = ts.time_step_spec({
    'obs':
        tf.TensorSpec(dtype=tf.int32, shape=(), name='obs'),
    'model_selector':
        tf.TensorSpec(shape=(2,), dtype=tf.uint64, name='model_selector')
})

action_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)


class CombinedTFPolicyTest(absltest.TestCase):
  """Test for CombinedTFPolicy."""

  def test_select_add_policy(self):
    policy1 = AddOnePolicy()
    policy2 = SubtractOnePolicy()
    combined_policy = combine_tfa_policies_lib.CombinedTFPolicy(
        tf_policies={
            'add_one': policy1,
            'subtract_one': policy2
        },
        time_step_spec=observation_spec,
        action_spec=action_spec)

    model_selector = client_side_model_selector_calculation('add_one')

    state = ts.TimeStep(
        discount=tf.constant(np.array([0.]), dtype=tf.float32),
        observation={
            'obs': tf.constant(np.array([42]), dtype=tf.int64),
            'model_selector': model_selector
        },
        reward=tf.constant(np.array([0]), dtype=tf.float64),
        step_type=tf.constant(np.array([0]), dtype=tf.int64))

    self.assertEqual(
        combined_policy.action(state).action, tf.constant(43, dtype=tf.int64))

  def test_select_subtract_policy(self):
    policy1 = AddOnePolicy()
    policy2 = SubtractOnePolicy()
    combined_policy = combine_tfa_policies_lib.CombinedTFPolicy(
        tf_policies={
            'add_one': policy1,
            'subtract_one': policy2
        },
        time_step_spec=observation_spec,
        action_spec=action_spec)

    model_selector = client_side_model_selector_calculation('subtract_one')

    state = ts.TimeStep(
        discount=tf.constant(np.array([0.]), dtype=tf.float32),
        observation={
            'obs': tf.constant(np.array([42]), dtype=tf.int64),
            'model_selector': model_selector
        },
        reward=tf.constant(np.array([0]), dtype=tf.float64),
        step_type=tf.constant(np.array([0]), dtype=tf.int64))

    self.assertEqual(
        combined_policy.action(state).action, tf.constant(41, dtype=tf.int64))
