Skip to content

Commit

Permalink
minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
boomanaiden154 committed Sep 8, 2024
1 parent 2982d39 commit 72990ea
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 48 deletions.
5 changes: 2 additions & 3 deletions compiler_opt/es/blackbox_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,9 @@ def _log_tf_summary(self, rewards: List[float]) -> None:
with self._summary_writer.as_default():
tf.summary.scalar(
'reward/average_reward_train', np.mean(rewards), step=self._step)

tf.summary.scalar(
'reward/maximum_reward_train', np.max(rewards), step=self._step
)
'reward/maximum_reward_train', np.max(rewards), step=self._step)

tf.summary.histogram('reward/reward_train', rewards, step=self._step)

Expand Down
37 changes: 15 additions & 22 deletions compiler_opt/es/es_trainer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,12 @@
"pretrained_policy_path", None,
"The path of the pretrained policy. If not provided, it will \
construct a new policy with randomly initialized weights.")
_CORPUS_DIR = flags.DEFINE_string("corpus_dir", None,
"The path to the corpus to use")
_CLANG_PATH = flags.DEFINE_string("clang_path", None,
"The path to the clang binary to use.")
_TRACE_PATH = flags.DEFINE_string("trace_path", None,
"The path to the BB trace to use.")
_BB_TRACE_MODEL_PATH = flags.DEFINE_string(
"bb_trace_model_path", None,
"THe path to the basic_block_trace_model binary to use.")

_CORPUS_DIR = '/usr/local/google/home/aidengrossman/opt_mlregalloc/corpus_subset'
_CLANG_PATH = '/usr/local/google/home/aidengrossman/opt_mlregalloc/clang'
_TRACE_PATH = '/usr/local/google/home/aidengrossman/opt_mlregalloc/bb_trace.pb'
_FUNCTION_INDEX_PATH = '/usr/local/google/home/aidengrossman/opt_mlregalloc/function_index.pb'
_BB_TRACE_MODEL_PATH = '/usr/local/google/home/aidengrossman/opt_mlregalloc/basic_block_trace_model'


class ESWorker(worker.Worker):
Expand All @@ -76,11 +73,7 @@ def __init__(self, *, all_gin):
self._template_dir = tempfile.mkdtemp()
saver.save(self._template_dir)

self._corpus_dir = '/usr/local/google/home/aidengrossman/programming/opt_mlregalloc/corpus'
self._clang_path = '/usr/local/google/home/aidengrossman/programming/opt_mlregalloc/clang'
self._trace_path = '/usr/local/google/home/aidengrossman/programming/opt_mlregalloc/execution_trace.pb'
self._bb_trace_model_path = '/usr/local/google/home/aidengrossman/programming/opt_mlregalloc/basic_block_trace_model'
self._models_for_test_path = '/usr/local/google/home/aidengrossman/programming/output_traces/'
self._models_for_test_path = '/usr/local/google/home/aidengrossman/output_models/'

def es_compile(self, params: list[float], baseline_score: float) -> float:
with tempfile.TemporaryDirectory() as tempdir:
Expand All @@ -97,15 +90,16 @@ def es_compile(self, params: list[float], baseline_score: float) -> float:
policy_saver.OUTPUT_SIGNATURE),
os.path.join(tflitedir, policy_saver.OUTPUT_SIGNATURE))

trace_data_collector.compile_corpus(self._corpus_dir, tempdir,
self._clang_path, tflitedir, single_threaded=True)
trace_data_collector.compile_corpus(
_CORPUS_DIR, tempdir, _CLANG_PATH, tflitedir, thread_count=4)
score = trace_data_collector.evaluate_compiled_corpus(
tempdir, self._trace_path, self._bb_trace_model_path)
tempdir, _TRACE_PATH, _FUNCTION_INDEX_PATH, _BB_TRACE_MODEL_PATH, 4)

reward = compilation_runner._calculate_reward(score, baseline_score)
print(reward)

output_path = os.path.join(self._models_for_test_path, "model" + str(reward))
output_path = os.path.join(self._models_for_test_path,
"model" + str(reward))
if reward > 0 and not os.path.exists(output_path):
shutil.copytree(tflitedir, output_path)
return compilation_runner._calculate_reward(score, baseline_score)
Expand Down Expand Up @@ -235,10 +229,9 @@ def train(worker_class=None):

# Get baseline score
with tempfile.TemporaryDirectory() as tempdir:
trace_data_collector.compile_corpus(_CORPUS_DIR.value, tempdir,
_CLANG_PATH.value)
trace_data_collector.compile_corpus(_CORPUS_DIR, tempdir, _CLANG_PATH)
baseline_score = trace_data_collector.evaluate_compiled_corpus(
tempdir, _TRACE_PATH.value, _BB_TRACE_MODEL_PATH.value)
tempdir, _TRACE_PATH, _FUNCTION_INDEX_PATH, _BB_TRACE_MODEL_PATH)

