Final debugging steps for pretrained model evaluation
Roman Joeres authored and Roman Joeres committed Aug 30, 2024
1 parent afcbd83 commit 336ad74
Showing 3 changed files with 32 additions and 25 deletions.
3 changes: 2 additions & 1 deletion configs/downstream/pretrained.yaml
@@ -25,11 +25,12 @@ datasets:
 # task: multilabel
 pre-transforms:
   PretrainEmbed:
-    file_path: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/
+    folder: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/
     model_name: GIFFLAR
     hash_str: 3fd297ab
 model:
   - name: mlp
     feat_dim: 1024
     hidden_dim: 1024
     batch_size: 256
     num_layers: 3
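For orientation, here is a minimal sketch of how the renamed `folder` key resolves to the cached embedding file that `PretrainEmbed` loads. The file-naming scheme is taken from the pretransforms.py diff below; the concrete dataset name is an assumed placeholder.

    from pathlib import Path

    # Values from the YAML above; the dataset name is assumed for illustration.
    folder = "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/"
    model_name = "GIFFLAR"
    hash_str = "3fd297ab"
    dataset_name = "Immunogenicity"

    # PretrainEmbed loads <folder>/<dataset>_<model>_<hash>.pt (see the diff below).
    cache_file = Path(folder) / f"{dataset_name}_{model_name}_{hash_str}.pt"
    print(cache_file)
    # -> /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/Immunogenicity_GIFFLAR_3fd297ab.pt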
11 changes: 8 additions & 3 deletions gifflar/pretransforms.py
@@ -414,15 +414,20 @@ class PretrainEmbed(RootTransform):
     def __init__(self, folder: str, dataset_name: str, model_name: str, hash_str: str, **kwargs: Any):
         super(PretrainEmbed, self).__init__(**kwargs)
         self.data = torch.load(Path(folder) / f"{dataset_name}_{model_name}_{hash_str}.pt")
-        self.lookup = {smiles: (i, j) for i in range(len(self.data)) for j, smiles in enumerate(self.data[i])}
+        self.lookup = {smiles: (i, j) for i in range(len(self.data)) for j, smiles in enumerate(self.data[i]["smiles"])}
         self.pooling = GIFFLARPooling()
+        self.layer = kwargs.get("layer", -1)
 
     def __call__(self, data: HeteroData) -> HeteroData:
+        if data["smiles"] not in self.lookup:
+            print(data["smiles"], "not found in preprocessed data.")
+            data["fp"] = torch.zeros_like(data["fp"])
+            return data
 
         a, b = self.lookup[data["smiles"]]
         mask = {key: self.data[a]["batch_ids"][key] == b for key in ["atoms", "bonds", "monosacchs"]}
-        node_embeds = {key: self.data[a]["node_embeds"][self.layer][mask[key]] for key in ["atoms", "bonds", "monosacchs"]}
-        batch_ids = {key: torch.zeros_like(node_embeds[key], dtype=torch.long) for key in ["atoms", "bonds", "monosacchs"]}
+        node_embeds = {key: self.data[a]["node_embeds"][self.layer][key][mask[key]] for key in ["atoms", "bonds", "monosacchs"]}
+        batch_ids = {key: torch.zeros(len(node_embeds[key]), dtype=torch.long) for key in ["atoms", "bonds", "monosacchs"]}
         data["fp"] = self.pooling(node_embeds, batch_ids)
         return data
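To make the three indexing fixes concrete, here is a self-contained sketch of the cache layout the corrected code implies: `self.data` is a list of per-batch dicts with "smiles", "batch_ids", and "node_embeds" entries, where "node_embeds" is indexed first by layer and then by node type. Key names come from the diff above; all shapes, glycan strings, and values are illustrative assumptions.

    import torch

    # Assumed layout of one entry of self.data (shapes are illustrative only).
    batch = {
        "smiles": ["Gal(b1-4)Glc", "Man(a1-2)Man"],  # lookup maps each string to (batch index i, position j)
        "batch_ids": {  # which glycan in the batch each node belongs to
            "atoms": torch.tensor([0, 0, 0, 1, 1]),
            "bonds": torch.tensor([0, 0, 1]),
            "monosacchs": torch.tensor([0, 0, 1, 1]),
        },
        "node_embeds": [  # one dict per layer; self.layer = -1 selects the last one
            {
                "atoms": torch.randn(5, 1024),
                "bonds": torch.randn(3, 1024),
                "monosacchs": torch.randn(4, 1024),
            },
        ],
    }

    # The corrected __call__ logic, extracting the second glycan (j = 1):
    j = 1
    keys = ["atoms", "bonds", "monosacchs"]
    mask = {key: batch["batch_ids"][key] == j for key in keys}
    node_embeds = {key: batch["node_embeds"][-1][key][mask[key]] for key in keys}
    batch_ids = {key: torch.zeros(len(node_embeds[key]), dtype=torch.long) for key in keys}

The `batch_ids` fix matters because the old `torch.zeros_like(node_embeds[key], dtype=torch.long)` produced a 2-D, embedding-shaped tensor, while pooling needs exactly one batch id per node; `torch.zeros(len(node_embeds[key]), dtype=torch.long)` yields the correct 1-D shape.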
43 changes: 22 additions & 21 deletions gifflar/train.py
@@ -232,24 +232,25 @@ def main(config):


 if __name__ == '__main__':
-    embed(
-        prep_args={
-            "model_name": "GIFFLAR",
-            "ckpt_path": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/checkpoints/epoch=99-step=6200.ckpt",
-            "hparams_path": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/hparams.yaml",
-            "save_dir":"/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/",
-        },
-        **{
-            "seed": 42,
-            "data_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/",
-            "root_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed",
-            "logs_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_embed",
-            "dataset": {"name": "Immunogenicity", "task": "classification"},
-            "pre-transforms": {},
-            "hash": "12345678",
-            "model": {},
-        }
-    )
-    # parser = ArgumentParser()
-    # parser.add_argument("config", type=str, help="Path to YAML config file")
-    # main(parser.parse_args().config)
+    #embed(
+    #    prep_args={
+    #        "model_name": "GIFFLAR",
+    #        "ckpt_path": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/checkpoints/epoch=99-step=6200.ckpt",
+    #        "hparams_path": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/hparams.yaml",
+    #        "save_dir":"/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/",
+    #    },
+    #    **{
+    #        "seed": 42,
+    #        "data_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/",
+    #        "root_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed",
+    #        "logs_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_embed",
+    #        "dataset": {"name": "Immunogenicity", "task": "classification"},
+    #        "pre-transforms": {},
+    #        "hash": "12345678",
+    #        "model": {},
+    #    }
+    #)
+    parser = ArgumentParser()
+    parser.add_argument("config", type=str, help="Path to YAML config file")
+    main(parser.parse_args().config)
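With the command-line entry point restored, downstream evaluation is driven by a YAML config again. A typical invocation, assuming the script is run from the repository root against the config changed above, would be:

    python gifflar/train.py configs/downstream/pretrained.yaml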
