Skip to content

Commit

Permalink
Merge pull request #384 from mlcommons/seed_checker_updates
Browse files Browse the repository at this point in the history
Check seeds are always logged
  • Loading branch information
hiwotadese authored Sep 26, 2024
2 parents 191ed2c + 24bed7e commit 9a66740
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 43 deletions.
27 changes: 20 additions & 7 deletions mlperf_logging/package_checker/package_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _print_divider_bar():


def check_training_result_files(folder, usage, ruleset, quiet, werror,
rcp_bypass, rcp_bert_train_samples):
rcp_bypass, rcp_bert_train_samples, seed_checker_bypass):
"""Checks all result files for compliance.
Args:
Expand All @@ -44,6 +44,7 @@ def check_training_result_files(folder, usage, ruleset, quiet, werror,
allowed_benchmarks = get_allowed_benchmarks(usage, ruleset)
benchmark_file_counts = get_result_file_counts(usage)
global_rcp_bypass = rcp_bypass
global_seed_checker_bypass = seed_checker_bypass

seed_checker = SeedChecker(ruleset)
too_many_errors = False
Expand All @@ -58,12 +59,15 @@ def check_training_result_files(folder, usage, ruleset, quiet, werror,
# Set system wide rcp-bypass
params_path = os.path.join(system_folder, "package_checker_params")
system_rcp_bypass = False
system_seed_checker_bypass = False
if os.path.exists(params_path):
with open(params_path) as f:
lines = f.readlines()
for line in lines:
if line == "rcp-bypass":
system_rcp_bypass = True
if line == "seed-checker-bypass":
system_seed_checker_bypass = True
for benchmark_folder in benchmark_folders:
folder_parts = benchmark_folder.split('/')
benchmark = folder_parts[-1]
Expand Down Expand Up @@ -98,13 +102,16 @@ def check_training_result_files(folder, usage, ruleset, quiet, werror,
params_files = []
params_path = os.path.join(benchmark_folder, "package_checker_params")
result_rcp_bypass = False
result_seed_checker_bypass = False
if os.path.exists(params_path):
params_files.append(params_path)
with open(params_path) as f:
lines = f.readlines()
for line in lines:
if line == "rcp-bypass":
result_rcp_bypass = True
if line == "seed-checker-bypass":
result_seed_checker_bypass = True

# Find all source codes for this benchmark.
source_files = find_source_files_under(
Expand Down Expand Up @@ -176,7 +183,8 @@ def check_training_result_files(folder, usage, ruleset, quiet, werror,

# Check if each run use unique seeds.
if ruleset in {'1.0.0', '1.1.0', '2.0.0', '2.1.0', '3.0.0', '3.1.0', '4.0.0', '4.1.0'} and division == 'closed':
if not seed_checker.check_seeds(result_files, source_files):
seed_checker_bypass = (global_seed_checker_bypass or system_seed_checker_bypass or result_seed_checker_bypass)
if not seed_checker.check_seeds(result_files, seed_checker_bypass):
too_many_errors = True
logging.error('Seed checker failed')

Expand Down Expand Up @@ -226,7 +234,7 @@ def check_systems(folder, usage, ruleset):

return not too_many_errors

def check_training_package(folder, usage, ruleset, quiet, werror, rcp_bypass, rcp_bert_train_samples, log_output):
def check_training_package(folder, usage, ruleset, quiet, werror, rcp_bypass, rcp_bert_train_samples, seed_checker_bypass, log_output):
"""Checks a training package for compliance.
Args:
Expand All @@ -242,7 +250,7 @@ def check_training_package(folder, usage, ruleset, quiet, werror, rcp_bypass, rc
if not system_description_pass:
logging.error('System description file checker failed')

training_pass = check_training_result_files(folder, usage, ruleset, quiet, werror, rcp_bypass, rcp_bert_train_samples)
training_pass = check_training_result_files(folder, usage, ruleset, quiet, werror, rcp_bypass, rcp_bert_train_samples, seed_checker_bypass)
too_many_errors = too_many_errors or not training_pass
if too_many_errors:
logging.info('PACKAGE CHECKER FOUND ERRORS, LOOK INTO ERROR LOG LINES AND FIX THEM.')
Expand Down Expand Up @@ -286,17 +294,22 @@ def get_parser():
help='Suppress warnings. Does nothing if --werror is set',
)
parser.add_argument(
'--rcp_bypass',
'--rcp-bypass',
action='store_true',
help='Bypass failed RCP checks so that submission uploads'
)
parser.add_argument(
'--rcp_bert_train_samples',
'--rcp-bert-train-samples',
action='store_true',
help='If set, num samples used for training '
'bert benchmark is taken from train_samples, '
'istead of epoch_num',
)
parser.add_argument(
'--seed-checker-bypass',
action='store_true',
help='If set, Seed checker is bypassed '
)
parser.add_argument(
'--log_output',
type=str,
Expand All @@ -317,7 +330,7 @@ def main():
logging.getLogger().handlers[1].setFormatter(formatter)

check_training_package(args.folder, args.usage, args.ruleset, args.quiet, args.werror,
args.rcp_bypass, args.rcp_bert_train_samples, args.log_output)
args.rcp_bypass, args.rcp_bert_train_samples, args.seed_checker_bypass, args.log_output)


