Skip to content

Commit

Permalink
Refactored combine_weights to group the codes for histogram correction
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Nov 2, 2023
1 parent a7a37ba commit 25077aa
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,7 @@ def combine_weights(self, hist, weights, weights_err=None, print_values=True):
g_vec : np.ndarray
An array of alchemical weights of the whole range of states.
"""
# (1) Print the original weights and histogram counts
if print_values is True:
w = np.round(weights, decimals=3).tolist() # just for printing
print(' Original weights:')
Expand All @@ -1193,22 +1194,15 @@ def combine_weights(self, hist, weights, weights_err=None, print_values=True):
for i in range(len(hist)):
print(f' Rep {i}: {hist[i]}')

# Calculate adjacent weight differences and g_vec
# (2) Calculate adjacent weight differences and g_vec
# Note that N_ratio_vec and other similar variables are only used in histogram corrections
dg_vec, N_ratio_vec = [], [] # alchemical weight differences and histogram count ratios for the whole range
dg_adjacent = [list(np.diff(weights[i])) for i in range(len(weights))]
# Suppress the specific warning here
with warnings.catch_warnings():

with warnings.catch_warnings(): # Suppress the specific warning here
warnings.simplefilter("ignore", category=RuntimeWarning)
N_ratio_adjacent = [list(np.array(hist[i][1:]) / np.array(hist[i][:-1])) for i in range(len(hist))]

# Below we deal with the case where the sampling is poor or the WL incrementor just got updated such that
# the histogram counts are 0 for some states, in which case we simply skip histogram correction.
contains_nan = any(np.isnan(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by 0/0 # noqa: E501
contains_inf = any(np.isinf(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by x/0, where x is a finite number # noqa: E501
skip_hist_correction = contains_nan or contains_inf
if skip_hist_correction:
print('\n Histogram correction is skipped because the histogram counts are 0 for some states.')

if weights_err is not None:
dg_adjacent_err = [[np.sqrt(weights_err[i][j] ** 2 + weights_err[i][j + 1] ** 2) for j in range(len(weights_err[i]) - 1)] for i in range(len(weights_err))] # noqa: E501

Expand All @@ -1229,26 +1223,38 @@ def combine_weights(self, hist, weights, weights_err=None, print_values=True):
dg_vec.insert(0, 0)
N_ratio_vec.insert(0, hist[0][0])
g_vec = np.array([sum(dg_vec[:(i + 1)]) for i in range(len(dg_vec))])
if skip_hist_correction is False:
# When skip_hist_correction is True, previous lines for calculating N_ratio_vec or N_ratio_list will
# still not error out so it's fine to not add the conditional statement like here, since we will
# have hist_modified = hist at the end anyway. However, if skip_hist_correction, things like
# int(np.nan) will lead to an error, so we put an if condition here.
N_vec = np.array([int(np.ceil(np.prod(N_ratio_vec[:(i + 1)]))) for i in range(len(N_ratio_vec))])

# Determine the vector of alchemical weights and histogram counts for each replica

# (3) Determine the vector of alchemical weights and histogram counts for each replica
weights_modified = np.zeros_like(weights)
for i in range(self.n_sim):
hist_modified = []
if self.equil[i] == -1: # unequilibrated
weights_modified[i] = list(g_vec[i * self.s: i * self.s + self.n_sub] - g_vec[i * self.s: i * self.s + self.n_sub][0]) # noqa: E501
else:
weights_modified[i] = self.equilibrated_weights[i]

# (4) Perform histogram correction
# Below we first deal with the case where the sampling is poor or the WL incrementor just got updated such that
# the histogram counts are 0 for some states, in which case we simply skip histogram correction.
contains_nan = any(np.isnan(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by 0/0 # noqa: E501
contains_inf = any(np.isinf(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by x/0, where x is a finite number # noqa: E501
skip_hist_correction = contains_nan or contains_inf
if skip_hist_correction:
print('\n Histogram correction is skipped because the histogram counts are 0 for some states.')

if skip_hist_correction is False:
# When skip_hist_correction is True, previous lines for calculating N_ratio_vec or N_ratio_list will
# still not error out so it's fine to not add the conditional statement like here, since we will
# have hist_modified = hist at the end anyway. However, if skip_hist_correction, things like
# int(np.nan) will lead to an error, so we put an if condition here.
N_vec = np.array([int(np.ceil(np.prod(N_ratio_vec[:(i + 1)]))) for i in range(len(N_ratio_vec))])

if skip_hist_correction is False:
hist_modified = [list(N_vec[self.state_ranges[i]]) for i in range(self.n_sim)]
else:
hist_modified = hist

# (5) Print the modified weights and histogram counts
if print_values is True:
w = np.round(weights_modified, decimals=3).tolist() # just for printing
print('\n Modified weights:')
Expand Down

0 comments on commit 25077aa

Please sign in to comment.