diff --git a/caikit/runtime/train.py b/caikit/runtime/train.py index 6308d467f..f5fb32e7d 100644 --- a/caikit/runtime/train.py +++ b/caikit/runtime/train.py @@ -52,8 +52,6 @@ class ArgumentParserError(Exception): """Custom exception class for ArgumentParser errors.""" - pass - class TrainArgumentParser(argparse.ArgumentParser): def error(self, message): @@ -63,9 +61,8 @@ def error(self, message): def write_termination_log(text: str, log_file="/dev/termination-log"): try: - f = open(log_file, "a") - f.write(text) - f.close() + with open(log_file, "a") as handle: + handle.write(text) except Exception as e: log.warning( "", @@ -145,14 +142,12 @@ def main() -> int: train_kwargs["trainer"] = args.trainer except Exception as e: - message = "Exception raised during training. This may be a problem with your input: {}".format( - e - ) + message = "Exception raised during training. This may be a problem with your input: {e}" log.warning( { "log_code": "", "message": message, - "stacK_trace": traceback.format_exc(), + "stack_trace": traceback.format_exc(), } ) write_termination_log(message) @@ -165,7 +160,13 @@ def main() -> int: importlib.import_module(library) except Exception as e: message = "Unable to import module {}".format(library) - log.error("", message) + log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + } + ) write_termination_log(message) exit(USER_ERROR_EXIT_CODE) @@ -192,11 +193,11 @@ def main() -> int: "Unable to find module {} to train", args.module, ) - except Exception as e: - message = "Unable to import module {}".format(args.module) + except (ValueError, Exception) as e: + message = "Unable to find module {} to train".format(args.module) log.warning( { - "log_code": "", + "log_code": "", "message": message, "stack_trace": traceback.format_exc, } @@ -211,30 +212,7 @@ def main() -> int: training_kwargs = json.load(handle) else: training_kwargs = json.loads(args.training_kwargs) - except json.decoder.JSONDecodeError: - message = "training-kwargs must be valid json or point to a valid json file" - log.warning( - { - "log_code": "", - "message": message, - "stack_trace": traceback.format_exc(), - } - ) - write_termination_log(message) - exit(USER_ERROR_EXIT_CODE) - except Exception as e: - message = "Exception encountered when attempting to parse input parameters" - log.warning( - { - "log_code": "", - "message": message, - "stack_trace": traceback.format_exc(), - } - ) - write_termination_log(message) - exit(USER_ERROR_EXIT_CODE) - try: # Convert datatypes to match the training API training_service = ServicePackageFactory.get_service_package( ServicePackageFactory.ServiceType.TRAINING, @@ -264,11 +242,31 @@ def main() -> int: ) train_kwargs.update(req_kwargs) log.debug3("All train kwargs: %s", train_kwargs) + except json.decoder.JSONDecodeError: + message = "training-kwargs must be valid json or point to a valid json file" + log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + } + ) + write_termination_log(message) + exit(USER_ERROR_EXIT_CODE) + except ValueError as e: + message = "Invalid value for one or more input parameters: {e}" + log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + } + ) except Exception as e: message = "Exception encountered when attempting to parse input parameters" log.warning( { - "log_code": "", + "log_code": "", "message": message, "stack_trace": traceback.format_exc(), }