From e362b85dece8c06c7f5513113a118727b317adf7 Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Thu, 12 Sep 2024 23:10:08 +0000 Subject: [PATCH] Rename corpus reference to train_corpus in blackbox_learner.py This patch renames the corpus refereince in blackbox_learner.py from sampler to train_corpus to better represent what it actually is. This patch also fixes an incorrect string as a byproduct of this where an argument name in the docstring did not match the actual argument name. --- compiler_opt/es/blackbox_learner.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/compiler_opt/es/blackbox_learner.py b/compiler_opt/es/blackbox_learner.py index 289d3e57..b37a1f0d 100644 --- a/compiler_opt/es/blackbox_learner.py +++ b/compiler_opt/es/blackbox_learner.py @@ -128,7 +128,7 @@ class BlackboxLearner: def __init__(self, blackbox_opt: blackbox_optimizers.BlackboxOptimizer, - sampler: corpus.Corpus, + train_corpus: corpus.Corpus, tf_policy_path: str, output_dir: str, policy_saver_fn: PolicySaverCallableType, @@ -141,7 +141,7 @@ def __init__(self, Args: blackbox_opt: the blackbox optimizer to use - train_sampler: corpus_sampler for training data. + train_corpus: the training corpus to utiilize tf_policy_path: where to write the tf policy output_dir: the directory to write all outputs policy_saver_fn: function to save a policy to cns @@ -152,7 +152,7 @@ def __init__(self, deadline: the deadline in seconds for requests to the inlining server. """ self._blackbox_opt = blackbox_opt - self._sampler = sampler + self._train_corpus = train_corpus self._tf_policy_path = tf_policy_path self._output_dir = output_dir self._policy_saver_fn = policy_saver_fn @@ -250,7 +250,8 @@ def _get_results( perturbations: List[bytes]) -> List[concurrent.futures.Future]: if not self._samples: for _ in range(self._config.total_num_perturbations): - sample = self._sampler.sample(self._config.num_ir_repeats_within_worker) + sample = self._train_corpus.sample( + self._config.num_ir_repeats_within_worker) self._samples.append(sample) # add copy of sample for antithetic perturbation pair if self._config.est_type == (