Skip to content

Commit

Permalink
Fixed unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Nov 2, 2023
1 parent 174b45d commit 21ff0fd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 52 deletions.
8 changes: 5 additions & 3 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,9 @@ def print_params(self, params_analysis=False):
print(f"Proposal scheme: {self.proposal}")
print(f"Acceptance scheme for swapping simulations: {self.acceptance}")
print(f"Whether to perform weight combination: {self.w_combine}")
print(f"Histogram cutoff: {self.N_cutoff}")
print(f"Type of means for weight combination: {self.w_mean_type}")
print(f"Whether to perform histogram correction: {self.hist_corr}")
print(f"Histogram cutoff for weight correction: {self.N_cutoff}")
print(f"Number of replicas: {self.n_sim}")
print(f"Number of iterations: {self.n_iter}")
print(f"Length of each replica: {self.dt * self.nst_sim} ps")
Expand Down Expand Up @@ -1179,8 +1181,8 @@ def histogram_correction(self, hist, print_values=True):
# (1) Print the original histogram counts
if print_values is True:
print(' Original histogram counts:')
for i in range(len(self.hist)):
print(f' Rep {i}: {self.hist[i]}')
for i in range(len(hist)):
print(f' Rep {i}: {hist[i]}')

# (2) Calculate adjacent weight differences and g_vec
N_ratio_vec = [] # N_{k-1}/N_k for the whole range
Expand Down
65 changes: 16 additions & 49 deletions ensemble_md/tests/test_replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,13 @@ def test_set_params_error(self, params_dict):
params_dict['n_sim'] = 4 # so params_dict can be read without failing in the assertions below

# 2. Available options
check_param_error(params_dict, 'proposal', "The specified proposal scheme is not available. Available options include 'single', 'neighboring', 'exhaustive', and 'multiple'.", 'cool', 'multiple') # set as multiple for later tests for n_ex # noqa: E501
check_param_error(params_dict, 'proposal', "The specified proposal scheme is not available. Available options include 'single', 'neighboring', and 'exhaustive'.", 'cool', 'exhaustive') # noqa: E501
check_param_error(params_dict, 'acceptance', "The specified acceptance scheme is not available. Available options include 'same-state' and 'metropolis'.") # noqa: E501
check_param_error(params_dict, 'df_method', "The specified free energy estimator is not available. Available options include 'TI', 'BAR', and 'MBAR'.") # noqa: E501
check_param_error(params_dict, 'err_method', "The specified method for error estimation is not available. Available options include 'propagate', and 'bootstrap'.") # noqa: E501

# 3. Integer parameters
check_param_error(params_dict, 'nst_sim', "The parameter 'nst_sim' should be an integer.")
check_param_error(params_dict, 'n_ex', "The parameter 'n_ex' should be an integer.")
check_param_error(params_dict, 'seed', "The parameter 'seed' should be an integer.")
check_param_error(params_dict, 'n_sim', "The parameter 'n_sim' should be an integer.", 4.1, 4)

Expand All @@ -98,7 +97,6 @@ def test_set_params_error(self, params_dict):
check_param_error(params_dict, 'n_iter', "The parameter 'n_iter' should be positive.", 0, 10)

# 5. Non-negative parameters
check_param_error(params_dict, 'n_ex', "The parameter 'n_ex' should be non-negative.", -1)
check_param_error(params_dict, 'N_cutoff', "The parameter 'N_cutoff' should be non-negative unless no weight correction is needed, i.e. N_cutoff = -1.", -5) # noqa: E501

# 6. String parameters
Expand Down Expand Up @@ -169,9 +167,8 @@ def test_set_params(self, params_dict):
# 2. Check the default values of the parameters not specified in params.yaml
assert REXEE.proposal == "exhaustive"
assert REXEE.acceptance == "metropolis"
assert REXEE.w_combine is None
assert REXEE.w_combine is False
assert REXEE.N_cutoff == 1000
assert REXEE.n_ex == 'N^3'
assert REXEE.verbose is True
assert REXEE.runtime_args is None
assert REXEE.n_ckpt == 100
Expand Down Expand Up @@ -270,9 +267,11 @@ def test_print_params(self, capfd, params_dict):
L += "Verbose log file: True\n"
L += "Proposal scheme: exhaustive\n"
L += "Acceptance scheme for swapping simulations: metropolis\n"
L += "Type of weights to be combined: None\n"
L += "Histogram cutoff: 1000\nNumber of replicas: 4\nNumber of iterations: 10\n"
L += "Number of attempted swaps in one exchange interval: N^3\n"
L += "Whether to perform weight combination: False\n"
L += "Type of means for weight combination: simple\n"
L += "Whether to perform histogram correction: False\n"
L += "Histogram cutoff for weight correction: 1000\n"
L += "Number of replicas: 4\nNumber of iterations: 10\n"
L += "Length of each replica: 1.0 ps\nFrequency for checkpointing: 100 iterations\n"
L += "Total number of states: 9\n"
L += "Additionally defined swappable states: None\n"
Expand Down Expand Up @@ -485,35 +484,6 @@ def test_get_swapping_pattern(self, params_dict):
assert pattern_4_2 == [1, 0, 3, 2]
assert swap_list_4_2 == [(2, 3), (0, 1)]

