Skip to content

Commit

Permalink
Removed synthesize_transmtx from analyze_matrix.py; Simplified synthe…
Browse files Browse the repository at this point in the history
…size_transmtx in synthesize_data.py
  • Loading branch information
wehs7661 committed Apr 3, 2024
1 parent 322af8e commit cf3821d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 59 deletions.
34 changes: 0 additions & 34 deletions ensemble_md/analysis/analyze_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,40 +122,6 @@ def calc_equil_prob(trans_mtx):
return equil_prob


def synthesize_transmtx(trans_mtx, mtx_type='rep', n_frames=100000):
"""
Synthesizes a mock transition matrix by calculating the underlying equilibrium probability
of the input transition matrix, synthesizing a trajectory by drawing samples from the equilibrium
distribution, and calculating the transition matrix from the trajectory.
Parameters
----------
trans_mtx: np.ndarray
The input transition matrix.
mtx_type: str
The type of the input transition matrix. It can be either 'rep' (replica-space transition matrix)
or 'state' (state-space transition matrix).
n_frames: int
The number of frames of the synthesized trajectory from which the mock transition matrix is calculated.
Returns
-------
syn_mtx: np.ndarray
The synthesized transition matrix.
syn_traj: np.ndarray
The synthesized trajectory.
diff_mtx: np.ndarray
The absolute difference between the input and synthesized transition matrices.
"""
equil_prob = calc_equil_prob(trans_mtx)
n_states = len(equil_prob)
syn_traj = np.random.choice(n_states, size=n_frames, p=equil_prob.reshape(n_states))
syn_mtx = analyze_traj.traj2transmtx(syn_traj, n_states)
diff_mtx = np.abs(trans_mtx - syn_mtx)

return syn_mtx, syn_traj, diff_mtx


def calc_spectral_gap(trans_mtx, atol=1e-8, n_bootstrap=50):
"""
Calculates the spectral gap of the input transition matrix and estimates its
Expand Down
35 changes: 10 additions & 25 deletions ensemble_md/analysis/synthesize_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def synthesize_traj(trans_mtx, n_frames=100000, method='transmtx', start=0, seed
return syn_traj


def synthesize_transmtx(trans_mtx, mtx_type='rep', n_frames=100000, seed=None):
def synthesize_transmtx(trans_mtx, n_frames=100000, seed=None):
"""
Synthesizes a normalized transition matrix similar to the input transition matrix by first
generating a trajectory using :code:`synthesize_traj` with :code:`method='transmtx'` and then
Expand All @@ -71,9 +71,6 @@ def synthesize_transmtx(trans_mtx, mtx_type='rep', n_frames=100000, seed=None):
----------
trans_mtx: np.ndarray
The input transition matrix.
mtx_type: str
The type of the input transition matrix. It can be either 'rep' (replica-space transition matrix)
or 'state' (state-space transition matrix).
n_frames: int
The number of frames of the synthesized trajectory from which the mock transition matrix is calculated.
The default value is 100000.
Expand All @@ -85,32 +82,20 @@ def synthesize_transmtx(trans_mtx, mtx_type='rep', n_frames=100000, seed=None):
syn_mtx: np.ndarray
The synthesized transition matrix.
syn_traj: np.ndarray
The synthesized trajectory/trajectories from which the transition matrix is calculated. Note that
if :code:`mtx_type` is 'rep', this will be a list of trajectories, which represent synthesized
replica-space trajectories.
The synthesized trajectory from which the transition matrix is calculated.
diff_mtx: np.ndarray
The input transition matrix subtracted by the synthesized transition matrix.
"""
N = len(trans_mtx) # can be the number of states or number of replicas depending on mtx_type
if mtx_type == 'state':
# Note that here we just use the default values (method='transmtx' and start=0) for synthesize_traj, so that
# the synthesized matrix will be similar to the input one. (If equil_prob is used, the resulting matrix may
# be very different from the input one, though the equilibrium probabilities and spectral gap should be similar.)
# Note that for transition matrix synthesis, the starting state of the synthesized trajectory
# should not matter given that the number of frames is large.
syn_traj = synthesize_traj(trans_mtx, n_frames, seed=seed)
syn_mtx = analyze_traj.traj2transmtx(syn_traj, N)
elif mtx_type == 'rep':
rep_trajs = np.array([synthesize_traj(trans_mtx, n_frames, start=i, seed=seed) for i in range(N)])
counts = [analyze_traj.traj2transmtx(rep_trajs[i], N, normalize=False) for i in range(len(rep_trajs))]
syn_mtx = np.sum(counts, axis=0)
syn_mtx /= np.sum(syn_mtx, axis=1)[:, None]
syn_traj = rep_trajs
else:
raise ValueError(f'Invalid mtx_type: {mtx_type}. The mtx_type must be either "rep" or "state".')


# Note that here we just use the default values (method='transmtx' and start=0) for synthesize_traj, so that
# the synthesized matrix will be similar to the input one. (If equil_prob is used, the resulting matrix may
# be very different from the input one, though the equilibrium probabilities and spectral gap should be similar.)
# Note that for transition matrix synthesis, the starting state of the synthesized trajectory
# should not matter given that the number of frames is large.
syn_traj = synthesize_traj(trans_mtx, n_frames, seed=seed)
syn_mtx = analyze_traj.traj2transmtx(syn_traj, N)
diff_mtx = trans_mtx - syn_mtx

return syn_mtx, syn_traj, diff_mtx


0 comments on commit cf3821d

Please sign in to comment.