logging.info("Initializing blackbox learner.")
learner = blackbox_learner.BlackboxLearner(
Expand Down
6 changes: 3 additions & 3 deletions compiler_opt/es/gin_configs/blackbox_learner.gin
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ import compiler_opt.rl.gin_external_configurables
import compiler_opt.es.blackbox_optimizers

# Blackbox learner config
BlackboxLearnerConfig.total_steps = 100
BlackboxLearnerConfig.total_num_perturbations = 100
BlackboxLearnerConfig.total_steps = 10000
BlackboxLearnerConfig.total_num_perturbations = 25
BlackboxLearnerConfig.blackbox_optimizer = %blackbox_optimizers.Algorithm.MONTE_CARLO
BlackboxLearnerConfig.est_type = %blackbox_optimizers.EstimatorType.ANTITHETIC
# BlackboxLearnerConfig.est_type = %blackbox_optimizers.EstimatorType.FORWARD_FD
Expand All @@ -17,5 +17,5 @@ BlackboxLearnerConfig.num_top_directions = 0
BlackboxLearnerConfig.precision_parameter = 0.5

# Try the 0.0005 step size next
BlackboxLearnerConfig.step_size = 0.005
BlackboxLearnerConfig.step_size = 0.005
# BlackboxLearnerConfig.step_size = 0.0005
51 changes: 31 additions & 20 deletions compiler_opt/rl/trace_data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import multiprocessing


def compile_module(module_path, corpus_path, clang_path, tflite_dir, output_path):
def compile_module(module_path, corpus_path, clang_path, tflite_dir,
output_path):
module_full_input_path = os.path.join(corpus_path, module_path) + '.bc'
module_full_output_path = os.path.join(output_path, module_path) + '.bc.o'
module_command_full_path = os.path.join(corpus_path, module_path) + '.cmd'
Expand All @@ -36,20 +37,21 @@ def compile_module(module_path, corpus_path, clang_path, tflite_dir, output_path

command_vector = [clang_path]
command_vector.extend(module_command_line)
command_vector.extend(
[module_full_input_path, '-o', module_full_output_path])
command_vector.extend([module_full_input_path, '-o', module_full_output_path])

if tflite_dir is not None:
command_vector.extend(['-mllvm', '-regalloc-enable-advisor=development'])
command_vector.extend(['-mllvm', '-regalloc-model=' + tflite_dir])

subprocess.run(command_vector, check=True)
logging.info(
f'Just finished compiling {module_full_output_path}'
)
logging.info(f'Just finished compiling {module_full_output_path}')


def compile_corpus(corpus_path, output_path, clang_path, tflite_dir=None, single_threaded=False):
def compile_corpus(corpus_path,
output_path,
clang_path,
tflite_dir=None,
thread_count=multiprocessing.cpu_count()):
with open(
os.path.join(corpus_path, 'corpus_description.json'),
encoding='utf-8') as corpus_description_handle:
Expand All @@ -61,31 +63,40 @@ def compile_corpus(corpus_path, output_path, clang_path, tflite_dir=None, single
# Compile each module.
to_compile.append(module_path)

if single_threaded:
for module_to_compile in to_compile:
compile_module(module_to_compile, corpus_path, clang_path, tflite_dir, output_path)
else:
with multiprocessing.Pool() as pool:
pool.map(functools.partial(compile_module, corpus_path=corpus_path, clang_path=clang_path, tflite_dir=tflite_dir, output_path=output_path), to_compile)
with multiprocessing.Pool(thread_count) as pool:
pool.map(
functools.partial(
compile_module,
corpus_path=corpus_path,
clang_path=clang_path,
tflite_dir=tflite_dir,
output_path=output_path), to_compile)

shutil.copy(
os.path.join(corpus_path, 'corpus_description.json'),
os.path.join(output_path, 'corpus_description.json'))


def evaluate_compiled_corpus(compiled_corpus_path, trace_path,
bb_trace_model_path):
corpus_description_path = os.path.join(compiled_corpus_path, 'corpus_description.json')
function_index_path, bb_trace_model_path, thread_count=multiprocessing.cpu_count()):
corpus_description_path = os.path.join(compiled_corpus_path,
'corpus_description.json')
command_vector = [
bb_trace_model_path, '--bb_trace_path=' + trace_path,
'--corpus_path=' + corpus_description_path, '--cpu_name=skylake-avx512'
'--corpus_path=' + corpus_description_path, '--cpu_name=skylake-avx512',
'--function_index_path=' + function_index_path, '--thread_count=' + str(thread_count)
]

process_return = subprocess.run(
command_vector,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
command_vector, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

output = process_return.stdout.decode('utf-8')

return int(output)
total_cost = 0.0

for line in output.split('\n'):
if line == '':
continue
total_cost += float(line)

return total_cost

0 comments on commit 72990ea

Please sign in to comment.