From 4fae06cc9cc928180070504badc6f2b96f6d328e Mon Sep 17 00:00:00 2001 From: Old-Shatterhand Date: Fri, 18 Oct 2024 13:14:12 +0200 Subject: [PATCH] SB server update --- .gitignore | 5 +++-- collect.py | 12 ++++++++---- gifflar/data/hetero.py | 1 + gifflar/pretransforms.py | 20 +++++++++----------- gifflar/train.py | 2 +- 5 files changed, 22 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index 2bf02d1..644b774 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,9 @@ .idea/* logs/* data/* -gifflar/.ipynb_checkpoints/* +**/.ipynb_checkpoints/ *.ipynb *.tsv tests/data/* -tests/logs/* \ No newline at end of file +tests/logs/* +**/__pycache__/ diff --git a/collect.py b/collect.py index 564ae82..ee3b520 100644 --- a/collect.py +++ b/collect.py @@ -65,7 +65,11 @@ def order_models(name: str) -> tuple: worksheet = writer.sheets[metric] bold_format = workbook.add_format({"bold": True}) for r, (_, row) in enumerate(df.iterrows()): - max_val = row.values.max() - for c, val in enumerate(row.values): - if val == max_val: - worksheet.write(r + 1, c + 1, val, bold_format) + try: + max_val = row.values.max() + for c, val in enumerate(row.values): + if val == max_val: + worksheet.write(r + 1, c + 1, val, bold_format) + except: + pass + diff --git a/gifflar/data/hetero.py b/gifflar/data/hetero.py index b891d2d..dd07888 100644 --- a/gifflar/data/hetero.py +++ b/gifflar/data/hetero.py @@ -129,6 +129,7 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData] # For each baseline, collate its node features and edge indices as well for b in baselines: + break kwargs[f"{b}_x"] = torch.cat([d[f"{b}_x"] for d in data], dim=0) edges = [] batch = [] diff --git a/gifflar/pretransforms.py b/gifflar/pretransforms.py index 3218f67..bb95fb9 100644 --- a/gifflar/pretransforms.py +++ b/gifflar/pretransforms.py @@ -170,13 +170,10 @@ def __call__(self, data: HeteroData) -> HeteroData: ]) data["rgcn_node_type"] = ["atoms"] * len(data["atoms"].x) + ["bonds"] * len(data["bonds"].x) + ["monosacchs"] * len(data["monosacchs"].x) data["rgcn_num_nodes"] = len(data["rgcn_x"]) - data["rgcn_edge_type"] = torch.tensor( - [0] * data["atoms", "coboundary", "atoms"].edge_index.shape[1] - + [1] * data["atoms", "to", "bonds"].edge_index.shape[1] - + [2] * data["bonds", "boundary", "bonds"].edge_index.shape[1] - + [3] * data["bonds", "to", "monosacchs"].edge_index.shape[1] - + [4] * data["monosacchs", "boundary", "monosacchs"].edge_index.shape[1] - ) + tmp = [0] * data["atoms", "coboundary", "atoms"].edge_index.shape[1] + [1] * data["atoms", "to", "bonds"].edge_index.shape[1] + [2] * data["bonds", "boundary", "bonds"].edge_index.shape[1] + [3] * data["bonds", "to", "monosacchs"].edge_index.shape[1] + if len(data["monosacchs"].x) > 1: + tmp += [4] * data["monosacchs", "boundary", "monosacchs"].edge_index.shape[1] + data["rgcn_edge_type"] = torch.tensor(tmp) tmp = [] offset = {"atoms": 0, "bonds": data["atoms"]["num_nodes"], @@ -186,10 +183,11 @@ def __call__(self, data: HeteroData) -> HeteroData: ("bonds", "to", "monosacchs"), ("bonds", "boundary", "bonds"), ("monosacchs", "boundary", "monosacchs")]: - tmp.append(torch.stack([ - data[key].edge_index[0] + offset[key[0]], - data[key].edge_index[1] + offset[key[2]], - ])) + if len(data[key].edge_index.shape) == 2: + tmp.append(torch.stack([ + data[key].edge_index[0] + offset[key[0]], + data[key].edge_index[1] + offset[key[2]], + ])) data["rgcn_edge_index"] = torch.cat(tmp, dim=1) return data diff --git a/gifflar/train.py b/gifflar/train.py index e727b8a..400cec2 100644 --- a/gifflar/train.py +++ b/gifflar/train.py @@ -146,7 +146,7 @@ def train(**kwargs: Any) -> None: ], max_epochs=kwargs["model"]["epochs"], logger=logger, - accelerator="cpu", + # accelerator="cpu", ) start = time.time() trainer.fit(model, datamodule)