From 336ad740412346bcff3a08d0502300dd7e1cb165 Mon Sep 17 00:00:00 2001 From: Roman Joeres Date: Fri, 30 Aug 2024 13:47:20 +0200 Subject: [PATCH] Final debugging steps for pretrained model evaluation --- configs/downstream/pretrained.yaml | 3 ++- gifflar/pretransforms.py | 11 +++++--- gifflar/train.py | 43 +++++++++++++++--------------- 3 files changed, 32 insertions(+), 25 deletions(-) diff --git a/configs/downstream/pretrained.yaml b/configs/downstream/pretrained.yaml index cb6845d..46bfabf 100644 --- a/configs/downstream/pretrained.yaml +++ b/configs/downstream/pretrained.yaml @@ -25,11 +25,12 @@ datasets: # task: multilabel pre-transforms: PretrainEmbed: - file_path: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/ + folder: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/ model_name: GIFFLAR hash_str: 3fd297ab model: - name: mlp + feat_dim: 1024 hidden_dim: 1024 batch_size: 256 num_layers: 3 diff --git a/gifflar/pretransforms.py b/gifflar/pretransforms.py index a09b15c..174c6ed 100644 --- a/gifflar/pretransforms.py +++ b/gifflar/pretransforms.py @@ -414,15 +414,20 @@ class PretrainEmbed(RootTransform): def __init__(self, folder: str, dataset_name: str, model_name: str, hash_str: str, **kwargs: Any): super(PretrainEmbed, self).__init__(**kwargs) self.data = torch.load(Path(folder) / f"{dataset_name}_{model_name}_{hash_str}.pt") - self.lookup = {smiles: (i, j) for i in range(len(self.data)) for j, smiles in enumerate(self.data[i])} + self.lookup = {smiles: (i, j) for i in range(len(self.data)) for j, smiles in enumerate(self.data[i]["smiles"])} self.pooling = GIFFLARPooling() self.layer = kwargs.get("layer", -1) def __call__(self, data: HeteroData) -> HeteroData: + if data["smiles"] not in self.lookup: + print(data["smiles"], "not found in preprocessed data.") + data["fp"] = torch.zeros_like(data["fp"]) + return data + a, b = self.lookup[data["smiles"]] mask = {key: self.data[a]["batch_ids"][key] == b for key in ["atoms", "bonds", "monosacchs"]} - node_embeds = {key: self.data[a]["node_embeds"][self.layer][mask[key]] for key in ["atoms", "bonds", "monosacchs"]} - batch_ids = {key: torch.zeros_like(node_embeds[key], dtype=torch.long) for key in ["atoms", "bonds", "monosacchs"]} + node_embeds = {key: self.data[a]["node_embeds"][self.layer][key][mask[key]] for key in ["atoms", "bonds", "monosacchs"]} + batch_ids = {key: torch.zeros(len(node_embeds[key]), dtype=torch.long) for key in ["atoms", "bonds", "monosacchs"]} data["fp"] = self.pooling(node_embeds, batch_ids) return data diff --git a/gifflar/train.py b/gifflar/train.py index 2a180d9..afc3422 100644 --- a/gifflar/train.py +++ b/gifflar/train.py @@ -232,24 +232,25 @@ def main(config): if __name__ == '__main__': - embed( - prep_args={ - "model_name": "GIFFLAR", - "ckpt_path": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/checkpoints/epoch=99-step=6200.ckpt", - "hparams_path": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/hparams.yaml", - "save_dir":"/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/", - }, - **{ - "seed": 42, - "data_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/", - "root_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed", - "logs_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_embed", - "dataset": {"name": "Immunogenicity", "task": "classification"}, - "pre-transforms": {}, - "hash": "12345678", - "model": {}, - } - ) - # parser = ArgumentParser() - # parser.add_argument("config", type=str, help="Path to YAML config file") - # main(parser.parse_args().config) + #embed( + # prep_args={ + # "model_name": "GIFFLAR", + # "ckpt_path": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/checkpoints/epoch=99-step=6200.ckpt", + # "hparams_path": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/hparams.yaml", + # "save_dir":"/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/", + # }, + # **{ + # "seed": 42, + # "data_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/", + # "root_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed", + # "logs_dir": "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_embed", + # "dataset": {"name": "Immunogenicity", "task": "classification"}, + # "pre-transforms": {}, + # "hash": "12345678", + # "model": {}, + # } + #) + parser = ArgumentParser() + parser.add_argument("config", type=str, help="Path to YAML config file") + main(parser.parse_args().config) +