"""
We will deprecate multiple swaps anyway
# Case 5-1: Multiple swaps (proposal = 'multiple', n_ex = 5)
print('test 5-1')
random.seed(0)
REXEE = get_REXEE_instance(params_dict)
REXEE.n_ex = 5
REXEE.proposal = 'multiple'
states = [3, 1, 4, 6] # swappable pairs: [(0, 1), (0, 2), (1, 2)], first swap = (0, 2), accept
f = copy.deepcopy(dhdl_files)
pattern_5_1, swap_list_5_1 = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_swap_attempts == 5
assert REXEE.n_rejected == 4
assert pattern_5_1 == [2, 1, 0, 3]
assert swap_list_5_1 == [(0, 2)]
# Case 5-2: Multiple swaps but with only one swappable pair (proposal = 'multiple')
# This is specifically for testing n_swap_attempts in the case where n_ex reduces to 1.
random.seed(0)
REXEE = get_REXEE_instance(params_dict)
states = [0, 2, 3, 8] # The only swappable pair is [(1, 2)] --> accept
f = copy.deepcopy(dhdl_files)
pattern_5_2, swap_list_5_2 = REXEE.get_swapping_pattern(f, states)
assert REXEE.n_swap_attempts == 1 # since there is only 1 swappable pair
assert REXEE.n_rejected == 0
assert pattern_5_2 == [0, 2, 1, 3]
assert swap_list_5_2 == [(1, 2)]
"""

def test_calc_prob_acc(self, capfd, params_dict):
# k = 1.380649e-23; NA = 6.0221408e23; T = 298; kT = k * NA * T / 1000 = 2.4777098766670016
REXEE = get_REXEE_instance(params_dict)
Expand Down Expand Up @@ -587,47 +557,44 @@ def test_weight_correction(self, params_dict):
weights_2 = REXEE.weight_correction(weights_2, counts_2)
assert np.allclose(weights_2, [[0, 10.304, 20.073, 29.364 + np.log(5545 / 5955)]])

def test_combine_weights_1(self, params_dict):
def test_combine_weights(self, params_dict):
"""
Here we just test the combined weights, so the values of hist does not matter.
Here we just test the combined weights.
"""
REXEE = get_REXEE_instance(params_dict)
REXEE.n_tot = 6
REXEE.n_sub = 4
REXEE.s = 1
REXEE.n_sim = 3
REXEE.state_ranges = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]]
weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]]
hist = [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]

_, w_1, g_vec_1 = REXEE.combine_weights(hist, weights)
# Test 1: simple means
weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]]
w_1, g_vec_1 = REXEE.combine_weights(weights)
assert np.allclose(w_1, [
[0, 2.1, 3.9, 3.5],
[0, 1.8, 1.4, 2.75],
[0, -0.4, 0.95, 1.95]])
assert np.allclose(list(g_vec_1), [0, 2.1, 3.9, 3.5, 4.85, 5.85])

# Test 2: weighted means
weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]]
errors = [[0, 0.1, 0.15, 0.1], [0, 0.12, 0.1, 0.12], [0, 0.12, 0.15, 0.1]]
_, w_2, g_vec_2 = REXEE.combine_weights(hist, weights, errors)
w_2, g_vec_2 = REXEE.combine_weights(weights, errors)
assert np.allclose(w_2, [
[0, 2.1, 3.86140725, 3.45417313],
[0, 1.76140725, 1.35417313, 2.71436889],
[0, -0.40723412, 0.95296164, 1.95296164]])
assert np.allclose(list(g_vec_2), [0, 2.1, 3.861407249466951, 3.4541731330165306, 4.814368891580968, 5.814368891580968]) # noqa: E501

def test_combine_weights_2(self, params_dict):
"""
Here we just test the modified histograms, so the values of weights does not matter.
"""
def test_histogram_correction(self, params_dict):
REXEE = get_REXEE_instance(params_dict)
REXEE.n_tot = 6
REXEE.n_sub = 5
REXEE.s = 1
REXEE.n_sim = 2
REXEE.state_ranges = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]
weights = [[0, 2.1, 4.0, 3.7, 5], [0, 1.7, 1.2, 2.6, 4]]
hist = [[416, 332, 130, 71, 61], [303, 181, 123, 143, 260]]

hist_modified, _, _ = REXEE.combine_weights(hist, weights)
hist_modified = REXEE.histogram_correction(hist)
assert hist_modified == [[416, 332, 161, 98, 98], [332, 161, 98, 98, 178]]

0 comments on commit 21ff0fd

Please sign in to comment.