diff --git a/configs/downstream/all.yaml b/configs/downstream/all.yaml
index 99940cc..ea544a9 100644
--- a/configs/downstream/all.yaml
+++ b/configs/downstream/all.yaml
@@ -1,7 +1,7 @@
 seed: 42
 data_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/
-root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_test
-logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_test
+root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_feat
+logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/clean_logs
 datasets:
   - name: Immunogenicity
     task: classification
@@ -9,20 +9,20 @@ datasets:
     task: classification
   - name: Taxonomy_Domain
     task: multilabel
-#  - name: Taxonomy_Kingdom
-#    task: multilabel
-#  - name: Taxonomy_Phylum
-#    task: multilabel
-#  - name: Taxonomy_Class
-#    task: multilabel
-#  - name: Taxonomy_Order
-#    task: multilabel
-#  - name: Taxonomy_Family
-#    task: multilabel
-#  - name: Taxonomy_Genus
-#    task: multilabel
-#  - name: Taxonomy_Species
-#    task: multilabel
+  - name: Taxonomy_Kingdom
+    task: multilabel
+  - name: Taxonomy_Phylum
+    task: multilabel
+  - name: Taxonomy_Class
+    task: multilabel
+  - name: Taxonomy_Order
+    task: multilabel
+  - name: Taxonomy_Family
+    task: multilabel
+  - name: Taxonomy_Genus
+    task: multilabel
+  - name: Taxonomy_Species
+    task: multilabel
 pre-transforms:
 model:
 #  - name: rf
@@ -79,5 +79,5 @@ model:
     epochs: 100
     learning_rate: 0.001
     optimizer: Adam
-    pooling: local_mean
-    suffix: _128_8_local
+    pooling: weighted_attention
+    suffix: _128_8_la_pool
diff --git a/configs/downstream/pooling.yaml b/configs/downstream/pooling.yaml
index 4e40249..5d260d2 100644
--- a/configs/downstream/pooling.yaml
+++ b/configs/downstream/pooling.yaml
@@ -68,36 +68,6 @@ model:
     num_layers: 8
    epochs: 100
     learning_rate: 0.001
-    optimizer: Adam 
+    optimizer: Adam
     pooling: weighted_attention
     suffix: _weighted_att
-  - name: gifflar
-    feat_dim: 128
-    hidden_dim: 1024
-    batch_size: 256
-    num_layers: 8
-    epochs: 100
-    learning_rate: 0.001
-    optimizer: Adam
-    pooling: cell_attention
-    suffix: _cell_att
-  - name: gifflar
-    feat_dim: 128
-    hidden_dim: 1024
-    batch_size: 256
-    num_layers: 8
-    epochs: 100
-    learning_rate: 0.001
-    optimizer: Adam
-    pooling: local_cell_attention
-    suffix: _local_cell_att
-  - name: gifflar
-    feat_dim: 128
-    hidden_dim: 1024
-    batch_size: 256
-    num_layers: 8
-    epochs: 100
-    learning_rate: 0.001
-    optimizer: Adam
-    pooling: weighted_cell_attention
-    suffix: _weighted_cell_att
diff --git a/gifflar/model/downstream.py b/gifflar/model/downstream.py
index 4f2f175..cf3e4d8 100644
--- a/gifflar/model/downstream.py
+++ b/gifflar/model/downstream.py
@@ -37,7 +37,8 @@ def __init__(
         super().__init__(feat_dim, hidden_dim, num_layers, batch_size, pre_transform_args)
 
         self.output_dim = output_dim
-        self.pooling = GIFFLARPooling()
+        self.pooling = GIFFLARPooling(kwargs.get("pooling", "global_mean"))
+        # self.add_module('pooling', GIFFLARPooling(kwargs.get("pooling", "global_mean")))
         self.task = task
         self.head, self.loss, self.metrics = get_prediction_head(hidden_dim, output_dim, task)
 
@@ -53,6 +54,7 @@ def to(self, device: torch.device) -> "DownstreamGGIN":
             self: The model moved to the specified device
         """
         super(DownstreamGGIN, self).to(device)
+        self.pooling.to(device)
         for split, metric in self.metrics.items():
             self.metrics[split] = metric.to(device)
         return self
diff --git a/gifflar/model/utils.py b/gifflar/model/utils.py
index 34aaff3..a3c70e8 100644
--- a/gifflar/model/utils.py
+++ b/gifflar/model/utils.py
@@ -5,11 +5,24 @@
 from torch_geometric.nn.inits import reset
 from torch_geometric.utils import softmax
 from torch_geometric.nn import GINConv, global_mean_pool
+from torch_scatter import scatter_add
 
 from gifflar.utils import get_metrics
 
 
-class GlobalAttention(torch.nn.Module):
+class MultiGlobalAttention(nn.Module):
+    def __init__(self, gas: dict):
+        super().__init__()
+        for name, ga in gas.items():
+            setattr(self, name, ga)
+
+    def __getitem__(self, item):
+        if hasattr(self, item):
+            return getattr(self, item)
+        raise KeyError(item)  # nn.Module defines no __getitem__ to fall back on
+
+
+class GlobalAttention(nn.Module):
     r"""Global soft attention layer from the `"Gated Graph Sequence
     Neural Networks" <https://arxiv.org/abs/1511.05493>`_ paper
 
@@ -44,21 +57,20 @@ def reset_parameters(self):
         reset(self.gate_nn)
         reset(self.nn)
 
-
     def forward(self, x, batch, size=None):
         """"""
+        if batch is None:
+            batch = x.new_zeros(x.size(0), dtype=torch.int64)
+
         x = x.unsqueeze(-1) if x.dim() == 1 else x
-        size = batch[-1].item() + 1 if size is None else size
+        size = int(batch.max()) + 1 if size is None else size
 
         gate = self.gate_nn(x).view(-1, 1)
         x = self.nn(x) if self.nn is not None else x
         assert gate.dim() == x.dim() and gate.size(0) == x.size(0)
 
         gate = softmax(gate, batch, num_nodes=size)
-
-        # out = scatter_add(gate * x, index=batch, dim=0)  # , dim_size=size)
-        src = torch.zeros_like(size)
-        out = src.scatter_add(src=gate * x, index=batch, dim=0)
+        out = scatter_add(gate * x, batch, dim=0, dim_size=size)
 
         return out
 
@@ -190,18 +202,45 @@ def __init__(self, mode: str = "global_mean"):
         self.attention = None
         self.weights = None
         if "weighted" in mode:
+            print("Create weighting")
             self.weights = nn.Parameter(torch.ones(3), requires_grad=True)
         if "attention" in mode:
+            print("Create attention")
             if mode == "global_attention":
-                self.attention = {"": GlobalAttention(
-                    gate_nn=get_gin_layer(1024, 1),
-                    nn=get_gin_layer(1024, 1024),
-                )}
+                self.attention = MultiGlobalAttention({"": GlobalAttention(
+                    gate_nn=nn.Linear(1024, 1),
+                    nn=nn.Sequential(
+                        nn.Linear(1024, 1024),
+                        nn.PReLU(),
+                        nn.Dropout(0.2),
+                        nn.BatchNorm1d(1024),
+                        nn.Linear(1024, 1024),
+                    ),
+                )})
+            else:
+                self.attention = MultiGlobalAttention({key: GlobalAttention(
+                    gate_nn=nn.Linear(1024, 1),
+                    nn=nn.Sequential(
+                        nn.Linear(1024, 1024),
+                        nn.PReLU(),
+                        nn.Dropout(0.2),
+                        nn.BatchNorm1d(1024),
+                        nn.Linear(1024, 1024),
+                    ),
+                ) for key in ["atoms", "bonds", "monosacchs"]})
+
+    def to(self, device):
+        if self.weights is not None:
+            self.weights.data = self.weights.data.to(device)  # in-place move; Parameter.to() would return a plain Tensor
+        if self.attention is not None:
+            if self.mode == "global_attention":
+                self.attention[""].to(device)
             else:
-                self.attention = {key: GlobalAttention(
-                    gate_nn=get_gin_layer(1024, 1),
-                    nn=get_gin_layer(1024, 1024),
-                ) for key in ["atoms", "bonds", "monosacchs"]}
+                self.attention["atoms"].to(device)
+                self.attention["bonds"].to(device)
+                self.attention["monosacchs"].to(device)
+        super(GIFFLARPooling, self).to(device)
+        return self
 
     def forward(self, nodes: dict[str, torch.Tensor], batch_ids: dict[str, torch.Tensor]) -> torch.Tensor:
         """
@@ -210,24 +249,72 @@ def forward(self, nodes: dict[str, torch.Tensor], batch_ids: dict[str, torch.Ten
         Args:
             x: The input to the pooling layer
         """
+        device = nodes["atoms"].device
         match self.mode:
+            case "global_mean":
+                return global_mean_pool(
+                    torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
+                    torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0)
+                )
             case "local_mean":
                 nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) for key in nodes.keys()}
+                return global_mean_pool(
+                    torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
nodes["bonds"], nodes["monosacchs"]], dim=0), + torch.concat([torch.arange(len(nodes[key])).to(device) for key in ["atoms", "bonds", "monosacchs"]], dim=0) + ) case "weighted_mean": nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])} + return global_mean_pool( + torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0), + torch.concat([torch.arange(len(nodes[key])).to(device) for key in ["atoms", "bonds", "monosacchs"]], dim=0) + ) case "global_attention": - nodes = {key: self.attention[""](nodes[key], batch_ids[key]) for key in nodes.keys()} + return self.attention[""]( + torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0), + torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0) + ) case "local_attention": nodes = {key: self.attention[key](nodes[key], batch_ids[key]) for key in nodes.keys()} + return global_mean_pool( + torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0), + torch.concat([torch.arange(len(nodes[key])).to(device) for key in ["atoms", "bonds", "monosacchs"]], dim=0) + ) case "weighted_attention": nodes = {key: self.attention[key](nodes[key], batch_ids[key]) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])} - case "cell_attention": - nodes = {key: self.attention[""](global_mean_pool(nodes[key], batch_ids[key])) for key in nodes.keys()} - case "local_cell_attention": - nodes = {key: self.attention[key](global_mean_pool(nodes[key], batch_ids[key])) for key in nodes.keys()} - case "weighted_cell_attention": - nodes = {key: self.attention[key](global_mean_pool(nodes[key], batch_ids[key])) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])} - return global_mean_pool( - torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0), - torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0) - ) + device = nodes["atoms"].device + return global_mean_pool( + torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0), + torch.concat([torch.arange(len(nodes[key])).to(device) for key in ["atoms", "bonds", "monosacchs"]], dim=0) + ) + #case "cell_attention": + # nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) for key in nodes.keys()} + # return self.attention[""]( + # torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0), + # torch.concat([torch.arange(len(nodes["atoms"])), torch.arange(len(nodes["bonds"])), torch.arange(len(nodes["monosacchs"]))], dim=0) + # ) + #case "local_cell_attention": + # nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) for key in nodes.keys()} + # return global_mean_pool( + # torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0), + # torch.concat([torch.arange(len(nodes["atoms"])), torch.arange(len(nodes["bonds"])), torch.arange(len(nodes["monosacchs"]))], dim=0) + # ) + #case "weighted_cell_attention": + # nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])} + # return global_mean_pool( + # torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0), + # torch.concat([torch.arange(len(nodes["atoms"])), torch.arange(len(nodes["bonds"])), torch.arange(len(nodes["monosacchs"]))], dim=0) + # ) + #case "double_attention": + # nodes = {key: self.attention[key](global_mean_pool(nodes[key], batch_ids[key])) for key in 
+            #    return global_mean_pool(
+            #        torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
+            #        torch.concat([torch.arange(len(nodes["atoms"])), torch.arange(len(nodes["bonds"])), torch.arange(len(nodes["monosacchs"]))], dim=0)
+            #    )
+            #case "weighted_double_attention":
+            #    nodes = {key: self.attention[key](global_mean_pool(nodes[key], batch_ids[key])) for key in nodes.keys()}
+            #    return global_mean_pool(
+            #        torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
+            #        torch.concat([torch.arange(len(nodes["atoms"])), torch.arange(len(nodes["bonds"])), torch.arange(len(nodes["monosacchs"]))], dim=0)
+            #    )
+            case _:
+                raise ValueError(f"Pooling method {self.mode} not supported.")
+
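Reviewer note: a minimal, self-contained sketch of the attention pooling wired in above, for sanity-checking the new path outside the full pipeline. The shapes mirror the config (hidden_dim: 1024, three node types); the node counts, batch size, and the helper names gate_nn, feat_nn, and attention_pool are illustrative stand-ins for the nn.Linear / nn.Sequential networks built in GIFFLARPooling.__init__.

import torch
from torch import nn
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import softmax
from torch_scatter import scatter_add

hidden_dim, num_graphs = 1024, 4  # hypothetical batch of 4 glycans

gate_nn = nn.Linear(hidden_dim, 1)           # scores each node
feat_nn = nn.Linear(hidden_dim, hidden_dim)  # transforms each node

# Fake per-type node embeddings and per-node graph assignments.
nodes = {key: torch.randn(n, hidden_dim)
         for key, n in [("atoms", 50), ("bonds", 49), ("monosacchs", 7)]}
batch_ids = {key: torch.randint(0, num_graphs, (feat.size(0),)).sort().values
             for key, feat in nodes.items()}

def attention_pool(x, batch, size):
    # Gated attention pooling as in GlobalAttention.forward: per-graph
    # softmax over node scores, then a score-weighted sum per graph.
    gate = softmax(gate_nn(x).view(-1, 1), batch, num_nodes=size)
    return scatter_add(gate * feat_nn(x), batch, dim=0, dim_size=size)

# "local_attention": attention-pool each node type separately, then average
# the three type-level embeddings per graph via the arange batch-index trick.
pooled = {key: attention_pool(nodes[key], batch_ids[key], num_graphs) for key in nodes}
out = global_mean_pool(
    torch.concat([pooled["atoms"], pooled["bonds"], pooled["monosacchs"]], dim=0),
    torch.concat([torch.arange(num_graphs) for _ in range(3)], dim=0),
)
print(out.shape)  # torch.Size([4, 1024])

The arange trick works because each attention_pool call already returns one row per graph, so repeating the indices 0..num_graphs-1 three times lines the three type-level embeddings up for a final per-graph mean; weighted_attention additionally scales each type by a learnable weight before this step.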