Skip to content

Commit

Permalink
Tweaked set_params to better handle modify_coords; Added tests releva…
Browse files Browse the repository at this point in the history
…nt to MT-REXEE
  • Loading branch information
wehs7661 committed Apr 19, 2024
1 parent 27849b3 commit 6164680
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/simulations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ include parameters for data analysis here.
the parameters :code:`gro` and :code:`top`, only one MDP file can be specified for the parameter :code:`mdp`. If you wish to use
different parameters for different replicas, please use the parameter :code:`mdp_args`.
- :code:`modify_coords`: (Optional, Default: :code:`None`)
The name of the Python module (without including the :code:`.py` extension) for modifying the output coordinates of the swapping replicas
The file path to the Python module for modifying the output coordinates of the swapping replicas
before the coordinate exchange, which is generally required in REXEE simulations for multiple serial mutations.
For the CLI :code:`run_REXEE` to work, here is the predefined contract for the module/function based on the assumptions :code:`run_REXEE` makes.
Modules/functions not obeying the contract are unlikely to work.
Expand Down
19 changes: 14 additions & 5 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,14 @@ def set_params(self, analysis):
raise ParameterError(f"The parameter '{i}' should be a dictionary.")

if self.add_swappables is not None:
if not isinstance(self.add_swappables, list):
raise ParameterError("The parameter 'add_swappables' should be a nested list.")
for sublist in self.add_swappables:
if not isinstance(sublist, list):
raise ParameterError("The parameter 'add_swappables' should be a nested list.")
for item in sublist:
if not isinstance(item, int) or item < 0:
raise ParameterError("Each number specified in 'add_swappables' should be a non-negative integer.") # noqa: E501
if [len(i) for i in self.add_swappables] != [2] * len(self.add_swappables):
raise ParameterError("Each sublist in 'add_swappables' should contain two integers.")

if self.mdp_args is not None:
# Note that mdp_args is a dictionary including MDP parameters DIFFERING across replicas.
Expand Down Expand Up @@ -441,9 +441,17 @@ def set_params(self, analysis):

# 7-12. External module for coordinate modification
if self.modify_coords is not None:
sys.path.append(os.getcwd())
module = importlib.import_module(self.modify_coords)
self.modify_coords_fn = getattr(module, self.modify_coords)
module_file = os.path.basename(self.modify_coords)
module_dir = os.path.dirname(self.modify_coords)
if module_dir not in sys.path:
sys.path.append(module_dir) # so that the module can be imported
module_name = os.path.splitext(module_file)[0]
module = importlib.import_module(module_name)
if not hasattr(module, module_name):
err_msg = f'The module for coordinate manipulation (specified through the parameter) must have a function with the same name as the module, i.e., {module_name}.' # noqa: E501
raise ParameterError(err_msg)
else:
self.modify_coords_fn = getattr(module, module_name)
else:
self.modify_coords_fn = None

Expand Down Expand Up @@ -509,6 +517,7 @@ def print_params(self, params_analysis=False):
print(f"Additionally defined swappable states: {self.add_swappables}")
print(f"Additional grompp arguments: {self.grompp_args}")
print(f"Additional runtime arguments: {self.runtime_args}")
print(f"External modules for coordinate manipulation: {self.modify_coords}")
# print(f"Number of attempted swaps in one exchange interval: {self.n_ex}")
if self.mdp_args is not None and len(self.mdp_args.keys()) > 1:
print("MDP parameters differing across replicas:")
Expand Down
45 changes: 45 additions & 0 deletions ensemble_md/tests/test_replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""
Unit tests for the module replica_exchange_EE.py.
"""
import re
import os
import sys
import yaml
Expand Down Expand Up @@ -338,6 +339,49 @@ def test_set_params_edge_cases(self, params_dict):
mdp.write(os.path.join(input_path, "expanded_test.mdp"))
REXEE = get_REXEE_instance(params_dict)
assert REXEE.fixed_weights is True
assert REXEE.modify_coords_fn is None # Just an additional test for modify_coords_fn

def test_set_params_mtrexee(self, params_dict):
# Test 1: Below we check if the parameter "add_swappables" is well-defined.
params_dict['add_swappables'] = 5
params_dict['mdp'] = 'expanded.mdp' # irrelevant to MT-REXEE but just to cover some lines
with pytest.raises(ParameterError, match="The parameter 'add_swappables' should be a list."):
get_REXEE_instance(params_dict)

params_dict['add_swappables'] = [15, 16]
with pytest.raises(ParameterError, match="The parameter 'add_swappables' should be a nested list."):
get_REXEE_instance(params_dict)

params_dict['add_swappables'] = [[-3, 1], [4, 5]]
with pytest.raises(ParameterError, match="Each number specified in 'add_swappables' should be a non-negative integer."): # noqa: E501
get_REXEE_instance(params_dict)

params_dict['add_swappables'] = [[1, 2, 3], [4, 5]]
with pytest.raises(ParameterError, match="Each sublist in 'add_swappables' should contain two integers."): # noqa: E501
get_REXEE_instance(params_dict)

# Test 2: Below are some checks for the parameter "modify_coords"
# 2-1. The case where a function has the same name as the module.
params_dict['mdp'] = 'ensemble_md/tests/data/expanded.mdp'
params_dict['modify_coords'] = 'ensemble_md/tests/data/edit_gro.py'
params_dict['add_swappables'] = [[2, 3], [4, 5]]
with open('ensemble_md/tests/data/edit_gro.py', 'w') as f:
f.write('def edit_gro():\n')
f.write(' pass\n')
REXEE = get_REXEE_instance(params_dict)
assert REXEE.modify_coords_fn.__name__ == 'edit_gro'
os.remove('ensemble_md/tests/data/edit_gro.py')

# 2-2. The case where no function has the same name as the module.
params_dict['modify_coords'] = 'ensemble_md/tests/data/check_gro.py'
with open('ensemble_md/tests/data/check_gro.py', 'w') as f:
f.write('def test_gro():\n')
f.write(' pass\n')
# Below we escape the error message since things like "." could be interpreted as special characters
err_msg = re.escape("The module for coordinate manipulation (specified through the parameter) must have a function with the same name as the module, i.e., check_gro.") # noqa: E501
with pytest.raises(ParameterError, match=err_msg): # noqa: E501
REXEE = get_REXEE_instance(params_dict)
os.remove('ensemble_md/tests/data/check_gro.py')

def test_reformat_MDP(self, params_dict):
# Note that the function reformat_MDP is called in set_params
Expand Down Expand Up @@ -377,6 +421,7 @@ def test_print_params(self, capfd, params_dict):
L += "Additionally defined swappable states: None\n"
L += "Additional grompp arguments: None\n"
L += "Additional runtime arguments: None\n"
L += "External modules for coordinate manipulation: None\n"
L += "MDP parameters differing across replicas: None\n"
L += "Alchemical ranges of each replica in REXEE:\n - Replica 0: States [0, 1, 2, 3, 4, 5]\n"
L += " - Replica 1: States [1, 2, 3, 4, 5, 6]\n - Replica 2: States [2, 3, 4, 5, 6, 7]\n"
Expand Down

0 comments on commit 6164680

Please sign in to comment.