diff --git a/ensemble_md/replica_exchange_EE.py b/ensemble_md/replica_exchange_EE.py index 8cb1c7a0..90a33ebb 100644 --- a/ensemble_md/replica_exchange_EE.py +++ b/ensemble_md/replica_exchange_EE.py @@ -493,7 +493,9 @@ def print_params(self, params_analysis=False): print(f"Proposal scheme: {self.proposal}") print(f"Acceptance scheme for swapping simulations: {self.acceptance}") print(f"Whether to perform weight combination: {self.w_combine}") - print(f"Histogram cutoff: {self.N_cutoff}") + print(f"Type of means for weight combination: {self.w_mean_type}") + print(f"Whether to perform histogram correction: {self.hist_corr}") + print(f"Histogram cutoff for weight correction: {self.N_cutoff}") print(f"Number of replicas: {self.n_sim}") print(f"Number of iterations: {self.n_iter}") print(f"Length of each replica: {self.dt * self.nst_sim} ps") @@ -1179,8 +1181,8 @@ def histogram_correction(self, hist, print_values=True): # (1) Print the original histogram counts if print_values is True: print(' Original histogram counts:') - for i in range(len(self.hist)): - print(f' Rep {i}: {self.hist[i]}') + for i in range(len(hist)): + print(f' Rep {i}: {hist[i]}') # (2) Calculate adjacent weight differences and g_vec N_ratio_vec = [] # N_{k-1}/N_k for the whole range diff --git a/ensemble_md/tests/test_replica_exchange_EE.py b/ensemble_md/tests/test_replica_exchange_EE.py index 2ac56cd3..05ebf744 100644 --- a/ensemble_md/tests/test_replica_exchange_EE.py +++ b/ensemble_md/tests/test_replica_exchange_EE.py @@ -82,14 +82,13 @@ def test_set_params_error(self, params_dict): params_dict['n_sim'] = 4 # so params_dict can be read without failing in the assertions below # 2. Available options - check_param_error(params_dict, 'proposal', "The specified proposal scheme is not available. Available options include 'single', 'neighboring', 'exhaustive', and 'multiple'.", 'cool', 'multiple') # set as multiple for later tests for n_ex # noqa: E501 + check_param_error(params_dict, 'proposal', "The specified proposal scheme is not available. Available options include 'single', 'neighboring', and 'exhaustive'.", 'cool', 'exhaustive') # noqa: E501 check_param_error(params_dict, 'acceptance', "The specified acceptance scheme is not available. Available options include 'same-state' and 'metropolis'.") # noqa: E501 check_param_error(params_dict, 'df_method', "The specified free energy estimator is not available. Available options include 'TI', 'BAR', and 'MBAR'.") # noqa: E501 check_param_error(params_dict, 'err_method', "The specified method for error estimation is not available. Available options include 'propagate', and 'bootstrap'.") # noqa: E501 # 3. Integer parameters check_param_error(params_dict, 'nst_sim', "The parameter 'nst_sim' should be an integer.") - check_param_error(params_dict, 'n_ex', "The parameter 'n_ex' should be an integer.") check_param_error(params_dict, 'seed', "The parameter 'seed' should be an integer.") check_param_error(params_dict, 'n_sim', "The parameter 'n_sim' should be an integer.", 4.1, 4) @@ -98,7 +97,6 @@ def test_set_params_error(self, params_dict): check_param_error(params_dict, 'n_iter', "The parameter 'n_iter' should be positive.", 0, 10) # 5. Non-negative parameters - check_param_error(params_dict, 'n_ex', "The parameter 'n_ex' should be non-negative.", -1) check_param_error(params_dict, 'N_cutoff', "The parameter 'N_cutoff' should be non-negative unless no weight correction is needed, i.e. N_cutoff = -1.", -5) # noqa: E501 # 6. String parameters @@ -169,9 +167,8 @@ def test_set_params(self, params_dict): # 2. Check the default values of the parameters not specified in params.yaml assert REXEE.proposal == "exhaustive" assert REXEE.acceptance == "metropolis" - assert REXEE.w_combine is None + assert REXEE.w_combine is False assert REXEE.N_cutoff == 1000 - assert REXEE.n_ex == 'N^3' assert REXEE.verbose is True assert REXEE.runtime_args is None assert REXEE.n_ckpt == 100 @@ -270,9 +267,11 @@ def test_print_params(self, capfd, params_dict): L += "Verbose log file: True\n" L += "Proposal scheme: exhaustive\n" L += "Acceptance scheme for swapping simulations: metropolis\n" - L += "Type of weights to be combined: None\n" - L += "Histogram cutoff: 1000\nNumber of replicas: 4\nNumber of iterations: 10\n" - L += "Number of attempted swaps in one exchange interval: N^3\n" + L += "Whether to perform weight combination: False\n" + L += "Type of means for weight combination: simple\n" + L += "Whether to perform histogram correction: False\n" + L += "Histogram cutoff for weight correction: 1000\n" + L += "Number of replicas: 4\nNumber of iterations: 10\n" L += "Length of each replica: 1.0 ps\nFrequency for checkpointing: 100 iterations\n" L += "Total number of states: 9\n" L += "Additionally defined swappable states: None\n" @@ -485,35 +484,6 @@ def test_get_swapping_pattern(self, params_dict): assert pattern_4_2 == [1, 0, 3, 2] assert swap_list_4_2 == [(2, 3), (0, 1)] - """ - We will deprecate multiple swaps anyway - # Case 5-1: Multiple swaps (proposal = 'multiple', n_ex = 5) - print('test 5-1') - random.seed(0) - REXEE = get_REXEE_instance(params_dict) - REXEE.n_ex = 5 - REXEE.proposal = 'multiple' - states = [3, 1, 4, 6] # swappable pairs: [(0, 1), (0, 2), (1, 2)], first swap = (0, 2), accept - f = copy.deepcopy(dhdl_files) - pattern_5_1, swap_list_5_1 = REXEE.get_swapping_pattern(f, states) - assert REXEE.n_swap_attempts == 5 - assert REXEE.n_rejected == 4 - assert pattern_5_1 == [2, 1, 0, 3] - assert swap_list_5_1 == [(0, 2)] - - # Case 5-2: Multiple swaps but with only one swappable pair (proposal = 'multiple') - # This is specifically for testing n_swap_attempts in the case where n_ex reduces to 1. - random.seed(0) - REXEE = get_REXEE_instance(params_dict) - states = [0, 2, 3, 8] # The only swappable pair is [(1, 2)] --> accept - f = copy.deepcopy(dhdl_files) - pattern_5_2, swap_list_5_2 = REXEE.get_swapping_pattern(f, states) - assert REXEE.n_swap_attempts == 1 # since there is only 1 swappable pair - assert REXEE.n_rejected == 0 - assert pattern_5_2 == [0, 2, 1, 3] - assert swap_list_5_2 == [(1, 2)] - """ - def test_calc_prob_acc(self, capfd, params_dict): # k = 1.380649e-23; NA = 6.0221408e23; T = 298; kT = k * NA * T / 1000 = 2.4777098766670016 REXEE = get_REXEE_instance(params_dict) @@ -587,9 +557,9 @@ def test_weight_correction(self, params_dict): weights_2 = REXEE.weight_correction(weights_2, counts_2) assert np.allclose(weights_2, [[0, 10.304, 20.073, 29.364 + np.log(5545 / 5955)]]) - def test_combine_weights_1(self, params_dict): + def test_combine_weights(self, params_dict): """ - Here we just test the combined weights, so the values of hist does not matter. + Here we just test the combined weights. """ REXEE = get_REXEE_instance(params_dict) REXEE.n_tot = 6 @@ -597,37 +567,34 @@ def test_combine_weights_1(self, params_dict): REXEE.s = 1 REXEE.n_sim = 3 REXEE.state_ranges = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]] - weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] - hist = [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]] - _, w_1, g_vec_1 = REXEE.combine_weights(hist, weights) + # Test 1: simple means + weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] + w_1, g_vec_1 = REXEE.combine_weights(weights) assert np.allclose(w_1, [ [0, 2.1, 3.9, 3.5], [0, 1.8, 1.4, 2.75], [0, -0.4, 0.95, 1.95]]) assert np.allclose(list(g_vec_1), [0, 2.1, 3.9, 3.5, 4.85, 5.85]) + # Test 2: weighted means weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] errors = [[0, 0.1, 0.15, 0.1], [0, 0.12, 0.1, 0.12], [0, 0.12, 0.15, 0.1]] - _, w_2, g_vec_2 = REXEE.combine_weights(hist, weights, errors) + w_2, g_vec_2 = REXEE.combine_weights(weights, errors) assert np.allclose(w_2, [ [0, 2.1, 3.86140725, 3.45417313], [0, 1.76140725, 1.35417313, 2.71436889], [0, -0.40723412, 0.95296164, 1.95296164]]) assert np.allclose(list(g_vec_2), [0, 2.1, 3.861407249466951, 3.4541731330165306, 4.814368891580968, 5.814368891580968]) # noqa: E501 - def test_combine_weights_2(self, params_dict): - """ - Here we just test the modified histograms, so the values of weights does not matter. - """ + def test_histogram_correction(self, params_dict): REXEE = get_REXEE_instance(params_dict) REXEE.n_tot = 6 REXEE.n_sub = 5 REXEE.s = 1 REXEE.n_sim = 2 REXEE.state_ranges = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]] - weights = [[0, 2.1, 4.0, 3.7, 5], [0, 1.7, 1.2, 2.6, 4]] hist = [[416, 332, 130, 71, 61], [303, 181, 123, 143, 260]] - hist_modified, _, _ = REXEE.combine_weights(hist, weights) + hist_modified = REXEE.histogram_correction(hist) assert hist_modified == [[416, 332, 161, 98, 98], [332, 161, 98, 98, 178]]