Skip to content

Commit

Permalink
Enhanced the code coverage of replica_exchange_EE.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Mar 25, 2024
1 parent 206fc11 commit 06070b5
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 7 deletions.
3 changes: 2 additions & 1 deletion ensemble_md/cli/run_REXEE.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def main():
start_idx = comm.bcast(start_idx, root=0) # so that all the ranks are aware of start_idx

# 2-3. Get the reference distance for the distance restraint specified in the pull code, if any.
REXEE.get_ref_dist()
pullx_file = 'sim_0/iteration_0/pullx.xvg'
REXEE.get_ref_dist(pullx_file)

for i in range(start_idx, REXEE.n_iter):
# For a large code block like below executed on rank 0, we try to catch any exception and abort the simulation.
Expand Down
15 changes: 11 additions & 4 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def check_gmx_executable(self):
self.gmx_version = line.split()[-1]
break
except subprocess.CalledProcessError:
print(f"{self.gmx_executable} is not available on this system.")
print(f"{self.gmx_executable} is not available.")
except Exception as e:
print(f"An error occurred:\n{e}")

Expand Down Expand Up @@ -592,15 +592,21 @@ def initialize_MDP(self, idx):

return MDP

def get_ref_dist(self):
def get_ref_dist(self, pullx_file = 'sim_0/iteration_0/pullx.xvg'):
"""
Gets the reference distance(s) to use starting from the second iteration if distance restraint(s) are used.
Specifically, a reference distance determined here is the initial COM distance between the pull groups
in the input GRO file. This function initializes the attribute :code:`ref_dist`.
Parameter
---------
pullx_file : str
The path to the pullx file whose initial value will be used as the reference distance.
Usually, this should be the path of the pullx file of the first iteration. The default
is :code:`sim_0/iteration_0/pullx.xvg`.
"""
if hasattr(self, 'set_ref_dist'):
self.ref_dist = []
pullx_file = 'sim_0/iteration_0/pullx.xvg'
for i in range(len(self.set_ref_dist)):
if self.set_ref_dist[i] is True:
# dist = list(extract_dataframe(pullx_file, headers=headers)[f'{i+1}'])[0]
Expand Down Expand Up @@ -900,6 +906,7 @@ def get_swapping_pattern(self, dhdl_files, states):
# This should only happen when the method of exhaustive swaps is used.
if i == 0:
self.n_empty_swappable += 1
print('No swap is proposed because there is no swappable pair at all.')
break
else:
self.n_swap_attempts += 1
Expand All @@ -908,7 +915,7 @@ def get_swapping_pattern(self, dhdl_files, states):

swap = ReplicaExchangeEE.propose_swap(swappables)
print(f'\nProposed swap: {swap}')
if swap == []:
if swap == []: # the same as len(swappables) == 0, self.proposal must not be exhaustive if this line is reached.
self.n_empty_swappable += 1
print('No swap is proposed because there is no swappable pair at all.')
break # no need to re-identify swappable pairs and draw new samples
Expand Down
28 changes: 28 additions & 0 deletions ensemble_md/tests/data/pullx.xvg
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# This file was created Thu Feb 15 02:05:13 2024
# Created by:
# :-) GROMACS - gmx mdrun, 2022.5-dev-20230428-fdf57150ad (-:
#
# Executable: /jet/home/wehs7661/pkgs/gromacs/2022.5/bin/gmx
# Data prefix: /jet/home/wehs7661/pkgs/gromacs/2022.5
# Working dir: /ocean/projects/bio230014p/wehs7661/EEXE_experiments/CB7-10/complex/REXEE/fixed/Group_1/test_1/rep_1/sim_0/iteration_0
# Command line:
# gmx mdrun -s sys_EE.tpr -nt 16 -ntmpi 1
# gmx mdrun is part of G R O M A C S:
#
# GROwing Monsters And Cloning Shrimps
#
@ title "Pull COM"
@ xaxis label "Time (ps)"
@ yaxis label "Position (nm)"
@TYPE xy
@ view 0.15, 0.15, 0.75, 0.85
@ legend on
@ legend box on
@ legend loctype view
@ legend 0.78, 0.8
@ legend length 2
@ s0 legend "1"
@ s1 legend "1 ref"
0.0000 0.428422 0.428422
2.0000 0.457696 0.428422
4.0000 0.374694 0.428422
84 changes: 82 additions & 2 deletions ensemble_md/tests/test_replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,21 @@ def test_print_params(self, capfd, params_dict):
L += "Note that the input MDP file has been reformatted by replacing hypens with underscores. The original mdp file has been renamed as *backup.mdp.\n" # noqa: E501
assert out_2 == L

