Local Mean Pooling
Old-Shatterhand committed Sep 9, 2024
1 parent 9e2af21 commit 9215cdc
Showing 4 changed files with 58 additions and 52 deletions.
81 changes: 41 additions & 40 deletions configs/downstream/all.yaml
@@ -9,39 +9,39 @@ datasets:
     task: classification
   - name: Taxonomy_Domain
     task: multilabel
-  - name: Taxonomy_Kingdom
-    task: multilabel
-  - name: Taxonomy_Phylum
-    task: multilabel
-  - name: Taxonomy_Class
-    task: multilabel
-  - name: Taxonomy_Order
-    task: multilabel
-  - name: Taxonomy_Family
-    task: multilabel
-  - name: Taxonomy_Genus
-    task: multilabel
-  - name: Taxonomy_Species
-    task: multilabel
+# - name: Taxonomy_Kingdom
+#   task: multilabel
+# - name: Taxonomy_Phylum
+#   task: multilabel
+# - name: Taxonomy_Class
+#   task: multilabel
+# - name: Taxonomy_Order
+#   task: multilabel
+# - name: Taxonomy_Family
+#   task: multilabel
+# - name: Taxonomy_Genus
+#   task: multilabel
+# - name: Taxonomy_Species
+#   task: multilabel
 pre-transforms:
 model:
-# - name: rf
-#   n_estimators: 500
-#   n_jobs: -1
-#   random_state: 42
-#- name: svm
-#   random_state: 42
-#- name: xgb
-#   random_state: 42
-  - name: mlp
-    feat_dim: 1024
-    hidden_dim: 1024
-    batch_size: 256
-    num_layers: 3
-    epochs: 100
-    patience: 30
-    learning_rate: 0
-    optimizer: Adam
+# - name: rf
+#   n_estimators: 500
+#   n_jobs: -1
+#   random_state: 42
+# - name: svm
+#   random_state: 42
+# - name: xgb
+#   random_state: 42
+# - name: mlp
+#   feat_dim: 1024
+#   hidden_dim: 1024
+#   batch_size: 256
+#   num_layers: 3
+#   epochs: 100
+#   patience: 30
+#   learning_rate: 0
+#   optimizer: Adam
 # - name: sweetnet
 #   feat_dim: 128
 #   hidden_dim: 1024
@@ -71,12 +71,13 @@ model:
 #   learning_rate: 0.001
 #   optimizer: Adam
 #   suffix:
-# - name: gifflar
-#   feat_dim: 128
-#   hidden_dim: 1024
-#   batch_size: 256
-#   num_layers: 8
-#   epochs: 1
-#   learning_rate: 0.001
-#   optimizer: Adam
-#   suffix:
+  - name: gifflar
+    feat_dim: 128
+    hidden_dim: 1024
+    batch_size: 256
+    num_layers: 8
+    epochs: 1
+    learning_rate: 0.001
+    optimizer: Adam
+    pooling: local_mean
+    suffix:
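The net effect of this hunk is to swap the active experiment: the taxonomy datasets and the MLP baseline are commented out, and the gifflar model is re-enabled with the new `pooling: local_mean` key. A minimal sketch of how such a key might be read and defaulted downstream; the loader shown here is an assumption for illustration, not the repository's actual entry point:

```python
import yaml

# Hypothetical reader for the downstream config; the real pipeline
# presumably parses configs/downstream/all.yaml through its own CLI.
with open("configs/downstream/all.yaml") as f:
    cfg = yaml.safe_load(f)

for model_cfg in cfg["model"]:
    # Fall back to the previous behaviour when no pooling key is given.
    mode = model_cfg.get("pooling", "global_mean")
    print(model_cfg["name"], "->", mode)  # e.g. gifflar -> local_mean
```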
1 change: 1 addition & 0 deletions configs/downstream/pretrained.yaml
@@ -28,6 +28,7 @@ pre-transforms:
     folder: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/
     model_name: GIFFLAR
     hash_str: 659a0fbd
+    pooling: global_mean
 model:
   - name: mlp
     feat_dim: 1024
22 changes: 13 additions & 9 deletions gifflar/model/utils.py
@@ -107,11 +107,7 @@ def __init__(self, mode: str = "global_mean"):
             output_dim: The output dimension of the pooling layer
         """
         super().__init__()
-        match (mode):
-            case "global_mean":
-                self.pooling = global_mean_pool
-            case _:
-                raise ValueError(f"Pooling mode {mode} not implemented yet.")
+        self.mode = mode

     def forward(self, nodes: dict[str, torch.Tensor], batch_ids: dict[str, torch.Tensor]) -> torch.Tensor:
         """
@@ -120,7 +116,15 @@ def forward(self, nodes: dict[str, torch.Tensor], batch_ids: dict[str, torch.Tensor]) -> torch.Tensor:
         Args:
             x: The input to the pooling layer
         """
-        return self.pooling(
-            torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
-            torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0)
-        )
+        match self.mode:
+            case "global_mean":
+                return global_mean_pool(
+                    torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
+                    torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0)
+                )
+            case "local_mean":
+                return torch.sum(torch.stack([
+                    global_mean_pool(nodes[key], batch_ids[key]) for key in nodes.keys()
+                ]), dim=0)
+            case _:
+                raise ValueError(f"Pooling mode {self.mode} not supported")
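For intuition: `global_mean` averages over all node types at once, so types with many nodes (typically atoms) dominate the graph embedding, while the new `local_mean` mode averages each node type separately and then sums the three per-type means, giving atoms, bonds, and monosaccharides equal weight. A self-contained sketch of the two modes on toy tensors:

```python
import torch
from torch_geometric.nn import global_mean_pool

# Toy example: a single graph (batch id 0 everywhere) with 4 atoms,
# 3 bonds and 2 monosaccharides, each embedded in 8 dimensions.
nodes = {
    "atoms": torch.randn(4, 8),
    "bonds": torch.randn(3, 8),
    "monosacchs": torch.randn(2, 8),
}
batch_ids = {key: torch.zeros(x.size(0), dtype=torch.long) for key, x in nodes.items()}

# global_mean: one mean over all 9 nodes; types with more nodes dominate.
global_out = global_mean_pool(
    torch.concat(list(nodes.values()), dim=0),
    torch.concat(list(batch_ids.values()), dim=0),
)

# local_mean: mean per node type first, then sum, so each of the three
# types contributes equally regardless of its node count.
local_out = torch.sum(
    torch.stack([global_mean_pool(nodes[k], batch_ids[k]) for k in nodes]), dim=0
)
print(global_out.shape, local_out.shape)  # torch.Size([1, 8]) for both
```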
6 changes: 3 additions & 3 deletions gifflar/pretransforms.py
@@ -368,11 +368,11 @@ def __call__(self, data: HeteroData) -> HeteroData:
         """
         if self.individual:  # compute the random walk positional encodings for each node type individually
             for d, name in zip(split_hetero_graph(data), ["atoms", "bonds", "monosacchs"]):
-                self.forward(d)
+                d = self.forward(d)
                 data[f"{name}_{self.attr_name}"] = d[self.attr_name]
         else:  # or for the whole graph
             d = hetero_to_homo(data)
-            self.forward(d)
+            d = self.forward(d)
             data[f"atoms_{self.attr_name}"] = d[self.attr_name][:data["atoms"]["num_nodes"]]
             data[f"bonds_{self.attr_name}"] = d[self.attr_name][
                 data["atoms"]["num_nodes"]:-data["monosacchs"]["num_nodes"]]
@@ -419,7 +419,7 @@ def __call__(self, data: HeteroData) -> HeteroData:
 class PretrainEmbed(RootTransform):
     """Run a GIFFLAR model to embed the input data."""

-    def __init__(self, folder: str, dataset_name: str, model_name: str, hash_str: str, **kwargs: Any):
+    def __init__(self, folder: str, dataset_name: str, model_name: str, pooling: str, hash_str: str, **kwargs: Any):
         """
         Set up the Embedding from a pretrained model.
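The constructor grows a `pooling` parameter, which lines up with the `pooling: global_mean` key added to configs/downstream/pretrained.yaml above. A sketch of how the transform might now be instantiated; the call site is an assumption, the literal values are copied from that config, and the elided dataset name stays a placeholder:

```python
from gifflar.pretransforms import PretrainEmbed

# Hypothetical instantiation mirroring configs/downstream/pretrained.yaml.
embed = PretrainEmbed(
    folder="/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/",
    dataset_name="...",     # elided: not shown in this diff
    model_name="GIFFLAR",
    pooling="global_mean",  # new argument introduced by this commit
    hash_str="659a0fbd",
)
```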
