From 38cc95c3ec47385393d3b4ab297db672ce17998f Mon Sep 17 00:00:00 2001 From: GlycanConnector Date: Fri, 18 Oct 2024 13:27:10 +0200 Subject: [PATCH] Server merge conflicts resolved --- .gitignore | 2 +- gifflar/pretransforms.py | 18 +++++++++++++----- gifflar/train.py | 1 + 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 644b774..7ade3a6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,8 +2,8 @@ logs/* data/* **/.ipynb_checkpoints/ +**/__pycache__/ *.ipynb *.tsv tests/data/* tests/logs/* -**/__pycache__/ diff --git a/gifflar/pretransforms.py b/gifflar/pretransforms.py index bb95fb9..75e7474 100644 --- a/gifflar/pretransforms.py +++ b/gifflar/pretransforms.py @@ -170,10 +170,16 @@ def __call__(self, data: HeteroData) -> HeteroData: ]) data["rgcn_node_type"] = ["atoms"] * len(data["atoms"].x) + ["bonds"] * len(data["bonds"].x) + ["monosacchs"] * len(data["monosacchs"].x) data["rgcn_num_nodes"] = len(data["rgcn_x"]) - tmp = [0] * data["atoms", "coboundary", "atoms"].edge_index.shape[1] + [1] * data["atoms", "to", "bonds"].edge_index.shape[1] + [2] * data["bonds", "boundary", "bonds"].edge_index.shape[1] + [3] * data["bonds", "to", "monosacchs"].edge_index.shape[1] - if len(data["monosacchs"].x) > 1: - tmp += [4] * data["monosacchs", "boundary", "monosacchs"].edge_index.shape[1] - data["rgcn_edge_type"] = torch.tensor(tmp) + edges = [0] * data["atoms", "coboundary", "atoms"].edge_index.shape[1] + edges += [1] * data["atoms", "to", "bonds"].edge_index.shape[1] + edges += [2] * data["bonds", "boundary", "bonds"].edge_index.shape[1] + edges += [3] * data["bonds", "to", "monosacchs"].edge_index.shape[1] + try: + edges += [4] * data["monosacchs", "boundary", "monosacchs"].edge_index.shape[1] + except: + pass + data["rgcn_edge_type"] = torch.tensor(edges) + tmp = [] offset = {"atoms": 0, "bonds": data["atoms"]["num_nodes"], @@ -183,11 +189,13 @@ def __call__(self, data: HeteroData) -> HeteroData: ("bonds", "to", "monosacchs"), ("bonds", "boundary", "bonds"), ("monosacchs", "boundary", "monosacchs")]: - if len(data[key].edge_index.shape) == 2: + try: tmp.append(torch.stack([ data[key].edge_index[0] + offset[key[0]], data[key].edge_index[1] + offset[key[2]], ])) + except: + pass data["rgcn_edge_index"] = torch.cat(tmp, dim=1) return data diff --git a/gifflar/train.py b/gifflar/train.py index 400cec2..00e50c2 100644 --- a/gifflar/train.py +++ b/gifflar/train.py @@ -173,6 +173,7 @@ def pretrain(**kwargs: Any) -> None: logger.log_hyperparams(kwargs) trainer = Trainer( + devices=[1], callbacks=[ ModelCheckpoint(save_top_k=-1), RichModelSummary(),