Bug fixes for embedding and pretraining data expansion
Old-Shatterhand committed Sep 4, 2024
1 parent ab28865 commit 0a27576
Showing 6 changed files with 30 additions and 53 deletions.
10 changes: 0 additions & 10 deletions configs/embed/head.yaml
@@ -28,15 +28,5 @@ prepare:
ckpt_path: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/checkpoints/epoch=99-step=6200.ckpt
hparams_path: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/hparams.yaml
save_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/
nth_layer: -1
pre-transforms:
model:
# - name: gifflar
# feat_dim: 128
# hidden_dim: 1024
# batch_size: 256
# num_layers: 8
# epochs: 100
# learning_rate: 0.001
# optimizer: Adam
# loss: dynamic
15 changes: 10 additions & 5 deletions gifflar/acquisition/collect_pretrain_data.py
@@ -9,11 +9,6 @@
if not glycans_path.exists():
import collect_glycan_data

print("Collecting subglycan data...\n============================")

with open(glycans_path, "rb") as f:
_, iupacs, _ = pickle.load(f)


def node_label_hash(label):
"""Hash function for individual node labels."""
@@ -70,10 +65,20 @@ def cut_and_add(glycan):
cut_and_add(G)


print("Collecting subglycan data...\n============================")

with open(glycans_path, "rb") as f:
_, iupacs, _ = pickle.load(f)
iupacs = sorted(iupacs, key=lambda x: x.count("("))

known_iupacs = []
known = set()
for i, iupac in enumerate(iupacs):
print(f"\r{i}/{len(iupacs)}\t{iupac}", end="")
if iupac.count("(") == 15:
print()
print(f"Stopped calculation due to high complexity after {i} of {len(iupacs)} glycans.")
break
try:
cut_and_add(glycan_to_nxGraph(iupac))
except:
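For illustration, a minimal standalone sketch of the new ordering and early-stop logic above, using made-up IUPAC-like strings instead of the real pretraining list; glycowork's glycan_to_nxGraph and the cut_and_add helper are left as a comment since they need the actual data:

    iupacs = [
        "Man(a1-3)[Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        "Gal(b1-4)GlcNAc",
        "Glc",
    ]
    # Sorting by the number of "(" processes glycans with few glycosidic bonds first...
    iupacs = sorted(iupacs, key=lambda x: x.count("("))

    for i, iupac in enumerate(iupacs):
        # ...so the run can stop cleanly once the branch count, and with it the number
        # of subglycans to enumerate, becomes too large (15 linkages in this commit).
        if iupac.count("(") == 15:
            print(f"Stopped calculation due to high complexity after {i} of {len(iupacs)} glycans.")
            break
        # cut_and_add(glycan_to_nxGraph(iupac))  # the real work done in the script
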
4 changes: 1 addition & 3 deletions gifflar/data/modules.py
@@ -63,9 +63,7 @@ def predict_dataloader(self) -> DataLoader:
DataLoader for the combined data
"""
predict = ConcatDataset([self.train, self.val, self.test])
self.batch_size = 4
print("Batch-Size:", min(self.batch_size, len(predict)))
return DataLoader(predict, batch_size=min(self.batch_size, len(predict)), shuffle=False,
return DataLoader(predict, batch_size=1, shuffle=False,
collate_fn=hetero_collate, num_workers=self.num_workers)


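The effect of this change: prediction now always runs with batch size 1, so each predict step sees exactly one glycan, which is what lets the pretraining model below write one embedding file per glycan. A minimal sketch with plain tensor datasets standing in for the HeteroData pipeline:

    import torch
    from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

    train = TensorDataset(torch.arange(4, dtype=torch.float32).unsqueeze(1))
    val = TensorDataset(torch.arange(4, 6, dtype=torch.float32).unsqueeze(1))
    test = TensorDataset(torch.arange(6, 9, dtype=torch.float32).unsqueeze(1))

    predict = ConcatDataset([train, val, test])
    loader = DataLoader(predict, batch_size=1, shuffle=False)
    for batch in loader:
        pass  # exactly one sample per predict step
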
26 changes: 8 additions & 18 deletions gifflar/model/pretrain.py
@@ -1,4 +1,5 @@
import copy
from pathlib import Path
from typing import Literal, Optional, Any

import torch
@@ -13,7 +14,7 @@

class PretrainGGIN(GlycanGIN):
def __init__(self, hidden_dim: int, tasks: list[dict[str, Any]] | None, num_layers: int = 3, batch_size: int = 32,
pre_transform_args: Optional[dict] = None, **kwargs: Any):
pre_transform_args: Optional[dict] = None, save_dir: Path | str | None = None, **kwargs: Any):
"""
Initialize the PretrainGGIN model, a pre-training model for downstream tasks.
@@ -41,7 +42,7 @@ def __init__(self, hidden_dim: int, tasks: list[dict[str, Any]] | None, num_laye
= get_prediction_head(hidden_dim, 16, "multilabel", "mods")

self.loss = MultiLoss(4, dynamic=kwargs.get("loss", "static") == "dynamic")
self.n = -1
self.save_dir = save_dir

def to(self, device: torch.device) -> "PretrainGGIN":
"""
@@ -74,15 +75,6 @@ def to(self, device: torch.device) -> "PretrainGGIN":
super(PretrainGGIN, self).to(device)
return self

def save_nth_layer(self, n: int) -> None:
"""
Save the nth layer of the model to the specified path.
Args:
n: The layer to save
"""
self.nth_layer = n

def forward(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:
"""
Forward pass of the model.
@@ -128,19 +120,17 @@ def predict_step(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:

batch.x_dict[key] = torch.concat(pes, dim=1)

layer_count = 0
layers = [copy.deepcopy(batch.x_dict)]
for conv in self.convs:
# save the nth layer as the final node embeddings
if layer_count == self.nth_layer:
return {"node_embeds": batch.x_dict, "batch_ids": batch.batch_dict, "smiles": batch["smiles"]}

if isinstance(conv, HeteroConv):
batch.x_dict = conv(batch.x_dict, batch.edge_index_dict)
layer_count += 1
layers.append(copy.deepcopy(batch.x_dict))
else: # the layer is an activation function from the RGCN
batch.x_dict = conv(batch.x_dict)

return {"node_embeds": batch.x_dict, "batch_ids": batch.batch_dict, "smiles": batch["smiles"]}
torch.save(layers, self.save_dir / f"{hash(batch.smiles[0])}.pt")

return {}

def shared_step(self, batch: HeteroDataBatch, stage: Literal["train", "val", "test"]) -> dict[str, torch.Tensor]:
"""
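A minimal sketch of the new saving behaviour in predict_step: one snapshot of the node-feature dict is kept per HeteroConv layer and the whole list is written to a file named after the built-in hash of the glycan's SMILES (feature sizes and the layer count here are illustrative, not taken from the repository):

    import copy
    from pathlib import Path
    import torch

    save_dir = Path("embeds")                      # corresponds to PretrainGGIN.save_dir
    save_dir.mkdir(parents=True, exist_ok=True)
    smiles = "OC1OC(CO)C(O)C(O)C1O"                # placeholder SMILES

    x_dict = {"atoms": torch.randn(12, 128),
              "bonds": torch.randn(11, 128),
              "monosacchs": torch.randn(1, 128)}

    layers = [copy.deepcopy(x_dict)]               # layer 0: the input embeddings
    for _ in range(3):                             # stand-in for the HeteroConv stack
        x_dict = {k: v + 1.0 for k, v in x_dict.items()}   # fake message passing
        layers.append(copy.deepcopy(x_dict))

    # One file per glycan; note that Python's hash() of a str is salted per process
    # unless PYTHONHASHSEED is fixed, so writer and reader must share one process/seed.
    torch.save(layers, save_dir / f"{hash(smiles)}.pt")
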
15 changes: 6 additions & 9 deletions gifflar/pretransforms.py
@@ -431,9 +431,9 @@ def __init__(self, folder: str, dataset_name: str, model_name: str, hash_str: st
kwargs: Additional arguments
"""
super(PretrainEmbed, self).__init__(**kwargs)
self.data = torch.load(Path(folder) / f"{dataset_name}_{model_name}_{hash_str}.pt")
self.lookup = {smiles: (i, j) for i in range(len(self.data)) for j, smiles in enumerate(self.data[i]["smiles"])}
self.data_dir = Path(folder) / f"{dataset_name}_{model_name}_{hash_str}"
self.pooling = GIFFLARPooling()
self.layer = -1

def __call__(self, data: HeteroData) -> HeteroData:
"""
@@ -445,18 +445,15 @@ def __call__(self, data: HeteroData) -> HeteroData:
Returns:
The transformed data.
"""
if data["smiles"] not in self.lookup:
h = hash(data["smiles"])
p = self.data_dir / f"{h}.pt"
if not p.exists():
print(data["smiles"], "not found in preprocessed data.")
data["fp"] = torch.zeros_like(data["fp"])
return data

# Lookup positional indices of the smiles string and compute masks for which data to use
a, b = self.lookup[data["smiles"]]
mask = {key: self.data[a]["batch_ids"][key] == b for key in ["atoms", "bonds", "monosacchs"]}

# apply the masks and extract the node embeddings and compute batch ids
node_embeds = {key: self.data[a]["node_embeds"][key][mask[key]] for key in
["atoms", "bonds", "monosacchs"]}
node_embeds = torch.load(p)[self.layer]
batch_ids = {key: torch.zeros(len(node_embeds[key]), dtype=torch.long) for key in
["atoms", "bonds", "monosacchs"]}
data["fp"] = self.pooling(node_embeds, batch_ids)
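The reading side, as a minimal sketch of what PretrainEmbed now does: hash the SMILES, load that glycan's file, pick one layer (the default -1 is the last snapshot), and pool. GIFFLARPooling is replaced here by a simple per-type mean, and the zero fingerprint mirrors the missing-file fallback:

    from pathlib import Path
    import torch

    data_dir = Path("embeds")                      # corresponds to PretrainEmbed.data_dir
    smiles = "OC1OC(CO)C(O)C(O)C1O"                # placeholder SMILES
    layer = -1                                     # last layer, as in PretrainEmbed
    fp_dim = 3 * 128                               # illustrative fingerprint size

    p = data_dir / f"{hash(smiles)}.pt"
    if not p.exists():
        fp = torch.zeros(fp_dim)                   # glycan missing from the precomputed cache
    else:
        node_embeds = torch.load(p)[layer]
        # every node belongs to graph 0, hence per-type batch ids of zeros
        batch_ids = {k: torch.zeros(len(v), dtype=torch.long) for k, v in node_embeds.items()}
        # stand-in for GIFFLARPooling: concatenate the per-type mean embeddings
        fp = torch.cat([node_embeds[k].mean(dim=0) for k in ["atoms", "bonds", "monosacchs"]])
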
13 changes: 5 additions & 8 deletions gifflar/train.py
@@ -194,11 +194,8 @@ def embed(prep_args: dict[str, str], **kwargs: Any) -> None:
pkl_dir: The directory to save the embeddings.
kwargs: The configuration for the training.
"""
output_name = Path(prep_args["save_dir"]) / ("_".join([
kwargs["dataset"]["name"],
prep_args["name"],
hash_dict(prep_args, 8),
]) + ".pt")
output_name = (Path(prep_args["save_dir"]) /
f"{kwargs['dataset']['name']}_{prep_args['name']}_{hash_dict(prep_args, 8)}")
if output_name.exists():
return

@@ -213,9 +210,9 @@ def embed(prep_args: dict[str, str], **kwargs: Any) -> None:
model.save_nth_layer(int(prep_args["nth_layer"]))

data_config, data, _, _ = setup(2, **kwargs)
data.save_dir = output_name
trainer = Trainer()
preds = trainer.predict(model, data.predict_dataloader())
torch.save(preds, output_name)
trainer.predict(model, data.predict_dataloader())


def main(config: str | Path) -> None:
@@ -226,7 +223,7 @@
for args in unfold_config(custom_args):
print(args)
if "prepare" in args:
args = embed(args["prepare"], **args)
embed(args["prepare"], **args)
else:
if args["model"]["name"] in ["rf", "svm", "xgb"]:
fit(**args)
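Finally, a minimal sketch of how embed() now names its cache: a directory per dataset/model/config combination instead of a single .pt file, skipped entirely when it already exists. hash_dict is the repository's helper; a truncated sha256 of the sorted config stands in for it here, and the dataset name and prepare config are made up:

    import hashlib
    import json
    from pathlib import Path

    def config_hash(d: dict, n: int = 8) -> str:
        # stand-in for gifflar's hash_dict helper
        return hashlib.sha256(json.dumps(d, sort_keys=True).encode()).hexdigest()[:n]

    prep_args = {"name": "gifflar", "save_dir": "data_embed"}   # illustrative prepare config
    dataset_name = "glycan-ds"                                   # illustrative dataset name

    output_name = Path(prep_args["save_dir"]) / f"{dataset_name}_{prep_args['name']}_{config_hash(prep_args)}"
    if output_name.exists():
        print("Embeddings already computed, nothing to do.")
    else:
        output_name.mkdir(parents=True, exist_ok=True)
        # data.save_dir = output_name
        # trainer.predict(model, data.predict_dataloader())
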
