Serialize entire policy in blackbox_learner (#366)
This patch adjusts blackbox_learner so that it returns an entire policy
rather than just the bytes of the policy. When actually running
evaluations, we need to write out the full policy, including the output
spec, to disk so the compiler can pick it up. Before this patch, we were
not passing along the output spec to the worker.
boomanaiden154 authored Sep 16, 2024
1 parent cf2f790 commit 5230947
Showing 3 changed files with 16 additions and 14 deletions.
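For context, the object threaded through these files is compiler_opt.rl.policy_saver.Policy, which bundles the serialized model together with the output spec the compiler needs. Below is a minimal sketch of the worker-side step the commit message describes, writing the full policy to disk so the compiler can pick it up; the to_filesystem helper is an assumption mirroring the from_filesystem call visible in the diff, not a claim about this repository's actual worker code.

import tempfile

from compiler_opt.rl import policy_saver


def materialize_policy(policy: policy_saver.Policy) -> str:
  # Write the whole Policy, including the output spec, into a fresh directory
  # that can then be handed to the compiler. to_filesystem is assumed here.
  policy_dir = tempfile.mkdtemp(prefix='es_policy_')
  policy.to_filesystem(policy_dir)
  return policy_dir

Before this patch only policy_obj.policy, the raw model bytes, reached the worker, so the output spec could never be written out alongside the model.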
9 changes: 5 additions & 4 deletions compiler_opt/es/blackbox_evaluator.py
@@ -25,6 +25,7 @@
 from compiler_opt.rl import corpus
 from compiler_opt.es import blackbox_optimizers
 from compiler_opt.distributed import buffered_scheduler
+from compiler_opt.rl import policy_saver


 class BlackboxEvaluator(metaclass=abc.ABCMeta):
@@ -36,8 +37,8 @@ def __init__(self, train_corpus: corpus.Corpus):

   @abc.abstractmethod
   def get_results(
-      self, pool: FixedWorkerPool,
-      perturbations: List[bytes]) -> List[concurrent.futures.Future]:
+      self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy]
+  ) -> List[concurrent.futures.Future]:
     raise NotImplementedError()

   @abc.abstractmethod
@@ -66,8 +67,8 @@ def __init__(self, train_corpus: corpus.Corpus,
     super().__init__(train_corpus)

   def get_results(
-      self, pool: FixedWorkerPool,
-      perturbations: List[bytes]) -> List[concurrent.futures.Future]:
+      self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy]
+  ) -> List[concurrent.futures.Future]:
     if not self._samples:
       for _ in range(self._total_num_perturbations):
         sample = self._train_corpus.sample(self._num_ir_repeats_within_worker)
18 changes: 9 additions & 9 deletions compiler_opt/es/blackbox_learner.py
@@ -223,8 +223,10 @@ def _save_model(self) -> None:
   def get_model_weights(self) -> npt.NDArray[np.float32]:
     return self._model_weights

-  def _get_policy_as_bytes(self,
-                           perturbation: npt.NDArray[np.float32]) -> bytes:
+  # TODO: The current conversion is inefficient (performance-wise). We should
+  # consider doing this on the worker side.
+  def _get_policy_from_perturbation(
+      self, perturbation: npt.NDArray[np.float32]) -> policy_saver.Policy:
     sm = tf.saved_model.load(self._tf_policy_path)
     # devectorize the perturbation
     policy_utils.set_vectorized_parameters_for_policy(sm, perturbation)
@@ -242,7 +244,7 @@ def _get_policy_as_bytes(self,

     # create and return policy
     policy_obj = policy_saver.Policy.from_filesystem(tfl_dir)
-    return policy_obj.policy
+    return policy_obj

   def run_step(self, pool: FixedWorkerPool) -> None:
     """Run a single step of blackbox learning.
@@ -258,14 +260,12 @@ def run_step(self, pool: FixedWorkerPool) -> None:
         p for p in initial_perturbations for p in (p, -p)
     ]

-    # convert to bytes for compile job
-    # TODO: current conversion is inefficient.
-    # consider doing this on the worker side
-    perturbations_as_bytes = []
+    perturbations_as_policies = []
     for perturbation in initial_perturbations:
-      perturbations_as_bytes.append(self._get_policy_as_bytes(perturbation))
+      perturbations_as_policies.append(
+          self._get_policy_from_perturbation(perturbation))

-    results = self._evaluator.get_results(pool, perturbations_as_bytes)
+    results = self._evaluator.get_results(pool, perturbations_as_policies)
     rewards = self._evaluator.get_rewards(results)

     num_pruned = _prune_skipped_perturbations(initial_perturbations, rewards)
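The run_step hunk above does two things per step: it doubles the sampled perturbations into antipodal pairs and then converts each one into a full Policy via _get_policy_from_perturbation. A small standalone illustration of the pairing step (variable names are renamed from the shadowed p in the listing purely for readability):

import numpy as np

# Every sampled perturbation is evaluated together with its negation, so the
# evaluator receives [p1, -p1, p2, -p2, ...].
initial_perturbations = [np.array([1.0, -2.0]), np.array([0.5, 0.5])]
paired = [p for q in initial_perturbations for p in (q, -q)]
assert len(paired) == 2 * len(initial_perturbations)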
3 changes: 2 additions & 1 deletion compiler_opt/es/blackbox_learner_test.py
@@ -45,7 +45,8 @@ def __init__(self, arg, *, kwarg):
     self._kwarg = kwarg
     self.function_value = 0.0

-  def compile(self, policy: bytes, samples: List[corpus.ModuleSpec]) -> float:
+  def compile(self, policy: policy_saver.Policy,
+              samples: List[corpus.ModuleSpec]) -> float:
     if policy and samples:
       self.function_value += 1.0
       return self.function_value
