Migrate to pydantic 2.0+
hmacdope authored Oct 10, 2024
2 parents 8e157d8 + 965f0b7 commit 36b9f8b
Showing 3 changed files with 37 additions and 28 deletions.
4 changes: 3 additions & 1 deletion devtools/conda-envs/mtenn.yaml
@@ -3,14 +3,16 @@ channels:
- conda-forge
dependencies:
- pytorch
- pytorch_geometric
- pytorch_geometric >=2.5.0
- pytorch_cluster
- pytorch_scatter
- pytorch_sparse
- pydantic >=2.0.0a0
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
- fsspec
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env.yaml
@@ -7,6 +7,7 @@ dependencies:
- pytorch_cluster
- pytorch_scatter
- pytorch_sparse
- pydantic >=2.0.0a0
- numpy
- h5py
- e3nn
@@ -19,5 +20,4 @@ dependencies:
- pytest
- pytest-cov
- codecov
- pydantic >=1.10.8,<2.0.0a0
- fsspec
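
Both environment files now require pydantic 2.x at runtime, and the old test-environment pin below 2.0 is dropped. As a purely illustrative sketch (not part of this commit), downstream code that relies on the new API could fail fast on an older install:

    # Hypothetical guard, not part of this commit: fail fast if the installed
    # pydantic predates the 2.x API that mtenn.config now relies on.
    import pydantic

    if int(pydantic.VERSION.split(".")[0]) < 2:
        raise ImportError(
            f"mtenn requires pydantic >= 2.0, found {pydantic.VERSION}"
        )
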
59 changes: 33 additions & 26 deletions mtenn/config.py
@@ -15,9 +15,9 @@

import abc
from enum import Enum
from pydantic import BaseModel, Field, root_validator
from pydantic import model_validator, ConfigDict, BaseModel, Field
import random
from typing import Callable, ClassVar
from typing import Literal, Callable, ClassVar
import mtenn.combination
import mtenn.readout
import mtenn.model
@@ -140,7 +140,8 @@ class ModelConfigBase(BaseModel):
to implement the ``_build`` method in order to be used.
"""

model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False)
model_type: Literal[ModelType.INVALID] = ModelType.INVALID


# Random seed optional for reproducibility
rand_seed: int | None = Field(
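
Pydantic 2 removed const=True and allow_mutation=False from Field, so the fixed model_type is now expressed by annotating the field as a Literal of the single allowed enum member (hence the new Literal import above). A minimal sketch of the old and new forms, using a trimmed-down stand-in for ModelType:

    from enum import Enum
    from typing import Literal

    from pydantic import BaseModel


    class ModelType(str, Enum):  # trimmed-down stand-in for illustration
        INVALID = "INVALID"
        GAT = "GAT"


    class ModelConfigSketch(BaseModel):
        # pydantic v1 (removed):
        #     model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False)
        # pydantic v2: the Literal annotation pins the only accepted value
        model_type: Literal[ModelType.INVALID] = ModelType.INVALID


    ModelConfigSketch()                              # model_type == ModelType.INVALID
    # ModelConfigSketch(model_type=ModelType.GAT)    # raises ValidationError
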
@@ -240,9 +241,7 @@ class ModelConfigBase(BaseModel):
"``comb_substrate``."
),
)

class Config:
validate_assignment = True
model_config = ConfigDict(validate_assignment=True)
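
The inner class Config is gone in pydantic 2; per-model settings move into a model_config attribute built from ConfigDict. A small self-contained sketch of the replacement (the wrapper class is illustrative; rand_seed mirrors the field above):

    from pydantic import BaseModel, ConfigDict


    class ConfigSketch(BaseModel):
        # pydantic v1 (removed):
        #     class Config:
        #         validate_assignment = True
        # pydantic v2:
        model_config = ConfigDict(validate_assignment=True)

        rand_seed: int | None = None


    cfg = ConfigSketch()
    cfg.rand_seed = 42            # re-validated because validate_assignment=True
    # cfg.rand_seed = "oops"      # would raise a ValidationError
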

def build(self) -> mtenn.model.Model:
"""
@@ -394,7 +393,7 @@ def _check_grouped(values):
Makes sure that a Combination method is passed if using a GroupedModel. Only
needs to be called for structure-based models.
"""
if values["grouped"] and (not values["combination"]):
if values.grouped and not values.combination:
raise ValueError("combination must be specified for a GroupedModel.")
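
Since _check_grouped is now called from mode="after" validators, it receives the validated model itself rather than a values dict, so grouped and combination are read as attributes. A self-contained sketch of the same cross-field check (the wrapper class is illustrative; the field names follow the configs in this file):

    from typing import Any

    from pydantic import BaseModel, model_validator


    class GroupedSketch(BaseModel):  # illustrative stand-in for a structure-based config
        grouped: bool = False
        combination: Any = None

        @model_validator(mode="after")
        def check_grouped(self):
            # v1 handed the validator a dict (values["grouped"]); v2 hands it the model.
            if self.grouped and not self.combination:
                raise ValueError("combination must be specified for a GroupedModel.")
            return self


    GroupedSketch(grouped=False)       # fine
    # GroupedSketch(grouped=True)      # raises ValidationError
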


@@ -436,7 +435,7 @@ class GATModelConfig(ModelConfigBase):
"biases": bool,
} #: :meta private:

model_type: ModelType = Field(ModelType.GAT, const=True)
model_type: Literal[ModelType.GAT] = ModelType.GAT

in_feats: int = Field(
_CanonicalAtomFeaturizer().feat_size(),
@@ -527,14 +526,16 @@ class GATModelConfig(ModelConfigBase):
# num_layers
_from_num_layers = False

@root_validator(pre=False)
def massage_into_lists(cls, values) -> GATModelConfig:
@model_validator(mode="after")
def massage_into_lists(self) -> GATModelConfig:
"""
Validator to handle unifying all the values into the proper list forms based on
the rules described in the class docstring.
"""
values = self.dict()

# First convert string lists to actual lists
for param, param_type in cls.LIST_PARAMS.items():
for param, param_type in self.LIST_PARAMS.items():
param_val = values[param]
if isinstance(param_val, str):
try:
@@ -548,7 +549,7 @@ def massage_into_lists(cls, values) -> GATModelConfig:

# Get sizes of all lists
list_lens = {}
for p in cls.LIST_PARAMS:
for p in self.LIST_PARAMS:
param_val = values[p]
if not isinstance(param_val, list):
# Shouldn't be possible at this point but just in case
@@ -577,14 +578,17 @@ def massage_into_lists(cls, values) -> GATModelConfig:
# If we just want a model with one layer, can return early since we've already
# converted everything into lists
if num_layers == 1:
return values
# update self with the new values
self.__dict__.update(values)


# Adjust any length 1 list to be the right length
for p, list_len in list_lens.items():
if list_len == 1:
values[p] = values[p] * num_layers

return values
self.__dict__.update(values)
return self
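
This is the core pattern of the migration in GATModelConfig: the pre=False root_validator becomes a mode="after" model_validator that runs on the constructed instance, takes a snapshot of the field values, massages them into per-layer lists, and writes them back through self.__dict__ (which bypasses the validate_assignment hook) before returning self. A reduced sketch of that shape, with two illustrative per-layer fields instead of the full LIST_PARAMS set:

    from pydantic import BaseModel, ConfigDict, model_validator


    class LayerListSketch(BaseModel):
        model_config = ConfigDict(validate_assignment=True)

        num_layers: int = 2
        hidden_feats: int | list[int] = 32
        dropouts: float | list[float] = 0.2

        @model_validator(mode="after")
        def massage_into_lists(self) -> "LayerListSketch":
            values = self.model_dump()  # v2 name; the diff uses the deprecated .dict() alias
            for name in ("hidden_feats", "dropouts"):
                if not isinstance(values[name], list):
                    values[name] = [values[name]]
                if len(values[name]) == 1:
                    values[name] = values[name] * values["num_layers"]
            # Update __dict__ directly, bypassing pydantic's __setattr__ so
            # validate_assignment does not re-run validation per assignment.
            self.__dict__.update(values)
            return self


    print(LayerListSketch(num_layers=3).hidden_feats)   # [32, 32, 32]
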

def _build(self, mtenn_params={}):
"""
@@ -681,7 +685,7 @@ class SchNetModelConfig(ModelConfigBase):
given in PyG.
"""

model_type: ModelType = Field(ModelType.schnet, const=True)
model_type: Literal[ModelType.schnet] = ModelType.schnet

hidden_channels: int = Field(128, description="Hidden embedding size.")
num_filters: int = Field(
@@ -738,13 +742,14 @@ class SchNetModelConfig(ModelConfigBase):
),
)

@root_validator(pre=False)
@model_validator(mode="after")
@classmethod
def validate(cls, values):
# Make sure the grouped stuff is properly assigned
ModelConfigBase._check_grouped(values)

# Make sure atomref length is correct (this is required by PyG)
atomref = values["atomref"]
atomref = values.atomref
if (atomref is not None) and (len(atomref) != 100):
raise ValueError(f"atomref must be length 100 (got {len(atomref)})")
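
The SchNet validator (like the E3NN and ViSNet ones below) keeps the classmethod form: with mode="after", pydantic passes the already-validated instance as the second argument, still named values here, so lookups change from values["atomref"] to values.atomref and the method returns the instance. A trimmed sketch of the same check with only the atomref field (the class and method names are illustrative):

    from pydantic import BaseModel, model_validator


    class AtomrefSketch(BaseModel):
        atomref: list[float] | None = None

        @model_validator(mode="after")
        @classmethod
        def check_atomref(cls, values):
            # "values" is the validated AtomrefSketch instance, not a dict.
            if (values.atomref is not None) and (len(values.atomref) != 100):
                raise ValueError(f"atomref must be length 100 (got {len(values.atomref)})")
            return values


    AtomrefSketch(atomref=None)               # fine
    # AtomrefSketch(atomref=[0.0] * 5)        # raises ValidationError
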

@@ -816,7 +821,7 @@ class E3NNModelConfig(ModelConfigBase):
Class for constructing an e3nn ML model.
"""

model_type: ModelType = Field(ModelType.e3nn, const=True)
model_type: Literal[ModelType.e3nn] = ModelType.e3nn

num_atom_types: int = Field(
100,
@@ -862,7 +867,8 @@ class E3NNModelConfig(ModelConfigBase):
num_neighbors: float = Field(25, description="Typical number of neighbor nodes.")
num_nodes: float = Field(4700, description="Typical number of nodes in a graph.")

@root_validator(pre=False)
@model_validator(mode="after")
@classmethod
def massage_irreps(cls, values):
"""
Check that the value given for ``irreps_hidden`` can be converted into an Irreps
@@ -874,7 +880,7 @@ def massage_irreps(cls, values):
ModelConfigBase._check_grouped(values)

# Now deal with irreps
irreps = values["irreps_hidden"]
irreps = values.irreps_hidden
# First see if this string should be converted into a dict
if isinstance(irreps, str):
if ":" in irreps:
@@ -923,7 +929,7 @@ def massage_irreps(cls, values):
except ValueError:
raise ValueError(f"Couldn't parse irreps dict: {orig_irreps}")

values["irreps_hidden"] = irreps
values.irreps_hidden = irreps
return values

def _build(self, mtenn_params={}):
@@ -994,7 +1000,7 @@ class ViSNetModelConfig(ModelConfigBase):
given in PyG.
"""

model_type: ModelType = Field(ModelType.visnet, const=True)
model_type: Literal[ModelType.visnet] = ModelType.visnet
lmax: int = Field(1, description="The maximum degree of the spherical harmonics.")
vecnorm_type: str | None = Field(
None, description="The type of normalization to apply to the vectors."
@@ -1041,7 +1047,8 @@ class ViSNetModelConfig(ModelConfigBase):
),
)

@root_validator(pre=False)
@model_validator(mode="after")
@classmethod
def validate(cls, values):
"""
Check that ``atomref`` and ``max_z`` agree.
@@ -1050,10 +1057,10 @@ def validate(cls, values):
ModelConfigBase._check_grouped(values)

# Make sure atomref length is correct (this is required by PyG)
atomref = values["atomref"]
if (atomref is not None) and (len(atomref) != values["max_z"]):
atomref = values.atomref
if (atomref is not None) and (len(atomref) != values.max_z):
raise ValueError(
f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})"
f"atomref length must match max_z. (Expected {values.max_z}, got {len(atomref)})"
)

return values