From 663b14a4420cdcf8b7d29005e89ddd71103e14e3 Mon Sep 17 00:00:00 2001 From: Roman Joeres Date: Thu, 24 Oct 2024 12:25:11 +0200 Subject: [PATCH] Move from InMemory to OnDisk data --- gifflar/data/datasets.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/gifflar/data/datasets.py b/gifflar/data/datasets.py index c388dfa..0708eb7 100644 --- a/gifflar/data/datasets.py +++ b/gifflar/data/datasets.py @@ -5,13 +5,13 @@ import numpy as np import pandas as pd import torch -from torch_geometric.data import InMemoryDataset, HeteroData +from torch_geometric.data import InMemoryDataset, HeteroData, OnDiskDataset from tqdm import tqdm from gifflar.data.utils import GlycanStorage -class GlycanDataset(InMemoryDataset): +class GlycanDataset(OnDiskDataset): def __init__( self, root: str | Path, @@ -36,8 +36,8 @@ def __init__( """ self.filename = Path(filename) self.dataset_args = dataset_args - super().__init__(root=str(Path(root) / f"{self.filename.stem}_{hash_code}"), transform=transform, - pre_transform=pre_transform) + self.pre_transform = pre_transform + super().__init__(root=str(Path(root) / f"{self.filename.stem}_{hash_code}"), transform=transform) self.data, self.dataset_args = torch.load(self.processed_paths[path_idx]) def __len__(self) -> int: @@ -68,7 +68,6 @@ def process_(self, data: list[HeteroData], path_idx: int = 0) -> None: if self.pre_filter is not None: data = [d for d in data if self.pre_filter(d)] if self.pre_transform is not None: - # data = [self.pre_transform(d) for d in data] data = self.pre_transform(data) torch.save((data, self.dataset_args), self.processed_paths[path_idx])