Skip to content

Commit

Permalink
Tweaked stitch_time_series
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Mar 28, 2024
1 parent 8f88afa commit bac786e
Showing 1 changed file with 31 additions and 21 deletions.
52 changes: 31 additions & 21 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,32 @@ def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, sav

# Then, stitch the trajectories for each starting configuration
trajs = [[] for i in range(n_configs)] # for each starting configuration
t_last, val_last = None, None # just for checking the continuity of the trajectory
for i in range(n_configs):
for j in range(n_iter):
if j == 0:
if dhdl:
traj, t = extract_state_traj(files_sorted[i][j])
else:
traj = np.loadtxt(files_sorted[i][j], comments=['#', '@'])[:, col_idx]
t = np.loadtxt(files_sorted[i][j], comments=['#', '@'])[:, 0] # only used if save_xvg is True
if dhdl:
traj, t = extract_state_traj(files_sorted[i][j])
else:
# Starting from the 2nd iteration, we get rid of the first time frame the first
# frame of iteration n+1 is the same as the last frame of iteration n
if dhdl:
traj, t = extract_state_traj(files_sorted[i][j])
traj, t = traj[1:], t[1:]
else:
traj = np.loadtxt(files_sorted[i][j], comments=['#', '@'])[:, col_idx][1:]
traj = np.loadtxt(files_sorted[i][j], comments=['#', '@'])[:, col_idx]
t = np.loadtxt(files_sorted[i][j], comments=['#', '@'])[:, 0]

# Shift the indices so that global indices are used.
shift_idx = rep_trajs[i][j]
traj = list(np.array(traj) + shifts[shift_idx])

if j != 0:
# Check the continuity of the trajectory
if traj[0] != val_last or t[0] != t_last:
err_str = f'The first frame of iteration {j} of starting configuration {i} is not continuous with the last frame of the previous iteration. '
err_str += f'Please check files {files_sorted[i][j - 1]} and {files_sorted[i][j]}.'
raise ValueError(err_str)

if dhdl: # Trajectories of global alchemical indices will be generated.
shift_idx = rep_trajs[i][j]
traj = list(np.array(traj) + shifts[shift_idx])
t_last = t[-1]
val_last = traj[-1]

if j != 0:
traj = traj[:-1] # remove the last frame, which is the same as the first of the next time series.

trajs[i].extend(traj)

if save_npy is True:
Expand Down Expand Up @@ -158,7 +164,7 @@ def stitch_time_series_for_sim(files, dhdl=True, col_idx=-1, save=True):
n_sim = len(files) # number of replicas
n_iter = len(files[0]) # number of iterations per replica
trajs = [[] for i in range(n_sim)]
t_last = None # just for checking the continuity of the trajectory
t_last, val_last = None, None # just for checking the continuity of the trajectory
for i in range(n_sim):
for j in range(n_iter):
if dhdl:
Expand All @@ -172,14 +178,18 @@ def stitch_time_series_for_sim(files, dhdl=True, col_idx=-1, save=True):

if j != 0:
# Check the continuity of the trajectory
if traj[0] != trajs[i][-1] or t[0] != t_last:
err_str = f'The first frame of iteration {j} in replica {i} is not continuous with the last frame of the previous iteration.' # noqa: E501
if traj[0] != val_last or t[0] != t_last:
err_str = f'The first frame of iteration {j} in replica {i} is not continuous with the last frame of the previous iteration. ' # noqa: E501
err_str += f'Please check files {files[i][j - 1]} and {files[i][j]}.'
raise ValueError(err_str)


t_last = t[-1]
val_last = traj[-1]

if j != 0:
traj = traj[:-1] # remove the last frame, which is the same as the first of the next time series.

trajs[i].extend(traj)
t_last = t[-1]

# Save the trajectories as an NPY file if desired
if save is True:
Expand Down

0 comments on commit bac786e

Please sign in to comment.