Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check seeds are always logged #384

Merged
merged 2 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions mlperf_logging/package_checker/package_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _print_divider_bar():


def check_training_result_files(folder, usage, ruleset, quiet, werror,
rcp_bypass, rcp_bert_train_samples):
rcp_bypass, rcp_bert_train_samples, seed_checker_bypass):
"""Checks all result files for compliance.

Args:
Expand All @@ -44,6 +44,7 @@ def check_training_result_files(folder, usage, ruleset, quiet, werror,
allowed_benchmarks = get_allowed_benchmarks(usage, ruleset)
benchmark_file_counts = get_result_file_counts(usage)
global_rcp_bypass = rcp_bypass
global_seed_checker_bypass = seed_checker_bypass

seed_checker = SeedChecker(ruleset)
too_many_errors = False
Expand All @@ -58,12 +59,15 @@ def check_training_result_files(folder, usage, ruleset, quiet, werror,
# Set system wide rcp-bypass
params_path = os.path.join(system_folder, "package_checker_params")
system_rcp_bypass = False
system_seed_checker_bypass = False
if os.path.exists(params_path):
with open(params_path) as f:
lines = f.readlines()
for line in lines:
if line == "rcp-bypass":
system_rcp_bypass = True
if line == "seed-checker-bypass":
system_seed_checker_bypass = True
for benchmark_folder in benchmark_folders:
folder_parts = benchmark_folder.split('/')
benchmark = folder_parts[-1]
Expand Down Expand Up @@ -98,13 +102,16 @@ def check_training_result_files(folder, usage, ruleset, quiet, werror,
params_files = []
params_path = os.path.join(benchmark_folder, "package_checker_params")
result_rcp_bypass = False
result_seed_checker_bypass = False
if os.path.exists(params_path):
params_files.append(params_path)
with open(params_path) as f:
lines = f.readlines()
for line in lines:
if line == "rcp-bypass":
result_rcp_bypass = True
if line == "seed-checker-bypass":
result_seed_checker_bypass = True

# Find all source codes for this benchmark.
source_files = find_source_files_under(
Expand Down Expand Up @@ -176,7 +183,8 @@ def check_training_result_files(folder, usage, ruleset, quiet, werror,

# Check if each run use unique seeds.
if ruleset in {'1.0.0', '1.1.0', '2.0.0', '2.1.0', '3.0.0', '3.1.0', '4.0.0', '4.1.0'} and division == 'closed':
if not seed_checker.check_seeds(result_files, source_files):
seed_checker_bypass = (global_seed_checker_bypass or system_seed_checker_bypass or result_seed_checker_bypass)
if not seed_checker.check_seeds(result_files, seed_checker_bypass):
too_many_errors = True
logging.error('Seed checker failed')

Expand Down Expand Up @@ -226,7 +234,7 @@ def check_systems(folder, usage, ruleset):

return not too_many_errors

def check_training_package(folder, usage, ruleset, quiet, werror, rcp_bypass, rcp_bert_train_samples, log_output):
def check_training_package(folder, usage, ruleset, quiet, werror, rcp_bypass, rcp_bert_train_samples, seed_checker_bypass, log_output):
"""Checks a training package for compliance.

Args:
Expand All @@ -242,7 +250,7 @@ def check_training_package(folder, usage, ruleset, quiet, werror, rcp_bypass, rc
if not system_description_pass:
logging.error('System description file checker failed')

training_pass = check_training_result_files(folder, usage, ruleset, quiet, werror, rcp_bypass, rcp_bert_train_samples)
training_pass = check_training_result_files(folder, usage, ruleset, quiet, werror, rcp_bypass, rcp_bert_train_samples, seed_checker_bypass)
too_many_errors = too_many_errors or not training_pass
if too_many_errors:
logging.info('PACKAGE CHECKER FOUND ERRORS, LOOK INTO ERROR LOG LINES AND FIX THEM.')
Expand Down Expand Up @@ -286,17 +294,22 @@ def get_parser():
help='Suppress warnings. Does nothing if --werror is set',
)
parser.add_argument(
'--rcp_bypass',
'--rcp-bypass',
action='store_true',
help='Bypass failed RCP checks so that submission uploads'
)
parser.add_argument(
'--rcp_bert_train_samples',
'--rcp-bert-train-samples',
action='store_true',
help='If set, num samples used for training '
'bert benchmark is taken from train_samples, '
'istead of epoch_num',
)
parser.add_argument(
'--seed-checker-bypass',
action='store_true',
help='If set, Seed checker is bypassed '
)
parser.add_argument(
'--log_output',
type=str,
Expand All @@ -317,7 +330,7 @@ def main():
logging.getLogger().handlers[1].setFormatter(formatter)

check_training_package(args.folder, args.usage, args.ruleset, args.quiet, args.werror,
args.rcp_bypass, args.rcp_bert_train_samples, args.log_output)
args.rcp_bypass, args.rcp_bert_train_samples, args.seed_checker_bypass, args.log_output)


