Skip to content

Commit

Permalink
Pretraining code and minor reorganization of configs
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Aug 13, 2024
1 parent 475f2d6 commit a048970
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 38 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file added configs/pretraining/test.yaml
Empty file.
24 changes: 0 additions & 24 deletions gifflar/baselines/gnngly.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,6 @@ def __init__(self, output_dim, task, **kwargs):
del self.convs
del self.head

# Define the encoders (sizes based on table 1)
self.atom_encoder = torch.eye(101)
self.chiral_encoder = torch.eye(4)
self.degree_encoder = torch.eye(13)
self.charge_encoder = torch.eye(5)
self.h_encoder = torch.eye(5)
self.hybrid_encoder = torch.eye(5)

# Five layers of plain graph convolution with a hidden dimension of 14.
self.layers = [
GCNConv(133, 14),
Expand All @@ -66,12 +58,6 @@ def __init__(self, output_dim, task, **kwargs):
def to(self, device):
    """Move the model, including its graph-convolution layers, to ``device``.

    Args:
        device: Target device, passed unchanged to the parent ``to`` and to
            every entry of ``self.layers``.

    Returns:
        This model, so that ``model = model.to(device)`` keeps working.
    """
    super(GNNGLY, self).to(device)
    # ``self.layers`` is a plain Python list, so each layer is moved
    # explicitly here. NOTE(review): presumably the parent ``to`` does not
    # reach modules stored in a plain list -- confirm against the base class.
    self.layers = [layer.to(device) for layer in self.layers]
    return self


def forward(self, batch):
Expand All @@ -89,16 +75,6 @@ def forward(self, batch):
batch_ids = batch["gnngly_batch"]
edge_index = batch["gnngly_edge_index"]

# Compute the atom-wise encodings
x = torch.stack([torch.concat([
self.atom_encoder[a[0]],
self.chiral_encoder[a[1]],
self.degree_encoder[a[2]],
self.charge_encoder[a[3]],
self.h_encoder[a[4]],
self.hybrid_encoder[a[5]],
]) for a in x]).to(x.device)

# Propagate the data through the model
for layer in self.layers:
x = layer(x, edge_index)
Expand Down
18 changes: 13 additions & 5 deletions gifflar/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tqdm import tqdm
import networkx as nx

from gifflar.pretransforms import GIFFLARTransform
from gifflar.utils import S3NMerger, nx2mol

Chem.SetDefaultPickleProperties(Chem.PropertyPickleOptions.AllProps)
Expand Down Expand Up @@ -108,6 +109,8 @@ def to(self, device: str):
for key, value in v.items():
if hasattr(value, "to"):
v[key] = value.to(device)
else:
raise ValueError(f"Attribute {k} cannot be converted to device {device}.")
return self

def __getitem__(self, item: str) -> Any:
Expand Down Expand Up @@ -441,8 +444,6 @@ def __init__(
root: str | Path,
filename: str | Path,
hash_code: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
**dataset_args
):
"""
Expand All @@ -456,8 +457,8 @@ def __init__(
pre_transform: The pre-transform to apply to the data
**dataset_args: Additional arguments to pass to the dataset
"""
super().__init__(root=root, filename=filename, hash_code=hash_code, transform=transform,
pre_transform=pre_transform, **dataset_args)
super().__init__(root=root, filename=filename, hash_code=hash_code,
pre_transform=GIFFLARTransform(), **dataset_args)

@property
def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:
Expand All @@ -467,7 +468,14 @@ def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:
def process(self):
    """Process the data and store it.

    Reads the table at ``self.filename`` (tab-separated when the suffix is
    ``.tsv``, comma-separated otherwise), queries a ``GlycanStorage`` with
    every value of the ``IUPAC`` column, tags each result with its row
    position under ``"ID"`` and hands the collected list to
    ``self.process_``.
    """
    # The constructor annotates ``filename`` as ``str | Path``; normalise it
    # so that ``.suffix`` is available in both cases.
    filename = Path(self.filename)
    df = pd.read_csv(filename, sep="\t" if filename.suffix.lower() == ".tsv" else ",")

    data = []
    gs = GlycanStorage(Path(self.root).parent)
    try:
        for index, iupac in tqdm(enumerate(df["IUPAC"]), total=len(df)):
            d = gs.query(iupac)
            d["ID"] = index
            data.append(d)
    finally:
        # Close the storage even when a query fails part-way through.
        # NOTE(review): assumes close() is safe after a failed query -- confirm.
        gs.close()
    print("Processed", len(data), "entries")
    self.process_(data)


Expand Down
27 changes: 18 additions & 9 deletions gifflar/pretransforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __call__(self, data):
data["bonds", "coboundary", "bonds"].edge_index = torch.tensor(
[(bond1, bond2) for ring in data["mol"].GetRingInfo().BondRings() for bond1 in ring for bond2 in ring if
bond1 != bond2], dtype=torch.long).T
data["monosacchs"].x = torch.tensor([
data["monosacchs"].x = torch.tensor([ # This does not make sense. The monomer-ids are categorical features
lib_map.get(data["tree"].nodes[node]["name"], len(lib_map)) for node in data["tree"].nodes
])
data["monosacchs"].num_nodes = len(data["monosacchs"].x)
Expand All @@ -98,15 +98,24 @@ def __call__(self, data):


class GNNGLYTransform(RootTransform):
    """Transform that adds the ``gnngly_*`` entries to a data object.

    Every atom of the molecule stored under ``data["mol"]`` is described by
    six concatenated one-hot vectors (atomic number, chiral tag, degree,
    formal charge, total number of hydrogens, hybridization), which gives
    101 + 4 + 13 + 5 + 5 + 5 = 133 features per atom.
    """

    def __init__(self, **kwargs):
        """Create the one-hot lookup tables.

        Args:
            **kwargs: Forwarded unchanged to ``RootTransform.__init__``.
        """
        super().__init__(**kwargs)
        # Row ``i`` of an identity matrix is the one-hot vector for index ``i``.
        self.atom_encoder = torch.eye(101)
        self.chiral_encoder = torch.eye(4)
        self.degree_encoder = torch.eye(13)
        self.charge_encoder = torch.eye(5)
        self.h_encoder = torch.eye(5)
        self.hybrid_encoder = torch.eye(5)

    def __call__(self, data):
        """Attach atom features, node count and edge index for GNNGLY.

        Args:
            data: Object supporting item access that provides ``data["mol"]``
                and ``data["atoms", "coboundary", "atoms"].edge_index``.

        Returns:
            The same ``data`` object with ``gnngly_x``, ``gnngly_num_nodes``
            and ``gnngly_edge_index`` set.
        """
        # ``min`` caps each property at the last row of its lookup table.
        # NOTE(review): only the upper bound is capped; a negative value
        # (e.g. a negative formal charge) would index from the end of the
        # table -- confirm that this is intended.
        data["gnngly_x"] = torch.stack([
            torch.concat([
                self.atom_encoder[min(atom.GetAtomicNum(), 100)],
                self.chiral_encoder[min(atom.GetChiralTag(), 3)],
                self.degree_encoder[min(atom.GetDegree(), 12)],
                self.charge_encoder[min(atom.GetFormalCharge(), 4)],
                self.h_encoder[min(atom.GetTotalNumHs(), 4)],
                self.hybrid_encoder[min(atom.GetHybridization(), 4)],
            ])
            for atom in data["mol"].GetAtoms()
        ])
        data["gnngly_num_nodes"] = len(data["gnngly_x"])
        # Deep copy so the GNNGLY edge index is independent of the atom graph's.
        data["gnngly_edge_index"] = copy.deepcopy(data["atoms", "coboundary", "atoms"].edge_index)
        return data
Expand Down
6 changes: 6 additions & 0 deletions gifflar/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from gifflar.pretransforms import RootTransform


class MonosaccharideMasking(RootTransform):
    """Placeholder for a monosaccharide-masking transform.

    The masking itself is not implemented yet; the transform currently
    passes its input through unchanged.
    """

    def __call__(self, data):
        """Return ``data`` unchanged.

        Args:
            data: The data object handed to the transform.

        Returns:
            The same ``data`` object, so that the transform can be used where
            a transformed sample is expected instead of yielding ``None``.
        """
        # TODO: implement the actual masking of monosaccharides.
        return data

0 comments on commit a048970

Please sign in to comment.