Skip to content

Commit

Permalink
TQDMCompose transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Aug 13, 2024
1 parent 4189fbd commit 25c01d7
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions gifflar/pretransforms.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import copy
from typing import Any
from typing import Any, Union

import torch
from glycowork.glycan_data.loader import lib
from glycowork.motif.graph import glycan_to_nxGraph
from rdkit.Chem import AllChem, rdFingerprintGenerator
from torch_geometric import transforms as T
from torch_geometric.data import Data
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import Compose, AddLaplacianEigenvectorPE
from torch_geometric.transforms.base_transform import BaseTransform
from torch_geometric.utils import from_networkx, to_dense_adj
from tqdm import tqdm

from gifflar.utils import bond_map, lib_map, atom_map

Expand Down Expand Up @@ -236,6 +237,18 @@ def __call__(self, data):
return data


class TQDMCompose(Compose):
def forward(self, data: Union[Data, HeteroData]):
with tqdm(total=len(self.transforms), desc="Transforms") as t_bar:
for transform in self.transforms:
if isinstance(data, (list, tuple)):
data = transform(data)
else:
data = [transform(d) for d in tqdm(data, total=len(data), desc="Samples", leave=False)]
t_bar.update(1)
return data


def get_pretransforms(**pre_transform_args) -> [T.Compose]:
pre_transforms = [
GIFFLARTransform(**pre_transform_args.get("GIFFLARTransform", {})),
Expand All @@ -249,4 +262,4 @@ def get_pretransforms(**pre_transform_args) -> [T.Compose]:
pre_transforms.append(LaplacianPE(**args))
if name == "RandomWalkPE":
pre_transforms.append(RandomWalkPE(**args))
return Compose(pre_transforms)
return TQDMCompose(pre_transforms)

0 comments on commit 25c01d7

Please sign in to comment.