Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Oct 18, 2024
2 parents 4d56434 + 38cc95c commit 2fe83c6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
logs/*
data/*
**/.ipynb_checkpoints/
**/__pycache__/
*.ipynb
*.tsv
tests/data/*
tests/logs/*
**/__pycache__/
18 changes: 13 additions & 5 deletions gifflar/pretransforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,16 @@ def __call__(self, data: HeteroData) -> HeteroData:
])
data["rgcn_node_type"] = ["atoms"] * len(data["atoms"].x) + ["bonds"] * len(data["bonds"].x) + ["monosacchs"] * len(data["monosacchs"].x)
data["rgcn_num_nodes"] = len(data["rgcn_x"])
tmp = [0] * data["atoms", "coboundary", "atoms"].edge_index.shape[1] + [1] * data["atoms", "to", "bonds"].edge_index.shape[1] + [2] * data["bonds", "boundary", "bonds"].edge_index.shape[1] + [3] * data["bonds", "to", "monosacchs"].edge_index.shape[1]
if len(data["monosacchs"].x) > 1:
tmp += [4] * data["monosacchs", "boundary", "monosacchs"].edge_index.shape[1]
data["rgcn_edge_type"] = torch.tensor(tmp)
edges = [0] * data["atoms", "coboundary", "atoms"].edge_index.shape[1]
edges += [1] * data["atoms", "to", "bonds"].edge_index.shape[1]
edges += [2] * data["bonds", "boundary", "bonds"].edge_index.shape[1]
edges += [3] * data["bonds", "to", "monosacchs"].edge_index.shape[1]
try:
edges += [4] * data["monosacchs", "boundary", "monosacchs"].edge_index.shape[1]
except:
pass
data["rgcn_edge_type"] = torch.tensor(edges)

tmp = []
offset = {"atoms": 0,
"bonds": data["atoms"]["num_nodes"],
Expand All @@ -183,11 +189,13 @@ def __call__(self, data: HeteroData) -> HeteroData:
("bonds", "to", "monosacchs"),
("bonds", "boundary", "bonds"),
("monosacchs", "boundary", "monosacchs")]:
if len(data[key].edge_index.shape) == 2:
try:
tmp.append(torch.stack([
data[key].edge_index[0] + offset[key[0]],
data[key].edge_index[1] + offset[key[2]],
]))
except:
pass
data["rgcn_edge_index"] = torch.cat(tmp, dim=1)
return data

Expand Down
1 change: 1 addition & 0 deletions gifflar/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def pretrain(**kwargs: Any) -> None:
logger.log_hyperparams(kwargs)

trainer = Trainer(
devices=[1],
callbacks=[
ModelCheckpoint(save_top_k=-1),
RichModelSummary(),
Expand Down

0 comments on commit 2fe83c6

Please sign in to comment.