Skip to content

Commit

Permalink
Branch to update to new dgl version (#179)
Browse files Browse the repository at this point in the history
* Branch to update to new dgl version

* Also only linux builds

* See if this makes the solver happy

* see if the shim solves the GPU problem

* see if the python 3.10 build works better

* unpin pytorch since the conda-forge package should have the right pins

* until we get CPU builds, I think I need to fake a cuda toolkit version

* Remove shim

* see if CPU builds JustWork TM

* update to testing channel that has osx support

* update API to new dgl 1.1.0

* figure out which version of python are giving us issues

* bump ci

* see if new logic works

* Update espaloma.yaml

* bump ci
  • Loading branch information
mikemhenry committed Aug 21, 2023
1 parent dd8ebb7 commit 6229a3e
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 42 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ['macos', 'ubuntu']
os: ['ubuntu','macos']
python-version:
- "3.9"
- "3.11"
- "3.10"
- "3.9"

env:
OPENMM: ${{ matrix.cfg.openmm }}
Expand Down
5 changes: 1 addition & 4 deletions devtools/conda-envs/espaloma.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
name: espaloma-test
channels:
- conda-forge
- dglteam
- openeye
dependencies:
# Base dependencies
Expand All @@ -20,9 +19,7 @@ dependencies:
- openmmforcefields >=0.11.2
- tqdm
- pydantic <2 # We need our deps to fix this
# Pytorch
- pytorch>=1.8.0
- dgl <1
- dgl =1.1.0
# Testing
- pytest
- pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion espaloma/app/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test(self):
# from time to time
# make it just one giant graph
# g = list(self.data)
# g = dgl.batch_hetero(g)
# g = dgl.batch(g)
# g = g.to(self.device)

if self.states is None:
Expand Down
4 changes: 2 additions & 2 deletions espaloma/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,13 @@ def batch(graphs):
import dgl

if all(isinstance(graph, esp.graphs.graph.Graph) for graph in graphs):
return dgl.batch_hetero([graph.heterograph for graph in graphs])
return dgl.batch([graph.heterograph for graph in graphs])

elif all(isinstance(graph, dgl.DGLGraph) for graph in graphs):
return dgl.batch(graphs)

elif all(isinstance(graph, dgl.DGLHeteroGraph) for graph in graphs):
return dgl.batch_hetero(graphs)
return dgl.batch(graphs)

else:
raise RuntimeError(
Expand Down
16 changes: 8 additions & 8 deletions espaloma/mm/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def energy_in_graph(
{
"%s_in_g"
% term: (
dgl.function.copy_src(src="u%s" % suffix, out="m_%s" % term),
dgl.function.copy_u(u="u%s" % suffix, out="m_%s" % term),
dgl.function.sum(
msg="m_%s" % term, out="u_%s%s" % (term, suffix)
),
Expand Down Expand Up @@ -462,31 +462,31 @@ def forward(self, g):
g.multi_update_all(
{
"n2_as_0_in_n3": (
dgl.function.copy_src("u", "m_u_0"),
dgl.function.copy_u("u", "m_u_0"),
dgl.function.sum("m_u_0", "u_left"),
),
"n2_as_1_in_n3": (
dgl.function.copy_src("u", "m_u_1"),
dgl.function.copy_u("u", "m_u_1"),
dgl.function.sum("m_u_1", "u_right"),
),
"n2_as_0_in_n4": (
dgl.function.copy_src("u", "m_u_0"),
dgl.function.copy_u("u", "m_u_0"),
dgl.function.sum("m_u_0", "u_bond_left"),
),
"n2_as_1_in_n4": (
dgl.function.copy_src("u", "m_u_1"),
dgl.function.copy_u("u", "m_u_1"),
dgl.function.sum("m_u_1", "u_bond_center"),
),
"n2_as_2_in_n4": (
dgl.function.copy_src("u", "m_u_2"),
dgl.function.copy_u("u", "m_u_2"),
dgl.function.sum("m_u_2", "u_bond_right"),
),
"n3_as_0_in_n4": (
dgl.function.copy_src("u", "m3_u_0"),
dgl.function.copy_u("u", "m3_u_0"),
dgl.function.sum("m3_u_0", "u_angle_left"),
),
"n3_as_1_in_n4": (
dgl.function.copy_src("u", "m3_u_1"),
dgl.function.copy_u("u", "m3_u_1"),
dgl.function.sum("m3_u_1", "u_angle_right"),
),
},
Expand Down
6 changes: 3 additions & 3 deletions espaloma/mm/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def geometry_in_graph(g):
**{
"n1_as_%s_in_n%s"
% (pos_idx, big_idx): (
dgl.function.copy_src(src="xyz", out="m_xyz%s" % pos_idx),
dgl.function.copy_u(u="xyz", out="m_xyz%s" % pos_idx),
dgl.function.sum(
msg="m_xyz%s" % pos_idx, out="xyz%s" % pos_idx
),
Expand All @@ -199,7 +199,7 @@ def geometry_in_graph(g):
**{
"n1_as_%s_in_%s"
% (pos_idx, term): (
dgl.function.copy_src(src="xyz", out="m_xyz%s" % pos_idx),
dgl.function.copy_u(u="xyz", out="m_xyz%s" % pos_idx),
dgl.function.sum(
msg="m_xyz%s" % pos_idx, out="xyz%s" % pos_idx
),
Expand All @@ -210,7 +210,7 @@ def geometry_in_graph(g):
**{
"n1_as_%s_in_%s"
% (pos_idx, term): (
dgl.function.copy_src(src="xyz", out="m_xyz%s" % pos_idx),
dgl.function.copy_u(u="xyz", out="m_xyz%s" % pos_idx),
dgl.function.sum(
msg="m_xyz%s" % pos_idx, out="xyz%s" % pos_idx
),
Expand Down
8 changes: 4 additions & 4 deletions espaloma/mm/nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def lorentz_berthelot(g, suffix=""):
{
"n1_as_%s_in_%s"
% (pos_idx, term): (
dgl.function.copy_src(
src="epsilon%s" % suffix, out="m_epsilon"
dgl.function.copy_u(
u="epsilon%s" % suffix, out="m_epsilon"
),
geometric_mean(msg="m_epsilon", out="epsilon%s" % suffix),
)
Expand All @@ -63,7 +63,7 @@ def lorentz_berthelot(g, suffix=""):
{
"n1_as_%s_in_%s"
% (pos_idx, term): (
dgl.function.copy_src(src="sigma%s" % suffix, out="m_sigma"),
dgl.function.copy_u(u="sigma%s" % suffix, out="m_sigma"),
arithmetic_mean(msg="m_sigma", out="sigma%s" % suffix),
)
for pos_idx in [0, 1]
Expand Down Expand Up @@ -94,7 +94,7 @@ def multiply_charges(g, suffix=""):
{
"n1_as_%s_in_%s"
% (pos_idx, term): (
dgl.function.copy_src(src="q%s" % suffix, out="m_q"),
dgl.function.copy_u(u="q%s" % suffix, out="m_q"),
dgl.function.sum(msg="m_q", out="_q")
# lambda node: {"q%s" % suffix: node.mailbox["m_q"].prod(dim=1)}
)
Expand Down
14 changes: 7 additions & 7 deletions espaloma/mm/tests/test_energy_ii.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,31 +78,31 @@ def forward(self, g):
g.multi_update_all(
{
"n2_as_0_in_n3": (
dgl.function.copy_src("u", "m_u_0"),
dgl.function.copy_u("u", "m_u_0"),
dgl.function.sum("m_u_0", "u_left"),
),
"n2_as_1_in_n3": (
dgl.function.copy_src("u", "m_u_1"),
dgl.function.copy_u("u", "m_u_1"),
dgl.function.sum("m_u_1", "u_right"),
),
"n2_as_0_in_n4": (
dgl.function.copy_src("u", "m_u_0"),
dgl.function.copy_u("u", "m_u_0"),
dgl.function.sum("m_u_0", "u_bond_left"),
),
"n2_as_1_in_n4": (
dgl.function.copy_src("u", "m_u_1"),
dgl.function.copy_u("u", "m_u_1"),
dgl.function.sum("m_u_1", "u_bond_center"),
),
"n2_as_2_in_n4": (
dgl.function.copy_src("u", "m_u_2"),
dgl.function.copy_u("u", "m_u_2"),
dgl.function.sum("m_u_2", "u_bond_right"),
),
"n3_as_0_in_n4": (
dgl.function.copy_src("u", "m3_u_0"),
dgl.function.copy_u("u", "m3_u_0"),
dgl.function.sum("m3_u_0", "u_angle_left"),
),
"n3_as_1_in_n4": (
dgl.function.copy_src("u", "m3_u_1"),
dgl.function.copy_u("u", "m3_u_1"),
dgl.function.sum("m3_u_1", "u_angle_right"),
),
},
Expand Down
12 changes: 6 additions & 6 deletions espaloma/nn/readout/charge_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def forward(self, g, total_charge=0.0):
if "q_ref" in g.nodes["n1"].data:
# get total charge
g.update_all(
dgl.function.copy_src(src="q_ref", out="m_q"),
dgl.function.copy_u(u="q_ref", out="m_q"),
dgl.function.sum(msg="m_q", out="sum_q"),
etype="n1_in_g",
)
Expand All @@ -95,32 +95,32 @@ def forward(self, g, total_charge=0.0):
)

g.update_all(
dgl.function.copy_src(src="sum_q", out="m_sum_q"),
dgl.function.copy_u(u="sum_q", out="m_sum_q"),
dgl.function.sum(msg="m_sum_q", out="sum_q"),
etype="g_has_n1",
)

# get the sum of $s^{-1}$ and $m_s^{-1}$
g.update_all(
dgl.function.copy_src(src="s_inv", out="m_s_inv"),
dgl.function.copy_u(u="s_inv", out="m_s_inv"),
dgl.function.sum(msg="m_s_inv", out="sum_s_inv"),
etype="n1_in_g",
)

g.update_all(
dgl.function.copy_src(src="e_s_inv", out="m_e_s_inv"),
dgl.function.copy_u(u="e_s_inv", out="m_e_s_inv"),
dgl.function.sum(msg="m_e_s_inv", out="sum_e_s_inv"),
etype="n1_in_g",
)

g.update_all(
dgl.function.copy_src(src="sum_s_inv", out="m_sum_s_inv"),
dgl.function.copy_u(u="sum_s_inv", out="m_sum_s_inv"),
dgl.function.sum(msg="m_sum_s_inv", out="sum_s_inv"),
etype="g_has_n1",
)

g.update_all(
dgl.function.copy_src(src="sum_e_s_inv", out="m_sum_e_s_inv"),
dgl.function.copy_u(u="sum_e_s_inv", out="m_sum_e_s_inv"),
dgl.function.sum(msg="m_sum_e_s_inv", out="sum_e_s_inv"),
etype="g_has_n1",
)
Expand Down
2 changes: 1 addition & 1 deletion espaloma/nn/readout/graph_level_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def forward(self, g):
)

g.update_all(
dgl.function.copy_src("h_global", "m"),
dgl.function.copy_u("h_global", "m"),
self.pool("m", "h_global"),
etype="n1_in_g",
)
Expand Down
8 changes: 4 additions & 4 deletions espaloma/nn/readout/janossy.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def forward(self, g):
{
"n1_as_%s_in_n%s"
% (relationship_idx, big_idx): (
dgl.function.copy_src("h", "m%s" % relationship_idx),
dgl.function.copy_u("h", "m%s" % relationship_idx),
dgl.function.mean(
"m%s" % relationship_idx, "h%s" % relationship_idx
),
Expand Down Expand Up @@ -240,7 +240,7 @@ def forward(self, g):
{
"n1_as_%s_in_%s"
% (relationship_idx, big_idx): (
dgl.function.copy_src("h", "m%s" % relationship_idx),
dgl.function.copy_u("h", "m%s" % relationship_idx),
dgl.function.mean(
"m%s" % relationship_idx, "h%s" % relationship_idx
),
Expand Down Expand Up @@ -358,7 +358,7 @@ def forward(self, g):
{
"n1_as_%s_in_%s"
% (relationship_idx, big_idx): (
dgl.function.copy_src("h", "m%s" % relationship_idx),
dgl.function.copy_u("h", "m%s" % relationship_idx),
dgl.function.mean(
"m%s" % relationship_idx, "h%s" % relationship_idx
),
Expand Down Expand Up @@ -475,7 +475,7 @@ def forward(self, g):
{
"n1_as_%s_in_%s"
% (relationship_idx, big_idx): (
dgl.function.copy_src("h", "m%s" % relationship_idx),
dgl.function.copy_u("h", "m%s" % relationship_idx),
dgl.function.mean(
"m%s" % relationship_idx, "h%s" % relationship_idx
),
Expand Down

0 comments on commit 6229a3e

Please sign in to comment.