diff --git a/docs/simulations.rst b/docs/simulations.rst index 69ecb7b3..e86ce445 100644 --- a/docs/simulations.rst +++ b/docs/simulations.rst @@ -277,34 +277,21 @@ include parameters for data analysis here. This could be useful for REXEE simulations for multiple serial mutations, where we enforce exchanges between states 4 and 5 (and 14 and 15) and perform coordinate manipulation. - :code:`proposal`: (Optional, Default: :code:`exhaustive`) - The method for proposing simulations to be swapped. Available options include :code:`single`, :code:`exhaustive`, :code:`neighboring`, and :code:`multiple`. + The method for proposing simulations to be swapped. Available options include :code:`single`, :code:`neighboring`, and :code:`exhaustive`. For more details, please refer to :ref:`doc_proposal`. - :code:`acceptance`: (Optional, Default: :code:`metropolis`) The Monte Carlo method for swapping simulations. Available options include :code:`same-state`/:code:`same_state`, :code:`metropolis`, and :code:`metropolis-eq`/:code:`metropolis_eq`. For more details, please refer to :ref:`doc_acceptance`. - - :code:`w_combine`: (Optional, Default: :code:`None`) - The type of weights to be combined across multiple replicas in a weight-updating REXEE simulation. The following options are available: - - - :code:`None`: No weight combination. - - :code:`final`: Combine the final weights. - - :code:`avg`: Combine the weights averaged over from last time the Wang-Landau incrementor was updated. - - For more details about weight combination, please refer to :ref:`doc_w_schemes`. - - - :code:`rmse_cutoff`: (Optional, Default: :code:`None`) - The cutoff for the root-mean-square error (RMSE) between the weights of the current iteration - and the weights averaged over from the last time the Wang-Landau incrementor was updated. - For each replica, the RMSE between the averaged weights and the current weights will be calculated. - When :code:`rmse_cutoff` is specified, weight combination will be performed only if the maximum RMSE across all replicas - is smaller than the cutoff. Otherwise, weight combination is deactivated (even if :code:`w_combine` is specified) - because a larger RMSE indicates that the weights are noisy and should not be combined. - The default value is infinity, which means that weight combination will always be performed if :code:`w_combine` is specified. - The units of the cutoff are :math:`k_B T`. + - :code:`w_combine`: (Optional, Default: :code:`False`) + Whether to perform weight combination or not. Note that weights averaged over from the last time the Wang-Landau incrementor was updated (instead of + final weights) will be used for weight combination. For more details about weight combination, please refer to :ref:`doc_w_schemes`. + - :code:`w_mean_type`: (Optional, Default: code:`simple`) + The type of mean to use when combining weights. Available options include :code:`simple` and :code:`weighted`. + For the later case, inverse-variance weighted means are used. - :code:`N_cutoff`: (Optional, Default: 1000) - The histogram cutoff. -1 means that no histogram correction will be performed. - - :code:`n_ex`: (Optional, Default: 1) - The number of attempts swap during an exchange interval. This option is only relevant if the option :code:`proposal` is :code:`multiple`. - Otherwise, this option is ignored. For more details, please refer to :ref:`doc_multiple_swaps`. + The histogram cutoff for weight corrections. -1 means that no histogram correction will be performed. + - :code:`hist_corr` (Optional, Default: :code:`False`) + Whether to perform histogram correction. - :code:`mdp_args`: (Optional, Default: :code:`None`) MDP parameters differing across replicas provided in a dictionary. For each key in the dictionary, the value should always be a list of length of the number of replicas. For example, :code:`{'ref_p': [1.0, 1.01, 1.02, 1.03]}` means that the @@ -389,10 +376,10 @@ infinity internally. add_swappables: null proposal: 'exhaustive' acceptance: 'metropolis' - w_combine: null - rmse_cutoff: null + w_combine: False + w_mean_type: 'simple' N_cutoff: 1000 - n_ex: 1 + hist_corr: False mdp_args: null grompp_args: null runtime_args: null diff --git a/docs/theory.rst b/docs/theory.rst index f3149f67..aae99f43 100644 --- a/docs/theory.rst +++ b/docs/theory.rst @@ -1,6 +1,8 @@ .. _doc_basic_idea: -.. note:: This page is still a work in progress. Please refer to our paper for more details before this page is completed. +.. note:: This page is still a work in progress. Please check `Issue 33`_ for the current progress. + +.. _`Issue 33`: https://github.com/wehs7661/ensemble_md/issues/33 1. Basic idea ============= diff --git a/ensemble_md/cli/run_REXEE.py b/ensemble_md/cli/run_REXEE.py index 4462ad86..f26bf57d 100644 --- a/ensemble_md/cli/run_REXEE.py +++ b/ensemble_md/cli/run_REXEE.py @@ -165,53 +165,81 @@ def main(): if wl_delta != [None for i in range(REXEE.n_sim)]: # weight-updating print(f'\nCurrent Wang-Landau incrementors: {wl_delta}\n') - # (1) First we prepare the weights to be combined. + # (1) First we prepare the time-averaged weights to be combined, if needed. # Note that although averaged weights are sometimes used for weight correction/weight combination, # the final weights are always used for calculating the acceptance ratio. if REXEE.N_cutoff != -1 or REXEE.w_combine is not None: # Only when weight correction/weight combination is needed. weights_avg, weights_err = REXEE.get_averaged_weights(log_files) - weights_input = REXEE.prepare_weights(weights_avg, weights) # weights_input is for weight combination # noqa: E501 + + # Calculate the RMSE between the averaged weights and the final weights by the way. + rmse_list = [utils.calc_rmse(weights_avg[i], weights[i]) for i in range(REXEE.n_sim)] + rmse_str = ', '.join([f'{i:.2f}' for i in rmse_list]) + print(f'RMSE between the final weights and time-averaged weights for each replica: {rmse_str} kT') # (2) Now we perform weight correction/weight combination. # The product of this step should always be named as "weights" to be used in update_MDP - if REXEE.N_cutoff != -1 and REXEE.w_combine is not None: - # perform both - if weights_input is None: - # Then only weight correction will be performed - print('Note: Weight combination is deactivated because the weights are too noisy.') - weights = REXEE.weight_correction(weights, counts) - _ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 + if REXEE.N_cutoff != -1 and REXEE.w_combine is True: + # Perform both weight correction and weight combination + if REXEE.verbose is True: + print('Performing weight correction ...') + else: + print('Performing weight correction ...', end='') + weights_preprocessed = REXEE.weight_correction(weights_avg, counts) + + if REXEE.verbose is True: + print('Performing weight combination ...') else: - weights_preprocessed = REXEE.weight_correction(weights_input, counts) - if REXEE.verbose is True: - print('Performing weight combination ...') - else: - print('Performing weight combination ...', end='') - counts, weights, g_vec = REXEE.combine_weights(counts_, weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501 - REXEE.g_vecs.append(g_vec) - elif REXEE.N_cutoff == -1 and REXEE.w_combine is not None: - # only perform weight combination + print('Performing weight combination ...', end='') + if REXEE.w_mean_type == 'simple': + weights, g_vec = REXEE.combine_weights(weights_preprocessed) # simple means + else: + # Note that here weights_err are acutally not the uncertainties for weights_prepocessed + # but weights_avg ... We might need to disable this feature in the future. + weights, g_vec = REXEE.combine_weights(weights_preprocessed, weights_err) # inverse-variance weighting # noqa: E501 + REXEE.g_vecs.append(g_vec) + + # Check if histogram correction is needed after weight combination + if REXEE.hist_corr is True: + print('Performing histogram correction ...') + counts = REXEE.histogram_correction(counts_) + else: + print('Note: No histogram correction will be performed.') + + elif REXEE.N_cutoff == -1 and REXEE.w_combine is True: + # Only perform weight combination print('Note: No weight correction will be performed.') - if weights_input is None: - print('Note: Weight combination is deactivated because the weights are too noisy.') - _ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 + if REXEE.verbose is True: + print('Performing weight combination ...') else: - if REXEE.verbose is True: - print('Performing weight combination ...') - else: - print('Performing weight combination ...', end='') - counts, weights, g_vec = REXEE.combine_weights(counts_, weights_input) # inverse-variance weighting seems worse # noqa: E501 - REXEE.g_vecs.append(g_vec) - elif REXEE.N_cutoff != -1 and REXEE.w_combine is None: - # only perform weight correction + print('Performing weight combination ...', end='') + if REXEE.w_mean_type == 'simple': + weights, g_vec = REXEE.combine_weights(weights_avg) # simple means + else: + weights, g_vec = REXEE.combine_weights(weights_avg, weights_err) # inverse-variance weighting + REXEE.g_vecs.append(g_vec) + + # Check if histogram correction is needed after weight combination + if REXEE.hist_corr is True: + print('Performing histogram correction ...') + counts = REXEE.histogram_correction(counts_) + else: + print('Note: No histogram correction will be performed.') + + elif REXEE.N_cutoff != -1 and REXEE.w_combine is False: + # Only perform weight correction print('Note: No weight combination will be performed.') - weights = REXEE.weights_correction(weights_input, counts) - _ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 + if REXEE.verbose is True: + print('Performing weight correction ...') + else: + print('Performing weight correction ...', end='') + weights = REXEE.weights_correction(weights_avg, counts) + _ = REXEE.combine_weights(weights, print_values=False)[1] # just to print the combined weights # noqa: E501 else: print('Note: No weight correction will be performed.') print('Note: No weight combination will be performed.') - _ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 + # Note that in this case, the final weights will be used in the next iteration. + _ = REXEE.combine_weights(weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 # 3-5. Modify the MDP files and swap out the GRO files (if needed) # Here we keep the lambda range set in mdp the same across different iterations in the same folder but swap out the gro file # noqa: E501 diff --git a/ensemble_md/replica_exchange_EE.py b/ensemble_md/replica_exchange_EE.py index bd60f297..90a33ebb 100644 --- a/ensemble_md/replica_exchange_EE.py +++ b/ensemble_md/replica_exchange_EE.py @@ -162,10 +162,10 @@ def set_params(self, analysis): "nst_sim": None, "proposal": 'exhaustive', "acceptance": "metropolis", - "w_combine": None, - "rmse_cutoff": np.inf, + "w_combine": False, + "w_mean_type": 'simple', "N_cutoff": 1000, - "n_ex": 'N^3', # only active for multiple swaps. + "hist_corr": False, "verbose": True, "mdp_args": None, "grompp_args": None, @@ -181,6 +181,7 @@ def set_params(self, analysis): "err_method": "propagate", "n_bootstrap": 50, "seed": None, + # "n_ex": 'N^3', # only active for multiple swaps. } for i in optional_args: if hasattr(self, i) is False or getattr(self, i) is None: @@ -193,48 +194,42 @@ def set_params(self, analysis): self.warnings.append(f'Warning: Parameter "{i}" specified in the input YAML file is not recognizable.') # Step 4: Check if the parameters in the YAML file are well-defined - if self.proposal not in [None, 'single', 'neighboring', 'exhaustive', 'multiple']: - raise ParameterError("The specified proposal scheme is not available. Available options include 'single', 'neighboring', 'exhaustive', and 'multiple'.") # noqa: E501 + if self.proposal not in [None, 'single', 'neighboring', 'exhaustive']: # deprecated option: multiple + raise ParameterError("The specified proposal scheme is not available. Available options include 'single', 'neighboring', and 'exhaustive'.") # noqa: E501 if self.acceptance not in [None, 'same-state', 'same_state', 'metropolis']: raise ParameterError("The specified acceptance scheme is not available. Available options include 'same-state' and 'metropolis'.") # noqa: E501 - if self.w_combine not in [None, 'final', 'avg']: - raise ParameterError("The specified type of weight to be combined is not available. Available options include 'final' and 'avg'.") # noqa: E501 - if self.df_method not in [None, 'TI', 'BAR', 'MBAR']: raise ParameterError("The specified free energy estimator is not available. Available options include 'TI', 'BAR', and 'MBAR'.") # noqa: E501 if self.err_method not in [None, 'propagate', 'bootstrap']: raise ParameterError("The specified method for error estimation is not available. Available options include 'propagate', and 'bootstrap'.") # noqa: E501 - if self.w_combine is not None and self.rmse_cutoff == np.inf: - self.warnings.append('Warning: We recommend setting rmse_cutoff when w_combine is used.') - - if self.rmse_cutoff != np.inf: - if type(self.rmse_cutoff) is not float and type(self.rmse_cutoff) is not int: - raise ParameterError("The parameter 'rmse_cutoff' should be a float.") - params_int = ['n_sim', 'n_iter', 's', 'N_cutoff', 'df_spacing', 'n_ckpt', 'n_bootstrap'] # integer parameters # noqa: E501 if self.nst_sim is not None: params_int.append('nst_sim') + """ if self.n_ex != 'N^3': # no need to add "and self.proposal == 'multiple' since if multiple swaps are not used, n_ex=1" # noqa: E501 params_int.append('n_ex') + """ if self.seed is not None: params_int.append('seed') for i in params_int: if type(getattr(self, i)) != int: raise ParameterError(f"The parameter '{i}' should be an integer.") - params_pos = ['n_sim', 'n_iter', 'n_ckpt', 'df_spacing', 'n_bootstrap', 'rmse_cutoff'] # positive parameters + params_pos = ['n_sim', 'n_iter', 'n_ckpt', 'df_spacing', 'n_bootstrap'] # positive parameters if self.nst_sim is not None: params_pos.append('nst_sim') for i in params_pos: if getattr(self, i) <= 0: raise ParameterError(f"The parameter '{i}' should be positive.") + """ if self.n_ex != 'N^3' and self.n_ex < 0: raise ParameterError("The parameter 'n_ex' should be non-negative.") + """ if self.s < 0: raise ParameterError("The parameter 's' should be non-negative.") @@ -256,7 +251,7 @@ def set_params(self, analysis): if type(getattr(self, i)) != str: raise ParameterError(f"The parameter '{i}' should be a string.") - params_bool = ['verbose', 'rm_cpt', 'msm', 'free_energy', 'subsampling_avg'] + params_bool = ['verbose', 'rm_cpt', 'msm', 'free_energy', 'subsampling_avg', 'w_combine'] for i in params_bool: if type(getattr(self, i)) != bool: raise ParameterError(f"The parameter '{i}' should be a boolean variable.") @@ -319,11 +314,11 @@ def set_params(self, analysis): self.equilibrated_weights = [None for i in range(self.n_sim)] if self.fixed_weights is True: - if self.N_cutoff != -1 or self.w_combine is not None: + if self.N_cutoff != -1 or self.w_combine is True: self.warnings.append('Warning: The weight correction/weight combination method is specified but will not be used since the weights are fixed.') # noqa: E501 # In the case that the warning is ignored, enforce the defaults. self.N_cutoff = -1 - self.w_combine = None + self.w_combine = False if 'lmc_seed' in self.template and self.template['lmc_seed'] != -1: self.warnings.append('Warning: We recommend setting lmc_seed as -1 so the random seed is different for each iteration.') # noqa: E501 @@ -497,17 +492,19 @@ def print_params(self, params_analysis=False): print(f"Verbose log file: {self.verbose}") print(f"Proposal scheme: {self.proposal}") print(f"Acceptance scheme for swapping simulations: {self.acceptance}") - print(f"Type of weights to be combined: {self.w_combine}") - print(f"Histogram cutoff: {self.N_cutoff}") + print(f"Whether to perform weight combination: {self.w_combine}") + print(f"Type of means for weight combination: {self.w_mean_type}") + print(f"Whether to perform histogram correction: {self.hist_corr}") + print(f"Histogram cutoff for weight correction: {self.N_cutoff}") print(f"Number of replicas: {self.n_sim}") print(f"Number of iterations: {self.n_iter}") - print(f"Number of attempted swaps in one exchange interval: {self.n_ex}") print(f"Length of each replica: {self.dt * self.nst_sim} ps") print(f"Frequency for checkpointing: {self.n_ckpt} iterations") print(f"Total number of states: {self.n_tot}") print(f"Additionally defined swappable states: {self.add_swappables}") print(f"Additional grompp arguments: {self.grompp_args}") print(f"Additional runtime arguments: {self.runtime_args}") + # print(f"Number of attempted swaps in one exchange interval: {self.n_ex}") if self.mdp_args is not None and len(self.mdp_args.keys()) > 1: print("MDP parameters differing across replicas:") for i in self.mdp_args.keys(): @@ -864,23 +861,24 @@ def get_swapping_pattern(self, dhdl_files, states): A list of tuples showing the accepted swaps. """ swap_list = [] - if self.proposal != 'multiple': - if self.proposal == 'exhaustive': - n_ex = int(np.floor(self.n_sim / 2)) # This is the maximum, not necessarily the number that will always be reached. # noqa - n_ex_exhaustive = 0 # The actual number of swaps atttempted. - else: - n_ex = 1 # single swap or neighboring swap + if self.proposal == 'exhaustive': + n_ex = int(np.floor(self.n_sim / 2)) # This is the maximum, not necessarily the number that will always be reached. # noqa + n_ex_exhaustive = 0 # The actual number of swaps atttempted. else: - # multiple swaps - if self.n_ex == 'N^3': - n_ex = self.n_tot ** 3 - else: - n_ex = self.n_ex + n_ex = 1 # single swap or neighboring swap + + """ + # multiple swaps + if self.n_ex == 'N^3': + n_ex = self.n_tot ** 3 + else: + n_ex = self.n_ex + """ shifts = list(self.s * np.arange(self.n_sim)) swap_pattern = list(range(self.n_sim)) # Can be regarded as the indices of DHDL files/configurations state_ranges = copy.deepcopy(self.state_ranges) - states_copy = copy.deepcopy(states) # only for re-identifying swappable pairs given updated state_ranges + # states_copy = copy.deepcopy(states) # only for re-identifying swappable pairs given updated state_ranges --> was needed for the multiple exchange proposal scheme # noqa: E501 swappables = ReplicaExchangeEE.identify_swappable_pairs(states, state_ranges, self.proposal == 'neighboring', self.add_swappables) # noqa: E501 # Note that if there is only 1 swappable pair, then it will still be the only swappable pair @@ -947,11 +945,13 @@ def get_swapping_pattern(self, dhdl_files, states): state_ranges[swap[0]], state_ranges[swap[1]] = state_ranges[swap[1]], state_ranges[swap[0]] self.configs[swap[0]], self.configs[swap[1]] = self.configs[swap[1]], self.configs[swap[0]] + """ if n_ex > 1 and self.proposal == 'multiple': # must be multiple swaps # After state_ranges have been updated, we re-identify the swappable pairs. # Notably, states_copy (instead of states) should be used. (They could be different.) swappables = ReplicaExchangeEE.identify_swappable_pairs(states_copy, state_ranges, self.proposal == 'neighboring', self.add_swappables) # noqa: E501 print(f" New swappable pairs: {swappables}") + """ else: # In this case, there is no need to update the swappables pass @@ -1077,6 +1077,44 @@ def accept_or_reject(self, prob_acc): print(" Swap rejected! ") return swap_bool + def get_averaged_weights(self, log_files): + """ + For each replica, calculate the averaged weights (and the associated error) from the time series + of the weights since the previous update of the Wang-Landau incrementor. + + Parameters + ---------- + log_files : list + A list of file paths to GROMACS LOG files of different replicas. + + Returned + -------- + weights_avg : list + A list of lists of weights averaged since the last update of the Wang-Landau + incrementor. The length of the list should be the number of replicas. + weights_err : list + A list of lists of errors corresponding to the values in :code:`weights_avg`. + """ + for i in range(self.n_sim): + weights, _, wl_delta, _ = gmx_parser.parse_log(log_files[i]) + if self.current_wl_delta[i] == wl_delta: + self.updating_weights[i] += weights # expand the list + else: + self.current_wl_delta[i] = wl_delta + self.updating_weights[i] = weights + + # shape of self.updating_weights is (n_sim, n_points, n_states), but n_points can be different + # for different replicas, which will error out np.mean(self.updating_weights, axis=1) + weights_avg = [np.mean(self.updating_weights[i], axis=0).tolist() for i in range(self.n_sim)] + weights_err = [] + for i in range(self.n_sim): + if len(self.updating_weights[i]) == 1: # this would lead to a RunTime Warning and nan + weights_err.append([0] * self.n_sub) # in `weighted_mean``, a simple average will be returned. + else: + weights_err.append(np.std(self.updating_weights[i], axis=0, ddof=1).tolist()) + + return weights_avg, weights_err + def weight_correction(self, weights, counts): """ Corrects the lambda weights based on the histogram counts. Namely, @@ -1099,11 +1137,7 @@ def weight_correction(self, weights, counts): weights : list An updated list of lists of corected weights. """ - if self.verbose is True: - print("\nPerforming weight correction for the lambda weights ...") - else: - print("\nPerforming weight correction for the lambda weights ...", end="") - + skip_correction = False for i in range(len(weights)): # loop over the replicas if self.verbose is True: print(f" Counts of rep {i}:\t\t{counts[i]}") @@ -1113,8 +1147,11 @@ def weight_correction(self, weights, counts): if counts[i][j - 1] != 0 and counts[i][j - 1] != 0: if np.min([counts[i][j - 1], counts[i][j]]) > self.N_cutoff: weights[i][j] += np.log(counts[i][j - 1] / counts[i][j]) + else: + skip_correction = True + print('Weight correction was deactivated because neither N_{k-1} or N_k is larger than the histogram cutoff.') # noqa: E501 - if self.verbose is True: + if self.verbose is True and skip_correction is False: print(f' Corrected weights of rep {i}:\t{[float(f"{k:.3f}") for k in weights[i]]}\n') if self.verbose is False: @@ -1122,183 +1159,144 @@ def weight_correction(self, weights, counts): return weights - def get_averaged_weights(self, log_files): + def histogram_correction(self, hist, print_values=True): """ - For each replica, calculate the averaged weights (and the associated error) from the time series - of the weights since the previous update of the Wang-Landau incrementor. + Adjust the histogram counts. Specifically, the ratio of corrected histogram counts + for adjancent states is the geometric mean of the ratio of the original histogram counts + for the same states. Note, however, if the histogram counts are 0 for some states, the + histogram correction will be skipped and the original histogram counts will be returned. Parameters ---------- - log_files : list - A list of file paths to GROMACS LOG files of different replicas. + hist : list + A list of lists of histogram counts of ALL simulations. + print_values : bool, optional + Whether to print the histograms for each replica before and after histogram correction. - Returned - -------- - weights_avg : list - A list of lists of weights averaged since the last update of the Wang-Landau - incrementor. The length of the list should be the number of replicas. - weights_err : list - A list of lists of errors corresponding to the values in :code:`weights_avg`. + Returns + ------- + hist_modified : list + A list of lists of modified histogram counts of ALL simulations. """ - for i in range(self.n_sim): - weights, _, wl_delta, _ = gmx_parser.parse_log(log_files[i]) - if self.current_wl_delta[i] == wl_delta: - self.updating_weights[i] += weights # expand the list - else: - self.current_wl_delta[i] = wl_delta - self.updating_weights[i] = weights + # (1) Print the original histogram counts + if print_values is True: + print(' Original histogram counts:') + for i in range(len(hist)): + print(f' Rep {i}: {hist[i]}') - # shape of self.updating_weights is (n_sim, n_points, n_states), but n_points can be different - # for different replicas, which will error out np.mean(self.updating_weights, axis=1) - weights_avg = [np.mean(self.updating_weights[i], axis=0).tolist() for i in range(self.n_sim)] - weights_err = [] - for i in range(self.n_sim): - if len(self.updating_weights[i]) == 1: # this would lead to a RunTime Warning and nan - weights_err.append([0] * self.n_sub) # in `weighted_mean``, a simple average will be returned. - else: - weights_err.append(np.std(self.updating_weights[i], axis=0, ddof=1).tolist()) + # (2) Calculate adjacent weight differences and g_vec + N_ratio_vec = [] # N_{k-1}/N_k for the whole range + with warnings.catch_warnings(): # Suppress the specific warning here + warnings.simplefilter("ignore", category=RuntimeWarning) + N_ratio_adjacent = [list(np.array(hist[i][1:]) / np.array(hist[i][:-1])) for i in range(len(hist))] - return weights_avg, weights_err + for i in range(self.n_tot - 1): + N_ratio_list = [] + for j in range(len(self.state_ranges)): + if i in self.state_ranges[j] and i + 1 in self.state_ranges[j]: + idx = self.state_ranges[j].index(i) + N_ratio_list.append(N_ratio_adjacent[j][idx]) + N_ratio_vec.append(np.prod(N_ratio_list) ** (1 / len(N_ratio_list))) # geometric mean + N_ratio_vec.insert(0, hist[0][0]) - def prepare_weights(self, weights_avg, weights_final): - """ - Prepared weights to be combined by the function :code:`combine_weights`. - For each replica, the RMSE between the averaged weights and the final weights is calculated. If the - maximum of the RMSEs of all replicas is smaller than the cutoff specified in the input YAML file - (the parameter :code:`rmse_cutoff`), either final weights or time-averaged weights will be used - (depending on the value of the parameter :code:`w_combine`). Otherwise, :code:`None` will be returned, - which will lead to deactivation of weight combination in the CLI :code:`run_REXEE`. + # (3) Check if the histogram counts are 0 for some states, if so, the histogram correction will be skipped. + # Zero histogram counts can happen when the sampling is poor or the WL incrementor just got updated + contains_nan = any(np.isnan(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by 0/0 # noqa: E501 + contains_inf = any(np.isinf(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by x/0, where x is a finite number # noqa: E501 + skip_hist_correction = contains_nan or contains_inf + if skip_hist_correction: + print('\n Histogram correction is skipped because the histogram counts are 0 for some states.') - Parameters - ---------- - weights_avg : list - A list of lists of weights averaged since the last update of the Wang-Landau - incrementor. The length of the list should be the number of replicas. - weights_final : list - A list of lists of final weights of all simulations. The length of the list should - be the number of replicas. + # (4) Perform histogram correction if it is not skipped + if skip_hist_correction is False: + print('\n Performing histogram correction ...') + # When skip_hist_correction is True, previous lines for calculating N_ratio_vec or N_ratio_list will + # still not error out so it's fine to not add the conditional statement like here, since we will + # have hist_modified = hist at the end anyway. However, if skip_hist_correction, things like + # int(np.nan) will lead to an error, so we put an if condition here. + N_vec = np.array([int(np.ceil(np.prod(N_ratio_vec[:(i + 1)]))) for i in range(len(N_ratio_vec))]) - Returns - ------- - weights_output : list - A list of lists of weights to be combined. - """ - rmse_list = [utils.calc_rmse(weights_avg[i], weights_final[i]) for i in range(self.n_sim)] - rmse_str = ', '.join([f'{i:.2f}' for i in rmse_list]) - print(f'RMSE between the final weights and time-averaged weights for each replica: {rmse_str} kT') - if np.max(rmse_list) < self.rmse_cutoff: - # Weight combination will be activated - if self.w_combine == 'final': - weights_output = weights_final - elif self.w_combine == 'avg': - weights_output = weights_avg + if skip_hist_correction is False: + hist_modified = [list(N_vec[self.state_ranges[i]]) for i in range(self.n_sim)] else: - weights_output = None + hist_modified = hist # the original input histogram - return weights_output + # (5) Print the modified histogram counts + if print_values is True: + print('\n Modified histogram counts:') + for i in range(len(hist_modified)): + print(f' Rep {i}: {hist_modified[i]}') + + return hist_modified - def combine_weights(self, hist, weights, weights_err=None, print_values=True): + def combine_weights(self, weights, weights_err=None, print_values=True): """ - Combine alchemical weights across multiple replicas and adjusts the histogram counts - correspondingly. Note that if :code:`weights_err` is provided, inverse-variance weighting will be used. + Combine alchemical weights across multiple replicas. Note that if + :code:`weights_err` is provided, inverse-variance weighting will be used. Care must be taken since inverse-variance weighting can lead to slower convergence if the provided errors are not accurate. (See :ref:`doc_w_schemes` for mor details.) Parameters ---------- - hist : list - A list of lists of histogram counts of ALL simulations. weights : list A list of lists alchemical weights of ALL simulations. weights_err : list, optional A list of lists of errors corresponding to the values in :code:`weights`. print_values : bool, optional - Whether to print the histograms and weights for each replica before and - after weight combinationfor each replica. + Whether to print the weights for each replica before and + after weight combination for each replica. Returns ------- - hist_modified : list - A list of modified histogram counts of ALL simulations. weights_modified : list A list of modified Wang-Landau weights of ALL simulations. g_vec : np.ndarray An array of alchemical weights of the whole range of states. """ + # (1) Print the original weights if print_values is True: w = np.round(weights, decimals=3).tolist() # just for printing print(' Original weights:') for i in range(len(w)): print(f' Rep {i}: {w[i]}') - print('\n Original histogram counts:') - for i in range(len(hist)): - print(f' Rep {i}: {hist[i]}') - # Calculate adjacent weight differences and g_vec - dg_vec, N_ratio_vec = [], [] # alchemical weight differences and histogram count ratios for the whole range + # (2) Calculate adjacent weight differences and g_vec + dg_vec = [] # alchemical weight differences for the whole range dg_adjacent = [list(np.diff(weights[i])) for i in range(len(weights))] - # Suppress the specific warning here - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=RuntimeWarning) - N_ratio_adjacent = [list(np.array(hist[i][1:]) / np.array(hist[i][:-1])) for i in range(len(hist))] - - # Below we deal with the case where the sampling is poor or the WL incrementor just got updated such that - # the histogram counts are 0 for some states, in which case we simply skip histogram correction. - contains_nan = any(np.isnan(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by 0/0 # noqa: E501 - contains_inf = any(np.isinf(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by x/0, where x is a finite number # noqa: E501 - skip_hist_correction = contains_nan or contains_inf - if skip_hist_correction: - print('\n Histogram correction is skipped because the histogram counts are 0 for some states.') if weights_err is not None: dg_adjacent_err = [[np.sqrt(weights_err[i][j] ** 2 + weights_err[i][j + 1] ** 2) for j in range(len(weights_err[i]) - 1)] for i in range(len(weights_err))] # noqa: E501 for i in range(self.n_tot - 1): - dg_list, dg_err_list, N_ratio_list = [], [], [] + dg_list, dg_err_list = [], [] for j in range(len(self.state_ranges)): if i in self.state_ranges[j] and i + 1 in self.state_ranges[j]: idx = self.state_ranges[j].index(i) dg_list.append(dg_adjacent[j][idx]) - N_ratio_list.append(N_ratio_adjacent[j][idx]) if weights_err is not None: dg_err_list.append(dg_adjacent_err[j][idx]) if weights_err is None: dg_vec.append(np.mean(dg_list)) else: dg_vec.append(utils.weighted_mean(dg_list, dg_err_list)[0]) - N_ratio_vec.append(np.prod(N_ratio_list) ** (1 / len(N_ratio_list))) # geometric mean + dg_vec.insert(0, 0) - N_ratio_vec.insert(0, hist[0][0]) g_vec = np.array([sum(dg_vec[:(i + 1)]) for i in range(len(dg_vec))]) - if skip_hist_correction is False: - # When skip_hist_correction is True, previous lines for calculating N_ratio_vec or N_ratio_list will - # still not error out so it's fine to not add the conditional statement like here, since we will - # have hist_modified = hist at the end anyway. However, if skip_hist_correction, things like - # int(np.nan) will lead to an error, so we put an if condition here. - N_vec = np.array([int(np.ceil(np.prod(N_ratio_vec[:(i + 1)]))) for i in range(len(N_ratio_vec))]) - # Determine the vector of alchemical weights and histogram counts for each replica + # (3) Determine the vector of alchemical weights for each replica weights_modified = np.zeros_like(weights) for i in range(self.n_sim): - hist_modified = [] if self.equil[i] == -1: # unequilibrated weights_modified[i] = list(g_vec[i * self.s: i * self.s + self.n_sub] - g_vec[i * self.s: i * self.s + self.n_sub][0]) # noqa: E501 else: weights_modified[i] = self.equilibrated_weights[i] - if skip_hist_correction is False: - hist_modified = [list(N_vec[self.state_ranges[i]]) for i in range(self.n_sim)] - else: - hist_modified = hist + # (4) Print the modified weights if print_values is True: w = np.round(weights_modified, decimals=3).tolist() # just for printing print('\n Modified weights:') for i in range(len(w)): print(f' Rep {i}: {w[i]}') - if skip_hist_correction is False: - print('\n Modified histogram counts:') - for i in range(len(hist_modified)): - print(f' Rep {i}: {hist_modified[i]}') if self.verbose is False: print(' DONE') @@ -1306,7 +1304,7 @@ def combine_weights(self, hist, weights, weights_err=None, print_values=True): else: print(f'\n The alchemical weights of all states: \n {list(np.round(g_vec, decimals=3))}') - return hist_modified, weights_modified, g_vec + return weights_modified, g_vec def _run_grompp(self, n, swap_pattern): """ diff --git a/ensemble_md/tests/test_replica_exchange_EE.py b/ensemble_md/tests/test_replica_exchange_EE.py index 2ac56cd3..05ebf744 100644 --- a/ensemble_md/tests/test_replica_exchange_EE.py +++ b/ensemble_md/tests/test_replica_exchange_EE.py @@ -82,14 +82,13 @@ def test_set_params_error(self, params_dict): params_dict['n_sim'] = 4 # so params_dict can be read without failing in the assertions below # 2. Available options - check_param_error(params_dict, 'proposal', "The specified proposal scheme is not available. Available options include 'single', 'neighboring', 'exhaustive', and 'multiple'.", 'cool', 'multiple') # set as multiple for later tests for n_ex # noqa: E501 + check_param_error(params_dict, 'proposal', "The specified proposal scheme is not available. Available options include 'single', 'neighboring', and 'exhaustive'.", 'cool', 'exhaustive') # noqa: E501 check_param_error(params_dict, 'acceptance', "The specified acceptance scheme is not available. Available options include 'same-state' and 'metropolis'.") # noqa: E501 check_param_error(params_dict, 'df_method', "The specified free energy estimator is not available. Available options include 'TI', 'BAR', and 'MBAR'.") # noqa: E501 check_param_error(params_dict, 'err_method', "The specified method for error estimation is not available. Available options include 'propagate', and 'bootstrap'.") # noqa: E501 # 3. Integer parameters check_param_error(params_dict, 'nst_sim', "The parameter 'nst_sim' should be an integer.") - check_param_error(params_dict, 'n_ex', "The parameter 'n_ex' should be an integer.") check_param_error(params_dict, 'seed', "The parameter 'seed' should be an integer.") check_param_error(params_dict, 'n_sim', "The parameter 'n_sim' should be an integer.", 4.1, 4) @@ -98,7 +97,6 @@ def test_set_params_error(self, params_dict): check_param_error(params_dict, 'n_iter', "The parameter 'n_iter' should be positive.", 0, 10) # 5. Non-negative parameters - check_param_error(params_dict, 'n_ex', "The parameter 'n_ex' should be non-negative.", -1) check_param_error(params_dict, 'N_cutoff', "The parameter 'N_cutoff' should be non-negative unless no weight correction is needed, i.e. N_cutoff = -1.", -5) # noqa: E501 # 6. String parameters @@ -169,9 +167,8 @@ def test_set_params(self, params_dict): # 2. Check the default values of the parameters not specified in params.yaml assert REXEE.proposal == "exhaustive" assert REXEE.acceptance == "metropolis" - assert REXEE.w_combine is None + assert REXEE.w_combine is False assert REXEE.N_cutoff == 1000 - assert REXEE.n_ex == 'N^3' assert REXEE.verbose is True assert REXEE.runtime_args is None assert REXEE.n_ckpt == 100 @@ -270,9 +267,11 @@ def test_print_params(self, capfd, params_dict): L += "Verbose log file: True\n" L += "Proposal scheme: exhaustive\n" L += "Acceptance scheme for swapping simulations: metropolis\n" - L += "Type of weights to be combined: None\n" - L += "Histogram cutoff: 1000\nNumber of replicas: 4\nNumber of iterations: 10\n" - L += "Number of attempted swaps in one exchange interval: N^3\n" + L += "Whether to perform weight combination: False\n" + L += "Type of means for weight combination: simple\n" + L += "Whether to perform histogram correction: False\n" + L += "Histogram cutoff for weight correction: 1000\n" + L += "Number of replicas: 4\nNumber of iterations: 10\n" L += "Length of each replica: 1.0 ps\nFrequency for checkpointing: 100 iterations\n" L += "Total number of states: 9\n" L += "Additionally defined swappable states: None\n" @@ -485,35 +484,6 @@ def test_get_swapping_pattern(self, params_dict): assert pattern_4_2 == [1, 0, 3, 2] assert swap_list_4_2 == [(2, 3), (0, 1)] - """ - We will deprecate multiple swaps anyway - # Case 5-1: Multiple swaps (proposal = 'multiple', n_ex = 5) - print('test 5-1') - random.seed(0) - REXEE = get_REXEE_instance(params_dict) - REXEE.n_ex = 5 - REXEE.proposal = 'multiple' - states = [3, 1, 4, 6] # swappable pairs: [(0, 1), (0, 2), (1, 2)], first swap = (0, 2), accept - f = copy.deepcopy(dhdl_files) - pattern_5_1, swap_list_5_1 = REXEE.get_swapping_pattern(f, states) - assert REXEE.n_swap_attempts == 5 - assert REXEE.n_rejected == 4 - assert pattern_5_1 == [2, 1, 0, 3] - assert swap_list_5_1 == [(0, 2)] - - # Case 5-2: Multiple swaps but with only one swappable pair (proposal = 'multiple') - # This is specifically for testing n_swap_attempts in the case where n_ex reduces to 1. - random.seed(0) - REXEE = get_REXEE_instance(params_dict) - states = [0, 2, 3, 8] # The only swappable pair is [(1, 2)] --> accept - f = copy.deepcopy(dhdl_files) - pattern_5_2, swap_list_5_2 = REXEE.get_swapping_pattern(f, states) - assert REXEE.n_swap_attempts == 1 # since there is only 1 swappable pair - assert REXEE.n_rejected == 0 - assert pattern_5_2 == [0, 2, 1, 3] - assert swap_list_5_2 == [(1, 2)] - """ - def test_calc_prob_acc(self, capfd, params_dict): # k = 1.380649e-23; NA = 6.0221408e23; T = 298; kT = k * NA * T / 1000 = 2.4777098766670016 REXEE = get_REXEE_instance(params_dict) @@ -587,9 +557,9 @@ def test_weight_correction(self, params_dict): weights_2 = REXEE.weight_correction(weights_2, counts_2) assert np.allclose(weights_2, [[0, 10.304, 20.073, 29.364 + np.log(5545 / 5955)]]) - def test_combine_weights_1(self, params_dict): + def test_combine_weights(self, params_dict): """ - Here we just test the combined weights, so the values of hist does not matter. + Here we just test the combined weights. """ REXEE = get_REXEE_instance(params_dict) REXEE.n_tot = 6 @@ -597,37 +567,34 @@ def test_combine_weights_1(self, params_dict): REXEE.s = 1 REXEE.n_sim = 3 REXEE.state_ranges = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]] - weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] - hist = [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]] - _, w_1, g_vec_1 = REXEE.combine_weights(hist, weights) + # Test 1: simple means + weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] + w_1, g_vec_1 = REXEE.combine_weights(weights) assert np.allclose(w_1, [ [0, 2.1, 3.9, 3.5], [0, 1.8, 1.4, 2.75], [0, -0.4, 0.95, 1.95]]) assert np.allclose(list(g_vec_1), [0, 2.1, 3.9, 3.5, 4.85, 5.85]) + # Test 2: weighted means weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] errors = [[0, 0.1, 0.15, 0.1], [0, 0.12, 0.1, 0.12], [0, 0.12, 0.15, 0.1]] - _, w_2, g_vec_2 = REXEE.combine_weights(hist, weights, errors) + w_2, g_vec_2 = REXEE.combine_weights(weights, errors) assert np.allclose(w_2, [ [0, 2.1, 3.86140725, 3.45417313], [0, 1.76140725, 1.35417313, 2.71436889], [0, -0.40723412, 0.95296164, 1.95296164]]) assert np.allclose(list(g_vec_2), [0, 2.1, 3.861407249466951, 3.4541731330165306, 4.814368891580968, 5.814368891580968]) # noqa: E501 - def test_combine_weights_2(self, params_dict): - """ - Here we just test the modified histograms, so the values of weights does not matter. - """ + def test_histogram_correction(self, params_dict): REXEE = get_REXEE_instance(params_dict) REXEE.n_tot = 6 REXEE.n_sub = 5 REXEE.s = 1 REXEE.n_sim = 2 REXEE.state_ranges = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]] - weights = [[0, 2.1, 4.0, 3.7, 5], [0, 1.7, 1.2, 2.6, 4]] hist = [[416, 332, 130, 71, 61], [303, 181, 123, 143, 260]] - hist_modified, _, _ = REXEE.combine_weights(hist, weights) + hist_modified = REXEE.histogram_correction(hist) assert hist_modified == [[416, 332, 161, 98, 98], [332, 161, 98, 98, 178]]