Skip to content

Commit

Permalink
🧹 Tidy up train.py
Browse files Browse the repository at this point in the history
Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>
  • Loading branch information
kellyaa committed May 3, 2024
1 parent fb22bb4 commit 423f0d7
Showing 1 changed file with 35 additions and 37 deletions.
72 changes: 35 additions & 37 deletions caikit/runtime/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@
class ArgumentParserError(Exception):
"""Custom exception class for ArgumentParser errors."""

pass


class TrainArgumentParser(argparse.ArgumentParser):
def error(self, message):
Expand All @@ -63,9 +61,8 @@ def error(self, message):

def write_termination_log(text: str, log_file="/dev/termination-log"):
try:
f = open(log_file, "a")
f.write(text)
f.close()
with open(log_file, "a") as handle:
handle.write(text)
except Exception as e:
log.warning(
"<COR96300323W>",
Expand Down Expand Up @@ -145,14 +142,12 @@ def main() -> int:
train_kwargs["trainer"] = args.trainer

except Exception as e:
message = "Exception raised during training. This may be a problem with your input: {}".format(
e
)
message = "Exception raised during training. This may be a problem with your input: {e}"
log.warning(
{
"log_code": "<COR39662029E>",
"message": message,
"stacK_trace": traceback.format_exc(),
"stack_trace": traceback.format_exc(),
}
)
write_termination_log(message)
Expand All @@ -165,7 +160,13 @@ def main() -> int:
importlib.import_module(library)
except Exception as e:
message = "Unable to import module {}".format(library)
log.error("<COR17776539E>", message)
log.warning(
{
"log_code": "<COR17776539E>",
"message": message,
"stack_trace": traceback.format_exc(),
}
)
write_termination_log(message)
exit(USER_ERROR_EXIT_CODE)

Expand All @@ -192,11 +193,11 @@ def main() -> int:
"Unable to find module {} to train",
args.module,
)
except Exception as e:
message = "Unable to import module {}".format(args.module)
except (ValueError, Exception) as e:
message = "Unable to find module {} to train".format(args.module)
log.warning(
{
"log_code": "<COR17776539E>",
"log_code": "<COR17476539E>",
"message": message,
"stack_trace": traceback.format_exc,
}
Expand All @@ -211,30 +212,7 @@ def main() -> int:
training_kwargs = json.load(handle)
else:
training_kwargs = json.loads(args.training_kwargs)
except json.decoder.JSONDecodeError:
message = "training-kwargs must be valid json or point to a valid json file"
log.warning(
{
"log_code": "<COR65834760E>",
"message": message,
"stack_trace": traceback.format_exc(),
}
)
write_termination_log(message)
exit(USER_ERROR_EXIT_CODE)
except Exception as e:
message = "Exception encountered when attempting to parse input parameters"
log.warning(
{
"log_code": "<COR17776549E>",
"message": message,
"stack_trace": traceback.format_exc(),
}
)
write_termination_log(message)
exit(USER_ERROR_EXIT_CODE)

try:
# Convert datatypes to match the training API
training_service = ServicePackageFactory.get_service_package(
ServicePackageFactory.ServiceType.TRAINING,
Expand Down Expand Up @@ -264,11 +242,31 @@ def main() -> int:
)
train_kwargs.update(req_kwargs)
log.debug3("All train kwargs: %s", train_kwargs)
except json.decoder.JSONDecodeError:
message = "training-kwargs must be valid json or point to a valid json file"
log.warning(
{
"log_code": "<COR65834760E>",
"message": message,
"stack_trace": traceback.format_exc(),
}
)
write_termination_log(message)
exit(USER_ERROR_EXIT_CODE)
except ValueError as e:
message = "Invalid value for one or more input parameters: {e}"
log.warning(
{
"log_code": "<COR65474760E>",
"message": message,
"stack_trace": traceback.format_exc(),
}
)
except Exception as e:
message = "Exception encountered when attempting to parse input parameters"
log.warning(
{
"log_code": "<COR17376549E>",
"log_code": "<COR17776549E>",
"message": message,
"stack_trace": traceback.format_exc(),
}
Expand Down

0 comments on commit 423f0d7

Please sign in to comment.