diff --git a/ensemble_md/analysis/analyze_traj.py b/ensemble_md/analysis/analyze_traj.py index c3300ab6..892c8ce8 100644 --- a/ensemble_md/analysis/analyze_traj.py +++ b/ensemble_md/analysis/analyze_traj.py @@ -965,7 +965,7 @@ def plot_swaps(swaps, swap_type='', stack=True, figsize=None): plt.savefig(f'{swap_type}_swaps.png', dpi=600) -def get_g_evolution(log_files, N_states, avg_frac=0): +def get_g_evolution(log_files, N_states, avg_frac=0, avg_from_last_update=False): """ For weight-updating simulations, gets the time series of the alchemical weights of all states. @@ -979,7 +979,9 @@ def get_g_evolution(log_files, N_states, avg_frac=0): avg_frac : float The fraction of the last part of the simulation to be averaged. The default is 0, which means no averaging. - + avg_from_last_update : bool + Whether to average from the last update of wl-delta. If False, the + averaging will be from the beginning of the simulation. Returns ------- @@ -989,6 +991,10 @@ def get_g_evolution(log_files, N_states, avg_frac=0): g_vecs_avg : list The alchemical weights of all states averaged over the last part of the simulation. If :code:`avg_frac` is 0, :code:`None` will be returned. + g_vecs_err : list + The errors of the alchemical weights of all states averaged over the + last part of the simulation. If :code:`avg_frac` is 0 and :code:`avg_from_last_update` + is :code:`False`, :code:`None` will be returned. """ g_vecs_all = [] for log_file in log_files: @@ -997,6 +1003,7 @@ def get_g_evolution(log_files, N_states, avg_frac=0): f.close() n = -1 + idx_updates = [] # the indices of the data points corresponding to the updates of wl-delta find_equil = False for line in lines: n += 1 @@ -1011,19 +1018,35 @@ def get_g_evolution(log_files, N_states, avg_frac=0): if find_equil is False: g_vecs_all.append(w) + if 'weights are now' in line: + idx_updates.append(len(g_vecs_all) - 1) + if "Weights have equilibrated" in line: find_equil = True w = [float(i) for i in lines[n - 2].split(':')[-1].split()] g_vecs_all.append(w) break - if avg_frac != 0: - n_avg = int(avg_frac * len(g_vecs_all)) - g_vecs_avg = np.mean(g_vecs_all[-n_avg:], axis=0) + if avg_from_last_update is True: + # If the weights are equilibrated, then the last occurrence of "weights are now" + # is right before the equilibration message, in which case we want to average + # from the second last occurrence of "weights are now". + if find_equil is True: + idx_updates = idx_updates[:-1] + + idx_last_update = idx_updates[-1] + g_vecs_avg = np.mean(g_vecs_all[idx_last_update + 1:], axis=0) + g_vecs_err = np.std(g_vecs_all[idx_last_update + 1:], axis=0, ddof=1) else: - g_vecs_avg = None + if avg_frac != 0: + n_avg = int(avg_frac * len(g_vecs_all)) + g_vecs_avg = np.mean(g_vecs_all[-n_avg:], axis=0) + g_vecs_err = np.std(g_vecs_all[-n_avg:], axis=0, ddof=1) + else: + g_vecs_avg = None + g_vecs_err = None - return g_vecs_all, g_vecs_avg + return g_vecs_all, g_vecs_avg, g_vecs_err def get_dg_evolution(log_files, start_state, end_state): @@ -1091,3 +1114,81 @@ def plot_dg_evolution(log_files, start_state, end_state, start_idx=0, end_idx=-1 plt.savefig('dg_evolution.png', dpi=600) return dg + + +def get_delta_w_updates(log_file, plot=False): + """ + Parses a log file of a weight-updating simulation and identifies the + time frames when the Wang-Landau incrementor is updated. + + Parameters + ---------- + log_file : str + The name of the log file. + plot : bool + Whether to plot the Wang-Landau incrementor as a function of time. + + Returns + ------- + t_updates : list + A list of time frames when the Wang-Landau incrementor is updated. + delta_updates : list + A list of the updated Wang-Landau incrementors. Should be the same + length as :code:`t_updates`. + equil : bool + Whether the weights have been equilibrated. + """ + f = open(log_file, "r") + lines = f.readlines() + f.close() + + # Get the parameters + for l in lines: # noqa: E741 + if 'dt ' in l: + dt = float(l.split('=')[-1]) + if 'init-wl-delta ' in l: + init_wl_delta = float(l.split('=')[-1]) + if 'wl-scale ' in l: + wl_scale = float(l.split('=')[-1]) + if 'weight-equil-wl-delta ' in l: + wl_delta_cutoff = float(l.split('=')[-1]) + if 'Started mdrun' in l: + break + + # Start parsing the data + n = -1 + t_updates, delta_updates = [0], [init_wl_delta] + for l in lines: # noqa: E741 + n += 1 + if 'weights are now' in l: + t_updates.append(int(l.split(':')[0].split('Step')[-1]) * dt / 1000) # in ns + + # search the following 10 lines to find the Wang-Landau incrementor + for i in range(10): + if 'Wang-Landau incrementor is:' in lines[n + i]: + delta_updates.append(float(lines[n + i].split()[-1])) + break + if 'Weights have equilibrated' in l: + equil = True + break + + if equil is True: + delta_updates.append(delta_updates[-1] * wl_scale) + + # Plot the Wang-Landau incrementor as a function of time if requested + # Note that between adjacen entries in t_updates, a horizontal line should be drawn. + if plot is True: + plt.figure() + for i in range(len(t_updates) - 1): + plt.plot([t_updates[i], t_updates[i + 1]], [delta_updates[i], delta_updates[i]], c='C0') + plt.plot([t_updates[i + 1], t_updates[i + 1]], [delta_updates[i], delta_updates[i + 1]], c='C0') + + plt.text(0.65, 0.95, f'init_wl_delta: {init_wl_delta}', transform=plt.gca().transAxes) + plt.text(0.65, 0.9, f'wl-scale: {wl_scale}', transform=plt.gca().transAxes) + plt.text(0.65, 0.85, f'wl_delta_cutoff: {wl_delta_cutoff}', transform=plt.gca().transAxes) + + plt.xlabel('Time (ns)') + plt.ylabel(r'Wang-Landau incrementor ($k_{B}T$)') + plt.grid() + plt.savefig('delta_updates.png', dpi=600) + return t_updates, delta_updates, equil