diff --git a/build/launch_training.py b/build/launch_training.py index f615e5e0a..1d6a518f6 100644 --- a/build/launch_training.py +++ b/build/launch_training.py @@ -162,7 +162,8 @@ def main(): shutil.copy(train_logs_filepath, original_output_dir) except Exception as e: # pylint: disable=broad-except logging.error(traceback.format_exc()) - # Continue, don't fail the training because of this + write_termination_log("Exception encountered in capturing training logs") + sys.exit(INTERNAL_ERROR_EXIT_CODE) return 0 diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py index 8ddb35434..11520f649 100644 --- a/tests/build/test_launch_script.py +++ b/tests/build/test_launch_script.py @@ -56,7 +56,7 @@ "prompt_tuning_init_text": "hello", "tokenizer_name_or_path": MODEL_NAME, "save_strategy": "epoch", - "output_dir": "tmp" + "output_dir": "tmp", } @@ -138,3 +138,9 @@ def test_config_parsing_error(): main() assert pytest_wrapped_e.type == SystemExit assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE + + +def test_cleanup(): + # This runs to unset env variables that could disrupt other tests + os.environ.pop("LAUNCH_TRAINING_SCRIPT", None) + assert True