Skip to content

Commit

Permalink
Refined docstrings for analyze_matrix.py; Modified api_analysis.rst
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Apr 18, 2024
1 parent 8275f3c commit d61e0d1
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 56 deletions.
14 changes: 14 additions & 0 deletions docs/api/api_analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ ensemble\_md.analysis.analyze_free_energy
:members:
:undoc-members:

ensemble\_md.analysis.synthesize_data
-----------------------------------------

.. automodule:: ensemble_md.analysis.synthesize_data
:members:
:undoc-members:

ensemble\_md.analysis.clustering
-----------------------------------------

.. automodule:: ensemble_md.analysis.clustering
:members:
:undoc-members:

ensemble\_md.analysis.msm_analysis
----------------------------------

Expand Down
106 changes: 56 additions & 50 deletions ensemble_md/analysis/analyze_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# #
####################################################################
"""
The :obj:`.analyze_matrix` module provides methods for analyzing matrices obtained from REXEE.
The :obj:`.analyze_matrix` module provides methods for analyzing matrices obtained from a REXEE simulation.
"""
import numpy as np
import seaborn as sns
Expand All @@ -19,7 +19,7 @@
from ensemble_md.analysis import synthesize_data


def calc_transmtx(log_file, expanded_ensemble=True):
def calc_transmtx(log_file, simulation_type='EE'):
"""
Parses the log file to get the transition matrix of an expanded ensemble
or replica exchange simulation. Notably, a theoretical transition matrix
Expand All @@ -28,18 +28,19 @@ def calc_transmtx(log_file, expanded_ensemble=True):
Parameters
----------
log_file : str
The log file to be parsed.
expanded_ensemble : bool
Whether the simulation is expanded ensemble or replica exchange
The file path to the log file to be parsed.
simulation_type : str, Optional
The type of simulation. It can be either a :code:`EE` (expanded ensemble) or :code:`HREX`
(Hamiltonian replica exchange) simulation. The default is :code:`EE`.
Returns
-------
empirical : np.ndarray
empirical : numpy.ndarray
The final empirical state transition matrix.
theoretical : None or np.ndarray
theoretical : None or numpy.ndarray
The final theoretical state transition matrix.
diff_matrix : None or np.ndarray
The difference between the theortial and empirical state transition matrix (empirical - theoretical).
diff_matrix : None or numpy.ndarray
The difference calculated by subtracting the theoretical matrix from the empirical matrix.
"""
f = open(log_file, "r")
lines = f.readlines()
Expand All @@ -56,10 +57,12 @@ def calc_transmtx(log_file, expanded_ensemble=True):
n_states = int(lines[n - 1].split()[-1])
empirical = np.zeros([n_states, n_states])
for i in range(n_states):
if expanded_ensemble is True:
if simulation_type == 'EE':
empirical[i] = [float(k) for k in lines[n - 2 - i].split()[:-1]]
else: # replica exchange
elif simulation_type == 'HREX':
empirical[i] = [float(k) for k in lines[n - 2 - i].split()[1:-1]]
else:
raise ValueError(f"Invalid simulation type {simulation_type}.")

if "Transition Matrix" in l and "Empirical" not in l: # only occurs in expanded ensemble
theoretical_found = True
Expand All @@ -83,22 +86,23 @@ def calc_transmtx(log_file, expanded_ensemble=True):

def calc_equil_prob(trans_mtx):
"""
Calculates the equilibrium probability of each state from the state transition matrix.
The input state transition matrix can be either left or right stochastic, although the left
Calculates the equilibrium probability of each state from a transition matrix.
The input transition matrix can be either left or right stochastic, although the left
stochastic ones are not common in GROMACS. Generally, transition matrices in GROMACS are either
doubly stochastic (replica exchange), or right stochastic (expanded ensemble). For the latter case,
the staionary distribution vector is the left eigenvector corresponding to the eigenvalue 1
the stationary distribution vector is the left eigenvector corresponding to the eigenvalue 1
of the transition matrix. (For the former case, it's either left or right eigenvector corresponding
to the eigenvalue 1 - as the left and right eigenvectors are the same for a doubly stochasti matrix.)
Note that the input transition matrix can be either state-space or replica-space.
Parameters
----------
trans_mtx : np.ndarray
The input state transition matrix
trans_mtx : numpy.ndarray
The input transition matrix.
Returns
-------
equil_prob : np.ndarray
equil_prob : numpy.ndarray
"""
check_row = sum([np.isclose(np.sum(trans_mtx[i]), 1) for i in range(len(trans_mtx))])
check_col = sum([np.isclose(np.sum(trans_mtx[:, i]), 1) for i in range(len(trans_mtx))])
Expand All @@ -123,30 +127,31 @@ def calc_equil_prob(trans_mtx):
return equil_prob


