Skip to content

Commit

Permalink
SB server update
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Oct 18, 2024
1 parent 29cde05 commit 4fae06c
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 18 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
.idea/*
logs/*
data/*
gifflar/.ipynb_checkpoints/*
**/.ipynb_checkpoints/
*.ipynb
*.tsv
tests/data/*
tests/logs/*
tests/logs/*
**/__pycache__/
12 changes: 8 additions & 4 deletions collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def order_models(name: str) -> tuple:
worksheet = writer.sheets[metric]
bold_format = workbook.add_format({"bold": True})
for r, (_, row) in enumerate(df.iterrows()):
max_val = row.values.max()
for c, val in enumerate(row.values):
if val == max_val:
worksheet.write(r + 1, c + 1, val, bold_format)
try:
max_val = row.values.max()
for c, val in enumerate(row.values):
if val == max_val:
worksheet.write(r + 1, c + 1, val, bold_format)
except:
pass

1 change: 1 addition & 0 deletions gifflar/data/hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]

# For each baseline, collate its node features and edge indices as well
for b in baselines:
break
kwargs[f"{b}_x"] = torch.cat([d[f"{b}_x"] for d in data], dim=0)
edges = []
batch = []
Expand Down
20 changes: 9 additions & 11 deletions gifflar/pretransforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,10 @@ def __call__(self, data: HeteroData) -> HeteroData:
])
data["rgcn_node_type"] = ["atoms"] * len(data["atoms"].x) + ["bonds"] * len(data["bonds"].x) + ["monosacchs"] * len(data["monosacchs"].x)
data["rgcn_num_nodes"] = len(data["rgcn_x"])
data["rgcn_edge_type"] = torch.tensor(
[0] * data["atoms", "coboundary", "atoms"].edge_index.shape[1]
+ [1] * data["atoms", "to", "bonds"].edge_index.shape[1]
+ [2] * data["bonds", "boundary", "bonds"].edge_index.shape[1]
+ [3] * data["bonds", "to", "monosacchs"].edge_index.shape[1]
+ [4] * data["monosacchs", "boundary", "monosacchs"].edge_index.shape[1]
)
tmp = [0] * data["atoms", "coboundary", "atoms"].edge_index.shape[1] + [1] * data["atoms", "to", "bonds"].edge_index.shape[1] + [2] * data["bonds", "boundary", "bonds"].edge_index.shape[1] + [3] * data["bonds", "to", "monosacchs"].edge_index.shape[1]
if len(data["monosacchs"].x) > 1:
tmp += [4] * data["monosacchs", "boundary", "monosacchs"].edge_index.shape[1]
data["rgcn_edge_type"] = torch.tensor(tmp)
tmp = []
offset = {"atoms": 0,
"bonds": data["atoms"]["num_nodes"],
Expand All @@ -186,10 +183,11 @@ def __call__(self, data: HeteroData) -> HeteroData:
("bonds", "to", "monosacchs"),
("bonds", "boundary", "bonds"),
("monosacchs", "boundary", "monosacchs")]:
tmp.append(torch.stack([
data[key].edge_index[0] + offset[key[0]],
data[key].edge_index[1] + offset[key[2]],
]))
if len(data[key].edge_index.shape) == 2:
tmp.append(torch.stack([
data[key].edge_index[0] + offset[key[0]],
data[key].edge_index[1] + offset[key[2]],
]))
data["rgcn_edge_index"] = torch.cat(tmp, dim=1)
return data

Expand Down
2 changes: 1 addition & 1 deletion gifflar/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def train(**kwargs: Any) -> None:
],
max_epochs=kwargs["model"]["epochs"],
logger=logger,
accelerator="cpu",
# accelerator="cpu",
)
start = time.time()
trainer.fit(model, datamodule)
Expand Down

0 comments on commit 4fae06c

Please sign in to comment.