Skip to content

Commit

Permalink
Add some more tests for functionalities for MT-REXEE in replica_excha…
Browse files Browse the repository at this point in the history
…nge_EE.py
  • Loading branch information
wehs7661 committed Apr 19, 2024
1 parent 6164680 commit 84558ca
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 25 deletions.
2 changes: 1 addition & 1 deletion ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,6 @@ def get_swapping_pattern(self, dhdl_files, states):
print('No swap is proposed because there is no swappable pair at all.')
break
else:
self.n_swap_attempts += 1
if self.proposal == 'exhaustive':
n_ex_exhaustive += 1

Expand All @@ -955,6 +954,7 @@ def get_swapping_pattern(self, dhdl_files, states):
print('No swap is proposed because there is no swappable pair at all.')
break # no need to re-identify swappable pairs and draw new samples
else:
self.n_swap_attempts += 1
if self.verbose is True and self.proposal != 'exhaustive':
print(f'A swap ({i + 1}/{n_ex}) is proposed between the configurations of Simulation {swap[0]} (state {states[swap[0]]}) and Simulation {swap[1]} (state {states[swap[1]]}) ...') # noqa: E501

Expand Down
107 changes: 83 additions & 24 deletions ensemble_md/tests/test_replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import copy
import random
import pytest
import subprocess
import numpy as np
import ensemble_md
from unittest.mock import patch
from ensemble_md.utils import gmx_parser
from ensemble_md.replica_exchange_EE import ReplicaExchangeEE
from ensemble_md.utils.exceptions import ParameterError
Expand Down Expand Up @@ -383,6 +385,29 @@ def test_set_params_mtrexee(self, params_dict):
REXEE = get_REXEE_instance(params_dict)
os.remove('ensemble_md/tests/data/check_gro.py')

@patch('ensemble_md.replica_exchange_EE.subprocess.run')
@patch('builtins.print')
def test_check_gmx_executable(self, mock_print, mock_run, params_dict):
# Here we test the case where the GROMACS executable is not available or an unexpected error occurs.
# The case where the executable is found is tested in the other unit test.
# Note that in check_gmx_executable, exceptions are caught (i.e., not raised), with a message printed,
# so we can only check the printed messages.

# Test 1: The case where the GROMACS executable is not available
mock_run.side_effect = subprocess.CalledProcessError(1, ['which', 'gmx'])
get_REXEE_instance(params_dict)
mock_run.assert_called()
mock_print.assert_called_with("gmx is not available.")

# Test 2: The case where an unexpected error occurs
mock_run.reset_mock()
mock_print.reset_mock()

mock_run.side_effect = Exception("Some error")
get_REXEE_instance(params_dict)
mock_run.assert_called()
mock_print.assert_called_with("An error occurred:\nSome error")

def test_reformat_MDP(self, params_dict):
# Note that the function reformat_MDP is called in set_params
mdp = gmx_parser.MDP(os.path.join(input_path, "expanded.mdp"))
Expand Down Expand Up @@ -483,6 +508,9 @@ def test_get_ref_dist(self, params_dict):
REXEE.get_ref_dist('ensemble_md/tests/data/pullx.xvg')
REXEE.ref_dist = [0.428422]

with pytest.raises(FileNotFoundError):
REXEE.get_ref_dist() # the default f"{self.working_dir}/sim_0/iteration_0/pullx.xvg" would not exist

def test_update_MDP(self, params_dict):
new_template = "ensemble_md/tests/data/expanded.mdp"
iter_idx = 3
Expand Down Expand Up @@ -622,6 +650,13 @@ def test_identify_swappable_pairs(self, params_dict):
swappables_2 = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring')
assert swappables_2 == [(0, 1), (1, 2)]

# Case 3: Non-neighboring exchange, with add_swappables
REXEE.proposal = 'exhaustive'
REXEE.add_swappables = [[3, 7], [4, 7]]
states = [4, 3, 2, 7] # Without add_swappables, the swappables would be [(0, 1), (0, 2), (1, 2)]
swappables_3 = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring', REXEE.add_swappables) # noqa: E501
assert swappables_3 == [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3)]

def test_propose_swap(self, params_dict):
random.seed(0)
REXEE = get_REXEE_instance(params_dict)
Expand All @@ -635,78 +670,102 @@ def test_get_swapping_pattern(self, params_dict):
# state_ranges are: 0-5, 1-6, ..., 3-8
dhdl_files = [os.path.join(input_path, f"dhdl/dhdl_{i}.xvg") for i in range(4)]

# Case 1: Empty swap list
# Test 1: Empty swap list, exhaustive proposal
REXEE = get_REXEE_instance(params_dict)
REXEE.verbose = False
states = [0, 6, 7, 8] # No swappable pairs
f = copy.deepcopy(dhdl_files)
pattern_1, swap_list_1 = REXEE.get_swapping_pattern(f, states)
pattern, swap_list = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_empty_swappable == 1
assert REXEE.n_swap_attempts == 0
assert REXEE.n_rejected == 0
assert pattern == [0, 1, 2, 3]
assert swap_list == []

