diff --git a/ensemble_md/analysis/analyze_traj.py b/ensemble_md/analysis/analyze_traj.py index 45af0bb..e546b4e 100644 --- a/ensemble_md/analysis/analyze_traj.py +++ b/ensemble_md/analysis/analyze_traj.py @@ -1330,3 +1330,90 @@ def get_delta_w_updates(log_file, plot=False): plt.savefig('delta_w_updates.png', dpi=600) return t_updates, delta_w_updates, equil + +def end_states_only_traj(working_dir, n_sim, n_iter, l0_states, l1_states, swap_rep_pattern, ps_per_frame): + import pandas as pd + import os + import mdtraj as md + + #Determine how many end states are present, which simulations and lambdas those end states correspond to + state_name = ['A'] + considered_swaps = [[0,0]] + cat = ord('A') + 1 + for swap in swap_rep_pattern: + part_1, part_2 = swap + if part_1 in considered_swaps and part_2 in considered_swaps: + continue + elif part_1 in considered_swaps: + index = considered_swaps.index(part_1) + state_name.append(state_name[index]) + considered_swaps.append(part_2) + elif part_2 in considered_swaps: + index = considered_swaps.index(part_2) + state_name.append(state_name[index]) + considered_swaps.append(part_1) + else: + state_name.append(chr(cat)) + state_name.append(chr(cat)) + considered_swaps.append(part_1) + considered_swaps.append(part_2) + cat += 1 + for i in range(n_sim): + for j in [0, 1]: + if [i, j] not in considered_swaps: + state_name.append(chr(cat)) + considered_swaps.append([i, j]) + cat += 1 + + #Determine which frames correspond to which end states + state_frame_df = pd.DataFrame() + for n in range(n_sim): + for i in range(n_iter): + l0_frame, l1_frame = [],[] + dhdl_file = open(f'{working_dir}/sim_{n}/iteration_{i}/dhdl.xvg', 'r').readlines() + start = True + for line in dhdl_file: + split_line = line.split(' ') + while '' in split_line: + split_line.remove('') + if '#' not in split_line[0] and '@' not in split_line[0]: + time = float(split_line[0]) + if start: + start_time = time + start = False + state = float(split_line[1]) + if time%ps_per_frame == 0: + if state in l0_states: + l0_frame.append(int((time-start_time)/ps_per_frame)) + elif state in l1_states: + l1_frame.append(int((time-start_time)/ps_per_frame)) + if len(l0_frame) != 0: + df_0 = pd.DataFrame({'Sim': n, 'Iteration': i, 'Frame': l0_frame, 'Lambda': 0}) + state_frame_df = pd.concat([state_frame_df, df_0]) + if len(l1_frame) != 0: + df_1 = pd.DataFrame({'Sim': n, 'Iteration': i, 'Frame': l1_frame, 'Lambda': 1}) + state_frame_df = pd.concat([state_frame_df, df_1]) + + #Concatenate all frames from each set of trajectories for each end state + unique_states = list(set(state_name)) + for state in unique_states: + indices = [i for i, value in enumerate(state_name) if value == state] + for i, index in enumerate(indices): + rep, l = considered_swaps[index] + started = False + if os.path.exists(f'{working_dir}/sim_{rep}/iteration_0/confout_backup.gro'): + name = 'confout_backup' + else: + name = 'confout' + for iteration in range(n_iter): + frames_select = state_frame_df[(state_frame_df['Sim'] == rep) & (state_frame_df['Iteration'] == iteration) & (state_frame_df['Lambda'] == l)]['Frame'].to_numpy() + if len(frames_select) != 0: + if not started: + traj = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro') + started = True + else: + traj_add = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro') + traj = md.join(traj, traj_add) + traj.save_xtc(f'{working_dir}/analysis/{state}_{rep}.xtc') + + diff --git a/ensemble_md/cli/analyze_REXEE.py b/ensemble_md/cli/analyze_REXEE.py index 62dfbd6..017d9d6 100644 --- a/ensemble_md/cli/analyze_REXEE.py +++ b/ensemble_md/cli/analyze_REXEE.py @@ -25,6 +25,7 @@ warnings.simplefilter(action='ignore', category=UserWarning) from ensemble_md.utils import utils # noqa: E402 +from ensemble_md.utils import gmx_parser # noqa: E402 from ensemble_md.analysis import analyze_traj # noqa: E402 from ensemble_md.analysis import analyze_matrix # noqa: E402 from ensemble_md.analysis import msm_analysis # noqa: E402 @@ -120,6 +121,7 @@ def main(): print('\nData analysis of the simulation ensemble') print('========================================') + # Section 1. Analysis based on transitions between state sets print('[ Section 1. Analysis based on transitions between state sets/replicas ]') section_idx += 1 @@ -128,6 +130,15 @@ def main(): print('1-0. Reading in the replica-space trajectory ...') rep_trajs = np.load(args.rep_trajs) # Shape: (n_sim, n_iter) + # ***** Testing Section ******* + if REXEE.modify_coords is not None: + l0, l1, ps_per_frame = gmx_parser.get_end_states(f'{REXEE.working_dir}/sim_0/iteration_0/expanded.mdp') + n_sim, n_iter = np.shape(rep_trajs) + if REXEE.swap_rep_pattern is None: + raise Exception('MT-REXEE trajectory analysis requires swap_rep_pattern to be defined') + analyze_traj.end_states_only_traj(REXEE.working_dir, n_sim, n_iter, l0, l1, REXEE.swap_rep_pattern, ps_per_frame) + exit() + # 1-1. Plot the replica-sapce trajectory print('1-1. Plotting transitions between state sets/replicas ...') dt_swap = REXEE.nst_sim * REXEE.dt # dt for swapping replicas diff --git a/ensemble_md/utils/gmx_parser.py b/ensemble_md/utils/gmx_parser.py index 9c41669..5279ee2 100644 --- a/ensemble_md/utils/gmx_parser.py +++ b/ensemble_md/utils/gmx_parser.py @@ -408,3 +408,22 @@ def read_top(file_name, resname): if len(line_sep) > 4 and line_sep[3] == resname: return input_file raise Exception(f'Residue {resname} can not be found in {file_name}') + +def get_end_states(mdp_path): + mdp = MDP(mdp_path) + end_0, end_1 = [], [] + coul_lambda = mdp['coul_lambdas'] + vdw_lambda = mdp['vdw_lambdas'] + n = 0 + for vdw, coul in zip(coul_lambda, vdw_lambda): + if vdw == 0.0 and coul == 0.0: + end_0.append(n) + elif vdw == 1.0 and coul == 1.0: + end_1.append(n) + n += 1 + dt = mdp['dt'] + steps_per_frame = mdp['nstxout'] + ps_per_frame = dt*steps_per_frame + + return end_0, end_1, ps_per_frame +