From 9bb2431c52ea72ad6ffae8fe121e1c864db0c056 Mon Sep 17 00:00:00 2001 From: Wei-Tse Hsu Date: Wed, 1 Nov 2023 19:44:37 -0600 Subject: [PATCH] Added hist_corr, developed histogram_correction, and modified run_REXEE.py correspondingly --- ensemble_md/cli/run_REXEE.py | 24 +++- ensemble_md/replica_exchange_EE.py | 195 +++++++++++++++++------------ 2 files changed, 133 insertions(+), 86 deletions(-) diff --git a/ensemble_md/cli/run_REXEE.py b/ensemble_md/cli/run_REXEE.py index 98a128da..9782a097 100644 --- a/ensemble_md/cli/run_REXEE.py +++ b/ensemble_md/cli/run_REXEE.py @@ -186,8 +186,16 @@ def main(): print('Performing weight combination ...') else: print('Performing weight combination ...', end='') - counts, weights, g_vec = REXEE.combine_weights(counts_, weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501 + weights, g_vec = REXEE.combine_weights(weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501 REXEE.g_vecs.append(g_vec) + + # Check if histogram correction is needed after weight combination + if REXEE.hist_corr is True: + print('Performing histogram correction ...') + counts = REXEE.histogram_correction(counts_) + else: + print('Note: No histogram correction will be performed.') + elif REXEE.N_cutoff == -1 and REXEE.w_combine is True: # Only perform weight combination print('Note: No weight correction will be performed.') @@ -195,18 +203,26 @@ def main(): print('Performing weight combination ...') else: print('Performing weight combination ...', end='') - counts, weights, g_vec = REXEE.combine_weights(counts_, weights_avg) + weights, g_vec = REXEE.combine_weights(weights_avg) REXEE.g_vecs.append(g_vec) + + # Check if histogram correction is needed after weight combination + if REXEE.hist_corr is True: + print('Performing histogram correction ...') + counts = REXEE.histogram_correction(counts_) + else: + print('Note: No histogram correction will be performed.') + elif REXEE.N_cutoff != -1 and REXEE.w_combine is False: # Only perform weight correction print('Note: No weight combination will be performed.') weights = REXEE.weights_correction(weights_avg, counts) - _ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 + _ = REXEE.combine_weights(weights, print_values=False)[1] # just to print the combined weights # noqa: E501 else: print('Note: No weight correction will be performed.') print('Note: No weight combination will be performed.') # Note that in this case, the final weights will be used in the next iteration. - _ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 + _ = REXEE.combine_weights(weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 # 3-5. Modify the MDP files and swap out the GRO files (if needed) # Here we keep the lambda range set in mdp the same across different iterations in the same folder but swap out the gro file # noqa: E501 diff --git a/ensemble_md/replica_exchange_EE.py b/ensemble_md/replica_exchange_EE.py index d7d0f10d..e1e757b1 100644 --- a/ensemble_md/replica_exchange_EE.py +++ b/ensemble_md/replica_exchange_EE.py @@ -164,7 +164,7 @@ def set_params(self, analysis): "acceptance": "metropolis", "w_combine": False, "N_cutoff": 1000, - # "n_ex": 'N^3', # only active for multiple swaps. + "hist_corr": False, "verbose": True, "mdp_args": None, "grompp_args": None, @@ -180,6 +180,7 @@ def set_params(self, analysis): "err_method": "propagate", "n_bootstrap": 50, "seed": None, + # "n_ex": 'N^3', # only active for multiple swaps. } for i in optional_args: if hasattr(self, i) is False or getattr(self, i) is None: @@ -1073,6 +1074,44 @@ def accept_or_reject(self, prob_acc): print(" Swap rejected! ") return swap_bool + def get_averaged_weights(self, log_files): + """ + For each replica, calculate the averaged weights (and the associated error) from the time series + of the weights since the previous update of the Wang-Landau incrementor. + + Parameters + ---------- + log_files : list + A list of file paths to GROMACS LOG files of different replicas. + + Returned + -------- + weights_avg : list + A list of lists of weights averaged since the last update of the Wang-Landau + incrementor. The length of the list should be the number of replicas. + weights_err : list + A list of lists of errors corresponding to the values in :code:`weights_avg`. + """ + for i in range(self.n_sim): + weights, _, wl_delta, _ = gmx_parser.parse_log(log_files[i]) + if self.current_wl_delta[i] == wl_delta: + self.updating_weights[i] += weights # expand the list + else: + self.current_wl_delta[i] = wl_delta + self.updating_weights[i] = weights + + # shape of self.updating_weights is (n_sim, n_points, n_states), but n_points can be different + # for different replicas, which will error out np.mean(self.updating_weights, axis=1) + weights_avg = [np.mean(self.updating_weights[i], axis=0).tolist() for i in range(self.n_sim)] + weights_err = [] + for i in range(self.n_sim): + if len(self.updating_weights[i]) == 1: # this would lead to a RunTime Warning and nan + weights_err.append([0] * self.n_sub) # in `weighted_mean``, a simple average will be returned. + else: + weights_err.append(np.std(self.updating_weights[i], axis=0, ddof=1).tolist()) + + return weights_avg, weights_err + def weight_correction(self, weights, counts): """ Corrects the lambda weights based on the histogram counts. Namely, @@ -1118,152 +1157,144 @@ def weight_correction(self, weights, counts): return weights - def get_averaged_weights(self, log_files): + def histogram_correction(self, hist, print_values=True): """ - For each replica, calculate the averaged weights (and the associated error) from the time series - of the weights since the previous update of the Wang-Landau incrementor. + Adjust the histogram counts. Specifically, the ratio of corrected histogram counts + for adjancent states is the geometric mean of the ratio of the original histogram counts + for the same states. Note, however, if the histogram counts are 0 for some states, the + histogram correction will be skipped and the original histogram counts will be returned. Parameters ---------- - log_files : list - A list of file paths to GROMACS LOG files of different replicas. + hist : list + A list of lists of histogram counts of ALL simulations. + print_values : bool, optional + Whether to print the histograms for each replica before and after histogram correction. - Returned - -------- - weights_avg : list - A list of lists of weights averaged since the last update of the Wang-Landau - incrementor. The length of the list should be the number of replicas. - weights_err : list - A list of lists of errors corresponding to the values in :code:`weights_avg`. + Returns + ------- + hist_modified : list + A list of lists of modified histogram counts of ALL simulations. """ - for i in range(self.n_sim): - weights, _, wl_delta, _ = gmx_parser.parse_log(log_files[i]) - if self.current_wl_delta[i] == wl_delta: - self.updating_weights[i] += weights # expand the list - else: - self.current_wl_delta[i] = wl_delta - self.updating_weights[i] = weights + # (1) Print the original histogram counts + if print_values is True: + print(' Original histogram counts:') + for i in range(len(self.hist)): + print(f' Rep {i}: {self.hist[i]}') - # shape of self.updating_weights is (n_sim, n_points, n_states), but n_points can be different - # for different replicas, which will error out np.mean(self.updating_weights, axis=1) - weights_avg = [np.mean(self.updating_weights[i], axis=0).tolist() for i in range(self.n_sim)] - weights_err = [] - for i in range(self.n_sim): - if len(self.updating_weights[i]) == 1: # this would lead to a RunTime Warning and nan - weights_err.append([0] * self.n_sub) # in `weighted_mean``, a simple average will be returned. - else: - weights_err.append(np.std(self.updating_weights[i], axis=0, ddof=1).tolist()) + # (2) Calculate adjacent weight differences and g_vec + N_ratio_vec = [] # N_{k-1}/N_k for the whole range + with warnings.catch_warnings(): # Suppress the specific warning here + warnings.simplefilter("ignore", category=RuntimeWarning) + N_ratio_adjacent = [list(np.array(hist[i][1:]) / np.array(hist[i][:-1])) for i in range(len(hist))] - return weights_avg, weights_err + for i in range(self.n_tot - 1): + N_ratio_list = [] + for j in range(len(self.state_ranges)): + if i in self.state_ranges[j] and i + 1 in self.state_ranges[j]: + idx = self.state_ranges[j].index(i) + N_ratio_list.append(N_ratio_adjacent[j][idx]) + N_ratio_vec.append(np.prod(N_ratio_list) ** (1 / len(N_ratio_list))) # geometric mean + N_ratio_vec.insert(0, hist[0][0]) - def combine_weights(self, hist, weights, weights_err=None, print_values=True): + # (3) Check if the histogram counts are 0 for some states, if so, the histogram correction will be skipped. + # Zero histogram counts can happen when the sampling is poor or the WL incrementor just got updated + contains_nan = any(np.isnan(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by 0/0 # noqa: E501 + contains_inf = any(np.isinf(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by x/0, where x is a finite number # noqa: E501 + skip_hist_correction = contains_nan or contains_inf + if skip_hist_correction: + print('\n Histogram correction is skipped because the histogram counts are 0 for some states.') + + # (4) Perform histogram correction if it is not skipped + if skip_hist_correction is False: + print('\n Performing histogram correction ...') + # When skip_hist_correction is True, previous lines for calculating N_ratio_vec or N_ratio_list will + # still not error out so it's fine to not add the conditional statement like here, since we will + # have hist_modified = hist at the end anyway. However, if skip_hist_correction, things like + # int(np.nan) will lead to an error, so we put an if condition here. + N_vec = np.array([int(np.ceil(np.prod(N_ratio_vec[:(i + 1)]))) for i in range(len(N_ratio_vec))]) + + if skip_hist_correction is False: + hist_modified = [list(N_vec[self.state_ranges[i]]) for i in range(self.n_sim)] + else: + hist_modified = hist # the original input histogram + + # (5) Print the modified histogram counts + if print_values is True: + print('\n Modified histogram counts:') + for i in range(len(hist_modified)): + print(f' Rep {i}: {hist_modified[i]}') + + return hist_modified + + def combine_weights(self, weights, weights_err=None, print_values=True): """ - Combine alchemical weights across multiple replicas and adjusts the histogram counts - correspondingly. Note that if :code:`weights_err` is provided, inverse-variance weighting will be used. + Combine alchemical weights across multiple replicas. Note that if + :code:`weights_err` is provided, inverse-variance weighting will be used. Care must be taken since inverse-variance weighting can lead to slower convergence if the provided errors are not accurate. (See :ref:`doc_w_schemes` for mor details.) Parameters ---------- - hist : list - A list of lists of histogram counts of ALL simulations. weights : list A list of lists alchemical weights of ALL simulations. weights_err : list, optional A list of lists of errors corresponding to the values in :code:`weights`. print_values : bool, optional - Whether to print the histograms and weights for each replica before and - after weight combinationfor each replica. + Whether to print the weights for each replica before and + after weight combination for each replica. Returns ------- - hist_modified : list - A list of modified histogram counts of ALL simulations. weights_modified : list A list of modified Wang-Landau weights of ALL simulations. g_vec : np.ndarray An array of alchemical weights of the whole range of states. """ - # (1) Print the original weights and histogram counts + # (1) Print the original weights if print_values is True: w = np.round(weights, decimals=3).tolist() # just for printing print(' Original weights:') for i in range(len(w)): print(f' Rep {i}: {w[i]}') - print('\n Original histogram counts:') - for i in range(len(hist)): - print(f' Rep {i}: {hist[i]}') # (2) Calculate adjacent weight differences and g_vec - # Note that N_ratio_vec and other similar variables are only used in histogram corrections - dg_vec, N_ratio_vec = [], [] # alchemical weight differences and histogram count ratios for the whole range + dg_vec = [] # alchemical weight differences for the whole range dg_adjacent = [list(np.diff(weights[i])) for i in range(len(weights))] - with warnings.catch_warnings(): # Suppress the specific warning here - warnings.simplefilter("ignore", category=RuntimeWarning) - N_ratio_adjacent = [list(np.array(hist[i][1:]) / np.array(hist[i][:-1])) for i in range(len(hist))] - if weights_err is not None: dg_adjacent_err = [[np.sqrt(weights_err[i][j] ** 2 + weights_err[i][j + 1] ** 2) for j in range(len(weights_err[i]) - 1)] for i in range(len(weights_err))] # noqa: E501 for i in range(self.n_tot - 1): - dg_list, dg_err_list, N_ratio_list = [], [], [] + dg_list, dg_err_list = [], [] for j in range(len(self.state_ranges)): if i in self.state_ranges[j] and i + 1 in self.state_ranges[j]: idx = self.state_ranges[j].index(i) dg_list.append(dg_adjacent[j][idx]) - N_ratio_list.append(N_ratio_adjacent[j][idx]) if weights_err is not None: dg_err_list.append(dg_adjacent_err[j][idx]) if weights_err is None: dg_vec.append(np.mean(dg_list)) else: dg_vec.append(utils.weighted_mean(dg_list, dg_err_list)[0]) - N_ratio_vec.append(np.prod(N_ratio_list) ** (1 / len(N_ratio_list))) # geometric mean + dg_vec.insert(0, 0) - N_ratio_vec.insert(0, hist[0][0]) g_vec = np.array([sum(dg_vec[:(i + 1)]) for i in range(len(dg_vec))]) - # (3) Determine the vector of alchemical weights and histogram counts for each replica + # (3) Determine the vector of alchemical weights for each replica weights_modified = np.zeros_like(weights) for i in range(self.n_sim): - hist_modified = [] if self.equil[i] == -1: # unequilibrated weights_modified[i] = list(g_vec[i * self.s: i * self.s + self.n_sub] - g_vec[i * self.s: i * self.s + self.n_sub][0]) # noqa: E501 else: weights_modified[i] = self.equilibrated_weights[i] - # (4) Perform histogram correction - # Below we first deal with the case where the sampling is poor or the WL incrementor just got updated such that - # the histogram counts are 0 for some states, in which case we simply skip histogram correction. - contains_nan = any(np.isnan(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by 0/0 # noqa: E501 - contains_inf = any(np.isinf(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by x/0, where x is a finite number # noqa: E501 - skip_hist_correction = contains_nan or contains_inf - if skip_hist_correction: - print('\n Histogram correction is skipped because the histogram counts are 0 for some states.') - - if skip_hist_correction is False: - # When skip_hist_correction is True, previous lines for calculating N_ratio_vec or N_ratio_list will - # still not error out so it's fine to not add the conditional statement like here, since we will - # have hist_modified = hist at the end anyway. However, if skip_hist_correction, things like - # int(np.nan) will lead to an error, so we put an if condition here. - N_vec = np.array([int(np.ceil(np.prod(N_ratio_vec[:(i + 1)]))) for i in range(len(N_ratio_vec))]) - - if skip_hist_correction is False: - hist_modified = [list(N_vec[self.state_ranges[i]]) for i in range(self.n_sim)] - else: - hist_modified = hist - - # (5) Print the modified weights and histogram counts + # (4) Print the modified weights if print_values is True: w = np.round(weights_modified, decimals=3).tolist() # just for printing print('\n Modified weights:') for i in range(len(w)): print(f' Rep {i}: {w[i]}') - if skip_hist_correction is False: - print('\n Modified histogram counts:') - for i in range(len(hist_modified)): - print(f' Rep {i}: {hist_modified[i]}') if self.verbose is False: print(' DONE') @@ -1271,7 +1302,7 @@ def combine_weights(self, hist, weights, weights_err=None, print_values=True): else: print(f'\n The alchemical weights of all states: \n {list(np.round(g_vec, decimals=3))}') - return hist_modified, weights_modified, g_vec + return weights_modified, g_vec def _run_grompp(self, n, swap_pattern): """