def calc_spectral_gap(trans_mtx, atol=1e-8, n_bootstrap=50, bootstrap_seed=None):
def calc_spectral_gap(trans_mtx, atol=1e-8, n_bootstrap=50, seed=None):
"""
Calculates the spectral gap of the input transition matrix and estimates its
Calculates the spectral gap of an input transition matrix and estimates its
uncertainty using the bootstrap method.
Parameters
----------
trans_mtx : np.ndarray
trans_mtx : numpy.ndarray
The input transition matrix
atol: float
The absolute tolerance for checking the sum of columns and rows.
n_bootstrap: int
The number of bootstrap iterations for uncertainty estimation.
bootstrap_seed: int
The seed for the random number generator for the bootstrap method.
atol: float, Optional
The absolute tolerance for checking the sum of columns and rows. The default is 1e-8.
n_bootstrap: int, Optional
The number of bootstrap iterations for uncertainty estimation. The default is 50.
seed: int, Optional
The seed for the random number generator used in the bootstrap method. The default is :code:`None`,
which means no seed will be used.
Returns
-------
spectral_gap : float
The spectral gap of the input transitio n matrix.
The spectral gap of the input transition matrix.
spectral_gap_err : float
The estimated uncertainty of the spectral gap.
eig_vals : list
The list of eigenvalues.
The list of eigenvalues. The maximum eigenvalue should always be 1.
"""
check_row = sum([np.isclose(np.sum(trans_mtx[i]), 1, atol=atol) for i in range(len(trans_mtx))])
check_col = sum([np.isclose(np.sum(trans_mtx[:, i]), 1, atol=atol) for i in range(len(trans_mtx))])
Expand All @@ -171,7 +176,7 @@ def calc_spectral_gap(trans_mtx, atol=1e-8, n_bootstrap=50, bootstrap_seed=None)
spectral_gap_list = []
n_performed = 0
while n_performed < n_bootstrap:
mtx_boot = synthesize_data.synthesize_transmtx(trans_mtx, seed=bootstrap_seed)[0]
mtx_boot = synthesize_data.synthesize_transmtx(trans_mtx, seed=seed)[0]
check_row_boot = sum([np.isclose(np.sum(mtx_boot[i]), 1, atol=atol) for i in range(len(mtx_boot))])
check_col_boot = sum([np.isclose(np.sum(mtx_boot[:, i]), 1, atol=atol) for i in range(len(mtx_boot))])
if check_row_boot == len(mtx_boot):
Expand Down Expand Up @@ -201,9 +206,10 @@ def calc_t_relax(spectral_gap, exchange_period, spectral_gap_err=None):
The input spectral gap.
exchange_period : float
The exchange period of the simulation in ps.
spectral_gap_err : float
The uncertainty of the spectral gap, which is used to calculate the uncertainty of the relaxation time using
error propagation.
spectral_gap_err : float, Optional
The uncertainty of the spectral gap, which is used to calculate the uncertainty of the relaxation time by
error propagation. The default is :code:`None`, in which case the uncertainty of the relaxation time
will be :code:`None`.
Returns
-------
Expand All @@ -223,25 +229,25 @@ def calc_t_relax(spectral_gap, exchange_period, spectral_gap_err=None):

def split_transmtx(trans_mtx, n_sim, n_sub):
"""
Split the input transition matrix into blocks of smaller matrices corresponding to
difrerent alchemical ranges of different replicas. Notably, the function assumes
Splits the input state-space transition matrix into blocks of smaller matrices corresponding to
different state states sampled by different replicas. Notably, the function assumes
homogeneous shifts and number of states across replicas. Also, the blocks of the
transition matrix is generally not doubly stochastic but right stochastic even if
the input is doubly stochastic.
Parameters
----------
trans_mtx : np.ndarray
The input state transition matrix to split
trans_mtx : numpy.ndarray
The input state-space transition matrix to be split.
n_sim : int
The number of replicas in REXEE.
The number of replicas in the REXEE simulation.
n_sub : int
The number of states for each replica.
The number of states in each replica.
Returns
-------
sub_mtx: list
Blocks of transition matrices split from the input.
Blocks of transition matrices split from the input transition matrix.
"""
sub_mtx = []
ranges = [[i, i + n_sub] for i in range(n_sim)] # A list of lists containing the min/max of alchemcial ranges
Expand All @@ -255,20 +261,20 @@ def split_transmtx(trans_mtx, n_sim, n_sub):
return sub_mtx


