Minor bug fix with directory for embedding saving
Old-Shatterhand committed Sep 4, 2024
1 parent 0a27576 commit 8cfcaa4
Showing 1 changed file with 3 additions and 2 deletions.
gifflar/train.py: 5 changes (3 additions, 2 deletions)
@@ -198,10 +198,12 @@ def embed(prep_args: dict[str, str], **kwargs: Any) -> None:
         f"{kwargs['dataset']['name']}_{prep_args['name']}_{hash_dict(prep_args, 8)}")
     if output_name.exists():
         return
+    else:
+        output_name.mkdir(parents=True)

     with open(prep_args["hparams_path"], "r") as file:
         config = yaml.load(file, Loader=yaml.FullLoader)
-    model = PretrainGGIN(**config["model"], tasks=None, pre_transform_args=kwargs.get("pre-transforms", {}))
+    model = PretrainGGIN(**config["model"], tasks=None, pre_transform_args=kwargs.get("pre-transforms", {}), save_dir=output_name)
     if torch.cuda.is_available():
         model.load_state_dict(torch.load(prep_args["ckpt_path"])["state_dict"])
     else:
@@ -210,7 +212,6 @@ def embed(prep_args: dict[str, str], **kwargs: Any) -> None:
     model.save_nth_layer(int(prep_args["nth_layer"]))

     data_config, data, _, _ = setup(2, **kwargs)
-    data.save_dir = output_name
     trainer = Trainer()
     trainer.predict(model, data.predict_dataloader())

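The fix amounts to two related changes: the output directory is now created (with parents) when it does not exist yet, and the save location is handed to the model via save_dir rather than set on the data module after setup. A minimal standalone sketch of the directory-handling pattern, using a hypothetical helper name ensure_output_dir (hash_dict and embed above are the repository's own names):

from pathlib import Path
from typing import Optional

def ensure_output_dir(root_dir: str, run_name: str) -> Optional[Path]:
    # Hypothetical helper mirroring the fixed logic in embed(): skip the
    # whole run if results already exist, otherwise create the directory
    # (including missing parents) before anything tries to write into it.
    output_name = Path(root_dir) / run_name
    if output_name.exists():
        return None  # embeddings were already computed for this configuration
    output_name.mkdir(parents=True)
    return output_name

Presumably the model, not the data module, writes the embeddings during trainer.predict, which is why save_dir moves into the PretrainGGIN constructor and why the directory must exist before prediction starts.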