if __name__ == '__main__':
Expand Down
58 changes: 22 additions & 36 deletions mlperf_logging/package_checker/seed_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,10 @@ class SeedChecker:
""" Check if the seeds fit MLPerf submission requirements.
Current requirements are:

1. All seeds must be logged through mllog (if choose to log seeds). Any seed
logged via any other method will be discarded.
2. All seeds, if choose to be logged, must be valid integers (convertible
via int()).
3. If any run logs at least one seed, we expect all runs to log at least
one seed.
1. All seeds must be logged through mllog. Any seed logged via any other
method will be discarded.
2. All seeds must be valid integers (convertible via int()).
3. We expect all runs to log at least one seed.
4. If one run logs one seed on a certain line in a certain source file, no
other run can log the same seed on the same line in the same file.

Expand All @@ -59,10 +57,6 @@ class SeedChecker:
A warning is raised for the following situations:

1. Any run logs more than one seed.
2. No seed is logged, however, the source files (after being converted to
lowercase characters) contain the keyword "seed". What files are
considered as source files are defined in SOURCE_FILE_EXT and
is_source_file().
"""
def __init__(self, ruleset):
self._ruleset = ruleset
Expand Down Expand Up @@ -96,14 +90,11 @@ def _assert_unique_seed_per_run(self, result_files):
"{}: {}".format(result_file, e))
continue

if not no_logged_seed and len(seed_records) == 0:
no_logged_seed = (len(seed_records) <= 0)
if no_logged_seed:
error_messages.append(
"Result file {} logs no seed. However, other "
"result files, including {}, already logs some "
"seeds.".format(result_file,
list(seed_to_result_file.keys())))
if no_logged_seed and len(seed_records) > 0:
no_logged_seed = False
"Result file {} logs no seed.".format(result_file)
)
if len(seed_records) > 1:
warnings.warn(
"Result file {} logs more than one seeds {}!".format(
Expand All @@ -123,7 +114,7 @@ def _assert_unique_seed_per_run(self, result_files):
else:
seed_to_result_file[(f, ln, s)] = result_file

return no_logged_seed, error_messages
return error_messages

def _has_seed_keyword(self, source_file):
with open(source_file, 'r') as file_handle:
Expand All @@ -132,31 +123,26 @@ def _has_seed_keyword(self, source_file):
return True
return False

def check_seeds(self, result_files, source_files):
def check_seeds(self, result_files, seed_checker_bypass = False):
""" Check the seeds for a specific benchmark submission.

Args:
result_files: An iterable containing paths to all the result files for
this benchmark.
source_files: An iterable contains paths to all the source files for
this benchmark.

"""
_print_divider_bar()
logging.info(" Running Seed Checker")
no_logged_seed, error_messages = self._assert_unique_seed_per_run(
result_files)

if len(error_messages) > 0:
logging.error(" Seed checker failed and found the following errors: %s", '\n'.join(error_messages))
#print("Seed checker failed and found the following "
# "errors:\n{}".format('\n'.join(error_messages)))
return False

if no_logged_seed:
for source_file in source_files:
if self._has_seed_keyword(source_file):
warnings.warn(
"Source file {} contains the keyword 'seed' but no "
"seed value is logged!".format(source_file))
if seed_checker_bypass:
logging.info("Bypassing Seed Checker")
else:
error_messages = self._assert_unique_seed_per_run(
result_files
)

if len(error_messages) > 0:
logging.error(" Seed checker failed and found the following errors: %s", '\n'.join(error_messages))
#print("Seed checker failed and found the following "
# "errors:\n{}".format('\n'.join(error_messages)))
return False
return True
Loading