diff --git a/src/nhl_model/ann.py b/src/nhl_model/ann.py index 3c3a3dd..bc8d0b9 100644 --- a/src/nhl_model/ann.py +++ b/src/nhl_model/ann.py @@ -137,8 +137,8 @@ def parseAnnArguments(config): # pylint: disable=too-many-branches if loadModel == "yes": - if exists(path_join(*[BASE_SAVE_DIR, "nhl_model"])): - outputs["savedModel"] = path_join(*[BASE_SAVE_DIR, "nhl_model"]) + if exists(path_join(*[BASE_SAVE_DIR, "nhl_model.keras"])): + outputs["savedModel"] = path_join(*[BASE_SAVE_DIR, "nhl_model.keras"]) else: # logger.debug("failed to find model, asking to create a new one") # allow the user to create the model @@ -458,7 +458,7 @@ def createModel(analysisFile, featureSelection, **kwargs): # attempt to save the model logger.debug(f"saving model as nhl_model, this will override the current model") - model.save(path_join(*[BASE_SAVE_DIR, "nhl_model"])) + model.save(path_join(*[BASE_SAVE_DIR, "nhl_model.keras"])) return model @@ -736,7 +736,7 @@ def execAnnSpecificDate(day, month, year): outputs = _askForCommonData(inputs) # load the model - expectedModelPath = path_join(*[BASE_SAVE_DIR, "nhl_model"]) + expectedModelPath = path_join(*[BASE_SAVE_DIR, "nhl_model.keras"]) if not exists(expectedModelPath): logger.critical(f"failed to find model {expectedModelPath}") return