# Test 2: Empty swap list, neighboring proposal
REXEE = get_REXEE_instance(params_dict)
REXEE.proposal = 'neighboring' # n_ex will be set to 1 automatically.
states = [0, 6, 7, 8] # No swappable pairs
f = copy.deepcopy(dhdl_files)
pattern, swap_list = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_empty_swappable == 1
assert REXEE.n_swap_attempts == 0
assert REXEE.n_rejected == 0
assert pattern_1 == [0, 1, 2, 3]
assert swap_list_1 == []
assert pattern == [0, 1, 2, 3]
assert swap_list == []

# Case 2: Single swap (proposal = 'single')
# Test 3: Single swap (proposal = 'single')
random.seed(0)
REXEE = get_REXEE_instance(params_dict)
REXEE.verbose = True
REXEE.proposal = 'single' # n_ex will be set to 1 automatically.
states = [5, 2, 2, 8] # swappable pairs: [(0, 1), (0, 2), (1, 2)], swap = (1, 2), accept
f = copy.deepcopy(dhdl_files)
pattern_2, swap_list_2 = REXEE.get_swapping_pattern(f, states)
pattern, swap_list = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_swap_attempts == 1
assert REXEE.n_rejected == 0
assert pattern_2 == [0, 2, 1, 3]
assert swap_list_2 == [(1, 2)]
assert pattern == [0, 2, 1, 3]
assert swap_list == [(1, 2)]

# Case 3: Neighboring swap
# Test 4: Neighboring swap
random.seed(0)
REXEE = get_REXEE_instance(params_dict)
REXEE.proposal = 'neighboring' # n_ex will be set to 1 automatically.
states = [5, 2, 2, 8] # swappable pairs: [(0, 1), (0, 2), (1, 2)], swap = (1, 2), accept
f = copy.deepcopy(dhdl_files)
pattern_3, swap_list_3 = REXEE.get_swapping_pattern(f, states)
pattern, swap_list = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_swap_attempts == 1
assert REXEE.n_rejected == 0
assert pattern_3 == [0, 2, 1, 3]
assert swap_list_3 == [(1, 2)]
assert pattern == [0, 2, 1, 3]
assert swap_list == [(1, 2)]

# Case 4-1: Exhaustive swaps that end up in a single swap
# Test 5: Exhaustive swaps that end up in a single swap
random.seed(0)
REXEE = get_REXEE_instance(params_dict)
REXEE.proposal = 'exhaustive'
states = [5, 2, 2, 8] # swappable pairs: [(0, 1), (0, 2), (1, 2)], swap = (1, 2), accept
f = copy.deepcopy(dhdl_files)
pattern_4_1, swap_list_4_1 = REXEE.get_swapping_pattern(f, states)
pattern, swap_list = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_swap_attempts == 1
assert REXEE.n_rejected == 0
assert pattern_4_1 == [0, 2, 1, 3]
assert swap_list_4_1 == [(1, 2)]
assert pattern == [0, 2, 1, 3]
assert swap_list == [(1, 2)]

# Case 4-2: Exhaustive swaps that involve multiple attempted swaps
# Test 6: Exhaustive swaps that involve multiple attempted swaps
random.seed(0)
REXEE = get_REXEE_instance(params_dict)
REXEE.proposal = 'exhaustive'
states = [4, 2, 4, 3] # swappable pairs: [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3)]; swap 1: (2, 3), accepted; swap 2: (0, 1), accept # noqa: E501
f = copy.deepcopy(dhdl_files)
pattern_4_2, swap_list_4_2 = REXEE.get_swapping_pattern(f, states)
pattern, swap_list = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_swap_attempts == 2 # \Delta is negative for both swaps -> both accepted
assert REXEE.n_rejected == 0
assert pattern_4_2 == [1, 0, 3, 2]
assert swap_list_4_2 == [(2, 3), (0, 1)]
assert pattern == [1, 0, 3, 2]
assert swap_list == [(2, 3), (0, 1)]

# Case 4-3: REXEE.proposal is set to exhaustive but there is only one swappable pair anyway.
# Test 7: REXEE.proposal is set to exhaustive but there is only one swappable pair anyway.
random.seed(0)
REXEE = get_REXEE_instance(params_dict)
REXEE.proposal = 'exhaustive'
states = [0, 2, 2, 8] # swappable pair: [(1, 2)], swap: (1, 2), accept
f = copy.deepcopy(dhdl_files)
pattern_4_3, swap_list_4_3 = REXEE.get_swapping_pattern(f, states)
pattern, swap_list = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_swap_attempts == 1
assert REXEE.n_rejected == 0
assert pattern == [0, 2, 1, 3]
assert swap_list == [(1, 2)]

# Test 8: modify_coords_fn is not None, so swap_bool is always True
random.seed(0)
REXEE = get_REXEE_instance(params_dict)
REXEE.modify_coords_fn = 'Cool'
states = [5, 2, 2, 8] # swappable pairs: [(0, 1), (0, 2), (1, 2)], swap = (1, 2), accept
f = copy.deepcopy(dhdl_files)
pattern, swap_list = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_swap_attempts == 1
assert REXEE.n_rejected == 0
assert pattern_4_3 == [0, 2, 1, 3]
assert swap_list_4_3 == [(1, 2)]
assert pattern == [0, 2, 1, 3]
assert swap_list == [(1, 2)]

def test_calc_prob_acc(self, capfd, params_dict):
# k = 1.380649e-23; NA = 6.0221408e23; T = 298; kT = k * NA * T / 1000 = 2.4777098766670016
Expand Down

0 comments on commit 84558ca

Please sign in to comment.