def plot_matrix(matrix, png_name, title=None, start_idx=0):
def plot_matrix(matrix, fig_name, title=None, start_idx=0):
"""
Visualizes a matrix a in heatmap.
Visualizes a matrix in a heatmap.
Parameters
----------
matrix : np.ndarray
The matrix to be visualized
png_name : str
The file name of the output PNG file (including the extension).
title : str
The title of the plot
start_idx : int
The starting value of the state index
matrix : numpy.ndarray
The matrix to be visualized.
fig_name : str
The file path to save the figure.
title : str, Optional
The title of the plot. The default is :code:`None`, which means no title will be added.
start_idx : int, Optional
The starting value of the state index. The default is 0.
"""

sns.set_context(
Expand Down Expand Up @@ -320,5 +326,5 @@ def plot_matrix(matrix, png_name, title=None, start_idx=0):
plt.title(title, fontsize=10, weight="bold")
plt.tight_layout(pad=1.0)

plt.savefig(png_name, dpi=600)
plt.savefig(fig_name, dpi=600)
plt.close()
6 changes: 3 additions & 3 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def plot_rep_trajs(trajs, fig_name, dt=None, stride=None):
trajs : list
A list of lists that represent the all replica-space trajectories.
fig_name : str
The file path of the PNG file to be saved.
The file path to save the figure.
dt : float or None, Optional
One trajectory timestep in ps. If :code:`dt=None`, the function assumes there are no time frames but MC steps.
The default is :code:`None`.
Expand Down Expand Up @@ -403,7 +403,7 @@ def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=None, title_
state_ranges : list
A list of lists of showing the state indices sampled by each replica.
fig_name : str
The file path of the PNG file to be saved.
The file path to save the figure.
dt : float or None, Optional
One trajectory timestep in ps. If :code:`dt=None`, the function assumes there are no time frames but MC steps.
The default is :code:`None`.
Expand Down Expand Up @@ -501,7 +501,7 @@ def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, pre
state_ranges : list
A list of lists of showing the state indices sampled by each replica.
fig_name : str
The file path of the PNG file to be saved.
The file path to save the figure.
stack : bool, Optional
Whether to stack the histograms. This parameter is only relevant when :code:`subplots` is :code:`False`.
The default is :code:`True`.
Expand Down
4 changes: 3 additions & 1 deletion ensemble_md/analysis/synthesize_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
# #
####################################################################
"""
The :obj:`.synthesize_data` module provides methods for synthesizing REXEE data.
The :obj:`.synthesize_data` module provides methods for synthesizing REXEE data. This is mainly useful
for carrying out the bootstrap analysis for some analysis, such as the spectral gap calculation in
:obj:`.analyze_matrix`.
"""
import numpy as np
from ensemble_md.analysis import analyze_traj
Expand Down
2 changes: 1 addition & 1 deletion ensemble_md/tests/test_analyze_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_calc_transmtx():
# Case 3: Hamiltonian replica exchange
# Note that the transition matrices shown in the log file of different replicas should all be the same.
# Here we use log/HREX.log, which is a part of the log file from anthracene HREX.
A3, B3, C3 = analyze_matrix.calc_transmtx(os.path.join(input_path, 'log/HREX.log'), expanded_ensemble=False)
A3, B3, C3 = analyze_matrix.calc_transmtx(os.path.join(input_path, 'log/HREX.log'), simulation_type='HREX')
A3_expected = np.array([[0.7869, 0.2041, 0.0087, 0.0003, 0.0000, 0.0000, 0.0000, 0.0000], # noqa: E128, E202, E203, E501
[0.2041, 0.7189, 0.0728, 0.0041, 0.0001, 0.0000, 0.0000, 0.0000], # noqa: E128, E202, E203
[0.0087, 0.0728, 0.7862, 0.1251, 0.0071, 0.0001, 0.0000, 0.0000], # noqa: E202, E203
Expand Down
2 changes: 1 addition & 1 deletion ensemble_md/tests/test_replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def test_print_params(self, capfd, params_dict):
REXEE.print_params()
out_1, err = capfd.readouterr()
L = ""
L += "Important parameters of EXEE\n============================\n"
L += "Important parameters of REXEE\n=============================\n"
L += f"Python version: {sys.version}\n"
L += f"GROMACS executable: {REXEE.gmx_path}\n" # Easier to pass CI. This is easy to catch anyway
L += f"GROMACS version: {REXEE.gmx_version}\n" # Easier to pass CI. This is easy to catch anyway
Expand Down

0 comments on commit d61e0d1

Please sign in to comment.