if __name__ == '__main__':
Expand Down
58 changes: 22 additions & 36 deletions mlperf_logging/package_checker/seed_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,10 @@ class SeedChecker:
""" Check if the seeds fit MLPerf submission requirements.
Current requirements are:
1. All seeds must be logged through mllog (if choose to log seeds). Any seed
logged via any other method will be discarded.
2. All seeds, if choose to be logged, must be valid integers (convertible
via int()).
3. If any run logs at least one seed, we expect all runs to log at least
one seed.
1. All seeds must be logged through mllog. Any seed logged via any other
method will be discarded.
2. All seeds, must be valid integers (convertible via int()).
3. We expect all runs to log at least one seed.
4. If one run logs one seed on a certain line in a certain source file, no
other run can log the same seed on the same line in the same file.
Expand All @@ -59,10 +57,6 @@ class SeedChecker:
A warning is raised for the following situations:
1. Any run logs more than one seed.
2. No seed is logged, however, the source files (after being converted to
lowercase characters) contain the keyword "seed". What files are
considered as source files are defined in SOURCE_FILE_EXT and
is_source_file().
"""
def __init__(self, ruleset):
self._ruleset = ruleset
Expand Down Expand Up @@ -96,14 +90,11 @@ def _assert_unique_seed_per_run(self, result_files):
"{}: {}".format(result_file, e))
continue

if not no_logged_seed and len(seed_records) == 0:
no_logged_seed = (len(seed_records) <= 0)
if no_logged_seed:
error_messages.append(
"Result file {} logs no seed. However, other "
"result files, including {}, already logs some "
"seeds.".format(result_file,
list(seed_to_result_file.keys())))
if no_logged_seed and len(seed_records) > 0:
no_logged_seed = False
"Result file {} logs no seed.".format(result_file)
)
if len(seed_records) > 1:
warnings.warn(
"Result file {} logs more than one seeds {}!".format(
Expand All @@ -123,7 +114,7 @@ def _assert_unique_seed_per_run(self, result_files):
else:
seed_to_result_file[(f, ln, s)] = result_file

return no_logged_seed, error_messages
return error_messages

def _has_seed_keyword(self, source_file):
with open(source_file, 'r') as file_handle:
Expand All @@ -132,31 +123,26 @@ def _has_seed_keyword(self, source_file):
return True
return False

def check_seeds(self, result_files, source_files):
def check_seeds(self, result_files, seed_checker_bypass = False):
""" Check the seeds for a specific benchmark submission.
Args:
result_files: An iterable contains paths to all the result files for
this benchmark.
source_files: An iterable contains paths to all the source files for
this benchmark.
"""
_print_divider_bar()
logging.info(" Running Seed Checker")
no_logged_seed, error_messages = self._assert_unique_seed_per_run(
result_files)

if len(error_messages) > 0:
logging.error(" Seed checker failed and found the following errors: %s", '\n'.join(error_messages))
#print("Seed checker failed and found the following "
# "errors:\n{}".format('\n'.join(error_messages)))
return False

if no_logged_seed:
for source_file in source_files:
if self._has_seed_keyword(source_file):
warnings.warn(
"Source file {} contains the keyword 'seed' but no "
"seed value is logged!".format(source_file))
if seed_checker_bypass:
logging.info("Bypassing Seed Checker")
else:
error_messages = self._assert_unique_seed_per_run(
result_files
)

if len(error_messages) > 0:
logging.error(" Seed checker failed and found the following errors: %s", '\n'.join(error_messages))
#print("Seed checker failed and found the following "
# "errors:\n{}".format('\n'.join(error_messages)))
return False
return True

0 comments on commit 9a66740

Please sign in to comment.