-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added synthesize_data.py to ensemble_md.analysis
- Loading branch information
Showing
2 changed files
with
122 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
#################################################################### | ||
# # | ||
# ensemble_md, # | ||
# a python package for running GROMACS simulation ensembles # | ||
# # | ||
# Written by Wei-Tse Hsu <wehs7661@colorado.edu> # | ||
# Copyright (c) 2022 University of Colorado Boulder # | ||
# # | ||
#################################################################### | ||
""" | ||
The :obj:`.synthesize_data` module provides methods for synthesizing REXEE data. | ||
""" | ||
import numpy as np | ||
from ensemble_md.analysis import analyze_traj | ||
from ensemble_md.analysis import analyze_matrix | ||
|
||
def synthesize_traj(trans_mtx, n_frames=100000, method='transmtx', start=0, seed=None):
    """
    Synthesize a trajectory based on the input transition matrix.

    Parameters
    ----------
    trans_mtx: np.ndarray
        The input transition matrix. Row :code:`i` is used as the probability distribution of the
        next state given that the current state is :code:`i`.
    n_frames: int
        The number of frames to be generated. The default value is 100000. If 0, an empty
        trajectory is returned.
    method: str
        The method to be used for trajectory synthesis. It can be either 'transmtx' or 'equil_prob'.
        The former refers to generating the trajectory by simulating the moves between states based on the
        input transition matrix, with the trajectory starting from the state specified by the :code:`start`
        parameter. If the method is :code:`equil_prob`, the trajectory will be generated by simply sampling
        from the equilibrium probability distribution calculated from the input transition matrix. The method
        'transmtx' should generate a trajectory characterized by a transition matrix similar to the input one,
        while the method 'equil_prob' may generate a trajectory that has a significantly different transition
        matrix. Still, a trajectory generated by either method should have a similar underlying equilibrium
        probability distribution (hence the spectral gap as well) as the input transition matrix.
        The default value is 'transmtx'.
    start: int
        The starting state of the synthesized trajectory if the method is :code:`transmtx`. The default value
        is 0, i.e., the first state. This parameter is ignored if the method is :code:`equil_prob`.
    seed: int
        The seed for NumPy's global random number generator. The default value is None, in which case
        the global generator is left untouched, so a seed set by the caller beforehand stays in effect.

    Returns
    -------
    syn_traj: np.ndarray
        The synthesized trajectory, i.e., a 1D array of :code:`n_frames` state (or replica) indices.

    Raises
    ------
    ValueError
        If :code:`method` is neither 'transmtx' nor 'equil_prob'.
    """
    # Only touch the global RNG when a seed is actually requested. Calling np.random.seed(None)
    # would re-seed the generator with fresh entropy and discard any seed the caller has set.
    if seed is not None:
        np.random.seed(seed)

    N = len(trans_mtx)  # Can be the number of states or replicas depending on the type of the input matrix
    if method == 'equil_prob':
        equil_prob = analyze_traj.calc_equil_prob(trans_mtx)
        syn_traj = np.random.choice(N, size=n_frames, p=equil_prob.reshape(N))
    elif method == 'transmtx':
        syn_traj = np.zeros(n_frames, dtype=int)
        if n_frames > 0:  # an empty trajectory has no starting frame to assign
            syn_traj[0] = start
            for i in range(1, n_frames):
                # Draw the next state from the row of the state visited in the previous frame.
                syn_traj[i] = np.random.choice(N, p=trans_mtx[syn_traj[i - 1]])
    else:
        raise ValueError(f'Invalid method: {method}. The method must be either "transmtx" or "equil_prob".')

    return syn_traj
|
||
|
||
def synthesize_transmtx(trans_mtx, mtx_type='rep', n_frames=100000, seed=None):
    """
    Synthesizes a normalized transition matrix similar to the input transition matrix by first
    generating a trajectory using :code:`synthesize_traj` with :code:`method='transmtx'` and then
    calculating the transition matrix from the synthesized trajectory.

    Parameters
    ----------
    trans_mtx: np.ndarray
        The input transition matrix.
    mtx_type: str
        The type of the input transition matrix. It can be either 'rep' (replica-space transition matrix)
        or 'state' (state-space transition matrix). The default value is 'rep'.
    n_frames: int
        The number of frames of the synthesized trajectory from which the mock transition matrix is calculated.
        The default value is 100000.
    seed: int
        The seed for the random number generator. The default value is None, i.e., the seed is not set.
        If :code:`mtx_type` is 'rep' and a seed is given, the trajectory starting from replica :code:`i`
        is generated with the seed :code:`seed + i`, so that the result is reproducible while the
        trajectories are not all driven by the same random number sequence.

    Returns
    -------
    syn_mtx: np.ndarray
        The synthesized transition matrix.
    syn_traj: np.ndarray
        The synthesized trajectory/trajectories from which the transition matrix is calculated. Note that
        if :code:`mtx_type` is 'rep', this will be a 2D array in which row :code:`i` is the synthesized
        replica-space trajectory starting from replica :code:`i`.
    diff_mtx: np.ndarray
        The input transition matrix subtracted by the synthesized transition matrix.

    Raises
    ------
    ValueError
        If :code:`mtx_type` is neither 'rep' nor 'state'.
    """
    N = len(trans_mtx)  # can be the number of states or number of replicas depending on mtx_type
    if mtx_type == 'state':
        # Note that here we just use the default values (method='transmtx' and start=0) for synthesize_traj, so that
        # the synthesized matrix will be similar to the input one. (If equil_prob is used, the resulting matrix may
        # be very different from the input one, though the equilibrium probabilities and spectral gap should be
        # similar.) Note that for transition matrix synthesis, the starting state of the synthesized trajectory
        # should not matter given that the number of frames is large.
        syn_traj = synthesize_traj(trans_mtx, n_frames, seed=seed)
        syn_mtx = analyze_traj.traj2transmtx(syn_traj, N)
    elif mtx_type == 'rep':
        # One trajectory per starting replica. Each one gets its own seed: reusing the same seed for every
        # call would feed all trajectories the identical random number sequence, so they would not be
        # independent samples of the input matrix.
        rep_seeds = [None if seed is None else seed + i for i in range(N)]
        rep_trajs = np.array([
            synthesize_traj(trans_mtx, n_frames, start=i, seed=rep_seeds[i]) for i in range(N)
        ])

        # Pool the per-trajectory matrices (presumably raw transition counts given normalize=False;
        # verify against analyze_traj.traj2transmtx) and normalize each row of the pooled matrix.
        counts = [analyze_traj.traj2transmtx(traj, N, normalize=False) for traj in rep_trajs]
        syn_mtx = np.sum(counts, axis=0, dtype=float)  # float so that the in-place division below is valid
        syn_mtx /= np.sum(syn_mtx, axis=1)[:, None]
        syn_traj = rep_trajs
    else:
        raise ValueError(f'Invalid mtx_type: {mtx_type}. The mtx_type must be either "rep" or "state".')

    diff_mtx = trans_mtx - syn_mtx

    return syn_mtx, syn_traj, diff_mtx
|
||
|