diff --git a/gifflar/train.py b/gifflar/train.py index 9316635..d52376c 100644 --- a/gifflar/train.py +++ b/gifflar/train.py @@ -198,10 +198,12 @@ def embed(prep_args: dict[str, str], **kwargs: Any) -> None: f"{kwargs['dataset']['name']}_{prep_args['name']}_{hash_dict(prep_args, 8)}") if output_name.exists(): return + else: + output_name.mkdir(parents=True) with open(prep_args["hparams_path"], "r") as file: config = yaml.load(file, Loader=yaml.FullLoader) - model = PretrainGGIN(**config["model"], tasks=None, pre_transform_args=kwargs.get("pre-transforms", {})) + model = PretrainGGIN(**config["model"], tasks=None, pre_transform_args=kwargs.get("pre-transforms", {}), save_dir=output_name) if torch.cuda.is_available(): model.load_state_dict(torch.load(prep_args["ckpt_path"])["state_dict"]) else: @@ -210,7 +212,6 @@ def embed(prep_args: dict[str, str], **kwargs: Any) -> None: model.save_nth_layer(int(prep_args["nth_layer"])) data_config, data, _, _ = setup(2, **kwargs) - data.save_dir = output_name trainer = Trainer() trainer.predict(model, data.predict_dataloader())