
Commit

Add exception catching / writing to termination log (foundation-model-stack#149)

* Rebase exception handling in launch_training

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* add exception catching to accelerate launch script

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* linter fixes

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* 🧹 linting fixes for launch scripts

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* 🧹 fmt fixes for launch scripts

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* Make exit codes configurable

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* Add tests for accelerate_launch

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* Add launch script var, fix linting

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* Fail job when logs cannot be copied

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* fix rebase mistake

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* fix rebase mistake

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* Exception catching improvements

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

* Update model loading exception message

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>

---------

Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>
kellyaa authored May 14, 2024
1 parent 762be59 commit 38c4f22
Showing 4 changed files with 378 additions and 67 deletions.
build/accelerate_launch.py (59 additions, 5 deletions)

The changed hunk, as it reads after this commit:

@@ -20,23 +20,77 @@
# Standard
import os
import logging
import subprocess
import sys
import traceback

# Third Party
from accelerate.commands.launch import launch_command

# Local
from build.utils import (
    process_accelerate_launch_args,
    get_job_config,
    write_termination_log,
    USER_ERROR_EXIT_CODE,
    INTERNAL_ERROR_EXIT_CODE,
)


def main():
    LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
    logging.basicConfig(level=LOGLEVEL)

    ##########
    #
    # Parse arguments
    #
    ##########
    try:
        job_config = get_job_config()

        args = process_accelerate_launch_args(job_config)
        logging.debug("accelerate launch parsed args: %s", args)
    except FileNotFoundError as e:
        logging.error(traceback.format_exc())
        write_termination_log("Unable to load file: {}".format(e))
        sys.exit(USER_ERROR_EXIT_CODE)
    except (TypeError, ValueError, EnvironmentError) as e:
        logging.error(traceback.format_exc())
        write_termination_log(
            f"Exception raised during training. This may be a problem with your input: {e}"
        )
        sys.exit(USER_ERROR_EXIT_CODE)
    except Exception as e:  # pylint: disable=broad-except
        logging.error(traceback.format_exc())
        write_termination_log(f"Unhandled exception during training. {e}")
        sys.exit(INTERNAL_ERROR_EXIT_CODE)

    ##########
    #
    # Launch training
    #
    ##########
    try:
        launch_command(args)
    except subprocess.CalledProcessError as e:
        # If the subprocess throws an exception, the base exception is hidden in the subprocess call
        # and is difficult to access at this level. However, that is not an issue because
        # launch_training.py would have already written the exception message to termination log.
        logging.error(traceback.format_exc())
        # The exit code that launch_training.py threw is captured in e.returncode

        return_code = e.returncode
        if return_code not in [INTERNAL_ERROR_EXIT_CODE, USER_ERROR_EXIT_CODE]:
            return_code = INTERNAL_ERROR_EXIT_CODE
            write_termination_log(f"Unhandled exception during training. {e}")
        sys.exit(return_code)
    except Exception as e:  # pylint: disable=broad-except
        logging.error(traceback.format_exc())
        write_termination_log(f"Unhandled exception during training. {e}")
        sys.exit(INTERNAL_ERROR_EXIT_CODE)

    return 0


if __name__ == "__main__":
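These exit codes are intended for whatever supervises the job, for example a Kubernetes operator that surfaces the termination log to the user. As a rough, non-authoritative sketch of how a caller might interpret them, assuming the launcher is run from the repository root with its dependencies installed:

# Hypothetical wrapper: run the launcher and map its exit code back to the
# user-error / internal-error convention added in this commit.
import subprocess
import sys

from build.utils import INTERNAL_ERROR_EXIT_CODE, USER_ERROR_EXIT_CODE

completed = subprocess.run(
    [sys.executable, "build/accelerate_launch.py"],  # invocation is an assumption
    check=False,
)
if completed.returncode == USER_ERROR_EXIT_CODE:
    print("Training failed on a user input error; see the termination log.")
elif completed.returncode == INTERNAL_ERROR_EXIT_CODE:
    print("Training failed on an internal error; see the termination log.")
elif completed.returncode != 0:
    print(f"Training failed with unexpected exit code {completed.returncode}.")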
build/launch_training.py (127 additions, 61 deletions)

The changed hunks, as they read after this commit:

@@ -21,6 +21,12 @@
import os
import tempfile
import shutil
import sys
import traceback

# Third Party
from huggingface_hub.utils._validators import HFValidationError
from torch.cuda import OutOfMemoryError

# First Party
import logging

@@ -29,7 +35,13 @@
from tuning import sft_trainer
from tuning.utils.merge_model_utils import create_merged_model
from tuning.config.tracker_configs import TrackerConfigFactory
from build.utils import (
    process_launch_training_args,
    get_job_config,
    write_termination_log,
    USER_ERROR_EXIT_CODE,
    INTERNAL_ERROR_EXIT_CODE,
)


def get_highest_checkpoint(dir_path):

@@ -53,78 +65,132 @@ def main():

    logging.info("Initializing launch training script")

    try:
        job_config = get_job_config()
        logging.debug("Input params parsed: %s", job_config)

        (
            model_args,
            data_args,
            training_args,
            tune_config,
            merge_model,
            file_logger_config,
            aim_config,
        ) = process_launch_training_args(job_config)
    except Exception as e:  # pylint: disable=broad-except
        logging.error(traceback.format_exc())
        write_termination_log(
            f"Exception raised during training. This may be a problem with your input: {e}"
        )
        sys.exit(USER_ERROR_EXIT_CODE)

    original_output_dir = training_args.output_dir
    with tempfile.TemporaryDirectory() as tempdir:
        training_args.output_dir = tempdir
        try:
            tracker_config_args = TrackerConfigFactory(
                file_logger_config=file_logger_config, aim_config=aim_config
            )
            sft_trainer.train(
                model_args=model_args,
                data_args=data_args,
                train_args=training_args,
                peft_config=tune_config,
                tracker_configs=tracker_config_args,
            )
        except (MemoryError, OutOfMemoryError) as e:
            logging.error(traceback.format_exc())
            write_termination_log(f"OOM error during training. {e}")
            sys.exit(INTERNAL_ERROR_EXIT_CODE)
        except FileNotFoundError as e:
            logging.error(traceback.format_exc())
            write_termination_log("Unable to load file: {}".format(e))
            sys.exit(USER_ERROR_EXIT_CODE)
        except HFValidationError as e:
            logging.error(traceback.format_exc())
            write_termination_log(
                f"There may be a problem with loading the model. Exception: {e}"
            )
            sys.exit(USER_ERROR_EXIT_CODE)
        except (TypeError, ValueError, EnvironmentError) as e:
            logging.error(traceback.format_exc())
            write_termination_log(
                f"Exception raised during training. This may be a problem with your input: {e}"
            )
            sys.exit(USER_ERROR_EXIT_CODE)
        except Exception as e:  # pylint: disable=broad-except
            logging.error(traceback.format_exc())
            write_termination_log(f"Unhandled exception during training: {e}")
            sys.exit(INTERNAL_ERROR_EXIT_CODE)

        if merge_model:
            try:
                export_path = os.getenv(
                    "LORA_MERGE_MODELS_EXPORT_PATH", original_output_dir
                )

                # get the highest checkpoint dir (last checkpoint)
                lora_checkpoint_dir = get_highest_checkpoint(training_args.output_dir)
                full_checkpoint_dir = os.path.join(
                    training_args.output_dir, lora_checkpoint_dir
                )

                logging.info(
                    "Merging lora tuned checkpoint %s with base model into output path: %s",
                    lora_checkpoint_dir,
                    export_path,
                )

                create_merged_model(
                    checkpoint_models=full_checkpoint_dir,
                    export_path=export_path,
                    base_model=model_args.model_name_or_path,
                    save_tokenizer=True,
                )
            except Exception as e:  # pylint: disable=broad-except
                logging.error(traceback.format_exc())
                write_termination_log(
                    f"Exception encountered merging base model with checkpoint. {e}"
                )
                sys.exit(INTERNAL_ERROR_EXIT_CODE)
        else:
            try:
                # copy last checkpoint into mounted output dir
                pt_checkpoint_dir = get_highest_checkpoint(training_args.output_dir)
                logging.info(
                    "Copying last checkpoint %s into output dir %s",
                    pt_checkpoint_dir,
                    original_output_dir,
                )
                shutil.copytree(
                    os.path.join(training_args.output_dir, pt_checkpoint_dir),
                    original_output_dir,
                    dirs_exist_ok=True,
                )
            except Exception as e:  # pylint: disable=broad-except
                logging.error(traceback.format_exc())
                write_termination_log(
                    f"Exception encountered writing output model to storage: {e}"
                )
                sys.exit(INTERNAL_ERROR_EXIT_CODE)

        # copy over any loss logs
        try:
            train_logs_filepath = os.path.join(
                training_args.output_dir,
                tracker_config_args.file_logger_config.training_logs_filename,
            )
            if os.path.exists(train_logs_filepath):
                shutil.copy(train_logs_filepath, original_output_dir)
        except Exception as e:  # pylint: disable=broad-except
            logging.error(traceback.format_exc())
            write_termination_log(
                f"Exception encountered in capturing training logs: {e}"
            )
            sys.exit(INTERNAL_ERROR_EXIT_CODE)

    return 0


if __name__ == "__main__":
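One of the squashed commits adds tests for accelerate_launch; the new error paths in launch_training.py can be exercised in a similar way. The sketch below is illustrative only: the test name and the monkeypatched helpers are assumptions, and importing build.launch_training assumes the tuning package and its torch dependency are installed.

# Hypothetical pytest sketch: a failure while parsing user input should make
# launch_training.main() exit with USER_ERROR_EXIT_CODE (see the first
# try/except block above). The patched helpers are illustrative only.
import pytest

from build import launch_training
from build.utils import USER_ERROR_EXIT_CODE


def test_bad_input_exits_with_user_error_code(monkeypatch, tmp_path):
    # Redirect the termination log so the test does not need /dev/termination-log.
    monkeypatch.setenv("TERMINATION_LOG_FILE", str(tmp_path / "termination-log"))
    monkeypatch.setattr(launch_training, "get_job_config", lambda: {})

    def fail_parse(_job_config):
        raise ValueError("bad input")

    # Force the argument-parsing step to fail the way a bad user config would.
    monkeypatch.setattr(launch_training, "process_launch_training_args", fail_parse)

    with pytest.raises(SystemExit) as exc_info:
        launch_training.main()

    assert exc_info.value.code == USER_ERROR_EXIT_CODE
    assert "problem with your input" in (tmp_path / "termination-log").read_text()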
build/utils.py (20 additions, 1 deletion)

The changed hunks, as they read after this commit:

@@ -27,6 +27,24 @@
# Local
from tuning.config import configs, peft_config, tracker_configs

# The USER_ERROR_EXIT_CODE will be thrown when the process must exit
# as result of a user input error. User-related errors should be
# >= 1 and <=127 due to how some kubernetes operators interpret them.
USER_ERROR_EXIT_CODE = 1
# The INTERNAL_ERROR_EXIT_CODE will be thrown when training
# abnormally terminates, and it is not clearly fault of the user.
# System-level errors should be >= 128 and <= 254
INTERNAL_ERROR_EXIT_CODE = 203


def write_termination_log(text):
    log_file = os.environ.get("TERMINATION_LOG_FILE", "/dev/termination-log")
    try:
        with open(log_file, "a", encoding="utf-8") as handle:
            handle.write(text)
    except Exception as e:  # pylint: disable=broad-except
        logging.warning("Unable to write termination log due to error {}".format(e))


def txt_to_obj(txt):
    base64_bytes = txt.encode("ascii")

@@ -203,7 +221,8 @@ def process_accelerate_launch_args(job_config_dict):
    )

    # Add training_script
    script = os.environ.get("LAUNCH_TRAINING_SCRIPT", "/app/launch_training.py")
    accelerate_launch_args.append(script)

    logging.debug("accelerate_launch_args: %s", accelerate_launch_args)
    args = parser.parse_args(args=accelerate_launch_args)
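Both launch scripts report failures through this helper, so the message ends up wherever TERMINATION_LOG_FILE points, defaulting to the Kubernetes termination log at /dev/termination-log. A minimal standalone sketch, assuming the repository's build package is importable and redirecting the log to a temporary directory for illustration:

# Minimal sketch of the termination-log helper above, redirected to a
# temporary file instead of the default /dev/termination-log.
import os
import tempfile

from build.utils import USER_ERROR_EXIT_CODE, write_termination_log

with tempfile.TemporaryDirectory() as tmpdir:
    log_path = os.path.join(tmpdir, "termination-log")
    os.environ["TERMINATION_LOG_FILE"] = log_path

    write_termination_log("Unable to load file: missing training config")

    with open(log_path, encoding="utf-8") as handle:
        print(handle.read())        # message an operator would surface to the user
    print(USER_ERROR_EXIT_CODE)     # 1, the code used for user input errors

The LAUNCH_TRAINING_SCRIPT variable in the second hunk serves a similar configurability goal: deployments that do not place the training entry point at /app/launch_training.py can point accelerate at a different path without editing the image.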
(The diff for the fourth changed file did not load and is not shown.)
