Skip to content

Commit

Permalink
Developed get_delta_w_updates and modified get_g_evolution
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Nov 3, 2023
1 parent a3f07fc commit 5fd9440
Showing 1 changed file with 108 additions and 7 deletions.
115 changes: 108 additions & 7 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def plot_swaps(swaps, swap_type='', stack=True, figsize=None):
plt.savefig(f'{swap_type}_swaps.png', dpi=600)


def get_g_evolution(log_files, N_states, avg_frac=0):
def get_g_evolution(log_files, N_states, avg_frac=0, avg_from_last_update=False):
"""
For weight-updating simulations, gets the time series of the alchemical
weights of all states.
Expand All @@ -979,7 +979,9 @@ def get_g_evolution(log_files, N_states, avg_frac=0):
avg_frac : float
The fraction of the last part of the simulation to be averaged. The
default is 0, which means no averaging.
avg_from_last_update : bool
Whether to average from the last update of wl-delta. If False, the
averaging will be from the beginning of the simulation.
Returns
-------
Expand All @@ -989,6 +991,10 @@ def get_g_evolution(log_files, N_states, avg_frac=0):
g_vecs_avg : list
The alchemical weights of all states averaged over the last part of
the simulation. If :code:`avg_frac` is 0, :code:`None` will be returned.
g_vecs_err : list
The errors of the alchemical weights of all states averaged over the
last part of the simulation. If :code:`avg_frac` is 0 and :code:`avg_from_last_update`
is :code:`False`, :code:`None` will be returned.
"""
g_vecs_all = []
for log_file in log_files:
Expand All @@ -997,6 +1003,7 @@ def get_g_evolution(log_files, N_states, avg_frac=0):
f.close()

n = -1
idx_updates = [] # the indices of the data points corresponding to the updates of wl-delta
find_equil = False
for line in lines:
n += 1
Expand All @@ -1011,19 +1018,35 @@ def get_g_evolution(log_files, N_states, avg_frac=0):
if find_equil is False:
g_vecs_all.append(w)

if 'weights are now' in line:
idx_updates.append(len(g_vecs_all) - 1)

if "Weights have equilibrated" in line:
find_equil = True
w = [float(i) for i in lines[n - 2].split(':')[-1].split()]
g_vecs_all.append(w)
break

if avg_frac != 0:
n_avg = int(avg_frac * len(g_vecs_all))
g_vecs_avg = np.mean(g_vecs_all[-n_avg:], axis=0)
if avg_from_last_update is True:
# If the weights are equilibrated, then the last occurrence of "weights are now"
# is right before the equilibration message, in which case we want to average
# from the second last occurrence of "weights are now".
if find_equil is True:
idx_updates = idx_updates[:-1]

idx_last_update = idx_updates[-1]
g_vecs_avg = np.mean(g_vecs_all[idx_last_update + 1:], axis=0)
g_vecs_err = np.std(g_vecs_all[idx_last_update + 1:], axis=0, ddof=1)
else:
g_vecs_avg = None
if avg_frac != 0:
n_avg = int(avg_frac * len(g_vecs_all))
g_vecs_avg = np.mean(g_vecs_all[-n_avg:], axis=0)
g_vecs_err = np.std(g_vecs_all[-n_avg:], axis=0, ddof=1)
else:
g_vecs_avg = None
g_vecs_err = None

return g_vecs_all, g_vecs_avg
return g_vecs_all, g_vecs_avg, g_vecs_err


def get_dg_evolution(log_files, start_state, end_state):
Expand Down Expand Up @@ -1091,3 +1114,81 @@ def plot_dg_evolution(log_files, start_state, end_state, start_idx=0, end_idx=-1
plt.savefig('dg_evolution.png', dpi=600)

return dg


def get_delta_w_updates(log_file, plot=False):
"""
Parses a log file of a weight-updating simulation and identifies the
time frames when the Wang-Landau incrementor is updated.
Parameters
----------
log_file : str
The name of the log file.
plot : bool
Whether to plot the Wang-Landau incrementor as a function of time.
Returns
-------
t_updates : list
A list of time frames when the Wang-Landau incrementor is updated.
delta_updates : list
A list of the updated Wang-Landau incrementors. Should be the same
length as :code:`t_updates`.
equil : bool
Whether the weights have been equilibrated.
"""
f = open(log_file, "r")
lines = f.readlines()
f.close()

# Get the parameters
for l in lines: # noqa: E741
if 'dt ' in l:
dt = float(l.split('=')[-1])
if 'init-wl-delta ' in l:
init_wl_delta = float(l.split('=')[-1])
if 'wl-scale ' in l:
wl_scale = float(l.split('=')[-1])
if 'weight-equil-wl-delta ' in l:
wl_delta_cutoff = float(l.split('=')[-1])
if 'Started mdrun' in l:
break

# Start parsing the data
n = -1
t_updates, delta_updates = [0], [init_wl_delta]
for l in lines: # noqa: E741
n += 1
if 'weights are now' in l:
t_updates.append(int(l.split(':')[0].split('Step')[-1]) * dt / 1000) # in ns

# search the following 10 lines to find the Wang-Landau incrementor
for i in range(10):
if 'Wang-Landau incrementor is:' in lines[n + i]:
delta_updates.append(float(lines[n + i].split()[-1]))
break
if 'Weights have equilibrated' in l:
equil = True
break

if equil is True:
delta_updates.append(delta_updates[-1] * wl_scale)

# Plot the Wang-Landau incrementor as a function of time if requested
# Note that between adjacen entries in t_updates, a horizontal line should be drawn.
if plot is True:
plt.figure()
for i in range(len(t_updates) - 1):
plt.plot([t_updates[i], t_updates[i + 1]], [delta_updates[i], delta_updates[i]], c='C0')
plt.plot([t_updates[i + 1], t_updates[i + 1]], [delta_updates[i], delta_updates[i + 1]], c='C0')

plt.text(0.65, 0.95, f'init_wl_delta: {init_wl_delta}', transform=plt.gca().transAxes)
plt.text(0.65, 0.9, f'wl-scale: {wl_scale}', transform=plt.gca().transAxes)
plt.text(0.65, 0.85, f'wl_delta_cutoff: {wl_delta_cutoff}', transform=plt.gca().transAxes)

plt.xlabel('Time (ns)')
plt.ylabel(r'Wang-Landau incrementor ($k_{B}T$)')
plt.grid()
plt.savefig('delta_updates.png', dpi=600)
return t_updates, delta_updates, equil

0 comments on commit 5fd9440

Please sign in to comment.