Make embedding extraction less memory heavy
Old-Shatterhand committed Sep 3, 2024
1 parent c65e03e commit bc0a530
Showing 5 changed files with 65 additions and 7 deletions.
40 changes: 40 additions & 0 deletions configs/downstream/pretrained.yaml
@@ -0,0 +1,40 @@
+seed: 42
+data_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/
+root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed
+logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_embed
+datasets:
+  - name: Immunogenicity
+    task: classification
+#  - name: Glycosylation
+#    task: classification
+#  - name: Taxonomy_Domain
+#    task: multilabel
+#  - name: Taxonomy_Kingdom
+#    task: multilabel
+#  - name: Taxonomy_Phylum
+#    task: multilabel
+#  - name: Taxonomy_Class
+#    task: multilabel
+#  - name: Taxonomy_Order
+#    task: multilabel
+#  - name: Taxonomy_Family
+#    task: multilabel
+#  - name: Taxonomy_Genus
+#    task: multilabel
+#  - name: Taxonomy_Species
+#    task: multilabel
+pre-transforms:
+  PretrainEmbed:
+    folder: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/
+    model_name: GIFFLAR
+    hash_str: 3fd297ab
+model:
+  - name: mlp
+    feat_dim: 1024
+    hidden_dim: 1024
+    batch_size: 256
+    num_layers: 3
+    epochs: 100
+    patience: 30
+    learning_rate: 0
+    optimizer: Adam
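
For reference (not part of this commit's diff): the three PretrainEmbed keys above appear to name the cached embedding file that the pre-transform loads, following the pattern used in PretrainEmbed.__init__ in gifflar/pretransforms.py further down; the dataset name would come from the datasets block. A minimal sketch, assuming that naming convention:

from pathlib import Path

folder = "/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/"
dataset_name, model_name, hash_str = "Immunogenicity", "GIFFLAR", "3fd297ab"

# PretrainEmbed.__init__ loads: torch.load(Path(folder) / f"{dataset_name}_{model_name}_{hash_str}.pt")
cache_file = Path(folder) / f"{dataset_name}_{model_name}_{hash_str}.pt"
print(cache_file)  # .../data_embed/Immunogenicity_GIFFLAR_3fd297ab.pt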
1 change: 1 addition & 0 deletions configs/embed/head.yaml
@@ -28,5 +28,6 @@ prepare:
   ckpt_path: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/checkpoints/epoch=99-step=6200.ckpt
   hparams_path: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pret/gifflar_dyn_re_pretrain/version_0/hparams.json
   save_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_embed/
+  nth_layer: -1
 pre-transforms:
 model:
20 changes: 17 additions & 3 deletions gifflar/model/pretrain.py
@@ -41,6 +41,7 @@ def __init__(self, hidden_dim: int, tasks: list[dict[str, Any]] | None, num_laye
             = get_prediction_head(hidden_dim, 16, "multilabel", "mods")

         self.loss = MultiLoss(4, dynamic=kwargs.get("loss", "static") == "dynamic")
+        self.nth_layer = -1

     def to(self, device: torch.device) -> "PretrainGGIN":
         """
@@ -73,6 +74,15 @@ def to(self, device: torch.device) -> "PretrainGGIN":
         super(PretrainGGIN, self).to(device)
         return self

+    def save_nth_layer(self, n: int) -> None:
+        """
+        Select which HeteroConv layer's node embeddings predict_step returns.
+
+        Args:
+            n: Index of the layer to keep (-1 keeps the embeddings of the final layer)
+        """
+        self.nth_layer = n
+
     def forward(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:
         """
         Forward pass of the model.
@@ -118,15 +128,19 @@ def predict_step(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:

             batch.x_dict[key] = torch.concat(pes, dim=1)

-        node_embeds = []
+        layer_count = 0
         for conv in self.convs:
+            # save the nth layer as the final node embeddings
+            if layer_count == self.nth_layer:
+                return {"node_embeds": batch.x_dict, "batch_ids": batch.batch_dict, "smiles": batch["smiles"]}
+
             if isinstance(conv, HeteroConv):
                 batch.x_dict = conv(batch.x_dict, batch.edge_index_dict)
-                node_embeds.append(copy.deepcopy(batch.x_dict))
+                layer_count += 1
             else:  # the layer is an activation function from the RGCN
                 batch.x_dict = conv(batch.x_dict)

-        return {"node_embeds": node_embeds, "batch_ids": batch.batch_dict, "smiles": batch["smiles"]}
+        return {"node_embeds": batch.x_dict, "batch_ids": batch.batch_dict, "smiles": batch["smiles"]}

     def shared_step(self, batch: HeteroDataBatch, stage: Literal["train", "val", "test"]) -> dict[str, torch.Tensor]:
         """
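The memory saving in predict_step comes from dropping the per-layer copy.deepcopy of batch.x_dict: instead of accumulating a list of node embeddings for every HeteroConv layer, the method now returns a single embedding dict, either when the layer counter reaches the configured nth_layer or, for the default of -1 (which the counter never reaches), after the last layer. Below is a toy sketch of that early-exit pattern, with plain Linear layers standing in for the heterogeneous convolutions and a dict mimicking batch.x_dict; it is an illustration, not the committed code.

import torch
from torch import nn

def embed_after_n_layers(x_dict: dict[str, torch.Tensor], convs: list[nn.Module], nth_layer: int = -1) -> dict[str, torch.Tensor]:
    layer_count = 0
    for conv in convs:
        if layer_count == nth_layer:
            return x_dict                                # early exit: only one dict is ever kept
        x_dict = {key: conv(x) for key, x in x_dict.items()}
        layer_count += 1
    return x_dict                                        # nth_layer=-1 never matches, so this is the final layer

x = {"atoms": torch.randn(5, 8), "bonds": torch.randn(4, 8)}
convs = [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)]
assert embed_after_n_layers(x, convs, nth_layer=0) is x  # 0 returns the inputs untouched
final = embed_after_n_layers(x, convs)                   # default: embeddings after the last layer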
5 changes: 2 additions & 3 deletions gifflar/pretransforms.py
@@ -434,7 +434,6 @@ def __init__(self, folder: str, dataset_name: str, model_name: str, hash_str: st
         self.data = torch.load(Path(folder) / f"{dataset_name}_{model_name}_{hash_str}.pt")
         self.lookup = {smiles: (i, j) for i in range(len(self.data)) for j, smiles in enumerate(self.data[i]["smiles"])}
         self.pooling = GIFFLARPooling()
-        self.layer = kwargs.get("layer", -1)

     def __call__(self, data: HeteroData) -> HeteroData:
         """
@@ -456,7 +455,7 @@ def __call__(self, data: HeteroData) -> HeteroData:
         mask = {key: self.data[a]["batch_ids"][key] == b for key in ["atoms", "bonds", "monosacchs"]}

         # apply the masks and extract the node embeddings and compute batch ids
-        node_embeds = {key: self.data[a]["node_embeds"][self.layer][key][mask[key]] for key in
+        node_embeds = {key: self.data[a]["node_embeds"][key][mask[key]] for key in
                        ["atoms", "bonds", "monosacchs"]}
         batch_ids = {key: torch.zeros(len(node_embeds[key]), dtype=torch.long) for key in
                      ["atoms", "bonds", "monosacchs"]}
@@ -485,7 +484,7 @@ def forward(self, data: list[Union[Data, HeteroData]]):
         return data


-def get_pretransforms(dataset_name, **pre_transform_args) -> TQDMCompose:
+def get_pretransforms(dataset_name: str = "", **pre_transform_args: dict[str, dict]) -> TQDMCompose:
     """
     Calculate the list of pre-transforms to be applied to the data.
6 changes: 5 additions & 1 deletion gifflar/train.py
@@ -200,8 +200,12 @@ def embed(prep_args: dict[str, str], **kwargs: Any) -> None:
     with open(prep_args["hparams_path"], "r") as file:
         config = yaml.load(file, Loader=yaml.FullLoader)
     model = PretrainGGIN(**config["model"], tasks=None, pre_transform_args=kwargs.get("pre-transforms", {}))
-    model.load_state_dict(torch.load(prep_args["ckpt_path"], map_location=torch.device("cpu"))["state_dict"])
+    if torch.cuda.is_available():
+        model.load_state_dict(torch.load(prep_args["ckpt_path"])["state_dict"])
+    else:
+        model.load_state_dict(torch.load(prep_args["ckpt_path"], map_location=torch.device("cpu"))["state_dict"])
     model.eval()
+    model.save_nth_layer(int(prep_args["nth_layer"]))

     data_config, data, _, _ = setup(2, **kwargs)
     trainer = Trainer()
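The two load branches differ only in map_location, so an equivalent, slightly more compact variant (an alternative sketch, not the committed code) resolves the device once and calls torch.load a single time:

import torch

def load_checkpoint_state(ckpt_path: str) -> dict:
    # map_location=None keeps tensors on their saved (GPU) device; otherwise remap to CPU
    map_location = None if torch.cuda.is_available() else torch.device("cpu")
    return torch.load(ckpt_path, map_location=map_location)["state_dict"]

# inside embed():
# model.load_state_dict(load_checkpoint_state(prep_args["ckpt_path"]))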