REXEE.gro = ['ensemble_md/tests/data/sys.gro', 'ensemble_md/tests/data/sys.gro'] # noqa: E501
REXEE.top = ['ensemble_md/tests/data/sys.top', 'ensemble_md/tests/data/sys.top']
REXEE.mdp_args = {'ref_p': [1.0, 1.01, 1.02, 1.03], 'ref_t': [298, 300, 302, 303]}
REXEE.print_params()
out_3, err = capfd.readouterr()
line_1 = 'Simulation inputs: ensemble_md/tests/data/sys.gro, ensemble_md/tests/data/sys.gro, ensemble_md/tests/data/sys.top, ensemble_md/tests/data/sys.top, ensemble_md/tests/data/expanded.mdp\n' # noqa: E501
line_2 = 'MDP parameters differing across replicas:\n - ref_p: [1.0, 1.01, 1.02, 1.03]\n - ref_t: [298, 300, 302, 303]' # noqa: E501
assert line_1 in out_3
assert line_2 in out_3

def test_initialize_MDP(self, params_dict):
params_dict['mdp_args'] = {'ref_p': [1.0, 1.01, 1.02, 1.03], 'ref_t': [298, 300, 302, 303]}
REXEE = get_REXEE_instance(params_dict)
MDP = REXEE.initialize_MDP(2) # the third replica
assert MDP["ref_p"] == 1.02
assert MDP["nsteps"] == 500
assert all(
[
Expand All @@ -420,6 +432,12 @@ def test_initialize_MDP(self, params_dict):
[a == b for a, b in zip(MDP["init_lambda_weights"], [0, 0, 0, 0, 0, 0])]
)

def test_get_ref_dist(self, params_dict):
params_dict['set_ref_dist'] = [True]
REXEE = get_REXEE_instance(params_dict)
REXEE.get_ref_dist('ensemble_md/tests/data/pullx.xvg')
REXEE.ref_dist = [0.428422]

def test_update_MDP(self, params_dict):
new_template = "ensemble_md/tests/data/expanded.mdp"
iter_idx = 3
Expand All @@ -430,13 +448,22 @@ def test_update_MDP(self, params_dict):
[0, 0, 0, 0, 0, 0],
[3.48, 2.78, 3.21, 4.56, 8.79, 0.48],
[8.45, 0.52, 3.69, 2.43, 4.56, 6.73], ]
counts = [
[4, 11, 9, 9, 11, 6],
[9, 8, 8, 11, 7, 7],
[3, 1, 1, 9, 15, 21],
[0, 0, 0, 1, 18, 31],
]
params_dict['set_ref_dist'] = [True]

REXEE = get_REXEE_instance(params_dict)
REXEE.equil = [-1, 1, 0, -1] # i.e. the 3rd replica will use fixed weights in the next iteration
REXEE.equil = [-1, 1, 0, -1] # i.e., the 3rd replica will use fixed weights in the next iteration
MDP_1 = REXEE.update_MDP(
new_template, 2, iter_idx, states, wl_delta, weights) # third replica

REXEE.get_ref_dist('ensemble_md/tests/data/pullx.xvg') # so that we can test the pull code
MDP_2 = REXEE.update_MDP(
new_template, 3, iter_idx, states, wl_delta, weights) # fourth replica
new_template, 3, iter_idx, states, wl_delta, weights, counts) # fourth replica

assert MDP_1["tinit"] == MDP_2["tinit"] == 3
assert MDP_1["nsteps"] == MDP_2["nsteps"] == 500
Expand All @@ -461,6 +488,9 @@ def test_update_MDP(self, params_dict):
)
]
)
assert MDP_2['init_histogram_counts'] == [0, 0, 0, 1, 18, 31]
assert MDP_2['pull_coord1_start'] == 'no'
assert MDP_2['pull_coord1_init'] == 0.428422

def test_extract_final_dhdl_info(self, params_dict):
REXEE = get_REXEE_instance(params_dict)
Expand Down Expand Up @@ -488,11 +518,48 @@ def test_extract_final_log_info(self, params_dict):
[0, 0, 0, 1, 18, 31], ]
assert REXEE.equil == [-1, -1, -1, -1]

