From 9215cdc19e89f1ae064e5e8deea101d3a412af50 Mon Sep 17 00:00:00 2001
From: Roman Joeres
Date: Mon, 9 Sep 2024 13:33:30 +0200
Subject: [PATCH] Local Mean Pooling

---
 configs/downstream/all.yaml        | 81 +++++++++++++++---------------
 configs/downstream/pretrained.yaml |  1 +
 gifflar/model/utils.py             | 22 ++++----
 gifflar/pretransforms.py           |  6 +--
 4 files changed, 58 insertions(+), 52 deletions(-)

diff --git a/configs/downstream/all.yaml b/configs/downstream/all.yaml
index e4f556d..ceaf146 100644
--- a/configs/downstream/all.yaml
+++ b/configs/downstream/all.yaml
@@ -9,39 +9,39 @@ datasets:
     task: classification
   - name: Taxonomy_Domain
     task: multilabel
-  - name: Taxonomy_Kingdom
-    task: multilabel
-  - name: Taxonomy_Phylum
-    task: multilabel
-  - name: Taxonomy_Class
-    task: multilabel
-  - name: Taxonomy_Order
-    task: multilabel
-  - name: Taxonomy_Family
-    task: multilabel
-  - name: Taxonomy_Genus
-    task: multilabel
-  - name: Taxonomy_Species
-    task: multilabel
+#  - name: Taxonomy_Kingdom
+#    task: multilabel
+#  - name: Taxonomy_Phylum
+#    task: multilabel
+#  - name: Taxonomy_Class
+#    task: multilabel
+#  - name: Taxonomy_Order
+#    task: multilabel
+#  - name: Taxonomy_Family
+#    task: multilabel
+#  - name: Taxonomy_Genus
+#    task: multilabel
+#  - name: Taxonomy_Species
+#    task: multilabel
 pre-transforms:
 model:
-  #  - name: rf
-  #    n_estimators: 500
-  #    n_jobs: -1
-  #    random_state: 42
-  #- name: svm
-  #  random_state: 42
-  #- name: xgb
-  #  random_state: 42
-  - name: mlp
-    feat_dim: 1024
-    hidden_dim: 1024
-    batch_size: 256
-    num_layers: 3
-    epochs: 100
-    patience: 30
-    learning_rate: 0
-    optimizer: Adam
+#  - name: rf
+#    n_estimators: 500
+#    n_jobs: -1
+#    random_state: 42
+#  - name: svm
+#    random_state: 42
+#  - name: xgb
+#    random_state: 42
+#  - name: mlp
+#    feat_dim: 1024
+#    hidden_dim: 1024
+#    batch_size: 256
+#    num_layers: 3
+#    epochs: 100
+#    patience: 30
+#    learning_rate: 0
+#    optimizer: Adam
 #  - name: sweetnet
 #    feat_dim: 128
 #    hidden_dim: 1024
@@ -71,12 +71,13 @@ model:
 #    learning_rate: 0.001
 #    optimizer: Adam
 #    suffix:
-#  - name: gifflar
-#    feat_dim: 128
-#    hidden_dim: 1024
-#    batch_size: 256
-#    num_layers: 8
-#    epochs: 1
-#    learning_rate: 0.001
-#    optimizer: Adam
-#    suffix:
+  - name: gifflar
+    feat_dim: 128
+    hidden_dim: 1024
+    batch_size: 256
+    num_layers: 8
+    epochs: 1
+    learning_rate: 0.001
+    optimizer: Adam
+    pooling: local_mean
+    suffix:
diff --git a/configs/downstream/pretrained.yaml b/configs/downstream/pretrained.yaml
index 66f33da..91511cc 100644
--- a/configs/downstream/pretrained.yaml
+++ b/configs/downstream/pretrained.yaml
@@ -28,6 +28,7 @@ pre-transforms:
     folder: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/
     model_name: GIFFLAR
     hash_str: 659a0fbd
+    pooling: global_mean
 model:
   - name: mlp
     feat_dim: 1024
diff --git a/gifflar/model/utils.py b/gifflar/model/utils.py
index 9cdf50f..155aeb5 100644
--- a/gifflar/model/utils.py
+++ b/gifflar/model/utils.py
@@ -107,11 +107,7 @@ def __init__(self, mode: str = "global_mean"):
             output_dim: The output dimension of the pooling layer
         """
         super().__init__()
-        match (mode):
-            case "global_mean":
-                self.pooling = global_mean_pool
-            case _:
-                raise ValueError(f"Pooling mode {mode} not implemented yet.")
+        self.mode = mode

     def forward(self, nodes: dict[str, torch.Tensor], batch_ids: dict[str, torch.Tensor]) -> torch.Tensor:
         """
@@ -120,7 +116,15 @@ def forward(self, nodes: dict[str, torch.Tensor], batch_ids: dict[str, torch.Ten
         Args:
             x: The input to the pooling layer
         """
-        return self.pooling(
-            torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
-            torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0)
-        )
+        match self.mode:
+            case "global_mean":
+                return global_mean_pool(
+                    torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
+                    torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0)
+                )
+            case "local_mean":
+                return torch.sum(torch.stack([
+                    global_mean_pool(nodes[key], batch_ids[key]) for key in nodes.keys()
+                ]), dim=0)
+            case _:
+                raise ValueError(f"Pooling mode {self.mode} not supported")
diff --git a/gifflar/pretransforms.py b/gifflar/pretransforms.py
index f294eb5..0cb5506 100644
--- a/gifflar/pretransforms.py
+++ b/gifflar/pretransforms.py
@@ -368,11 +368,11 @@ def __call__(self, data: HeteroData) -> HeteroData:
         """
         if self.individual:  # compute the random walk positional encodings for each node type individually
             for d, name in zip(split_hetero_graph(data), ["atoms", "bonds", "monosacchs"]):
-                self.forward(d)
+                d = self.forward(d)
                 data[f"{name}_{self.attr_name}"] = d[self.attr_name]
         else:  # or for the whole graph
             d = hetero_to_homo(data)
-            self.forward(d)
+            d = self.forward(d)
             data[f"atoms_{self.attr_name}"] = d[self.attr_name][:data["atoms"]["num_nodes"]]
             data[f"bonds_{self.attr_name}"] = d[self.attr_name][
                                               data["atoms"]["num_nodes"]:-data["monosacchs"]["num_nodes"]]
@@ -419,7 +419,7 @@ def __call__(self, data: HeteroData) -> HeteroData:
 class PretrainEmbed(RootTransform):
     """Run a GIFFLAR model to embed the input data."""

-    def __init__(self, folder: str, dataset_name: str, model_name: str, hash_str: str, **kwargs: Any):
+    def __init__(self, folder: str, dataset_name: str, model_name: str, pooling: str, hash_str: str, **kwargs: Any):
         """
         Set up the Embedding from a pretrained model.