From 6164680d58bdb66cf9cff576b646c9f1d8abd4ed Mon Sep 17 00:00:00 2001 From: Wei-Tse Hsu Date: Fri, 19 Apr 2024 18:17:36 +0800 Subject: [PATCH] Tweaked set_params to better handle modify_coords; Added tests relevant to MT-REXEE --- docs/simulations.rst | 2 +- ensemble_md/replica_exchange_EE.py | 19 +++++--- ensemble_md/tests/test_replica_exchange_EE.py | 45 +++++++++++++++++++ 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/docs/simulations.rst b/docs/simulations.rst index 9a549c61..4207c082 100644 --- a/docs/simulations.rst +++ b/docs/simulations.rst @@ -243,7 +243,7 @@ include parameters for data analysis here. the parameters :code:`gro` and :code:`top`, only one MDP file can be specified for the parameter :code:`mdp`. If you wish to use different parameters for different replicas, please use the parameter :code:`mdp_args`. - :code:`modify_coords`: (Optional, Default: :code:`None`) - The name of the Python module (without including the :code:`.py` extension) for modifying the output coordinates of the swapping replicas + The file path to the Python module for modifying the output coordinates of the swapping replicas before the coordinate exchange, which is generally required in REXEE simulations for multiple serial mutations. For the CLI :code:`run_REXEE` to work, here is the predefined contract for the module/function based on the assumptions :code:`run_REXEE` makes. Modules/functions not obeying the contract are unlikely to work. diff --git a/ensemble_md/replica_exchange_EE.py b/ensemble_md/replica_exchange_EE.py index 4104a218..0f11803c 100644 --- a/ensemble_md/replica_exchange_EE.py +++ b/ensemble_md/replica_exchange_EE.py @@ -264,14 +264,14 @@ def set_params(self, analysis): raise ParameterError(f"The parameter '{i}' should be a dictionary.") if self.add_swappables is not None: - if not isinstance(self.add_swappables, list): - raise ParameterError("The parameter 'add_swappables' should be a nested list.") for sublist in self.add_swappables: if not isinstance(sublist, list): raise ParameterError("The parameter 'add_swappables' should be a nested list.") for item in sublist: if not isinstance(item, int) or item < 0: raise ParameterError("Each number specified in 'add_swappables' should be a non-negative integer.") # noqa: E501 + if [len(i) for i in self.add_swappables] != [2] * len(self.add_swappables): + raise ParameterError("Each sublist in 'add_swappables' should contain two integers.") if self.mdp_args is not None: # Note that mdp_args is a dictionary including MDP parameters DIFFERING across replicas. @@ -441,9 +441,17 @@ def set_params(self, analysis): # 7-12. External module for coordinate modification if self.modify_coords is not None: - sys.path.append(os.getcwd()) - module = importlib.import_module(self.modify_coords) - self.modify_coords_fn = getattr(module, self.modify_coords) + module_file = os.path.basename(self.modify_coords) + module_dir = os.path.dirname(self.modify_coords) + if module_dir not in sys.path: + sys.path.append(module_dir) # so that the module can be imported + module_name = os.path.splitext(module_file)[0] + module = importlib.import_module(module_name) + if not hasattr(module, module_name): + err_msg = f'The module for coordinate manipulation (specified through the parameter) must have a function with the same name as the module, i.e., {module_name}.' # noqa: E501 + raise ParameterError(err_msg) + else: + self.modify_coords_fn = getattr(module, module_name) else: self.modify_coords_fn = None @@ -509,6 +517,7 @@ def print_params(self, params_analysis=False): print(f"Additionally defined swappable states: {self.add_swappables}") print(f"Additional grompp arguments: {self.grompp_args}") print(f"Additional runtime arguments: {self.runtime_args}") + print(f"External modules for coordinate manipulation: {self.modify_coords}") # print(f"Number of attempted swaps in one exchange interval: {self.n_ex}") if self.mdp_args is not None and len(self.mdp_args.keys()) > 1: print("MDP parameters differing across replicas:") diff --git a/ensemble_md/tests/test_replica_exchange_EE.py b/ensemble_md/tests/test_replica_exchange_EE.py index 3e079052..4ce8ec4b 100644 --- a/ensemble_md/tests/test_replica_exchange_EE.py +++ b/ensemble_md/tests/test_replica_exchange_EE.py @@ -10,6 +10,7 @@ """ Unit tests for the module replica_exchange_EE.py. """ +import re import os import sys import yaml @@ -338,6 +339,49 @@ def test_set_params_edge_cases(self, params_dict): mdp.write(os.path.join(input_path, "expanded_test.mdp")) REXEE = get_REXEE_instance(params_dict) assert REXEE.fixed_weights is True + assert REXEE.modify_coords_fn is None # Just an additional test for modify_coords_fn + + def test_set_params_mtrexee(self, params_dict): + # Test 1: Below we check if the parameter "add_swappables" is well-defined. + params_dict['add_swappables'] = 5 + params_dict['mdp'] = 'expanded.mdp' # irrelevant to MT-REXEE but just to cover some lines + with pytest.raises(ParameterError, match="The parameter 'add_swappables' should be a list."): + get_REXEE_instance(params_dict) + + params_dict['add_swappables'] = [15, 16] + with pytest.raises(ParameterError, match="The parameter 'add_swappables' should be a nested list."): + get_REXEE_instance(params_dict) + + params_dict['add_swappables'] = [[-3, 1], [4, 5]] + with pytest.raises(ParameterError, match="Each number specified in 'add_swappables' should be a non-negative integer."): # noqa: E501 + get_REXEE_instance(params_dict) + + params_dict['add_swappables'] = [[1, 2, 3], [4, 5]] + with pytest.raises(ParameterError, match="Each sublist in 'add_swappables' should contain two integers."): # noqa: E501 + get_REXEE_instance(params_dict) + + # Test 2: Below are some checks for the parameter "modify_coords" + # 2-1. The case where a function has the same name as the module. + params_dict['mdp'] = 'ensemble_md/tests/data/expanded.mdp' + params_dict['modify_coords'] = 'ensemble_md/tests/data/edit_gro.py' + params_dict['add_swappables'] = [[2, 3], [4, 5]] + with open('ensemble_md/tests/data/edit_gro.py', 'w') as f: + f.write('def edit_gro():\n') + f.write(' pass\n') + REXEE = get_REXEE_instance(params_dict) + assert REXEE.modify_coords_fn.__name__ == 'edit_gro' + os.remove('ensemble_md/tests/data/edit_gro.py') + + # 2-2. The case where no function has the same name as the module. + params_dict['modify_coords'] = 'ensemble_md/tests/data/check_gro.py' + with open('ensemble_md/tests/data/check_gro.py', 'w') as f: + f.write('def test_gro():\n') + f.write(' pass\n') + # Below we escape the error message since things like "." could be interpreted as special characters + err_msg = re.escape("The module for coordinate manipulation (specified through the parameter) must have a function with the same name as the module, i.e., check_gro.") # noqa: E501 + with pytest.raises(ParameterError, match=err_msg): # noqa: E501 + REXEE = get_REXEE_instance(params_dict) + os.remove('ensemble_md/tests/data/check_gro.py') def test_reformat_MDP(self, params_dict): # Note that the function reformat_MDP is called in set_params @@ -377,6 +421,7 @@ def test_print_params(self, capfd, params_dict): L += "Additionally defined swappable states: None\n" L += "Additional grompp arguments: None\n" L += "Additional runtime arguments: None\n" + L += "External modules for coordinate manipulation: None\n" L += "MDP parameters differing across replicas: None\n" L += "Alchemical ranges of each replica in REXEE:\n - Replica 0: States [0, 1, 2, 3, 4, 5]\n" L += " - Replica 1: States [1, 2, 3, 4, 5, 6]\n - Replica 2: States [2, 3, 4, 5, 6, 7]\n"