# Below is a case where one of the replicas (the first replica) got equilibrated
log_files[0] = os.path.join(input_path, "log/case2_1.log") # equilibrated weights
wl_delta, weights, counts = REXEE.extract_final_log_info(log_files)
assert np.allclose(REXEE.equil, [6.06, -1, -1, -1])
assert REXEE.equilibrated_weights == [[0.00000, 1.40453, 2.85258, 2.72480, 3.46220, 5.88607], [], [], []]

def test_get_averaged_weights(self, params_dict):
REXEE = get_REXEE_instance(params_dict)
log_files = [
os.path.join(input_path, f"log/EXE_{i}.log") for i in range(REXEE.n_sim)]
avg, err = REXEE.get_averaged_weights(log_files)
assert REXEE.current_wl_delta == [0.4, 0.5, 0.5, 0.5]
assert REXEE.updating_weights == [
[
[0, 3.83101, 4.95736, 5.63808, 6.07220, 6.13408],
[0, 3.43101, 3.75736, 5.23808, 4.87220, 5.33408],
[0, 2.63101, 2.95736, 5.23808, 4.47220, 5.73408],
[0, 1.83101, 2.55736, 4.43808, 4.47220, 6.13408],
[0, 1.03101, 2.55736, 3.63808, 4.47220, 6.13408],
], # the weights of the first replica at 5 different time frames
[
[0, 0.72635, 0.80707, 1.44120, 2.10308, 4.03106],
[0, 0.72635, 1.30707, 1.44120, 2.10308, 4.53106],
[0, 0.72635, 2.80707, 2.94120, 4.10308, 6.53106],
[0, 1.72635, 2.30707, 2.44120, 5.10308, 6.53106],
[0, 1.22635, 2.30707, 2.44120, 4.10308, 6.03106],
], # the weights of the second replica at 5 different time frames
[
[0, -0.33569, -0.24525, 2.74443, 4.59472, 7.70726],
[0, -0.33569, -0.24525, 2.74443, 3.59472, 3.70726],
[0, -0.33569, -0.24525, 2.74443, 2.09472, 0.20726],
[0, -0.33569, -0.24525, 1.74443, -0.90528, -0.79274],
[0, 0.66431, 1.25475, 0.24443, 0.59472, 0.70726]
], # the weights of the third replica at 5 different time frames
[
[0, 0.09620, 1.59937, -4.31679, -14.89436, -16.08701],
[0, 0.09620, 1.59937, -4.31679, -15.89436, -20.08701],
[0, 0.09620, 1.59937, -4.31679, -18.39436, -22.58701],
[0, 0.09620, 1.59937, -4.31679, -20.39436, -25.58701],
[0, 0.09620, 1.59937, -4.31679, -22.89436, -28.08701]
]
]
assert np.allclose(avg[0], [0, 2.55101, 3.35736, 4.83808, 4.8722, 5.89408])
assert np.allclose(err[0], [0, 1.14542569, 1.0198039, 0.8, 0.69282032, 0.35777088])

Expand Down Expand Up @@ -584,6 +651,19 @@ def test_get_swapping_pattern(self, params_dict):
assert pattern_4_2 == [1, 0, 3, 2]
assert swap_list_4_2 == [(2, 3), (0, 1)]

# Case 4-3: REXEE.proposal is set to exhaustive but there is only one swappable pair anyway.
random.seed(0)
REXEE = get_REXEE_instance(params_dict)
REXEE.proposal = 'exhaustive'
states = [0, 2, 2, 8] # swappable pair: [(1, 2)], swap: (1, 2), accept
f = copy.deepcopy(dhdl_files)
pattern_4_3, swap_list_4_3 = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_swap_attempts == 1
assert REXEE.n_rejected == 0
assert pattern_4_3 == [0, 2, 1, 3]
assert swap_list_4_3 == [(1, 2)]


def test_calc_prob_acc(self, capfd, params_dict):
# k = 1.380649e-23; NA = 6.0221408e23; T = 298; kT = k * NA * T / 1000 = 2.4777098766670016
REXEE = get_REXEE_instance(params_dict)
Expand Down

0 comments on commit 06070b5

Please sign in to comment.