Commit

Added hist_corr, developed histogram_correction, and modified run_REXEE.py correspondingly
wehs7661 committed Nov 2, 2023
1 parent 25077aa commit 9bb2431
Showing 2 changed files with 133 additions and 86 deletions.
24 changes: 20 additions & 4 deletions ensemble_md/cli/run_REXEE.py
@@ -186,27 +186,43 @@ def main():
print('Performing weight combination ...')
else:
print('Performing weight combination ...', end='')
counts, weights, g_vec = REXEE.combine_weights(counts_, weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501
weights, g_vec = REXEE.combine_weights(weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501
REXEE.g_vecs.append(g_vec)

# Check if histogram correction is needed after weight combination
if REXEE.hist_corr is True:
print('Performing histogram correction ...')
counts = REXEE.histogram_correction(counts_)
else:
print('Note: No histogram correction will be performed.')

elif REXEE.N_cutoff == -1 and REXEE.w_combine is True:
# Only perform weight combination
print('Note: No weight correction will be performed.')
if REXEE.verbose is True:
print('Performing weight combination ...')
else:
print('Performing weight combination ...', end='')
counts, weights, g_vec = REXEE.combine_weights(counts_, weights_avg)
weights, g_vec = REXEE.combine_weights(weights_avg)
REXEE.g_vecs.append(g_vec)

# Check if histogram correction is needed after weight combination
if REXEE.hist_corr is True:
print('Performing histogram correction ...')
counts = REXEE.histogram_correction(counts_)
else:
print('Note: No histogram correction will be performed.')

elif REXEE.N_cutoff != -1 and REXEE.w_combine is False:
# Only perform weight correction
print('Note: No weight combination will be performed.')
weights = REXEE.weights_correction(weights_avg, counts)
_ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501
_ = REXEE.combine_weights(weights, print_values=False)[1] # just to print the combined weights # noqa: E501
else:
print('Note: No weight correction will be performed.')
print('Note: No weight combination will be performed.')
# Note that in this case, the final weights will be used in the next iteration.
_ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501
_ = REXEE.combine_weights(weights, print_values=False)[1] # just to print the combined weights # noqa: E501

# 3-5. Modify the MDP files and swap out the GRO files (if needed)
# Here we keep the lambda range set in mdp the same across different iterations in the same folder but swap out the gro file # noqa: E501
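For reference, the branching above can be summarized with the hypothetical helper below. This is only an illustrative sketch: the condition of the first branch is collapsed above this hunk and is assumed here to be N_cutoff != -1 with w_combine enabled, and decide_actions is not a function in the package.

def decide_actions(N_cutoff, w_combine, hist_corr):
    """Hypothetical summary of which operations an iteration performs (not package code)."""
    actions = []
    if N_cutoff != -1 and w_combine:        # assumed first branch (collapsed above this hunk)
        actions += ['weight correction', 'weight combination']
    elif N_cutoff == -1 and w_combine:      # weight combination only
        actions.append('weight combination')
    elif N_cutoff != -1 and not w_combine:  # weight correction only
        actions.append('weight correction')
    # With this commit, histogram correction is applied only after weight combination
    # and only when the new hist_corr option is enabled.
    if w_combine and hist_corr:
        actions.append('histogram correction')
    return actions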
195 changes: 113 additions & 82 deletions ensemble_md/replica_exchange_EE.py
@@ -164,7 +164,7 @@ def set_params(self, analysis):
"acceptance": "metropolis",
"w_combine": False,
"N_cutoff": 1000,
# "n_ex": 'N^3', # only active for multiple swaps.
"hist_corr": False,
"verbose": True,
"mdp_args": None,
"grompp_args": None,
@@ -180,6 +180,7 @@
"err_method": "propagate",
"n_bootstrap": 50,
"seed": None,
# "n_ex": 'N^3', # only active for multiple swaps.
}
for i in optional_args:
if hasattr(self, i) is False or getattr(self, i) is None:
@@ -1073,6 +1074,44 @@ def accept_or_reject(self, prob_acc):
print(" Swap rejected! ")
return swap_bool

def get_averaged_weights(self, log_files):
"""
For each replica, calculate the averaged weights (and the associated error) from the time series
of the weights since the previous update of the Wang-Landau incrementor.
Parameters
----------
log_files : list
A list of file paths to GROMACS LOG files of different replicas.
Returns
-------
weights_avg : list
A list of lists of weights averaged since the last update of the Wang-Landau
incrementor. The length of the list should be the number of replicas.
weights_err : list
A list of lists of errors corresponding to the values in :code:`weights_avg`.
"""
for i in range(self.n_sim):
weights, _, wl_delta, _ = gmx_parser.parse_log(log_files[i])
if self.current_wl_delta[i] == wl_delta:
self.updating_weights[i] += weights # expand the list
else:
self.current_wl_delta[i] = wl_delta
self.updating_weights[i] = weights

# The shape of self.updating_weights is (n_sim, n_points, n_states), but n_points can differ
# between replicas, which would cause np.mean(self.updating_weights, axis=1) to fail.
weights_avg = [np.mean(self.updating_weights[i], axis=0).tolist() for i in range(self.n_sim)]
weights_err = []
for i in range(self.n_sim):
if len(self.updating_weights[i]) == 1: # np.std with ddof=1 on a single sample would lead to a RuntimeWarning and nan
weights_err.append([0] * self.n_sub) # in `weighted_mean`, a simple average will be returned.
else:
weights_err.append(np.std(self.updating_weights[i], axis=0, ddof=1).tolist())

return weights_avg, weights_err

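As a toy illustration of the averaging described above (made-up numbers, not package code): for one replica, three snapshots of the per-state weights recorded since the last Wang-Landau incrementor update are averaged state by state, and the sample standard deviation is used as the error.

import numpy as np

# Made-up weight time series for one replica (3 snapshots, 4 states)
updating_weights = [
    [0.0, 1.2, 2.5, 3.1],
    [0.0, 1.3, 2.4, 3.0],
    [0.0, 1.1, 2.6, 3.2],
]
weights_avg = np.mean(updating_weights, axis=0).tolist()         # ≈ [0.0, 1.2, 2.5, 3.1]
weights_err = np.std(updating_weights, axis=0, ddof=1).tolist()  # ≈ [0.0, 0.1, 0.1, 0.1]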
def weight_correction(self, weights, counts):
"""
Corrects the lambda weights based on the histogram counts. Namely,
@@ -1118,160 +1157,152 @@ def weight_correction(self, weights, counts):

return weights

def get_averaged_weights(self, log_files):
def histogram_correction(self, hist, print_values=True):
"""
For each replica, calculate the averaged weights (and the associated error) from the time series
of the weights since the previous update of the Wang-Landau incrementor.
Adjust the histogram counts. Specifically, the ratio of the corrected histogram counts
for adjacent states is the geometric mean, across replicas, of the ratios of the original
histogram counts for the same states. Note, however, that if the histogram counts are 0
for some states, the histogram correction will be skipped and the original histogram
counts will be returned.
Parameters
----------
log_files : list
A list of file paths to GROMACS LOG files of different replicas.
hist : list
A list of lists of histogram counts of ALL simulations.
print_values : bool, optional
Whether to print the histograms for each replica before and after histogram correction.
Returns
-------
weights_avg : list
A list of lists of weights averaged since the last update of the Wang-Landau
incrementor. The length of the list should be the number of replicas.
weights_err : list
A list of lists of errors corresponding to the values in :code:`weights_avg`.
Returns
-------
hist_modified : list
A list of lists of modified histogram counts of ALL simulations.
"""
for i in range(self.n_sim):
weights, _, wl_delta, _ = gmx_parser.parse_log(log_files[i])
if self.current_wl_delta[i] == wl_delta:
self.updating_weights[i] += weights # expand the list
else:
self.current_wl_delta[i] = wl_delta
self.updating_weights[i] = weights
# (1) Print the original histogram counts
if print_values is True:
print(' Original histogram counts:')
for i in range(len(hist)):
print(f' Rep {i}: {hist[i]}')

# The shape of self.updating_weights is (n_sim, n_points, n_states), but n_points can differ
# between replicas, which would cause np.mean(self.updating_weights, axis=1) to fail.
weights_avg = [np.mean(self.updating_weights[i], axis=0).tolist() for i in range(self.n_sim)]
weights_err = []
for i in range(self.n_sim):
if len(self.updating_weights[i]) == 1: # np.std with ddof=1 on a single sample would lead to a RuntimeWarning and nan
weights_err.append([0] * self.n_sub) # in `weighted_mean`, a simple average will be returned.
else:
weights_err.append(np.std(self.updating_weights[i], axis=0, ddof=1).tolist())
# (2) Calculate the ratios of the histogram counts for adjacent states
N_ratio_vec = [] # N_k/N_{k-1} for the whole range
with warnings.catch_warnings(): # Suppress the specific warning here
warnings.simplefilter("ignore", category=RuntimeWarning)
N_ratio_adjacent = [list(np.array(hist[i][1:]) / np.array(hist[i][:-1])) for i in range(len(hist))]

return weights_avg, weights_err
for i in range(self.n_tot - 1):
N_ratio_list = []
for j in range(len(self.state_ranges)):
if i in self.state_ranges[j] and i + 1 in self.state_ranges[j]:
idx = self.state_ranges[j].index(i)
N_ratio_list.append(N_ratio_adjacent[j][idx])
N_ratio_vec.append(np.prod(N_ratio_list) ** (1 / len(N_ratio_list))) # geometric mean
N_ratio_vec.insert(0, hist[0][0])

def combine_weights(self, hist, weights, weights_err=None, print_values=True):
# (3) Check if the histogram counts are 0 for some states, if so, the histogram correction will be skipped.
# Zero histogram counts can happen when the sampling is poor or the WL incrementor just got updated
contains_nan = any(np.isnan(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by 0/0 # noqa: E501
contains_inf = any(np.isinf(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by x/0, where x is a finite number # noqa: E501
skip_hist_correction = contains_nan or contains_inf
if skip_hist_correction:
print('\n Histogram correction is skipped because the histogram counts are 0 for some states.')

# (4) Perform histogram correction if it is not skipped
if skip_hist_correction is False:
print('\n Performing histogram correction ...')
# When skip_hist_correction is True, the earlier lines that calculate N_ratio_vec and N_ratio_list
# do not error out, so no conditional statement is needed there; hist_modified is simply set to
# hist at the end anyway. However, expressions like int(np.nan) would raise an error, so the
# calculation of N_vec below is guarded by this condition.
N_vec = np.array([int(np.ceil(np.prod(N_ratio_vec[:(i + 1)]))) for i in range(len(N_ratio_vec))])

if skip_hist_correction is False:
hist_modified = [list(N_vec[self.state_ranges[i]]) for i in range(self.n_sim)]
else:
hist_modified = hist # the original input histogram

# (5) Print the modified histogram counts
if print_values is True:
print('\n Modified histogram counts:')
for i in range(len(hist_modified)):
print(f' Rep {i}: {hist_modified[i]}')

return hist_modified

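As a toy example of the geometric-mean correction described above (made-up counts and hypothetical state ranges, not package code): with two replicas that overlap at states 1-2, the corrected counts for the whole range are built from the geometric means of the adjacent-state count ratios and then sliced back per replica.

import numpy as np

state_ranges = [[0, 1, 2], [1, 2, 3]]  # hypothetical overlapping state ranges
hist = [[100, 80, 60], [90, 72, 54]]   # made-up histogram counts

# Ratios of adjacent counts within each replica, e.g. 80/100 and 60/80 for replica 0
ratios = [np.array(h[1:]) / np.array(h[:-1]) for h in hist]

# Geometric mean of the ratios over all replicas that sample each pair of adjacent states
n_tot = 4
N_ratio_vec = [hist[0][0]]
for i in range(n_tot - 1):
    vals = [ratios[j][state_ranges[j].index(i)]
            for j in range(len(state_ranges))
            if i in state_ranges[j] and i + 1 in state_ranges[j]]
    N_ratio_vec.append(np.prod(vals) ** (1 / len(vals)))

# Cumulative products give the corrected counts over the whole range of states
N_vec = [int(np.ceil(np.prod(N_ratio_vec[:i + 1]))) for i in range(len(N_ratio_vec))]
hist_modified = [[N_vec[k] for k in r] for r in state_ranges]
print(hist_modified)  # [[100, 80, 62], [80, 62, 47]]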
def combine_weights(self, weights, weights_err=None, print_values=True):
"""
Combine alchemical weights across multiple replicas and adjusts the histogram counts
correspondingly. Note that if :code:`weights_err` is provided, inverse-variance weighting will be used.
Combine alchemical weights across multiple replicas. Note that if
:code:`weights_err` is provided, inverse-variance weighting will be used.
Care must be taken since inverse-variance weighting can lead to slower
convergence if the provided errors are not accurate. (See :ref:`doc_w_schemes` for more details.)
Parameters
----------
hist : list
A list of lists of histogram counts of ALL simulations.
weights : list
A list of lists of alchemical weights of ALL simulations.
weights_err : list, optional
A list of lists of errors corresponding to the values in :code:`weights`.
print_values : bool, optional
Whether to print the histograms and weights for each replica before and
after weight combinationfor each replica.
Whether to print the weights for each replica before and
after weight combination.
Returns
-------
hist_modified : list
A list of modified histogram counts of ALL simulations.
weights_modified : list
A list of modified Wang-Landau weights of ALL simulations.
g_vec : np.ndarray
An array of alchemical weights of the whole range of states.
"""
# (1) Print the original weights and histogram counts
# (1) Print the original weights
if print_values is True:
w = np.round(weights, decimals=3).tolist() # just for printing
print(' Original weights:')
for i in range(len(w)):
print(f' Rep {i}: {w[i]}')
print('\n Original histogram counts:')
for i in range(len(hist)):
print(f' Rep {i}: {hist[i]}')

# (2) Calculate adjacent weight differences and g_vec
# Note that N_ratio_vec and other similar variables are only used in histogram corrections
dg_vec, N_ratio_vec = [], [] # alchemical weight differences and histogram count ratios for the whole range
dg_vec = [] # alchemical weight differences for the whole range
dg_adjacent = [list(np.diff(weights[i])) for i in range(len(weights))]

with warnings.catch_warnings(): # Suppress the specific warning here
warnings.simplefilter("ignore", category=RuntimeWarning)
N_ratio_adjacent = [list(np.array(hist[i][1:]) / np.array(hist[i][:-1])) for i in range(len(hist))]

if weights_err is not None:
dg_adjacent_err = [[np.sqrt(weights_err[i][j] ** 2 + weights_err[i][j + 1] ** 2) for j in range(len(weights_err[i]) - 1)] for i in range(len(weights_err))] # noqa: E501

for i in range(self.n_tot - 1):
dg_list, dg_err_list, N_ratio_list = [], [], []
dg_list, dg_err_list = [], []
for j in range(len(self.state_ranges)):
if i in self.state_ranges[j] and i + 1 in self.state_ranges[j]:
idx = self.state_ranges[j].index(i)
dg_list.append(dg_adjacent[j][idx])
N_ratio_list.append(N_ratio_adjacent[j][idx])
if weights_err is not None:
dg_err_list.append(dg_adjacent_err[j][idx])
if weights_err is None:
dg_vec.append(np.mean(dg_list))
else:
dg_vec.append(utils.weighted_mean(dg_list, dg_err_list)[0])
N_ratio_vec.append(np.prod(N_ratio_list) ** (1 / len(N_ratio_list))) # geometric mean

dg_vec.insert(0, 0)
N_ratio_vec.insert(0, hist[0][0])
g_vec = np.array([sum(dg_vec[:(i + 1)]) for i in range(len(dg_vec))])

# (3) Determine the vector of alchemical weights and histogram counts for each replica
# (3) Determine the vector of alchemical weights for each replica
weights_modified = np.zeros_like(weights)
for i in range(self.n_sim):
hist_modified = []
if self.equil[i] == -1: # unequilibrated
weights_modified[i] = list(g_vec[i * self.s: i * self.s + self.n_sub] - g_vec[i * self.s: i * self.s + self.n_sub][0]) # noqa: E501
else:
weights_modified[i] = self.equilibrated_weights[i]

# (4) Perform histogram correction
# Below we first deal with the case where the sampling is poor or the WL incrementor just got updated such that
# the histogram counts are 0 for some states, in which case we simply skip histogram correction.
contains_nan = any(np.isnan(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by 0/0 # noqa: E501
contains_inf = any(np.isinf(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by x/0, where x is a finite number # noqa: E501
skip_hist_correction = contains_nan or contains_inf
if skip_hist_correction:
print('\n Histogram correction is skipped because the histogram counts are 0 for some states.')

if skip_hist_correction is False:
# When skip_hist_correction is True, previous lines for calculating N_ratio_vec or N_ratio_list will
# still not error out so it's fine to not add the conditional statement like here, since we will
# have hist_modified = hist at the end anyway. However, if skip_hist_correction, things like
# int(np.nan) will lead to an error, so we put an if condition here.
N_vec = np.array([int(np.ceil(np.prod(N_ratio_vec[:(i + 1)]))) for i in range(len(N_ratio_vec))])

if skip_hist_correction is False:
hist_modified = [list(N_vec[self.state_ranges[i]]) for i in range(self.n_sim)]
else:
hist_modified = hist

# (5) Print the modified weights and histogram counts
# (4) Print the modified weights
if print_values is True:
w = np.round(weights_modified, decimals=3).tolist() # just for printing
print('\n Modified weights:')
for i in range(len(w)):
print(f' Rep {i}: {w[i]}')
if skip_hist_correction is False:
print('\n Modified histogram counts:')
for i in range(len(hist_modified)):
print(f' Rep {i}: {hist_modified[i]}')

if self.verbose is False:
print(' DONE')
print(f'The alchemical weights of all states: \n {list(np.round(g_vec, decimals=3))}')
else:
print(f'\n The alchemical weights of all states: \n {list(np.round(g_vec, decimals=3))}')

return hist_modified, weights_modified, g_vec
return weights_modified, g_vec

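As a toy example of the weight combination described above (made-up weights and hypothetical state ranges, not package code): the adjacent weight differences of two overlapping replicas are averaged over the replicas that sample both states (a simple mean here; inverse-variance weighting would replace np.mean with a weighted mean based on weights_err), accumulated into g_vec, and sliced back per replica.

import numpy as np

state_ranges = [[0, 1, 2], [1, 2, 3]]         # hypothetical overlapping state ranges
weights = [[0.0, 1.5, 3.2], [0.0, 1.8, 3.9]]  # made-up alchemical weights

dg_adjacent = [list(np.diff(w)) for w in weights]  # ≈ [[1.5, 1.7], [1.8, 2.1]]

# Average the weight difference of each pair of adjacent states over the replicas sampling both
n_tot = 4
dg_vec = [0.0]
for i in range(n_tot - 1):
    vals = [dg_adjacent[j][state_ranges[j].index(i)]
            for j in range(len(state_ranges))
            if i in state_ranges[j] and i + 1 in state_ranges[j]]
    dg_vec.append(np.mean(vals))

g_vec = np.cumsum(dg_vec)  # ≈ [0.0, 1.5, 3.25, 5.35], weights over the whole range of states
weights_modified = [[round(float(g_vec[k] - g_vec[r[0]]), 3) for k in r] for r in state_ranges]
print(weights_modified)    # [[0.0, 1.5, 3.25], [0.0, 1.75, 3.85]]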
def _run_grompp(self, n, swap_pattern):
"""
