diff --git a/docs/api/api_analysis.rst b/docs/api/api_analysis.rst index 6f305b8e..22dfbe62 100644 --- a/docs/api/api_analysis.rst +++ b/docs/api/api_analysis.rst @@ -22,6 +22,20 @@ ensemble\_md.analysis.analyze_free_energy :members: :undoc-members: +ensemble\_md.analysis.synthesize_data +----------------------------------------- + +.. automodule:: ensemble_md.analysis.synthesize_data + :members: + :undoc-members: + +ensemble\_md.analysis.clustering +----------------------------------------- + +.. automodule:: ensemble_md.analysis.clustering + :members: + :undoc-members: + ensemble\_md.analysis.msm_analysis ---------------------------------- diff --git a/docs/api/api_ensemble_EXE.rst b/docs/api/api_ensemble_EXE.rst index 2dc5bdba..f853dfde 100644 --- a/docs/api/api_ensemble_EXE.rst +++ b/docs/api/api_ensemble_EXE.rst @@ -1,6 +1,6 @@ -ensemble\_md.ensemble_EXE -========================= +ensemble\_md.replica_exchange_EE +================================ -.. automodule:: ensemble_md.ensemble_EXE +.. automodule:: ensemble_md.replica_exchange_EE :members: :undoc-members: \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 3497cee3..8119210b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -50,6 +50,7 @@ 'sphinx.ext.napoleon', 'sphinx.ext.intersphinx', 'sphinx.ext.extlinks', + 'sphinx.ext.todo', 'nbsphinx', ] @@ -57,6 +58,7 @@ napoleon_google_docstring = False napoleon_use_param = False napoleon_use_ivar = True +todo_include_todos = True # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -180,5 +182,13 @@ # -- Others ------------------------------------------------------------------ autodoc_default_options = { - 'private-members': True, + 'private-members': False, +} + +intersphinx_mapping = { + 'python': ('https://docs.python.org/3', None), + 'numpy': ('https://numpy.org/doc/stable/', None), + 'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None), + 'pymbar': ('https://pymbar.readthedocs.io/en/latest/', None), + 'alchemlyb': ('https://alchemlyb.readthedocs.io/en/latest/', None), } \ No newline at end of file diff --git a/docs/simulations.rst b/docs/simulations.rst index f5979c24..9a549c61 100644 --- a/docs/simulations.rst +++ b/docs/simulations.rst @@ -7,7 +7,7 @@ can be used to perform and analyze REXEE simulations, respectively. Below we provide more details about each of these CLIs. 1.1. CLI :code:`explore_REXEE` ------------------------------ +------------------------------ Here is the help message of :code:`explore_REXEE`: :: @@ -32,7 +32,7 @@ Here is the help message of :code:`explore_REXEE`: 1.2. CLI :code:`run_REXEE` -------------------------- +-------------------------- Here is the help message of :code:`run_REXEE`: :: @@ -71,7 +71,7 @@ so each of the 4 replicas will use 32 threads (assuming thread-MPI GROMACS), tak of 128 cores. 1.3. CLI :code:`analyze_REXEE` ------------------------------ +------------------------------ Finally, here is the help message of :code:`analyze_REXEE`: :: @@ -258,7 +258,7 @@ include parameters for data analysis here. .. _doc_REXEE_parameters: 3.3. REXEE parameters --------------------- +--------------------- - :code:`n_sim`: (Required) The number of replica simulations. @@ -425,6 +425,7 @@ MDP parameters: can be observed in a fixed-weight REXEE simulation and the equilibration time may be much longer for a weight-updating REXEE simulation. To ensure the same reference distance across all iterations in an REXEE simulation, consider the following scenarios: + - If you would like to use the COM distance between the pull groups in the input GRO file as the reference distance for all the iterations (whatever that value is), then specify :code:`pull_coord1_start = yes` with :code:`pull_coord1_init = 0` in your input MDP template. In this case, :obj:`.update_MDP` will parse :code:`pullx.xvg` diff --git a/docs/theory.rst b/docs/theory.rst index aae99f43..da369a7b 100644 --- a/docs/theory.rst +++ b/docs/theory.rst @@ -1,9 +1,9 @@ -.. _doc_basic_idea: - .. note:: This page is still a work in progress. Please check `Issue 33`_ for the current progress. .. _`Issue 33`: https://github.com/wehs7661/ensemble_md/issues/33 +.. _doc_basic_idea: + 1. Basic idea ============= Replica exchange of expanded ensemble (REXEE) integrates the core principles of replica exchange (REX) diff --git a/ensemble_md/analysis/analyze_free_energy.py b/ensemble_md/analysis/analyze_free_energy.py index b5e8ba55..0d2d1a0e 100644 --- a/ensemble_md/analysis/analyze_free_energy.py +++ b/ensemble_md/analysis/analyze_free_energy.py @@ -27,35 +27,34 @@ def preprocess_data(files_list, temp, data_type, spacing=1, t=None, g=None): """ - This function preprocesses :math:`u_{nk}`/:math:`dH/dλ` data for all replicas in an REXEE simulation. - For each replica, it reads in :math:`u_{nk}`/:math:`dH/dλ` data from all iterations, concatenate - them, remove the equilibrium region and and decorrelate the concatenated data. Notably, - the data preprocessing protocol is basically the same as the one adopted in - :code:`alchemlyb.subsampling.equilibrium_detection`. + This function preprocesses :math:`u_{nk}` or :math:`dH/dλ` data for all replicas in an REXEE simulation. + For each replica, it reads in :math:`u_{nk}` or :math:`dH/dλ` data from all iterations, concatenate + them, remove the equilibrium region and and decorrelate the concatenated data. Parameters ---------- files_list : list - A list of lists of naturally sorted dhdl file names from all iterations for different replicas. - :code:`files[i]` should be the list of dhdl file names from all iterations of replica :code:`i`. + A list of lists of naturally sorted DHDL file names from all iterations for different replicas. + Specifically, :code:`files[i]` should be the list of DHDL file names from all iterations of replica :code:`i`. temp : float The simulation temperature in Kelvin. We assume all replicas were performed at the same temperature. data_type : str The type of energy data to be procssed. Should be either :code:`'u_nk'` or :code:`'dhdl'`. - spacing : int - The spacing (number of data points) to consider when subsampling the data, which is assumed to - be the same for all replicas. - t : int + spacing : int, Optional + The spacing (in the number of data points) to consider when subsampling the data, which is assumed to + be the same for all replicas. The default is 1. + t : int, Optional The user-specified index that indicates the start of equilibrated data. If this parameter is not specified, - the function will estimate it using :code:`pymbar.timeseries.detect_equilibration`. - g : int - The user-specified index that indicates the start of equilibrated data. If this parameter is not specified, - the function will estimate it using :code:`pymbar.timeseries.detect_equilibration`. + the function will estimate it using :func:`pymbar.timeseries.detect_equilibration`. The default + is :code:`None`. + g : int, Optional + The user-specified statistical inefficiency. If this parameter is not specified, the function will estimate + it using :func:`pymbar.timeseries.detect_equilibration`. The default is :code:`None`. Returns ------- - preprocessed_data_all : pd.Dataframe - A list of preprocessed :math:`u_{nk}`/:math:`dH/dλ` data for all replicas that can serve as the + preprocessed_data_all : pandas.Dataframe + A list of preprocessed :math:`u_{nk}` or :math:`dH/dλ` data for all replicas that can serve as the input to free energy estimators. t_list : list A list of indices indicating the start of equilibrated data for different replicas. This list will @@ -78,7 +77,7 @@ def preprocess_data(files_list, temp, data_type, spacing=1, t=None, g=None): n_sim = len(files_list) preprocessed_data_all, t_list, g_list = [], [], [] for i in range(n_sim): - print(f'Reading dhdl files of alchemical range {i} ...') + print(f'Reading dhdl files of state range {i} ...') print(f'Collecting {data_type} data from all iterations ...') data = alchemlyb.concat([extract_fn(xvg, T=temp) for xvg in files_list[i]]) data_series = convert_fn(data) @@ -112,21 +111,27 @@ def preprocess_data(files_list, temp, data_type, spacing=1, t=None, g=None): def _apply_estimators(data, df_method="MBAR"): """ - An internal function that generates a list of estimators fitting the input data. + An internal function used in :func:`calculate_free_energy` to generate a list of estimators fitting the input data. Parameters ---------- - data : pd.Dataframe - A list of dHdl or u_nk dataframes obtained from all replicas of the REXEE simulation of interest. - Preferrably, the dHdl or u_nk data should be preprocessed by the function proprocess_data. - df_method : str - The selected free energy estimator. Options include "MBAR", "BAR" and "TI". + data : pandas.Dataframe + A list of :math:`dH/dλ` or :math:`u_{nk}` dataframes obtained from all replicas of the REXEE simulation + Preferrably, the :math:`dH/dλ` or :math:`u_{nk}` data should be preprocessed by the function + :func:`preprocess_data`. + df_method : str, Optional + The selected free energy estimator. Options include :code:`"MBAR"`, :code:`"BAR"` and :code:`"TI"`. + The default is :code:`"MBAR"`. Returns ------- estimators : list A list of estimators fitting the input data for all replicas. With this, the user can access all the free energies and their associated uncertainties for all states and replicas. + + See also + -------- + :func:`calculate_free_energy` """ n_sim = len(data) estimators = [] # A list of objects of the corresponding class in alchemlyb.estimators @@ -145,7 +150,7 @@ def _apply_estimators(data, df_method="MBAR"): def _calculate_df_adjacent(estimators): """ - An Internal function that calculates at list of free energy between adjacent + An internal function used in :func:`calculate_free_energy` to calculate a list of free energies between adjacent states for all replicas. Parameters @@ -153,13 +158,18 @@ def _calculate_df_adjacent(estimators): estimators : list A list of estimators fitting the input data for all replicas. With this, the user can access all the free energies and their associated uncertainties for all states and replicas. + In our code, these estimators come from the function :func:`_apply_estimators`. Returns ------- df_adjacent : list A list of lists free energy differences between adjacent states for all replicas. df_err_adjacent : list - A list of lists of uncertainties corresponding to the values of :code:`df_adjacent`. + A list of lists of uncertainties corresponding to the values in :code:`df_adjacent`. + + See also + -------- + :func:`calculate_free_energy` """ n_sim = len(estimators) df_adjacent = [list(np.array(estimators[i].delta_f_)[:-1, 1:].diagonal()) for i in range(n_sim)] @@ -168,35 +178,40 @@ def _calculate_df_adjacent(estimators): return df_adjacent, df_err_adjacent -def _combine_df_adjacent(df_adjacent, df_err_adjacent, state_ranges, err_type): +def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_type="propagate"): """ - An internal function that combines the free energy differences between adjacent states - in different state ranges using either simple means or inverse-variance weighted means. - Specifically, if :code:`df_err_adjacent` is :code:`None`, simple means will be used. - Otherwise, inverse-variance weighted means will be used. + An internal function used in :func:`calculate_free_energy` to combine the free energy differences between + adjacent states in different state ranges using either simple means or inverse-variance weighted means. Parameters ---------- df_adjacent : list A list of lists free energy differences between adjacent states for all replicas. - df_err_adjacent : list - A list of lists of uncertainties corresponding to the values of :code:`df_adjacent`. state_ranges : list - A list of lists of intergers that represents the alchemical states that can be sampled by different replicas. - err_type : str + A list of lists of showing the state indices sampled by each replica. + df_err_adjacent : list, Optional + A list of lists of uncertainties corresponding to the values of :code:`df_adjacent`. Notably, if + :code:`df_err_adjacent` is :code:`None`, simple means will be used. Otherwise, inverse-variance weighted + means will be used. The default is :code:`None` + err_type : str, Optional How the error of the combined free energy differences should be calculated. Available options include - "propagate" and "std". Note that the option "propagate" is only available when :code:`df_err_adjacent` - is not :code:`None`. + :code:`"propagate"` and :code:`"std"`. Note that the option :code:`"propagate"` is only available when + :code:`df_err_adjacent` is not :code:`None`. Returns ------- df : list - A list of free energy differences between states i and i + 1 for the entire state range. + A list of free energy differences between states :math:`i` and :math:`i + 1` for the entire state range. df_err : list A list of uncertainties of the free energy differences for the entire state range. overlap_bool : list - overlap_bool[i] = True means that the i-th free energy difference (i.e. df[i]) was available - in multiple replicas. + A list of boolean values indicating whether a free energy difference was available in multiple replicas. + For example, :code:`overlap_bool[i] = True` means that the :math:`i`-th free energy difference (i.e. + :code:`df[i]`) was available in multiple replicas. + + See also + -------- + :func:`calculate_free_energy` """ n_tot = state_ranges[-1][-1] + 1 df, df_err, overlap_bool = [], [], [] @@ -232,46 +247,62 @@ def _combine_df_adjacent(df_adjacent, df_err_adjacent, state_ranges, err_type): return df, df_err, overlap_bool -def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method='propagate', n_bootstrap=None, seed=None): +def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None): """ - Caculates the averaged free energy profile with the chosen method given dHdl or u_nk data obtained from - all replicas of the REXEE simulation of interest. Available methods include TI, BAR, and MBAR. TI - requires dHdl data while the other two require u_nk data. + Caculates the averaged free energy profile with the chosen method given :math:`u_{nk}` or :math:`dH/dλ` data + obtained from all replicas of the REXEE simulation. Available methods include TI, BAR, and MBAR. TI + requires :math:`dH/dλ` data while the other two require :math:`u_{nk}` data. Parameters ---------- - data : pd.Dataframe - A list of dHdl or u_nk dataframes obtained from all replicas of the REXEE simulation of interest. - Preferrably, the dHdl or u_nk data should be preprocessed by the function proprocess_data. + data : pandas.Dataframe + A list of :math:`u_{nk}` or :math:`dH/dλ` dataframes obtained from all replicas of the REXEE simulation. + Preferrably, the :math:`u_{nk}` or :math:`dH/dλ` data should be preprocessed by the function + :func:`proprocess_data`. state_ranges : list - A list of lists of intergers that represents the alchemical states that can be sampled by different replicas. - df_method : str - The method used to calculate the free energy profile. Available choices include "TI", "BAR", and "MBAR". - err_method : str + A list of lists of showing the state indices sampled by each replica. + df_method : str, Optional + The method used to calculate the free energy profile. Available choices include :code:`"TI"`, + :code:`"BAR"`, and :code:`"MBAR"`. The default is :code:`"MBAR"`. + err_method : str, Optional The method used to estimate the uncertainty of the free energy combined across multiple replicas. - Available options include "propagate" and "bootstrap". The bootstrapping method is more accurate - but much more computationally expensive than simple error propagation. - n_bootstrap : int + Available options include :code:`"propagate"` and :code:`"bootstrap"`. The bootstrapping method is + more accurate but much more computationally expensive than simple error propagation. + n_bootstrap : int, Optional The number of bootstrap iterations. This parameter is used only when the boostrapping method is chosen to - estimate the uncertainties of the free energies. - seed : int - The random seed for bootstrapping. + estimate the uncertainties of the free energies. The default is :code:`None`. In the CLI :code:`analyze_REXEE`, + this number is set by the YAML parameter :code:`n_bootstrap`. + seed : int, Optional + The random seed for bootstrapping. Only relevant when :code:`err_method` is :code:`"bootstrap"`. + The default is :code:`None`. Returns ------- f : list The full-range free energy profile. f_err : list - The uncertainty corresponding to the values in :code:`f`. + The uncertainties corresponding to the values in :code:`f`. estimators : list A list of estimators fitting the input data for all replicas. With this, the user can access all the free energies and their associated uncertainties for all states and replicas. + + Example + ------- + In the CLI :code:`analyze_REXEE`, lines like below are used: + + >>> import glob + >>> import natsort + >>> from ensemble_md.analysis import analyze_free_energy + >>> state_ranges = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]] + >>> file_list = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/dhdl*xvg')) for i in range(4)] + >>> data_list, _, _ = analyze_free_energy.preprocess_data(file_list, temp=300, data_type='u_nk') + >>> f, _, _ = analyze_free_energy.calculate_free_energy(data_list, state_ranges, "MBAR", "propagate") """ n_sim = len(data) n_tot = state_ranges[-1][-1] + 1 estimators = _apply_estimators(data, df_method) df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators) - df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, df_err_adjacent, state_ranges, err_type='propagate') + df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') if err_method == 'bootstrap': if seed is not None: @@ -284,7 +315,7 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method='prop sampled_data = [sampled_data_all[i].iloc[b * len(data[i]):(b + 1) * len(data[i])] for i in range(n_sim)] bootstrap_estimators = _apply_estimators(sampled_data, df_method) df_adjacent, df_err_adjacent = _calculate_df_adjacent(bootstrap_estimators) - df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, df_err_adjacent, state_ranges, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501 + df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501 df_bootstrap.append(df_sampled) error_bootstrap = np.std(df_bootstrap, axis=0, ddof=1) @@ -309,7 +340,7 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method='prop def calculate_df_rmse(estimators, df_ref, state_ranges): """ - Calculates the RMSE values of the free energy profiles of different alchemical ranges given the reference free + Calculates the RMSE values of the free energy profiles of different state ranges given the reference free energy profile for the whole range of states. Parameters @@ -317,16 +348,21 @@ def calculate_df_rmse(estimators, df_ref, state_ranges): estimators : list A list of estimators fitting the input data for all replicas. With this, the user can access all the free energies and their associated uncertainties for all states and replicas. + The estimators should be generated by the function :func:`calculate_free_energy`. df_ref : list A list of values corresponding to the free energies of the whole range of states. The length of the list should be equal to the number of states in total. state_ranges : list - A list of lists of intergers that represents the alchemical states that can be sampled by different replicas. + A list of lists of showing the state indices sampled by each replica. Returns ------- rmse_list : list - A list of RMSE values of the free energy profiles of different alchemical ranges. + A list of RMSE values of the free energy profiles of different state ranges. + + See also + -------- + :func:`calculate_free_energy` """ n_sim = len(estimators) df_ref = np.array(df_ref) @@ -344,16 +380,16 @@ def calculate_df_rmse(estimators, df_ref, state_ranges): def plot_free_energy(f, f_err, fig_name): """ - Plot the free energy profile with error bars. + Plots the free energy profile with error bars. Parameters ---------- f : list The full-range free energy profile. f_err : list - The uncertainty corresponding to the values in :code:`f`. + The uncertainties corresponding to the values in :code:`f`. fig_name : str - The file name of the png file to be saved (with the extension). + The file path to save the figure. """ plt.figure() plt.plot(range(len(f)), f, 'o-', c='#1f77b4') @@ -366,24 +402,26 @@ def plot_free_energy(f, f_err, fig_name): def average_weights(g_vecs, frac): """ - Average the differences between the weights of the coupled and uncoupled states. - This can be an estimate of the free energy difference between two end states. + Given the time series of the whole range of alchemical weights, averages the + weight differences between the the coupled and decoupled states. This can be + an estimate of the free energy difference between two end states. This function + is only relevant for weight-updating REXEE simulations. Parameters ---------- - g_vecs : np.array + g_vecs : numpy.ndarray An array of alchemical weights of the whole range of states as a function of simulation time, which is typically generated by :obj:`.combine_weights`. frac : float - The fraction of g_vecs to average over. frac=0.2 means average the last 20% of - the weight vectors. + The fraction of :code:`g_vecs` to average over. :code:`frac=0.2` means average + the last 20% of the weight vectors will be averaged. Returns ------- dg_avg : float - The averaged difference in the weights between the coupled and uncoupled states. + The averaged weight difference between the coupled and decoupled states. dg_avg_err : float - The error of :code:`dg_avg`. + The errors corresponding to the value of :code:`dg_avg`. """ N = len(g_vecs) dg = [] diff --git a/ensemble_md/analysis/analyze_matrix.py b/ensemble_md/analysis/analyze_matrix.py index 5af1ef2f..2b226f5f 100644 --- a/ensemble_md/analysis/analyze_matrix.py +++ b/ensemble_md/analysis/analyze_matrix.py @@ -8,7 +8,7 @@ # # #################################################################### """ -The :obj:`.analyze_matrix` module provides methods for analyzing matrices obtained from REXEE. +The :obj:`.analyze_matrix` module provides methods for analyzing matrices obtained from a REXEE simulation. """ import numpy as np import seaborn as sns @@ -19,7 +19,7 @@ from ensemble_md.analysis import synthesize_data -def calc_transmtx(log_file, expanded_ensemble=True): +def calc_transmtx(log_file, simulation_type='EE'): """ Parses the log file to get the transition matrix of an expanded ensemble or replica exchange simulation. Notably, a theoretical transition matrix @@ -28,18 +28,19 @@ def calc_transmtx(log_file, expanded_ensemble=True): Parameters ---------- log_file : str - The log file to be parsed. - expanded_ensemble : bool - Whether the simulation is expanded ensemble or replica exchange + The file path to the log file to be parsed. + simulation_type : str, Optional + The type of simulation. It can be either a :code:`EE` (expanded ensemble) or :code:`HREX` + (Hamiltonian replica exchange) simulation. The default is :code:`EE`. Returns ------- - empirical : np.ndarray + empirical : numpy.ndarray The final empirical state transition matrix. - theoretical : None or np.ndarray + theoretical : None or numpy.ndarray The final theoretical state transition matrix. - diff_matrix : None or np.ndarray - The difference between the theortial and empirical state transition matrix (empirical - theoretical). + diff_matrix : None or numpy.ndarray + The difference calculated by subtracting the theoretical matrix from the empirical matrix. """ f = open(log_file, "r") lines = f.readlines() @@ -56,10 +57,12 @@ def calc_transmtx(log_file, expanded_ensemble=True): n_states = int(lines[n - 1].split()[-1]) empirical = np.zeros([n_states, n_states]) for i in range(n_states): - if expanded_ensemble is True: + if simulation_type == 'EE': empirical[i] = [float(k) for k in lines[n - 2 - i].split()[:-1]] - else: # replica exchange + elif simulation_type == 'HREX': empirical[i] = [float(k) for k in lines[n - 2 - i].split()[1:-1]] + else: + raise ValueError(f"Invalid simulation type {simulation_type}.") if "Transition Matrix" in l and "Empirical" not in l: # only occurs in expanded ensemble theoretical_found = True @@ -83,22 +86,23 @@ def calc_transmtx(log_file, expanded_ensemble=True): def calc_equil_prob(trans_mtx): """ - Calculates the equilibrium probability of each state from the state transition matrix. - The input state transition matrix can be either left or right stochastic, although the left + Calculates the equilibrium probability of each state from a transition matrix. + The input transition matrix can be either left or right stochastic, although the left stochastic ones are not common in GROMACS. Generally, transition matrices in GROMACS are either doubly stochastic (replica exchange), or right stochastic (expanded ensemble). For the latter case, - the staionary distribution vector is the left eigenvector corresponding to the eigenvalue 1 + the stationary distribution vector is the left eigenvector corresponding to the eigenvalue 1 of the transition matrix. (For the former case, it's either left or right eigenvector corresponding to the eigenvalue 1 - as the left and right eigenvectors are the same for a doubly stochasti matrix.) + Note that the input transition matrix can be either state-space or replica-space. Parameters ---------- - trans_mtx : np.ndarray - The input state transition matrix + trans_mtx : numpy.ndarray + The input transition matrix. Returns ------- - equil_prob : np.ndarray + equil_prob : numpy.ndarray """ check_row = sum([np.isclose(np.sum(trans_mtx[i]), 1) for i in range(len(trans_mtx))]) check_col = sum([np.isclose(np.sum(trans_mtx[:, i]), 1) for i in range(len(trans_mtx))]) @@ -123,30 +127,31 @@ def calc_equil_prob(trans_mtx): return equil_prob -def calc_spectral_gap(trans_mtx, atol=1e-8, n_bootstrap=50, bootstrap_seed=None): +def calc_spectral_gap(trans_mtx, atol=1e-8, n_bootstrap=50, seed=None): """ - Calculates the spectral gap of the input transition matrix and estimates its + Calculates the spectral gap of an input transition matrix and estimates its uncertainty using the bootstrap method. Parameters ---------- - trans_mtx : np.ndarray + trans_mtx : numpy.ndarray The input transition matrix - atol: float - The absolute tolerance for checking the sum of columns and rows. - n_bootstrap: int - The number of bootstrap iterations for uncertainty estimation. - bootstrap_seed: int - The seed for the random number generator for the bootstrap method. + atol: float, Optional + The absolute tolerance for checking the sum of columns and rows. The default is 1e-8. + n_bootstrap: int, Optional + The number of bootstrap iterations for uncertainty estimation. The default is 50. + seed: int, Optional + The seed for the random number generator used in the bootstrap method. The default is :code:`None`, + which means no seed will be used. Returns ------- spectral_gap : float - The spectral gap of the input transitio n matrix. + The spectral gap of the input transition matrix. spectral_gap_err : float The estimated uncertainty of the spectral gap. eig_vals : list - The list of eigenvalues. + The list of eigenvalues. The maximum eigenvalue should always be 1. """ check_row = sum([np.isclose(np.sum(trans_mtx[i]), 1, atol=atol) for i in range(len(trans_mtx))]) check_col = sum([np.isclose(np.sum(trans_mtx[:, i]), 1, atol=atol) for i in range(len(trans_mtx))]) @@ -171,7 +176,7 @@ def calc_spectral_gap(trans_mtx, atol=1e-8, n_bootstrap=50, bootstrap_seed=None) spectral_gap_list = [] n_performed = 0 while n_performed < n_bootstrap: - mtx_boot = synthesize_data.synthesize_transmtx(trans_mtx, seed=bootstrap_seed)[0] + mtx_boot = synthesize_data.synthesize_transmtx(trans_mtx, seed=seed)[0] check_row_boot = sum([np.isclose(np.sum(mtx_boot[i]), 1, atol=atol) for i in range(len(mtx_boot))]) check_col_boot = sum([np.isclose(np.sum(mtx_boot[:, i]), 1, atol=atol) for i in range(len(mtx_boot))]) if check_row_boot == len(mtx_boot): @@ -201,9 +206,10 @@ def calc_t_relax(spectral_gap, exchange_period, spectral_gap_err=None): The input spectral gap. exchange_period : float The exchange period of the simulation in ps. - spectral_gap_err : float - The uncertainty of the spectral gap, which is used to calculate the uncertainty of the relaxation time using - error propagation. + spectral_gap_err : float, Optional + The uncertainty of the spectral gap, which is used to calculate the uncertainty of the relaxation time by + error propagation. The default is :code:`None`, in which case the uncertainty of the relaxation time + will be :code:`None`. Returns ------- @@ -223,25 +229,25 @@ def calc_t_relax(spectral_gap, exchange_period, spectral_gap_err=None): def split_transmtx(trans_mtx, n_sim, n_sub): """ - Split the input transition matrix into blocks of smaller matrices corresponding to - difrerent alchemical ranges of different replicas. Notably, the function assumes + Splits the input state-space transition matrix into blocks of smaller matrices corresponding to + different state states sampled by different replicas. Notably, the function assumes homogeneous shifts and number of states across replicas. Also, the blocks of the transition matrix is generally not doubly stochastic but right stochastic even if the input is doubly stochastic. Parameters ---------- - trans_mtx : np.ndarray - The input state transition matrix to split + trans_mtx : numpy.ndarray + The input state-space transition matrix to be split. n_sim : int - The number of replicas in REXEE. + The number of replicas in the REXEE simulation. n_sub : int - The number of states for each replica. + The number of states in each replica. Returns ------- sub_mtx: list - Blocks of transition matrices split from the input. + Blocks of transition matrices split from the input transition matrix. """ sub_mtx = [] ranges = [[i, i + n_sub] for i in range(n_sim)] # A list of lists containing the min/max of alchemcial ranges @@ -255,20 +261,20 @@ def split_transmtx(trans_mtx, n_sim, n_sub): return sub_mtx -def plot_matrix(matrix, png_name, title=None, start_idx=0): +def plot_matrix(matrix, fig_name, title=None, start_idx=0): """ - Visualizes a matrix a in heatmap. + Visualizes a matrix in a heatmap. Parameters ---------- - matrix : np.ndarray - The matrix to be visualized - png_name : str - The file name of the output PNG file (including the extension). - title : str - The title of the plot - start_idx : int - The starting value of the state index + matrix : numpy.ndarray + The matrix to be visualized. + fig_name : str + The file path to save the figure. + title : str, Optional + The title of the plot. The default is :code:`None`, which means no title will be added. + start_idx : int, Optional + The starting value of the state index. The default is 0. """ sns.set_context( @@ -320,5 +326,5 @@ def plot_matrix(matrix, png_name, title=None, start_idx=0): plt.title(title, fontsize=10, weight="bold") plt.tight_layout(pad=1.0) - plt.savefig(png_name, dpi=600) + plt.savefig(fig_name, dpi=600) plt.close() diff --git a/ensemble_md/analysis/analyze_traj.py b/ensemble_md/analysis/analyze_traj.py index b57b1247..0d1619e8 100644 --- a/ensemble_md/analysis/analyze_traj.py +++ b/ensemble_md/analysis/analyze_traj.py @@ -8,7 +8,7 @@ # # #################################################################### """ -The :obj:`.analyze_traj` module provides methods for analyzing trajectories in REXEE. +The :obj:`.analyze_traj` module provides methods for analyzing trajectories of a REXEE simulation. """ import copy import numpy as np @@ -29,14 +29,14 @@ def extract_state_traj(dhdl): Parameters ---------- dhdl : str - The filename of the GROMACS DHDL file to be parsed. + The file path to the GROMACS DHDL file to be parsed. Returns ------- traj : list - A list that represents that state-space trajectory + A list that represents the state-space trajectory t : list - A list that represents the time series of the trajectory + A list that represents the time frames of the trajectory """ traj = list(extract_dataframe(dhdl, headers=get_headers(dhdl))['Thermodynamic state']) t = list(np.loadtxt(dhdl, comments=['#', '@'])[:, 0]) @@ -52,28 +52,28 @@ def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, sav Parameters ---------- files : list - A list of lists of file names of GROMACS DHDL files or general GROMACS XVG files or PLUMED ouptput files. + A list of lists of file paths to GROMACS DHDL files or general GROMACS XVG files or PLUMED ouptput files. Specifically, :code:`files[i]` should be a list containing the files of interest from all iterations in replica :code:`i`. The files should be sorted naturally. rep_trajs : list - A list of lists that represents the replica space trajectories for each starting configuration. For example, + A list of lists that represents the replica-space trajectories for each starting configuration. For example, :code:`rep_trajs[0] = [0, 2, 3, 0, 1, ...]` means that starting configuration 0 transitioned to replica 2, then - 3, 0, 1, in iterations 1, 2, 3, 4, ..., respectively. - shifts : list - A list of values for shifting the state indices for each replica. The length of the list - should be equal to the number of replicas. This is only needed when :code:`dhdl=True`. - dhdl : bool - Whether the input files are GROMACS dhdl files, in which case trajectories of global alchemical indices - will be generated. If :code:`dhdl=False`, the input files must be readable by `numpy.loadtxt` assuming that - the start of a comment is indicated by either the :code:`#` or :code:`@` characters. - Such files include any GROMACS XVG files or PLUMED output files (output by plumed driver, for instance). - In this case, trajectories of the configurational collective variable of interest are generated. + 3, 0, 1, in iterations 1, 2, 3, 4, ..., respectively. This can be read from the :code:`rep_trajs.npy` file + generated by the REXEE simulation. + shifts : list, Optional + A list of values for shifting the local state indices to global indices for each replica. The length of the + list should be equal to the number of replicas. This is only needed when :code:`dhdl=True`. + dhdl : bool, Optional + Whether the input files are GROMACS dhdl files. If :code:`dhdl=False`, the input files must be readable + by :func:`numpy.loadtxt` assuming that the start of a comment is indicated by either the :code:`#` or :code:`@` + characters. Such files include any GROMACS XVG files or PLUMED output files (output by plumed driver, + for instance). In this case, trajectories of the configurational collective variable of interest are generated. The default is :code:`True`. - col_idx : int + col_idx : int, Optional The index of the column to be extracted from the input files. This is only needed when :code:`dhdl=False`, By default, we extract the last column. - save_npy : bool - Whether to save the output trajectories as an NPY file. + save_npy : bool, Optional + Whether to save the output trajectories as an NPY file. The default is :code:`True`. Returns ------- @@ -81,6 +81,23 @@ def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, sav A list that contains lists of state-space/CV-space trajectory (in global indices) for each starting configuration. For example, :code:`trajs[i]` is the state-space/CV-space trajectory of starting configuration :code:`i`. + + Example + ------- + >>> import glob + >>> import natsort + >>> import numpy as np + >>> from ensemble_md.analysis import analyze_traj + >>> n_sim = 4 # Assuming 4 replicas sampling states sets 0-3, 2-5, 4-7, and 6-9, respectively. + >>> files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(n_sim)] + >>> shifts = [0, 2, 4, 6] + >>> rep_trajs = np.load('rep_trajs.npy') # rep_trajs.npy is generated by the REXEE simulation + >>> state_trajs = analyze_traj.stitch_time_series(files, rep_trajs, shifts, dhdl=True, save_npy=True) + + See also + -------- + :func:`.stitch_time_series_for_sim` + :func:`.stitch_xtc_trajs` """ n_configs = len(files) # number of starting configurations n_iter = len(files[0]) # number of iterations per replica @@ -120,7 +137,7 @@ def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, sav return trajs -def stitch_time_series_for_sim(files, shifts, dhdl=True, col_idx=-1, save=True): +def stitch_time_series_for_sim(files, shifts=None, dhdl=True, col_idx=-1, save_npy=True): """ Stitches the state-space/CV-space time series in the same replica/simulation folder. That is, the output time series is contributed by multiple different trajectories (initiated by @@ -129,30 +146,44 @@ def stitch_time_series_for_sim(files, shifts, dhdl=True, col_idx=-1, save=True): Parameters ---------- files : list - A list of lists of file names of GROMACS DHDL files or general GROMACS XVG files + A list of lists of file paths to GROMACS DHDL files or general GROMACS XVG files or PLUMED output files. Specifically, :code:`files[i]` should be a list containing the files of interest from all iterations in replica :code:`i`. The files should be sorted naturally. - shifts : list - A list of values for shifting the state indices for each replica. The length of the list - should be equal to the number of replicas. This is only needed when :code:`dhdl=True`. - dhdl : bool - Whether the input files are GROMACS dhdl files, in which case trajectories of global alchemical indices - will be generated. If :code:`dhdl=False`, the input files must be readable by `numpy.loadtxt` assuming that - the start of a comment is indicated by either the :code:`#` or :code:`@` characters. - Such files include any GROMACS XVG files or PLUMED output files (output by plumed driver, for instance). - In this case, trajectories of the configurational collective variable of interest are generated. + shifts : list, Optional + A list of values for shifting the local state indices to global indices for each replica. The length of the + list should be equal to the number of replicas. This is only needed when :code:`dhdl=True`. + dhdl : bool, Optional + Whether the input files are GROMACS dhdl files. If :code:`dhdl=False`, the input files must be readable + by :func:`numpy.loadtxt` assuming that the start of a comment is indicated by either the :code:`#` or :code:`@` + characters. Such files include any GROMACS XVG files or PLUMED output files (output by plumed driver, for + instance). In this case, trajectories of the configurational collective variable of interest are generated. The default is :code:`True`. - col_idx : int + col_idx : int, Optional The index of the column to be extracted from the input files. This is only needed when :code:`dhdl=False`, By default, we extract the last column. - save : bool - Whether to save the output trajectories as an NPY file. + save_npy : bool, Optional + Whether to save the output trajectories as an NPY file. The default is :code:`True`. Returns ------- trajs : list A list that contains lists of state-space/CV-space trajectory (in global indices) for each replica. For example, :code:`trajs[i]` is the state-space/CV-space trajectory of replica :code:`i`. + + Example + ------- + >>> import glob + >>> import natsort + >>> from ensemble_md.analysis import analyze_traj + >>> n_sim = 4 # Assuming 4 replicas sampling states sets 0-3, 2-5, 4-7, and 6-9, respectively. + >>> files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(n_sim)] + >>> shifts = [0, 2, 4, 6] + >>> state_trajs = analyze_traj.stitch_time_series(files, shifts, dhdl=True, save_npy=True) + + See also + -------- + :func:`.stitch_time_series` + :func:`.stitch_xtc_trajs` """ n_sim = len(files) # number of replicas n_iter = len(files[0]) # number of iterations per replica @@ -182,10 +213,11 @@ def stitch_time_series_for_sim(files, shifts, dhdl=True, col_idx=-1, save=True): trajs[i].extend(traj) # All segments for the same replica should have the same shift - trajs[i] = list(np.array(trajs[i]) + shifts[i]) + if dhdl: + trajs[i] = list(np.array(trajs[i]) + shifts[i]) # Save the trajectories as an NPY file if desired - if save is True: + if save_npy: np.save('state_trajs_for_sim.npy', trajs) return trajs @@ -200,12 +232,19 @@ def stitch_xtc_trajs(gmx_executable, files, rep_trajs): gmx_executable : str The path to the GROMACS executable. files : list - A list of lists of file names of GROMACS XTC files. Specifically, :code:`files[i]` should be a list containing - the files of interest from all iterations in replica :code:`i`. The files should be sorted naturally. + A list of lists of file paths to GROMACS XTC files. Specifically, :code:`files[i]` should be a list containing + the paths to the files of interest from all iterations in replica :code:`i`. The files should be sorted + naturally. rep_trajs : list A list of lists that represents the replica space trajectories for each starting configuration. For example, :code:`rep_trajs[0] = [0, 2, 3, 0, 1, ...]` means that starting configuration 0 transitioned to replica 2, then - 3, 0, 1, in iterations 1, 2, 3, 4, ..., respectively. + 3, 0, 1, in iterations 1, 2, 3, 4, ..., respectively. This can be read from the :code:`rep_trajs.npy` file + generated by the REXEE simulation. + + See also + -------- + :func:`.stitch_time_series` + :func:`.stitch_time_series_for_sim` """ n_sim = len(files) # number of replicas n_iter = len(files[0]) # number of iterations per replica @@ -231,16 +270,16 @@ def stitch_xtc_trajs(gmx_executable, files, rep_trajs): def convert_npy2xvg(trajs, dt, subsampling=1): """ - Convert a :code:`state_trajs.npy` or :code:`cv_trajs.npy` file to :math:`N_{\text{rep}}` XVG files - that have two columns: time (ps) and state index. + Convert a :code:`state_trajs.npy` or :code:`cv_trajs.npy` file to :math:`R` XVG files + that have two columns: time (ps) and state index/CV value. (:math:`R` is the number of replicas.) Parameters ---------- - trajs : ndarray + trajs : numpy.ndarray The state-space or CV-space trajectories read from :code:`state_trajs.npy` or :code:`cv_trajs.npy`. dt : float The time interval (in ps) between consecutive frames of the trajectories. - subsampling : int + subsampling : int, Optional The stride for subsampling the time series. The default is 1. """ n_configs = len(trajs) @@ -259,9 +298,9 @@ def convert_npy2xvg(trajs, dt, subsampling=1): def traj2transmtx(traj, N, normalize=True): """ Computes the transition matrix given a trajectory. For example, if a state-space - trajectory from a EXE or HREX simulation given, a state transition matrix is returned. + trajectory from a EXE or HREX simulation is given, a state-space transition matrix is returned. If a trajectory showing transitions between replicas in a REXEE simulation is given, - a replica transition matrix is returned. + a replica-space transition matrix is returned. Parameters --------- @@ -271,12 +310,12 @@ def traj2transmtx(traj, N, normalize=True): N : int The size (N) of the expcted transition matrix (N by N). normalize : bool - Whether to normalize the matrix so that each row sum to 1. If False, then + Whether to normalize the matrix so that each row sum to 1. If :code:`normalize=False`, then the entries will be the counts of transitions. Returns ------- - transmtx : np.ndarray + transmtx : numpy.ndarray The transition matrix computed from the trajectory """ transmtx = np.zeros([N, N]) @@ -291,20 +330,25 @@ def traj2transmtx(traj, N, normalize=True): def plot_rep_trajs(trajs, fig_name, dt=None, stride=None): """ - Plots the time series of replicas visited by each trajectory in a single plot. + Plots the replica-space trajectories for a REXEE simulation. Parameters ---------- trajs : list - A list of arrays that represent the all replica space trajectories. + A list of lists that represent the all replica-space trajectories. fig_name : str - The file name of the png file to be saved (with the extension). - dt : float or None, optional - One trajectory timestep in ps. If None, it assumes there are no timeframes but MC steps. - stride : int, optional + The file path to save the figure. + dt : float or None, Optional + One trajectory timestep in ps. If :code:`dt=None`, the function assumes there are no time frames but MC steps. + The default is :code:`None`. + stride : int, Optional The stride for plotting the time series. The default is 100 if the length of any trajectory has more than one million frames. Otherwise, it will be 1. Typically plotting more than 10 million frames can take a lot of memory. + + See also + -------- + :func:`.plot_state_trajs` """ n_sim = len(trajs) cmap = plt.cm.ocean # other good options are CMRmap, gnuplot, terrain, turbo, brg, etc. @@ -348,28 +392,34 @@ def plot_rep_trajs(trajs, fig_name, dt=None, stride=None): def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=None, title_prefix='Trajectory'): """ - Plots the time series of state index. + Plots the state-space trajectories for a REXEE simulation. Parameters ---------- trajs : list - A list of state index time series either from different continuous trajectories or from different - alchemical ranges (i.e. from different simulation folders). + A list of lists of state indices generated either from different continuous trajectories or from different + alchemical ranges (i.e. from different simulation folders). This can be generated by either + :func:`.stitch_time_series` or :func:`.stitch_time_series_for_sim`. state_ranges : list - A list of lists of state indices. (Like the attribute :code:`state_ranges` in :code:`EnsemblEXE`.) + A list of lists of showing the state indices sampled by each replica. fig_name : str - The file name of the png file to be saved (with the extension). - dt : float or None, optional - The time interval between consecutive frames of the trajectories. If None, it is assumed - that the trajectories are in terms of Monte Carlo (MC) moves instead of timeframes, and - the x-axis label is set to 'MC moves'. Default is None. - stride : int + The file path to save the figure. + dt : float or None, Optional + One trajectory timestep in ps. If :code:`dt=None`, the function assumes there are no time frames but MC steps. + The default is :code:`None`. + stride : int, Optional The stride for plotting the time series. The default is 10 if the length of any trajectory has more than 100,000 frames. Otherwise, it will be 1. Typically plotting more than 10 million frames can take a lot of memory. - title_prefix : str + title_prefix : str, Optional The prefix shared by the titles of the subplots. For example, if :code:`title_prefix` is set to "Trajectory", then the titles of the subplots will be "Trajectory 0", "Trajectory 1", ..., etc. + The default is :code:`'Trajectory'`. + + See also + -------- + :func:`.plot_rep_trajs` + :func:`.plot_state_hist` """ n_sim = len(trajs) cmap = plt.cm.ocean # other good options are CMRmap, gnuplot, terrain, turbo, brg, etc. @@ -440,36 +490,42 @@ def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=None, title_ def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, prefix='Trajectory', subplots=False, save_hist=True): # noqa: E501 """ - Plots state index histograms. + Plots the histograms of state visitation for all replicas in a REXEE simulation. Parameters ---------- trajs : list - A list of state index time series either from different continuous trajectories or from different - alchemical ranges (i.e. from different simulation folders). + A list of lists of state indices generated either from different continuous trajectories or from different + alchemical ranges (i.e. from different simulation folders). This can be generated by either + :func:`.stitch_time_series` or :func:`.stitch_time_series_for_sim`. state_ranges : list - A list of lists of state indices. (Like the attribute :code:`state_ranges` in :obj:`.ReplicaExchangeEE`.) + A list of lists of showing the state indices sampled by each replica. fig_name : str - The file name of the png file to be saved (with the extension). - stack : bool - Whether to stack the histograms. Only meaningful when :code:`subplots` is :code:`False`. - figsize : tuple + The file path to save the figure. + stack : bool, Optional + Whether to stack the histograms. This parameter is only relevant when :code:`subplots` is :code:`False`. + The default is :code:`True`. + figsize : tuple, Optional A tuple specifying the length and width of the output figure. The default is :code:`(6.4, 4.8)` for cases having less than 30 states and :code:`(10, 4.8)` otherwise. - prefix : str + prefix : str, Optional The prefix shared by the titles of the subplots, or the labels shown in the same plot. For example, if :code:`prefix` is set to "Trajectory", then the titles/labels of the - will be "Trajectory 0", "Trajectory 1", ..., etc. - subplots : bool - Whether to plot the histogram in multiple subplots, with the title of - each based on the value of :code:`prefix`. - save_hist : bool - Whether to save the histogram data. + will be "Trajectory 0", "Trajectory 1", ..., etc. The default is :code:`'Trajectory'`. + subplots : bool, Optional + Whether to plot the histograms in multiple subplots, with the title of + each based on the value of :code:`prefix`. The default is :code:`False`. + save_hist : bool, Optional + Whether to save the histogram data. The default is :code:`True`. Returns ------- hist_data : list The histogram data of the each state index time series. + + See also + -------- + :func:`.plot_state_trajs` """ n_configs = len(trajs) n_states = max(max(state_ranges)) + 1 @@ -578,12 +634,13 @@ def calc_hist_rmse(hist_data, state_ranges): hist_data : list The histogram data of the state index for each trajectory. state_ranges : list - A list of lists of state indices. (Like the attribute :code:`state_ranges` in :obj:`.ReplicaExchangeEE`.) + A list of lists of showing the state indices sampled by each replica. Returns ------- rmse : float - The RMSE value of accumulated histogram counts of the state index. + The RMSE value of accumulated histogram counts of the state index, with respect to the case + where equal sampling is reached for all states. """ N = np.max(state_ranges) + 1 # the number of states n_accessible = np.histogram(state_ranges, bins=np.arange(-0.5, N + 0.5))[0] @@ -598,34 +655,35 @@ def calc_hist_rmse(hist_data, state_ranges): def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'): """ - Caclulcates and plots the average transit times for each trajectory, including the time - it takes from states 0 to k, from k to 0 and from 0 to k back to 0 (i.e. round-trip time). + Calculates and plots the average transit times for each trajectory, including the time + it takes from states 0 to k, from k to 0 and from 0 to k back to 0 (i.e., round-trip time). If there are more than 100 round-trips, 3 histograms corresponding to t_0k, t_k0 and t_roundtrip will be generated. Parameters ---------- trajs : list - A list of arrays that represent the state space trajectories of all continuous trajectories. + A list of lists that represent the state-space trajectories of all continuous trajectories. N : int The total number of states in the whole alchemical range. - fig_prefix : str - A prefix to use for all generated figures. - dt : float or None, optional - One trajectory timestep in ps. If None, it assumes there are no timeframes but MC steps. - folder : str, optional - The directory for saving the figures + fig_prefix : str, Optional + A prefix to use for all generated figures. The default is :code:`None`, which means no prefix. + dt : float or None, Optional + One trajectory timestep in ps. If :code:`dt=None`, the function assumes there are no time frames but MC steps. + The default is :code:`None`. + folder : str, Optional + The directory where the figures will be saved. The default is the current directory. Returns ------- t_0k_list : list - A list of transit time from states 0 to k for each trajectory. + A list of transit times from states 0 to k for each trajectory. t_k0_list : list - A list of transit time from states k to 0 for each trajectory. + A list of transit times from states k to 0 for each trajectory. t_roundtrip_list : list A list of round-trip times for each trajectory. units : str - The units of the times. + The units of the time. """ if dt is None: x = np.arange(len(trajs[0])) @@ -772,23 +830,24 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'): def plot_g_vecs(g_vecs, refs=None, refs_err=None, plot_rmse=True): """ - Plots the alchemical weight for each alchemical intermediate state as a function of + For each alchemical intermediate state, plots the alchemical weight as a function of the iteration index. Note that the alchemical weight of the first state (which is always 0) - is skipped. If the reference values are given, they will be plotted in the figure and - an RMSE will be calculated. + is skipped. If the reference values are given, they will be plotted in the figure (as horizontoal lines) + and a final RMSE will be calculated. Note that this function is only meaningful for weight-updating + REXEE simulations. Parameters ---------- - g_vecs : np.array - The alchemical weights of all states as a function of iteration index. The shape should + g_vecs : numpy.ndarray + The alchemical weights of all states as a function of the iteration index. The shape should be (n_iterations, n_states). Such an array can be directly read from :code:`g_vecs.npy` - saved by :code:`run_REXEE`. - refs : np.array - The reference values of the alchemical weights. - refs_err : list or np.array - The errors of the reference values. - plot_rmse : bool - Whether to plot RMSE as a function of the iteration index. + generated by a REXEE simulation. + refs : numpy.ndarray + The reference values of the alchemical weights. The default is :code:`None`. + refs_err : list or numpy.ndarray, Optional + The errors of the reference values. The default is :code:`None`. + plot_rmse : bool, Optional + Whether to plot RMSE as a function of the iteration index. The default is :code:`True`. """ # n_iter, n_state = g_vecs.shape[0], g_vecs.shape[1] g_vecs = np.transpose(g_vecs) @@ -836,14 +895,12 @@ def plot_g_vecs(g_vecs, refs=None, refs_err=None, plot_rmse=True): def get_swaps(REXEE_log='run_REXEE_log.txt'): """ - For each replica, identifies the states involved in proposed and accepted. - (Todo: We should be able to only use :code:`rep_trajs.npy` and :code:`state_trajs.npy` - instead of parsing the REXEE log file to reach the same goal.) + For each replica, identifies the states involved in proposed and accepted swaps. Parameters ---------- - REXEE_log : str - The output log file of the REXEE simulation. + REXEE_log : str, Optional + The output log file of the REXEE simulation. The default is :code:`'run_REXEE_log.txt'`. Returns ------- @@ -857,6 +914,29 @@ def get_swaps(REXEE_log='run_REXEE_log.txt'): each replica. Each dictionary (corresponding to one replica) have keys being the global state indices and values being the number of accepted swaps that involved the state indicated by the key. + + Example + ------- + Below is an example based on a REXEE simulations having four replicas sampling states + 0-4, 1-5, 2-6 and 3-7, respectively. + + >>> from ensemble_md.analysis import analyze_traj + >>> proposed_swaps, accepted_swaps = analyze_traj.get_swaps('run_REXEE_log.txt') + >>> for i in range(len(proposed_swaps)): + >>> print(proposed_swaps[i]) + {0: 0, 1: 3, 2: 1, 3: 0, 4: 0} + {1: 2, 2: 2, 3: 0, 4: 1, 5: 1} + {2: 3, 3: 3, 4: 2, 5: 0, 6: 0} + {3: 0, 4: 1, 5: 0, 6: 3, 7: 0} + + todo + ---- + We should be able to only use :code:`rep_trajs.npy` and :code:`state_trajs.npy` + instead of parsing the REXEE log file to reach the same goal. + + See also + -------- + :func:`.plot_swaps` """ f = open(REXEE_log, 'r') lines = f.readlines() @@ -903,14 +983,18 @@ def plot_swaps(swaps, swap_type='', stack=True, figsize=None): swaps : list A list of dictionaries showing showing the number of swaps for each state for each replica. This list could be either of the outputs from :obj:`.get_swaps`. - swap_type : str - The value should be either :code:`'accepted'` or :code:`'proposed'`. This value - will only influence the name of y-axis and the output file name. - stack : bool - Whether to stack the histograms. - figsize : tuple + swap_type : str, Optional + The type of swaps to be plotted. Common options include :code:`'accepted'` and :code:`'proposed'`. + This value will only influence the name of y-axis and the output file name. The default is an empty string. + stack : bool, Optional + Whether to stack the histograms. The default is :code:`True`. + figsize : tuple, Optional A tuple specifying the length and width of the output figure. The default is :code:`(6.4, 4.8)` for cases having less than 30 states and :code:`(10, 4.8)` otherwise. + + See also + -------- + :func:`.get_swaps` """ n_sim = len(swaps) n_states = max(max(d.keys()) for d in swaps) + 1 @@ -985,30 +1069,31 @@ def plot_swaps(swaps, swap_type='', stack=True, figsize=None): def get_g_evolution(log_files, start_state, end_state, avg_frac=0, avg_from_last_update=False): """ - For weight-updating simulations, gets the time series of the alchemical + For a weight-updating simulation, gets the time series of the alchemical weights of all states. Note that this funciton is only suitable for analyzing - either a single expanded ensemble simulation or a replica in a REXEE simulation - (given all the log files for the replica). + either a single expanded ensemble simulation or a replica in a REXEE simulation. + For the latter case, all the log files for the replica should be provided. Parameters ---------- log_files : list - The list of log file names. If multiple log files are provided (for a REXEE) - simulations, please make sure the files are in the correct order. + The list of file paths to the log file(s). If multiple log files are provided (for a REXEE simulation), + please make sure the files are in the correct order such that the time series of the alchemical weights + are continuous. start_state : int The index of the first state of interest. The index starts from 0. end_state : int The index of the last state of interest. The index start from 0. For example, if :code:`start_state` is set to 1 and :code:`end_state` is set to 3, then the weight evolution for states 1, 2 and 3 will be extracted. - avg_frac : float + avg_frac : float, Optional The fraction of the last part of the simulation to be averaged. The default is 0, which means no averaging. Note that this parameter is ignored if :code:`avg_from_last_update` is :code:`True`. - avg_from_last_update : bool - Whether to average from the last update of wl-delta. If this option is set to False, - or the option is set to True but the wl-delta was not updated in the provided log - file(s), the all weights will be used for averging. + avg_from_last_update : bool, Optional + Whether to average from the last update of the Wang-Landau incrementor. If this option is set to + :code:`False`, or the option is set to :code:`True` but the Wang-Landau incrementor was not updated + in the provided log file(s), the all weights will be used for averging. Returns ------- @@ -1025,6 +1110,12 @@ def get_g_evolution(log_files, start_state, end_state, avg_frac=0, avg_from_last The errors of the alchemical weights of all states averaged over the last part of the simulation. If :code:`avg_frac` is 0 and :code:`avg_from_last_update` is :code:`False`, :code:`None` will be returned. Note that weights after equilibration are not considered. + + Example + ------- + >>> from ensemble_md import analyze_traj + >>> log_files = ['EXE.log'] # For analyzing a single expanded ensemble simulation + >>> results = analyze_traj.get_g_evolution(log_files, start_state=0, end_state=6) """ g_vecs_all = [] idx_updates = [] # the indices of the data points corresponding to the updates of wl-delta @@ -1092,14 +1183,15 @@ def get_g_evolution(log_files, start_state, end_state, avg_frac=0, avg_from_last def get_dg_evolution(log_files, start_state, end_state): """ - For weight-updating simulations, gets the time series of the weight - difference (:math:`Δg = g_2-g_1`) between the specified states. + For a weight-updating simulation, gets the time series of the weight + difference (:math:`Δg = g_2-g_1`) between the states of interest. Parameters ---------- log_files : list - The list of log file names. If multiple log files are provided (for a REXEE) - simulations, please make sure the files are in the correct order. + The list of file paths to the log file(s). If multiple log files are provided (for a REXEE simulation), + please make sure the files are in the correct order such that the time series of the alchemical weights + are continuous. start_state : int The index of the state (starting from 0) whose weight is :math:`g_1`. end_state : int @@ -1108,7 +1200,7 @@ def get_dg_evolution(log_files, start_state, end_state): Returns ------- dg : list - A list of :math:`Δg` values. + The time series of :math:`Δg`. """ # N_states = end_state - start_state + 1 # number of states for the range of insterest g_vecs, _, _ = get_g_evolution(log_files, start_state, end_state) @@ -1119,23 +1211,25 @@ def get_dg_evolution(log_files, start_state, end_state): def plot_dg_evolution(log_files, start_state, end_state, start_idx=None, end_idx=None, dt_log=2): """ - For weight-updating simulations, plots the time series of the weight - difference (:math:`Δg = g_2-g_1`) between the specified states. + For a weight-updating simulation, plots the time series of the weight + difference (:math:`Δg = g_2-g_1`) between the states of interest. Parameters ---------- log_files : list - The list of log file names. + The list of file paths to the log file(s). If multiple log files are provided (for a REXEE simulation), + please make sure the files are in the correct order such that the time series of the alchemical weights + are continuous. start_state : int The index of the state (starting from 0) whose weight is :math:`g_1`. end_state : int The index of the state (starting from 0) whose weight is :math:`g_2`. - start_idx : int - The index of the first frame to be plotted. - end_idx : int - The index of the last frame to be plotted. - dt_log : float - The time interval between two consecutive frames in the log file. The + start_idx : int, Optional + The index of the first frame to be plotted. The default is :code:`None`, which means the first frame. + end_idx : int, Optional + The index of the last frame to be plotted. The default is :code:`None`, which means the last frame. + dt_log : float, Optional + The time interval (in ps) between two consecutive frames in the log file. The default is 2 ps. """ dg = get_dg_evolution(log_files, start_state, end_state) @@ -1160,25 +1254,26 @@ def plot_dg_evolution(log_files, start_state, end_state, start_idx=None, end_idx def get_delta_w_updates(log_file, plot=False): """ - Parses a log file of a weight-updating simulation and identifies the - time frames when the Wang-Landau incrementor is updated. + Parses the log file of a weight-updating simulation and identifies the + time frames when the Wang-Landau incrementor was updated. Parameters ---------- log_file : str - The name of the log file. - plot : bool + The file path to the LOG file. + plot : bool, Optional Whether to plot the Wang-Landau incrementor as a function of time. + The default is :code:`False`. Returns ------- t_updates : list - A list of time frames (in ns) when the Wang-Landau incrementor is updated. + A list of time frames (in ns) when the Wang-Landau incrementor was updated. delta_w_updates : list A list of the updated Wang-Landau incrementors. Should be the same length as :code:`t_updates`. equil : bool - Whether the weights have been equilibrated. + Whether the weights got equilibrated during the simulation. """ f = open(log_file, "r") lines = f.readlines() diff --git a/ensemble_md/analysis/clustering.py b/ensemble_md/analysis/clustering.py index 337264ed..51f9f1db 100644 --- a/ensemble_md/analysis/clustering.py +++ b/ensemble_md/analysis/clustering.py @@ -16,7 +16,8 @@ def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkage', cutoff=0.1, suffix=None): """ Performs clustering analysis on a trajectory using the GROMACS command :code:`gmx cluster`. - Note that only fully coupled configurations are considered. + Note that this function encompasses the use of all the other functions in the module, including + :func:`get_cluster_info`, :func:`get_cluster_members`, and :func:`analyze_transitions`. Parameters ---------- @@ -27,24 +28,56 @@ def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkag The dictionary must have the following four keys: :code:`traj` (input trajectory file in XTC or TRR format), :code:`config` (the configuration file in TPR or GRO format), :code:`xvg` (a GROMACS XVG file), and :code:`index` (an index/NDX file), with the values - being the paths. Note that the value of the key :code:`index` can be :code:`None`, in which + being the paths to the files. Note that the value of the key :code:`index` can be :code:`None`,in which case the function will use a default index file generated by :code:`gmx make_ndx`. If the parameter :code:`coupled_only` is set to :code:`True`, an XVG file that contains the time series of the state index (e.g., :code:`dhdl.xvg`) must be provided with the key :code:`xvg`. Otherwise, the key :code:`xvg` can be set to :code:`None`. grps : dict A dictionary that contains the names of the groups in the index file (NDX) for - centering the system, calculating the RMSD, and outputting. The keys are + centering the system, calculating the RMSD, and outputting. The corresponding keys are :code:`center`, :code:`rmsd`, and :code:`output`. - coupled_only : bool - Whether to consider only the fully coupled configurations. The default is :code:`True`. - method : str - The method for clustering available for the GROMACS command :code:`gmx cluster`. The default is 'linkage'. - Check the GROMACS documentation for other available options. - cutoff : float + coupled_only : bool, Optional + Whether to only consider the fully coupled configurations. The default is :code:`True`. + method : str, Optional + The method for clustering available for the GROMACS command :code:`gmx cluster`. The default is + :code:`'linkage'`. Check the + `GROMACS documentation `_ + for other available options. + cutoff : float, Optional The RMSD cutoff for clustering in nm. The default is 0.1. - suffix : str + suffix : str, Optional The suffix for the output files. The default is :code:`None`, which means no suffix will be added. + + Example + ------- + Below is an example of performing a cluster analysis for all 4 replicas that compose of a REXEE simulation + of a host-guest binding complex. + + >>> import glob + >>> import natsort + >>> from ensemble_md.analysis.clustering import cluster_traj + >>> from ensemble_md.analysis.analyze_traj import stitch_trajs, convert_npy2xvg + >>> rep_trajs = np.load('rep_trajs.npy') # Usually genrated by the REXEE simulation + >>> state_trajs = np.load('state_trajs.npy') # Usually generated by analyze_traj.stitch_time_series + >>> files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*xtc')) for i in range(4)] + >>> stitch_trajs('gmx', files, rep_trajs) + >>> convert_npy2xvg(state_trajs, 0.2, subsampling=10) + >>> for i in range(4): + >>> print() + >>> print(f'Performing clustering analysis for traj_{i}.xtc ...') + >>> inputs = { + >>> 'traj': f'traj_{i}.xtc', + >>> 'config': 'complex.gro', + >>> 'xvg': f'traj_{i}.xvg', + >>> 'index': 'complex.ndx' + >>> } + >>> grps = { + >>> 'center': 'HOS_MOL', + >>> 'rmsd': 'complex_heavy', + >>> 'output': 'HOS_MOL' + >>> } + >>> cluster_traj('gmx', inputs, grps, coupled_only=False, cutoff=0.13, suffix=f'{i}') """ # Check input parameters required_keys_1 = ['traj', 'config', 'xvg', 'index'] @@ -192,7 +225,8 @@ def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkag def get_cluster_info(cluster_log): """ - Gets the metadata of the LOG file generated by the GROMACS :code:`gmx cluster` command. + Extracts basic results from the clustering analysis by parsing the LOG file generated + by the GROMACS :code:`gmx cluster` command. Parameters ---------- @@ -238,11 +272,11 @@ def get_cluster_members(cluster_log): Returns ------- clusters : dict - A dictionary that contains the cluster index (starting from 1) as the key and the list of members - (configurations at different timeframes) as the value. + A dictionary that contains the cluster indices (starting from 1) as the keys and the lists of members + (represented by time frames) as the values. sizes : dict - A dictionary that contains the cluster index (starting from 1) as the key and the size of the cluster - (in fraction) as the value. + A dictionary that contains the cluster indices (starting from 1) as the keys and the sizes of the cluster + (in fraction) as the values. """ clusters = {} current_cluster = 0 @@ -284,22 +318,24 @@ def analyze_transitions(clusters, normalize=True, plot_type=None): Parameters ---------- clusters : dict - A dictionary that contains the cluster index (starting from 1) as the key and the list of members - (configurations at different timeframes in ps) as the value. - plot_type : str + A dictionary that contains the cluster indices (starting from 1) as the keys and the lists of members + (represented by time frames) as the values. + normalize : bool, Optional + Whether to normalize the output transition matrix. The default is :code:`True`. + plot_type : str, Optional The type of the figure to be plotted. The default is :code:`None`, which means no figure will be plotted. The other options are :code:`'bar'` and :code:`'xy'`. The former plots the distribution of the clusters, while the latter plots the trajectory showing which cluster each configuration belongs to. Returns ------- - transmtx: np.ndarray + transmtx: numpy.ndarray The transition matrix. - traj: np.ndarray + traj: numpy.ndarray The trajectory showing which cluster each configuration belongs to. t_transitions: dict A dictionary with keys being pairs of cluster indices and values being the time frames of transitions - between the two clusters. If there is no transition, an empty dictionary will be returned. + between the two clusters. If there was no transition, an empty dictionary will be returned. """ # Combine all cluster members and sort them all_members = [] diff --git a/ensemble_md/analysis/msm_analysis.py b/ensemble_md/analysis/msm_analysis.py index 9455ff4a..c95986e0 100644 --- a/ensemble_md/analysis/msm_analysis.py +++ b/ensemble_md/analysis/msm_analysis.py @@ -8,7 +8,8 @@ # # #################################################################### """ -The :obj:`.msm_analysis` module provides analysis methods based on Markov state models. +The :obj:`.msm_analysis` module provides analysis methods based on Markov state models +built from REXEE simulations. """ import pyemma import ruptures as rpt @@ -19,16 +20,16 @@ def plot_acf(models, n_tot, fig_name): """ - Plots the state index autocorrelation times for all configurations in a single plot + Plots the state index autocorrelation times for all configurations in a single plot. Parameters ---------- models : list A list of MSM models (built by PyEMMA) that have the :code:`correlation` method. n_tot : int - The total number of states (whole range). + The total number of states across all replicas. fig_name : str - The file name of the png file to be saved (with the extension). + The file path to save the figure. """ plt.figure() for i in range(len(models)): @@ -45,7 +46,7 @@ def plot_acf(models, n_tot, fig_name): def plot_its(trajs, lags, fig_name, dt=1, units='step'): """ - Plots the implied timescales as a function of lag time for all configurations + Plots the implied timescales (ITS) as a function of lag time for all configurations in a subplot. Parameters @@ -55,16 +56,16 @@ def plot_its(trajs, lags, fig_name, dt=1, units='step'): lags : list A list of lag times to examine. fig_name : str - The file name of the png file to be saved (with the extension). + The file path to save the figure. dt : float Physical time between frames. The default is 1. units : str - The units of dt. The default is 'ps'. + The units of :code:`dt`. The default is :code:`'step'`. Returns ------- ts_list : list - An list of instances of the :code:`ImpliedTimescales` class in PyEMMA. + A list of instances of the :code:`ImpliedTimescales` class in PyEMMA. """ ts_list = [] n_rows, n_cols = utils.get_subplot_dimension(len(trajs)) @@ -90,21 +91,21 @@ def plot_its(trajs, lags, fig_name, dt=1, units='step'): def decide_lagtimes(ts_list): """ This function automatically estimates a lagtime for building an MSM for each configuration. - Specifically, the lag time will be estimated by the change point detection enabled by - ruptures for each (n-1) timescales (where n is the number of states). A good lag time - should be long enough such that the timescale is roughly constant but short enough to be - smaller than all timescales. If no lag time is smaller than all timescales, then a - warning will be printed and a lag time of 1 will be returned in chosen_lags. + Specifically, the lag time will be estimated using change-point detection enabled by + :code:`ruptures` for each (:math:`n-1`) timescales (where :math:`n` is the number of states). + A good lag time should be long enough such that the timescale is roughly constant but short + enough to be smaller than all timescales. If no lag time is smaller than all timescales, then a + warning will be printed and a lag time of 1 will be returned. Parameters ---------- ts_list : list - An list of instances of the ImpliedTimescales class in PyEMMA. + A list of instances of the :code:`ImpliedTimescales` class in PyEMMA. Returns ------- chosen_lags: list - A list of lag time automatically determined for each configuration. + A list of lag times automatically determined for each configuration. """ # Workflow: first find the timescales larger than the corressponding lag times, # then perform change change detection. diff --git a/ensemble_md/analysis/synthesize_data.py b/ensemble_md/analysis/synthesize_data.py index bf43281d..bb1cc6ef 100644 --- a/ensemble_md/analysis/synthesize_data.py +++ b/ensemble_md/analysis/synthesize_data.py @@ -8,7 +8,10 @@ # # #################################################################### """ -The :obj:`.synthesize_data` module provides methods for synthesizing REXEE data. +The :obj:`.synthesize_data` module provides methods for synthesizing REXEE data, +specifically trajectories and transition matrices. This is mainly useful +for carrying out the bootstrap analysis for some analysis, such as the spectral gap calculation in +:obj:`.analyze_matrix`. """ import numpy as np from ensemble_md.analysis import analyze_traj @@ -17,33 +20,33 @@ def synthesize_traj(trans_mtx, n_frames=100000, method='transmtx', start=0, seed=None): """ - Synthesize a trajectory based on the input transition matrix. + Synthesizes a trajectory based on the input transition matrix. Parameters ---------- - trans_mtx: np.ndarray + trans_mtx: numpy.ndarray The input transition matrix. - n_frames: int + n_frames: int, Optional The number of frames to be generated. The default value is 100000. - method: str - The method to be used for trajectory synthesis. It can be either 'transmtx' or 'equil_prob'. + method: str, Optional + The method to be used for trajectory synthesis. It can be either :code:`'transmtx'` or :code:`'equil_prob'`. The former refers to generating the trajectory by simulating the moves between states based on the input transition matrix, with the trajectory starting from the state specified by the :code:`start` parameter. If the method is :code:`equil_prob`, the trajectory will be generated by simply sampling from the equilibrium - probability distribution calculated from the input transition matrix. The method 'transmtx' should + probability distribution calculated from the input transition matrix. The method :code:`'transmtx'` should generate a trajectory characterized by a transition matrix similar to the input one, while the method - 'equil_prob' may generate a trajectory that has a significantly different transition matrix. Still, - a trajectory generated by either method should a similar underlying equilibrium probability distribution - (hence the spectral gap as well) as the input transition matrix. The default value is 'transmtx'. - start: int + :code:`'equil_prob'` may generate a trajectory that has a significantly different transition matrix. Still, + a trajectory generated by either method should have a similar underlying equilibrium probability distribution + (hence the spectral gap as well) as the input transition matrix. The default value is :code:`'transmtx'`. + start: int, Optional The starting state of the synthesized trajectory if the method is :code:`transmtx`. The default value is 0, i.e., the first state. This parameter is ignored if the method is :code:`equil_prob`. - seed: int - The seed for the random number generator. The default value is None, i.e., the seed is not set. + seed: int, Optional + The seed for the random number generator. The default value is :code:`None`, i.e., the seed will not be set. Returns ------- - syn_traj: np.ndarray + syn_traj: numpy.ndarray The synthesized trajectory. """ if seed is not None: @@ -79,26 +82,28 @@ def synthesize_traj(trans_mtx, n_frames=100000, method='transmtx', start=0, seed def synthesize_transmtx(trans_mtx, n_frames=100000, seed=None): """ Synthesizes a normalized transition matrix similar to the input transition matrix by first - generating a trajectory using :code:`synthesize_traj` with :code:`method='transmtx'` and then - calculating the transition matrix from the synthesized trajectory. + generating a trajectory using :obj:`synthesize_traj` with :code:`method='transmtx'` and then + calculating the transition matrix from the synthesized trajectory. This function can be + useful in performing boostrapping analysis when estimating the uncertainty of the spectral gap + associated with the input transition matrix. Parameters ---------- - trans_mtx: np.ndarray + trans_mtx: numpy.ndarray The input transition matrix. - n_frames: int + n_frames: int, Optional The number of frames of the synthesized trajectory from which the mock transition matrix is calculated. The default value is 100000. - seed: int - The seed for the random number generator. The default value is None, i.e., the seed is not set. + seed: int, Optional + The seed for the random number generator. The default value is :code:`None`, i.e., the seed will not be set. Returns ------- - syn_mtx: np.ndarray + syn_mtx: numpy.ndarray The synthesized transition matrix. - syn_traj: np.ndarray + syn_traj: numpy.ndarray The synthesized trajectory from which the transition matrix is calculated. - diff_mtx: np.ndarray + diff_mtx: numpy.ndarray The input transition matrix subtracted by the synthesized transition matrix. """ N = len(trans_mtx) # can be the number of states or number of replicas depending on mtx_type diff --git a/ensemble_md/cli/explore_REXEE.py b/ensemble_md/cli/explore_REXEE.py index e83334f4..886e798c 100644 --- a/ensemble_md/cli/explore_REXEE.py +++ b/ensemble_md/cli/explore_REXEE.py @@ -58,22 +58,22 @@ def solv_REXEE_diophantine(N, constraint=False): """ Solves the general nonlinear Diophantine equation associated with the homogeneous REXEE parameters. Specifically, given the total number of states :math:`N` and the number of replicas - r, the states for each replica n and the state shift s can be expressed as: - n = N + (r-1)(t-1), and s = 1 - t, with the range of t being either the following: - - Without the additional constraint, (r-N+1)/r <= t <= 0 - - With the additional constraint, (r-N+1)/r <= t <= (r-N+1)/(r+1) + :math:`R`, the number of states for each replica :math:`n_s` and the state shift :math:`ϕ` can be expressed as: + :math:`n_s = N + (R-1)(ϕ-1)`, and :math:`ϕ = 1 - t`, with the range of :math:`t` being either the following: + - Without the additional constraint, :math:`(r-N+1)/r ≤ t ≤ 0` + - With the additional constraint, :math:`(r-N+1)/r ≤ t ≤ (r-N+1)/(r+1)` Parameters ---------- N : int The total number of states of the homogeneous REXEE of interesst. constraint : bool - Whether to apply additional constraints such that n-s <= 1/2n. + Whether to apply additional constraints such that :math:`n-s ≤ 1/2n`. Returns ------- soln_all : pd.DataFrame - A pandas DataFrame that lists all the solutions of (N, r, n, s). + A pandas DataFrame that lists all the solutions of :math:`(N, R, n_s, ϕ)`. """ soln_all = [] # [N, r, n, s] r_list = range(2, N) @@ -103,7 +103,7 @@ def estimate_swapless_rate(state_ranges, N=1000000): Parameters ---------- state_ranges : list - A list of lists of state indices. (Like the attribute :code:`state_ranges` in :code:`EnsemblEXE`.) + A list of lists of state indices. (Like the attribute :code:`state_ranges` in :obj:`ReplicaExchangeEE`.) N : n The number of Monte Carlo iterations for the estimation. diff --git a/ensemble_md/replica_exchange_EE.py b/ensemble_md/replica_exchange_EE.py index 517684f8..4104a218 100644 --- a/ensemble_md/replica_exchange_EE.py +++ b/ensemble_md/replica_exchange_EE.py @@ -9,7 +9,7 @@ #################################################################### """ The :obj:`.replica_exchange_EE` module provides functions for setting up and -replica exchange and expanded ensemble (REXEE) simulations. +performing replica exchange and expanded ensemble (REXEE) simulations. """ import os import sys @@ -37,38 +37,39 @@ class ReplicaExchangeEE: """ - This class provides a variety of functions useful for setting up and running - replica exchange (REX) of expanded ensemble (EE), or REXEE simulations. - Upon instantiation, all parameters in the YAML file will be assigned to an + A class that provides a variety of functions useful for setting up and running + a replica exchange (REX) of expanded ensemble (EE) simulation, or a REXEE simulation. + Upon instantiation, all parameters in the input YAML file will be assigned to an attribute in the class. In addition to these variables, below is a list of attributes of the class. (All the the attributes are assigned by :obj:`set_params` - unless otherwise noted.) + except that :code:`yaml` is assigned by :code:`__init__`.) :ivar gmx_path: The absolute path of the GROMACS exectuable. :ivar gmx_version: The version of the GROMACS executable. - :ivar yaml: The input YAML file used to instantiate the class. Assigned by the :code:`__init__` function. + :ivar yaml: The input YAML file used to instantiate the class. The file should contain necessary REXEE parameters. + For more details, please check the :ref:`doc_parameters`. :ivar warnings: Warnings about parameter specification in either YAML or MDP files. - :ivar reformatted_mdp: Whether the templated MDP file has been reformatted by replacing hyphens + :ivar template: The template MDP file on which the instance of the :obj:`MDP` class is based. + :ivar reformatted_mdp: Whether the template MDP file has been reformatted by replacing hyphens with underscores or not. - :ivar template: The instance of the :obj:`MDP` class based on the template MDP file. :ivar dt: The simulation timestep in ps. :ivar temp: The simulation temperature in Kelvin. - :ivar fixed_weights: Whether the weights will be fixed during the simulation (according to the template MDP file). + :ivar fixed_weights: Whether the weights will be fixed during the simulation. :ivar updating_weights: The list of weights as a function of time (since the last update of the Wang-Landau incrementor) for different replicas. The length is equal to the number of replicas. This is only relevant for weight-updating simulations. - :ivar equilibrated_weights: The equilibrated weights of different replicas. For weight-equilibrating simulations, - this list is initialized as a list of empty lists. Otherwise (weight-fixed), it is initialized as a list of - :code:`None`. + :ivar equilibrated_weights: The equilibrated weights of different replicas. For weight-updating simulations, + this list is initialized as a list of empty lists. Otherwise (i.e., in fixed-weight simulations), it is + initialized as a list of :code:`None`. :ivar current_wl_delta: The current value of the Wang-Landau incrementor. This is only relevent for weight-updating simulations. :ivar kT: 1 kT in kJ/mol at the simulation temperature. - :ivar lambda_types: The types of lambda variables involved in expanded ensemble simulations, e.g. + :ivar lambda_types: The types of lambda variables involved in expanded ensemble simulations, e.g., :code:`fep_lambdas`, :code:`mass_lambdas`, :code:`coul_lambdas`, etc. :ivar n_tot: The total number of states for all replicas. - :ivar n_sub: The numbmer of states for each replica. The current implementation assumes homogenous replicas. - :ivar state_ranges: A list of list of state indices for each replica. - :ivar equil: A list of times it took to equilibrated the weights for different replicas. This + :ivar n_sub: The numbmer of states of each replica. The current implementation assumes homogenous replicas. + :ivar state_ranges: A list of list of (global) state indices for each replica. + :ivar equil: A list of times it took to equilibrate the weights for different replicas. This list is initialized with a list of -1, where -1 means that the weights haven't been equilibrated. Also, a value of 0 means that the simulation is a fixed-weight simulation. :ivar n_rejected: The number of proposed exchanges that have been rejected. Updated by :obj:`.accept_or_reject`. @@ -77,12 +78,14 @@ class ReplicaExchangeEE: :ivar n_emtpy_swappable: The number of times when there was no swappable pair. :ivar rep_trajs: The replica-space trajectories of all configurations. :ivar configs: A list that thows the current configuration index that each replica is sampling. - :ivar g_vecs: The time series of the (processed) whole-range alchemical weights. If no weight combination is - applied, this list will just be a list of :code:`None`'s. + :ivar g_vecs: The time series of processed (e.g., combined across replicas) alchemical weights for the entire state + space. If no weight combination scheme is applied, this list will just be a list of :code:`None`'s. :ivar df_data_type: The type of data (either :math:`u_{nk}` or :math:`dH/dλ`) that will be used for - free energy calculations if :code:`df_method` is :code:`True`. - :ivar modify_coords_fn: The function (callable) in the external module (specified as :code:`modify_coords` in - the input YAML file) for modifying coordinates at exchanges. + free energy calculations. This depends on the free energy estimator specified in the parameter + :code:`df_method`. + :ivar modify_coords_fn: The function (callable) in an external module (specified as :code:`modify_coords` in + the input YAML file) for modifying coordinates at exchanges. This parameter is only relevant to + multi-topology REXEE (i.e., MT-REXEE) simulations. """ def __init__(self, yaml_file, analysis=False): @@ -91,34 +94,34 @@ def __init__(self, yaml_file, analysis=False): def set_params(self, analysis): """ - Sets up or reads in the user-defined parameters from a yaml file and an MDP template. + Sets up or reads in the user-defined parameters from an input YAML file and an MDP template. This function is called to instantiate the class in the :code:`__init__` function of class. Specifically, it does the following: 1. Sets up constants. - 2. Reads in parameters from a YAML file. + 2. Reads in REXEE parameters from a YAML file. 3. Handles YAML parameters. 4. Checks if the parameters in the YAML file are well-defined. 5. Reformats the input MDP file to replace all hyphens with underscores. 6. Reads in parameters from the MDP template. - After instantiation, the class instance will have attributes corresponding to + After instantiation, the class instance will have an attribute corresponding to each of the parameters specified in the YAML file. For a full list of the parameters that can be specified in the YAML file, please refer to :ref:`doc_parameters`. - :param yaml_file: The file name of the YAML file for specifying the parameters for REXEE. + :param yaml_file: The file path of the input YAML file that specifies REXEE parameters. :type yaml_file: str :param analysis: Whether the instantiation of the class is for data analysis of REXEE simulations. - The default is :code:`False` - :type analysis: bool + The default is :code:`False`. + :type analysis: bool, Optional :raises ParameterError: - - If a required parameter is not specified in the YAML file. + - If a required parameter is not specified in the input YAML file. - If a specified parameter is not recognizable. - If a specified option is not available for a parameter. - If the data type or range (e.g., positive or negative) of a parameter is not correct. - - If an invalid MDP file is detected. + - If any MDP parameter invalid for the REXEE simulation is detected. """ self.warnings = [] # Store warnings, if any. @@ -468,12 +471,12 @@ def check_gmx_executable(self): def print_params(self, params_analysis=False): """ - Prints important parameters related to the EXEE simulation. + Prints important parameters relevant to the REXEE simulation to be performed. Parameters ---------- - params_analysis : bool, optional - If True, additional parameters related to data analysis will be printed. Default is False. + params_analysis : bool, Optional + Whether additional parameters for data analysis should be printed. The default is :code:`False`. """ if isinstance(self.gro, list): gro_str = ', '.join(self.gro) @@ -485,8 +488,8 @@ def print_params(self, params_analysis=False): else: top_str = self.top - print("Important parameters of EXEE") - print("============================") + print("Important parameters of REXEE") + print("=============================") print(f"Python version: {sys.version}") print(f"GROMACS executable: {self.gmx_path}") # we print the full path here print(f"GROMACS version: {self.gmx_version}") @@ -538,7 +541,12 @@ def reformat_MDP(mdp_file): will be set to :code:`True`. In this case, the new MDP object with reformatted parameter names will be written to the original file path of the file, while the original file will be renamed with a :code:`_backup` suffix. If the input MDP file is not reformatted, the function sets - the class attribute :code:`self.reformatted_mdp` to :code:`False`. + the attribute :code:`self.reformatted_mdp` to :code:`False`. + + Parameters + ---------- + mdp_file : str + The file path of the MDP file to be reformatted. Returns ------- @@ -564,10 +572,10 @@ def reformat_MDP(mdp_file): def initialize_MDP(self, idx): """ - Initializes the MDP object for generating MDP files for a replica based on the MDP template. + Initializes the MDP object for generating an MDP file for a specific replica based on the MDP template. This function should be called only for generating MDP files for the FIRST iteration and it has nothing to do with whether the weights are fixed or equilibrating. - It is assumed that the MDP template has all the common parameters of all replicas. + It is assumed that the MDP template has all the parameters shared by all replicas. Parameters ---------- @@ -577,7 +585,7 @@ def initialize_MDP(self, idx): Returns ------- MDP : :obj:`.gmx_parser.MDP` obj - An updated object of :obj:`.gmx_parser.MDP` that can be used to write MDP files. + A :obj:`.gmx_parser.MDP` object that can be used to write the MDP file. """ MDP = copy.deepcopy(self.template) MDP["nsteps"] = self.nst_sim @@ -597,13 +605,13 @@ def initialize_MDP(self, idx): def get_ref_dist(self, pullx_file=None): """ - Gets the reference distance(s) to use starting from the second iteration if distance restraint(s) are used. - Specifically, a reference distance determined here is the initial COM distance between the pull groups - in the input GRO file. This function initializes the attribute :code:`ref_dist`. + Gets the initial COM distance between the pull groups in the input GRO file. Importantly, this distance + will serve as the reference distance starting from the second iteration. This function initializes the + attribute :code:`ref_dist` and is only relevant when a distance restraint is applied in the GROMACS pull code. - Parameter - --------- - pullx_file : str + Parameters + ---------- + pullx_file : str, Optional The path to the pullx file whose initial value will be used as the reference distance. Usually, this should be the path of the pullx file of the first iteration. The default is :code:`sim_0/iteration_0/pullx.xvg`. @@ -620,14 +628,14 @@ def get_ref_dist(self, pullx_file=None): def update_MDP(self, new_template, sim_idx, iter_idx, states, wl_delta, weights, counts=None): """ - Updates the MDP file for a new iteration based on the new MDP template coming from the previous iteration. - Note that if the weights got equilibrated in the previous iteration, then the weights will be fixed - at these equilibrated values for all the following iterations. + Updates the MDP file for a new iteration based on the new MDP template, which is the MDP file + from the previous iteration. Note that if the weights got equilibrated in the previous iteration, + the weights will be fixed at these equilibrated values for all the following iterations. Parameters ---------- new_template : str - The new MDP template file. Typically the MDP file of the previous iteration. + The new MDP template file, which typically is the MDP file of the previous iteration. sim_idx : int The index of the simulation whose MDP parameters need to be updated. iter_idx : int @@ -635,20 +643,18 @@ def update_MDP(self, new_template, sim_idx, iter_idx, states, wl_delta, weights, states : list A list of last sampled states of all simulaitons in the previous iteration. wl_delta : list - A list of final Wang-Landau incrementors of all simulations. + A list of fina Wang-Landau incrementors of all simulations. weights : list A list of lists final weights of all simulations. - counts : list + counts : list, Optional A list of lists final counts of all simulations. If the value is :code:`None`, then the MDP parameter :code:`init-histogram-counts` won't be specified in the next iteration. - Note that not all the GROMACS versions have the MDP parameter :code:`init-histogram-counts` available, - in which case one should always pass :code:`None`, or set :code:`-maxwarn` in :code:`grompp_args` - in the input YAML file. + Note that this parameter is only supported by GROMACS with versions later than 2022.3. Return ------ MDP : :obj:`.gmx_parser.MDP` obj - An updated object of :obj:`.gmx_parser.MDP` that can be used to write MDP files. + A :obj:`.gmx_parser.MDP` object that can be used to write the MDP file. """ new_template = gmx_parser.MDP(new_template) # turn into a gmx_parser.MDP object MDP = copy.deepcopy(new_template) @@ -680,18 +686,18 @@ def update_MDP(self, new_template, sim_idx, iter_idx, states, wl_delta, weights, def extract_final_dhdl_info(self, dhdl_files): """ - For all the replica simulations, finds the last sampled state - and print the corresponding lambda values from a dhdl file. + Extracts the last sampled states for all replica simulations. Parameters ---------- dhdl_files : list - A list of file paths to GROMACS DHDL files of different replicas. + A list of file paths to GROMACS DHDL files of different replicas. Note that + the order of the files should be consistent with the order of the replicas. Returns ------- states : list - A list of the global indices of the last sampled states of all simulaitons. + A list of the global state indices of the last sampled states of all simulaitons. """ states = [] print("\nBelow are the final states being visited:") @@ -712,7 +718,8 @@ def extract_final_log_info(self, log_files): - The final Wang-Landau incrementors. - The final lists of weights. - The final lists of counts. - - Whether the weights were equilibrated in the simulations. + + Note that the order of the files should be consistent with the order of the replicas. Parameters ---------- @@ -761,19 +768,19 @@ def extract_final_log_info(self, log_files): @staticmethod def identify_swappable_pairs(states, state_ranges, neighbor_exchange, add_swappables=None): """ - Identify swappable pairs. By definition, a pair of simulation is considered swappable only if + Identifies swappable pairs. By definition, a pair of simulation is considered swappable only if their last sampled states are in the alchemical ranges of both simulations. This is required to ensure that the values of involved ΔH and Δg can always be looked up from the DHDL and LOG files. - This also automatically guarantee that the simulations to be swapped have overlapping alchemical ranges. + This also automatically guarantees that the simulations to be swapped have overlapping state sets. Parameters ---------- states : list - A list of the global indices of the last sampled states of all simulations. This list can be - generated by the :obj:`.extract_final_dhdl_info` method. Notably, the input list should not be + A list of the global state indices of the last sampled states of all simulations. This list can be + generated by :obj:`.extract_final_dhdl_info`. Notably, the input list should not be a list that has been updated/modified by :obj:`get_swapping_pattern`, or the result will be incorrect. state_ranges : list of lists - A list of state indies for all replicas. The input list can be a list updated by + A list of global state indices for all replicas. The input list can be a list updated by :obj:`.get_swapping_pattern`, especially in the case where there is a need to re-identify the swappable pairs after an attempted swap is accepted. neighbor_exchange : bool @@ -782,12 +789,27 @@ def identify_swappable_pairs(states, state_ranges, neighbor_exchange, add_swappa A list of lists that additionally consider states (in global indices) that can be swapped. For example, :code:`add_swappables=[[4, 5], [14, 15]]` means that if a replica samples state 4, it can be swapped with another replica that samples state 5 and vice versa. The same logic applies - to states 14 and 15. + to states 14 and 15. This parameter is only relevant to MT-REXEE simulations. Returns ------- swappables : list A list of tuples representing the simulations that can be swapped. + + Example + ------- + Below is an example where the REXEE simulation is composed of four replicas sampling states 0-3, 1-4, + 2-5, and 3-6, respectively. At exchanges, these replicas are respectively at states 2, 3, 3, and 4. + Therefore, the swappable pairs are [(0, 1), (1, 2), (2, 3)]. If only neighboring swaps are considered, + the swappable pairs will be [(1, 2), (2, 3)]. + + >>> from ensemble_md.replica_exchange_EE import ReplicaExchangeEE as REXEE + >>> states = [2, 3, 2, 5] + >>> state_ranges = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]] + >>> REXEE.identify_swappable_pairs(states, state_ranges, neighbor_exchange=False) + [(0, 1), (1, 2), (2, 3)] + >>> REXEE.identify_swappable_pairs(states, state_ranges, neighbor_exchange=True) + [(1, 2), (2, 3)] """ n_sim = len(states) sim_idx = list(range(n_sim)) @@ -810,7 +832,6 @@ def identify_swappable_pairs(states, state_ranges, neighbor_exchange, add_swappa swappables.append(pair) if neighbor_exchange is True: - print('Note: One neighboring swap will be proposed.') swappables = [i for i in swappables if np.abs(i[0] - i[1]) == 1] return swappables @@ -818,7 +839,7 @@ def identify_swappable_pairs(states, state_ranges, neighbor_exchange, add_swappa @staticmethod def propose_swap(swappables): """ - Proposes a swap of coordinates between replicas by drawing samples from the swappable pairs. + Proposes a swap of coordinates between replicas by drawing a pair from the list of swappable pairs. Parameters ---------- @@ -847,7 +868,7 @@ def get_swapping_pattern(self, dhdl_files, states): sample configurations 0, 2, 1, 3, respectively, where configurations 0, 1, 2, 3 here are defined as whatever configurations are in replicas 0, 1, 2, 3 in the CURRENT iteration (not iteration 0), respectively. - Notably, when this function is called (e.g. once every iteration in an REXEE simulation), the output + Notably, when this function is called (e.g., once every iteration in a REXEE simulation), the output list :code:`swap_pattern` is always initialized as :code:`[0, 1, 2, 3, ...]` and gets updated once every attempted swap. This is different from the attribute :code:`configs`, which is only initialized at the very beginning of the entire REXEE simulation (iteration 0), though :code:`configs` also gets updated with @@ -856,10 +877,10 @@ def get_swapping_pattern(self, dhdl_files, states): Parameters ---------- dhdl_files : list - A list of DHDL files. The indicies in the DHDL filenames shouuld be in an ascending order, e.g. + A list of paths to the DHDL files. The indicies in the DHDL filenames should be in an ascending order, e.g. :code:`[dhdl_0.xvg, dhdl_1.xvg, ..., dhdl_N.xvg]`. states : list - A list of last sampled states (in global indices) of ALL simulaitons. :code:`states[i]=j` means that + A list of last sampled states (in global indices) of ALL simulations. :code:`states[i]=j` means that the configuration in replica :code:`i` is at state :code:`j` at the time when the exchange is performed. This list can be generated :obj:`.extract_final_dhdl_info`. @@ -982,7 +1003,7 @@ def get_swapping_pattern(self, dhdl_files, states): def calc_prob_acc(self, swap, dhdl_files, states, shifts): """ - Calculates the acceptance ratio given the Monte Carlo scheme for swapping the simulations. + Calculates the acceptance ratio for swapping simulations. Parameters ---------- @@ -992,11 +1013,11 @@ def calc_prob_acc(self, swap, dhdl_files, states, shifts): A list of DHDL files, e.g. :code:`dhdl_files = ['dhdl_2.xvg', 'dhdl_1.xvg', 'dhdl_0.xvg', 'dhdl_3.xvg']` means that configurations 2, 1, 0, and 3 are now in replicas 0, 1, 2, 3. This can happen in multiple swaps when a previous swap between configurations 0 and 2 has just been accepted. Otherwise, the list of - filenames should always be in the ascending order, e.g. :code:`['dhdl_0.xvg', 'dhdl_1.xvg', 'dhdl_2.xvg', + filenames should always be in the ascending order, e.g., :code:`['dhdl_0.xvg', 'dhdl_1.xvg', 'dhdl_2.xvg', dhdl_3.xvg]`. states : list A list of last sampled states (in global indices) in the DHDL files corresponding to configurations 0, 1, - 2, ... (e.g. :code:`dhdl_0.xvg`, :code:`dhdl_1.xvg`, :code:`dhdl_2.xvg`, ...) + 2, ... (e.g., :code:`dhdl_0.xvg`, :code:`dhdl_1.xvg`, :code:`dhdl_2.xvg`, ...) This list can be generated by :obj:`.extract_final_dhdl_info`. shifts : list A list of state shifts for converting global state indices to the local ones. Specifically, :code:`states` @@ -1046,12 +1067,13 @@ def calc_prob_acc(self, swap, dhdl_files, states, shifts): def accept_or_reject(self, prob_acc): """ - Returns a boolean variable indiciating whether the proposed swap should be acceepted given the acceptance rate. + Returns a boolean variable indicating whether the proposed swap should be acceepted or not given + the acceptance ratio. Parameters ---------- prob_acc : float - The acceptance rate. + The acceptance ratio. Returns ------- @@ -1082,16 +1104,17 @@ def accept_or_reject(self, prob_acc): def get_averaged_weights(self, log_files): """ - For each replica, calculate the averaged weights (and the associated error) from the time series - of the weights since the previous update of the Wang-Landau incrementor. + For each replica, calculates the averaged weights (and the associated error) from the time series + of the weights since the previous update of the Wang-Landau incrementor. This is only relevant + for weight-updating REXEE simulations. Parameters ---------- log_files : list A list of file paths to GROMACS LOG files of different replicas. - Returned - -------- + Returns + ------- weights_avg : list A list of lists of weights averaged since the last update of the Wang-Landau incrementor. The length of the list should be the number of replicas. @@ -1120,18 +1143,20 @@ def get_averaged_weights(self, log_files): def weight_correction(self, weights, counts): """ - Corrects the lambda weights based on the histogram counts. Namely, + Adjusts the lambda weights based on the histogram counts by using the following equation: :math:`g_k' = g_k + ln(N_{k-1}/N_k)`, where :math:`g_k` and :math:`g_k'` - are the lambda weight after and before the correction, respectively. + are the lambda weight before and after the correction, respectively. Notably, in any of the following situations, we don't do any correction. - - Either :math:`N_{k-1}` or :math:`N_k` is 0. - - Either :math:`N_{k-1}` or :math:`N_k` is smaller than the histogram cutoff. + - Either :math:`N_{k-1}` or :math:`N_k` is :math:`0`. + - Either :math:`N_{k-1}` or :math:`N_k` is smaller than the histogram cutoff specified by :code:`N_cutoff` + in the input YAML file. Parameters ---------- weights : list - A list of lists of weights (of ALL simulations) to be corrected. + A list of lists of weights (of ALL simulations) to be corrected. The i-th element corresponds to + the list of weights of the i-th replica. counts : list A list of lists of counts (of ALL simulations). @@ -1164,17 +1189,22 @@ def weight_correction(self, weights, counts): def histogram_correction(self, hist, print_values=True): """ - Adjust the histogram counts. Specifically, the ratio of corrected histogram counts - for adjancent states is the geometric mean of the ratio of the original histogram counts - for the same states. Note, however, if the histogram counts are 0 for some states, the - histogram correction will be skipped and the original histogram counts will be returned. + Adjusts the histogram counts. For example, if replicas A and B both sample states 1 and 2 and have + histogram counts :math:`N^A_1`, :math:`N^A_2`, :math:`N^B_1`, and :math:`N^B_2`, the corrected histogram + counts for states 1 and 2 for BOTH replicas will be adjusted according to the following equation: + :math:`N_1'/N_2'=((N_1^A N_1^B)/(N_2^A N_2^B))^{1/2}`. Namely, the ratio of the corrected histogram + counts for adjacent states is the geometric mean of the ratio of the original histogram counts + for the same states. Note that if any histogram count is 0, histogram correction will not be performed + and the original histogram counts will be returned. Parameters ---------- hist : list - A list of lists of histogram counts of ALL simulations. - print_values : bool, optional - Whether to print the histograms for each replica before and after histogram correction. + A list of lists of histogram counts of ALL simulations. The i-th element corresponds to + the list of histogram counts of the i-th replica. + print_values : bool, Optional + Whether to print the histogram counts for each replica before and after histogram correction. + The default is :code:`True`. Returns ------- @@ -1234,26 +1264,27 @@ def histogram_correction(self, hist, print_values=True): def combine_weights(self, weights, weights_err=None, print_values=True): """ - Combine alchemical weights across multiple replicas. Note that if + Combines alchemical weights across multiple replicas. Note that if :code:`weights_err` is provided, inverse-variance weighting will be used. Care must be taken since inverse-variance weighting can lead to slower - convergence if the provided errors are not accurate. (See :ref:`doc_w_schemes` for mor details.) + convergence if the provided errors are not accurate. (See :ref:`doc_w_schemes` for more details.) Parameters ---------- weights : list - A list of lists alchemical weights of ALL simulations. - weights_err : list, optional - A list of lists of errors corresponding to the values in :code:`weights`. - print_values : bool, optional + A list of lists of alchemical weights of ALL simulations. The i-th element corresponds to + the list of weights of the i-th replica. + weights_err : list, Optional + A list of lists of errors corresponding to the values in :code:`weights`. The default is :code:`None`. + print_values : bool, Optional Whether to print the weights for each replica before and - after weight combination for each replica. + after weight combination for each replica. The default is :code:`True`. Returns ------- weights_modified : list - A list of modified Wang-Landau weights of ALL simulations. - g_vec : np.ndarray + A list of modified alchemical weights of ALL simulations. + g_vec : numpy.ndarray An array of alchemical weights of the whole range of states. """ # (1) Print the original weights @@ -1376,7 +1407,7 @@ def _run_grompp(self, n, swap_pattern): def _run_mdrun(self, n): """ - Executes GROMACS mdrun commands in parallel. + Executes GROMACS mdrun commands in parallel for a REXEE simulation. Parameters ---------- @@ -1424,18 +1455,18 @@ def _run_mdrun(self, n): def run_REXEE(self, n, swap_pattern=None): """ - Perform one iteration in the REXEE simulation, which includes generating the - TPR files using the GROMACS grompp :code:`command` and running the expanded ensemble simulations - in parallel using GROMACS :code:`mdrun` command. The GROMACS commands are launched by as subprocesses. + Performs one iteration of a REXEE simulation, which includes generating the + TPR files using the GROMACS :code:`grompp` command and running the expanded ensemble simulations + in parallel using the GROMACS :code:`mdrun` command. The GROMACS commands are launched as subprocesses. The function assumes that the GROMACS executable is available. Parameters ---------- n : int The iteration index (starting from 0). - swap_pattern : list + swap_pattern : list, Optional A list generated by :obj:`.get_swapping_pattern`. It represents how the replicas should be swapped. - This parameter is not needed only if :code:`n` is 0. + This parameter is not needed only if :code:`n` is 0. The default is :code:`None`. """ if rank == 0: iter_str = f'\nIteration {n}: {self.dt * self.nst_sim * n: .1f} - {self.dt * self.nst_sim * (n + 1): .1f} ps' # noqa: E501 diff --git a/ensemble_md/tests/test_analyze_free_energy.py b/ensemble_md/tests/test_analyze_free_energy.py index 272fd216..30f1b9fc 100644 --- a/ensemble_md/tests/test_analyze_free_energy.py +++ b/ensemble_md/tests/test_analyze_free_energy.py @@ -152,7 +152,7 @@ def test_combine_df_adjacent(): state_ranges = [[0, 1, 2], [1, 2, 3]] # Test 1: df_err_adjacent is None (in which case err_type is ignored) - results = analyze_free_energy._combine_df_adjacent(df_adjacent, None, state_ranges, "propagate") + results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, None, "propagate") assert results[0] == [1, 3.5, 6] assert math.isnan(results[1][0]) assert results[1][1] == np.std([3, 4], ddof=1) @@ -160,14 +160,14 @@ def test_combine_df_adjacent(): assert results[2] == [False, True, False] # Test 2: df_err_adjacent is not None and err_type is "std" - results = analyze_free_energy._combine_df_adjacent(df_adjacent, df_err_adjacent, state_ranges, "std") + results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, "std") assert results[0] == [1, 3.5, 6] np.testing.assert_array_almost_equal(results[1], [0.1, np.std([3, 4], ddof=1), 0.1]) assert results[2] == [False, True, False] # Test 3: df_err_adjacent is not None and err_type is "propagate" df_err_adjacent = [[0.1, 0.1], [0.2, 0.1]] # make the errs different so that the weighted mean will not be equal to simple mean # noqa: E501 - results = analyze_free_energy._combine_df_adjacent(df_adjacent, df_err_adjacent, state_ranges, "propagate") + results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, "propagate") assert results[0] == [1, utils.weighted_mean([3, 4], [0.1, 0.2])[0], 6] assert results[1] == [0.1, utils.weighted_mean([3, 4], [0.1, 0.2])[1], 0.1] assert results[2] == [False, True, False] diff --git a/ensemble_md/tests/test_analyze_matrix.py b/ensemble_md/tests/test_analyze_matrix.py index 6ec25fcb..f7523ab1 100644 --- a/ensemble_md/tests/test_analyze_matrix.py +++ b/ensemble_md/tests/test_analyze_matrix.py @@ -56,7 +56,7 @@ def test_calc_transmtx(): # Case 3: Hamiltonian replica exchange # Note that the transition matrices shown in the log file of different replicas should all be the same. # Here we use log/HREX.log, which is a part of the log file from anthracene HREX. - A3, B3, C3 = analyze_matrix.calc_transmtx(os.path.join(input_path, 'log/HREX.log'), expanded_ensemble=False) + A3, B3, C3 = analyze_matrix.calc_transmtx(os.path.join(input_path, 'log/HREX.log'), simulation_type='HREX') A3_expected = np.array([[0.7869, 0.2041, 0.0087, 0.0003, 0.0000, 0.0000, 0.0000, 0.0000], # noqa: E128, E202, E203, E501 [0.2041, 0.7189, 0.0728, 0.0041, 0.0001, 0.0000, 0.0000, 0.0000], # noqa: E128, E202, E203 [0.0087, 0.0728, 0.7862, 0.1251, 0.0071, 0.0001, 0.0000, 0.0000], # noqa: E202, E203 diff --git a/ensemble_md/tests/test_gmx_parser.py b/ensemble_md/tests/test_gmx_parser.py index 1a929645..2ce72239 100644 --- a/ensemble_md/tests/test_gmx_parser.py +++ b/ensemble_md/tests/test_gmx_parser.py @@ -23,11 +23,6 @@ def test_parse_log(): - """ - - Case 1: The weights have never been equilibrated. - - Case 2: The weights were equilibrated during the simulation. - - Case 3: The weights were fixed in the simulation. - """ # Case 1: weight-updating simulation weights_0, counts_0, wl_delta_0, equil_time_0 = gmx_parser.parse_log(os.path.join(input_path, 'log/EXE_0.log')) assert len(weights_0) == 5 @@ -68,15 +63,6 @@ def test_parse_log(): assert equil_time_3 == 0 -def test_filename(): - MDP = gmx_parser.MDP() - with pytest.raises(ValueError, match="A file name is required because no default file name was defined."): - MDP.filename() - - MDP._filename = 'test' - assert MDP.filename() == 'test' - - class Test_MDP: def test__eq__(self): mdp_1 = gmx_parser.MDP("ensemble_md/tests/data/expanded.mdp") @@ -84,12 +70,11 @@ def test__eq__(self): assert mdp_1 == mdp_2 def test_read(self): - mdp = gmx_parser.MDP() f = open("fake.mdp", "a") f.write("TEST") f.close() with pytest.raises(ParseError, match="'fake.mdp': unknown line in mdp file, 'TEST'"): - mdp.read('fake.mdp') + gmx_parser.MDP('fake.mdp') # This should call the read function in __init__ os.remove('fake.mdp') def test_write(self): diff --git a/ensemble_md/tests/test_replica_exchange_EE.py b/ensemble_md/tests/test_replica_exchange_EE.py index 3a2880d8..3e079052 100644 --- a/ensemble_md/tests/test_replica_exchange_EE.py +++ b/ensemble_md/tests/test_replica_exchange_EE.py @@ -359,7 +359,7 @@ def test_print_params(self, capfd, params_dict): REXEE.print_params() out_1, err = capfd.readouterr() L = "" - L += "Important parameters of EXEE\n============================\n" + L += "Important parameters of REXEE\n=============================\n" L += f"Python version: {sys.version}\n" L += f"GROMACS executable: {REXEE.gmx_path}\n" # Easier to pass CI. This is easy to catch anyway L += f"GROMACS version: {REXEE.gmx_version}\n" # Easier to pass CI. This is easy to catch anyway diff --git a/ensemble_md/tests/test_utils.py b/ensemble_md/tests/test_utils.py index 495fc68d..f130e1ca 100644 --- a/ensemble_md/tests/test_utils.py +++ b/ensemble_md/tests/test_utils.py @@ -89,21 +89,15 @@ def test_format_time(): assert utils.format_time(90061) == "1 day, 1 hour(s) 1 minute(s) 1 second(s)" -def test_autoconvert(): +def test_convert_to_numeric(): # Test non-string input - assert utils._autoconvert(42) == 42 - - # Test string input that can be converted to int - assert utils._autoconvert("42") == 42 - - # Test string input that can be converted to float - assert utils._autoconvert("3.14159") == 3.14159 - - # Test string input that can be converted to a numpy array of ints - assert utils._autoconvert("1 2 3") == [1, 2, 3] - - # Test string input that can be converted to a numpy array of floats - assert utils._autoconvert("1.0 2.0 3.0") == [1.0, 2.0, 3.0] + assert utils._convert_to_numeric(42) == 42 + assert utils._convert_to_numeric("42") == 42 + assert utils._convert_to_numeric("3.14159") == 3.14159 + assert utils._convert_to_numeric("1 2 3") == [1, 2, 3] + assert utils._convert_to_numeric("1.0 2.0 3.0") == [1.0, 2.0, 3.0] + assert utils._convert_to_numeric("Hello, world!") == ['Hello,', 'world!'] + assert utils._convert_to_numeric('Y Y Y') == ['Y', 'Y', 'Y'] def test_get_subplot_dimension(): diff --git a/ensemble_md/utils/exceptions.py b/ensemble_md/utils/exceptions.py index 88a03df7..9cba4960 100644 --- a/ensemble_md/utils/exceptions.py +++ b/ensemble_md/utils/exceptions.py @@ -13,4 +13,4 @@ class ParameterError(Exception): class ParseError(Exception): - """Error raised when parsing of a file failed. Modified from GromacsWrapper.""" + """Error raised during parsing a file.""" diff --git a/ensemble_md/utils/gmx_parser.py b/ensemble_md/utils/gmx_parser.py index 2fd31df8..9cddacc6 100644 --- a/ensemble_md/utils/gmx_parser.py +++ b/ensemble_md/utils/gmx_parser.py @@ -13,7 +13,6 @@ import os import re import six -import logging import warnings from collections import OrderedDict as odict @@ -23,29 +22,29 @@ def parse_log(log_file): """ - This function parses a log file generated by expanded ensemble and provides - important information, especially for running new iterations in REXEE. - Typically, there are three types of log files from an expanded ensemble simulation: + Parses a log file generated by a GROMACS expanded ensemble simulation and extracts + important information. This function is especially useful for extracting information from each iteration in a REXEE + simulation on the fly. There are three types of log files from an expanded ensemble simulation: - **Case 1**: The weights are still updating in the simulation and have never been equilibrated. - - The output :code:`equil_time` should always be -1. + - In this case, the output :code:`equil_time` should be -1. - **Case 2**: The weights were equilibrated during the simulation. - - The output :code:`equil_time` should be the time (in ps) it took to get the weights equilibrated. + - The output :code:`equil_time` is the time (in ps) it took to get the weights equilibrated. - The final weights (:code:`weights`) will just be the equilibrated weights. - **Case 3**: The weights were fixed in the simulation. - - The output :code:`equil_time` should always be 0. + - In this case, the output :code:`equil_time` should be 0. - The final weights (which never change during the simulation) and the final counts will still be returned. Parameters ---------- log_file : str - The log file to be parsed. + The file path of the input log file Returns ------- @@ -56,16 +55,18 @@ def parse_log(log_file): incrementor will be returned. - In Case 2, a list of list of weights as a function of time since the last update of the Wang-Landau incrementor up to equilibration will be returned. - - In Case 3, the returned list will only have one list inside, which is the list of the final weights. + - In Case 3, the returned list will only have one list inside, which is the list of values at which + the weights were fixed. - That is, for all cases, :code:`weights[-1]` will be the final weights, which are useful in REXEE. + That is, for all cases, :code:`weights[-1]` will be the final weights, which are important for seeding the + next iteration in a REXEE simulation. counts : list The final histogram counts. wl_delta : float The final Wang-Landau incementor. In Cases 2 and 3, :code:`None` will be returned. equil_time : int or float - In Case 1, -1 will be returned, which means that the weights have not been equilibrated. - - In Case 2, the time in ps that it took to equilibrate the weights will be returned. + - In Case 2, the equilibration time in ps will be returned. - In Case 3, 0 will be returned, which means that the weights were fixed during the simulation. """ f = open(log_file, "r") @@ -168,167 +169,64 @@ def parse_log(log_file): return weights, counts, wl_delta, equil_time -class FileUtils(object): - """Mixin class to provide additional file-related capabilities. - Modified from `utilities.py in GromacsWrapper `_. - Copyright (c) 2009 Oliver Beckstein +class MDP(odict): """ + A class that represents a GROMACS MDP file. Note that an MDP instance is an ordered dictionary, + with the i-th key corresponding to the i-th line in the MDP file. Comments and blank lines are + also preserved, e.g., with keys 'C0001' and 'B0001', respectively. The value corresponding to a + 'C' key is the comment itself, while the value corresponding to a 'B' key is an empty string. + Comments after a parameter on the same line are discarded. Leading and trailing spaces + are always stripped. - #: Default extension for files read/written by this class. - default_extension = None - - def _init_filename(self, filename=None, ext=None): - """Initialize the current filename :attr:`FileUtils.real_filename` of the object. - - Bit of a hack. - - - The first invocation must have ``filename != None``; this will set a - default filename with suffix :attr:`FileUtils.default_extension` - unless another one was supplied. - - - Subsequent invocations either change the filename accordingly or - ensure that the default filename is set with the proper suffix. - - """ - - extension = ext or self.default_extension - filename = self.filename( - filename, ext=extension, use_my_ext=True, set_default=True - ) - #: Current full path of the object for reading and writing I/O. - self.real_filename = os.path.realpath(filename) - - def filename(self, filename=None, ext=None, set_default=False, use_my_ext=False): - """Supply a file name for the class object. - - Typical uses:: - - fn = filename() ---> - fn = filename('name.ext') ---> 'name' - fn = filename(ext='pickle') ---> '.pickle' - fn = filename('name.inp','pdf') --> 'name.pdf' - fn = filename('foo.pdf',ext='png',use_my_ext=True) --> 'foo.pdf' - - The returned filename is stripped of the extension - (``use_my_ext=False``) and if provided, another extension is - appended. Chooses a default if no filename is given. - - Raises a ``ValueError`` exception if no default file name is known. - - If ``set_default=True`` then the default filename is also set. - - ``use_my_ext=True`` lets the suffix of a provided filename take - priority over a default ``ext`` tension. - """ - if filename is None: - if not hasattr(self, "_filename"): - self._filename = None # add attribute to class - if self._filename: - filename = self._filename - else: - raise ValueError( - "A file name is required because no default file name was defined." - ) - my_ext = None - else: - filename, my_ext = os.path.splitext(filename) - if set_default: # replaces existing default file name - self._filename = filename - if my_ext and use_my_ext: - ext = my_ext - if ext is not None: - if ext.startswith(os.extsep): - ext = ext[1:] # strip a dot to avoid annoying mistakes - if ext != "": - filename = filename + os.extsep + ext - return filename - - -class MDP(odict, FileUtils): - """Class that represents a Gromacs mdp run input file. - Modified from `GromacsWrapper `_. - Copyright (c) 2009-2011 Oliver Beckstein - The MDP instance is an ordered dictionary. - - - *Parameter names* are keys in the dictionary. - - *Comments* are sequentially numbered with keys Comment0001, - Comment0002, ... - - *Empty lines* are similarly preserved as Blank0001, .... - - When writing, the dictionary is dumped in the recorded order to a - file. Inserting keys at a specific position is not possible. - - Currently, comments after a parameter on the same line are - discarded. Leading and trailing spaces are always stripped. + Parameters + ---------- + input_mdp : str, Optional + The path to the input MDP file. The default is None. + **kwargs : Optional + Additional keyword arguments to be passed to add additional key-value pairs to the MDP instance. + Note that no sanity checks will be performed for the key-value pairs passed in this way. This + also does not work for keys that are not legal python variable names, such as anything that includes + a minus '-' sign or starts with a number. + + Attributes + ---------- + COMMENT : :code:`re.Pattern` object + A compiled regular expression pattern for comments in MDP files. + PARAMETER : :code:`re.Pattern` object + A compiled regular expression pattern for parameters in MDP files. + input_mdp : str + The real path to the input MDP file returned by :code:`os.path.realpath(input_mdp)`, + which resolves any symbolic links in the path. + + Example + ------- + >>> from ensemble_md.utils import gmx_parser + >>> gmx_parser.MDP("em.mdp") + MDP([('C0001', 'em.mdp - used as input into grompp to generate em.tpr'), ('C0002', 'All unspecified parameters adopt their own default values.'), ('B0001', ''), ('C0003', 'Run Control'), ('integrator', 'steep'), ('nsteps', 500000), ('B0002', ''), ('C0004', 'Energy minnimization'), ('emtol', 100.0), ('emstep', 0.01), ('B0003', ''), ('C0005', 'Neighbor searching/Electrostatics/Van der Waals'), ('cutoff-scheme', 'Verlet'), ('nstlist', 10), ('ns_type', 'grid'), ('pbc', 'xyz'), ('coulombtype', 'PME'), ('rcoulomb', 1.0), ('rvdw', 1.0)]) # noqa: E501 """ + # Below are some class variables accessible to all functions. + COMMENT = re.compile("""\s*;\s*(?P.*)""") # noqa: W605 + PARAMETER = re.compile("""\s*(?P[^=]+?)\s*=\s*(?P[^;]*)(?P\s*;.*)?""", re.VERBOSE) # noqa: W605, E501 - default_extension = "mdp" - logger = logging.getLogger("gromacs.formats.MDP") - - COMMENT = re.compile("""\s*;\s*(?P.*)""") # eat initial ws # noqa: W605 - # see regex in cbook.edit_mdp() - PARAMETER = re.compile( - """ - \s*(?P[^=]+?)\s*=\s* # parameter (ws-stripped), before '=' # noqa: W605 - (?P[^;]*) # value (stop before comment=;) # noqa: W605 - (?P\s*;.*)? # optional comment # noqa: W605 - """, - re.VERBOSE, - ) - - def __init__(self, filename=None, autoconvert=True, **kwargs): - """Initialize mdp structure. - - :Arguments: - *filename* - read from mdp file - *autoconvert* : boolean - ``True`` converts numerical values to python numerical types; - ``False`` keeps everything as strings [``True``] - *kwargs* - Populate the MDP with key=value pairs. (NO SANITY CHECKS; and also - does not work for keys that are not legal python variable names such - as anything that includes a minus '-' sign or starts with a number). - """ - super(MDP, self).__init__( - **kwargs - ) # can use kwargs to set dict! (but no sanity checks!) - - self.autoconvert = autoconvert - - if filename is not None: - self._init_filename(filename) - self.read(filename) + def __init__(self, input_mdp=None, **kwargs): + super(MDP, self).__init__(**kwargs) # can use kwargs to set dict! (but no sanity checks!) + if input_mdp is not None: + self.input_mdp = os.path.realpath(input_mdp) + self.read() - def __eq__(self, other): + def read(self): """ - __eq__ inherited from FileUtils needs to be overridden if new attributes (autoconvert in - this case) are assigned to the instance of the subclass (MDP in our case). - See `this post by LGTM `_ for more details. + Reads and parses the input MDP file. """ - if not isinstance(other, MDP): - return False - return FileUtils.__eq__(self, other) and self.autoconvert == other.autoconvert - - def _transform(self, value): - if self.autoconvert: - return utils._autoconvert(value) - else: - return value.rstrip() - - def read(self, filename=None): - """Read and parse mdp file *filename*.""" - self._init_filename(filename) - def BLANK(i): - return "B{0:04d}".format(i) + return f"B{i:04d}" def COMMENT(i): - return "C{0:04d}".format(i) + return f"C{i:04d}" data = odict() iblank = icomment = 0 - with open(self.real_filename) as mdp: + with open(self.input_mdp) as mdp: for line in mdp: line = line.strip() if len(line) == 0: @@ -340,73 +238,85 @@ def COMMENT(i): icomment += 1 data[COMMENT(icomment)] = m.group("value") continue - # parameter + m = self.PARAMETER.match(line) if m: - # check for comments after parameter?? -- currently discarded parameter = m.group("parameter") - value = self._transform(m.group("value")) + value = utils._convert_to_numeric(m.group("value")) data[parameter] = value else: - errmsg = "{filename!r}: unknown line in mdp file, {line!r}".format( - **vars() - ) - self.logger.error(errmsg) - raise ParseError(errmsg) + err_msg = f"{os.path.basename(self.input_mdp)!r}: unknown line in mdp file, {line!r}" + raise ParseError(err_msg) super(MDP, self).update(data) - def write(self, filename=None, skipempty=False): - """Write mdp file to *filename*. + def write(self, output_mdp=None, skipempty=False): + """ + Writes the MDP instance (the ordered dictionary) to an output MDP file. Parameters ---------- - filename : str - Output mdp file; default is the filename the mdp was read from. If the filename - is not supplied, the function will overwrite the file that the mdp was read from. - skipempty : bool - ``True`` removes any parameter lines from output that contain empty values [``False``] + output_mdp : str, Optional + The file path of the output MDP file. The default is the filename the MDP instance was built from. + If that if :code:`output_mdp` is not specified, the input MDP file will be overwritten. + skipempty : bool, Optional + Whether to skip empty values when writing the MDP file. If :code:`True`, any parameter lines from + the output that contain empty values will be removed. The default is :code:`False`. """ # The line 'if skipempty and (v == "" or v is None):' below could possibly incur FutureWarning warnings.simplefilter(action='ignore', category=FutureWarning) - with open(self.filename(filename, ext="mdp"), "w") as mdp: + if output_mdp is None: + output_mdp = self.input_mdp + + with open(output_mdp, "w") as mdp: for k, v in self.items(): if k[0] == "B": # blank line mdp.write("\n") elif k[0] == "C": # comment - mdp.write("; {v!s}\n".format(**vars())) + mdp.write(f"; {v!s}\n") else: # parameter = value if skipempty and (v == "" or v is None): continue if isinstance(v, six.string_types) or not hasattr(v, "__iter__"): - mdp.write("{k!s} = {v!s}\n".format(**vars())) + mdp.write(f"{k!s} = {v!s}\n") else: - mdp.write("{} = {}\n".format(k, " ".join(map(str, v)))) + mdp.write(f"{k} = {' '.join(map(str, v))}\n") def compare_MDPs(mdp_list, print_diff=False): """ - Given a list of MDP files, identify the parameters for which not all MDP - files have the same values. Note that this function is not aware of the default - values of GROMACS parameters. (Currently, this function is not used in the - workflow adopted in :code:`run_REXEE.py` but it might be useful in some places, - so we decided to keep it.) + Identifies the parameters differeing between a given list of MDP files. Note that + this function is not aware of the default values of GROMACS parameters. + (Currently, this function is not used in the workflow adopted in :code:`run_REXEE.py` + but it might be useful in some places, so we decided to keep it.) Parameters ---------- mdp_list : list A list of MDP files. - print_diff : bool - If :code:`True`, print to screen the parameters that are different among the MDP files - and the values of the parameters in the MDP files in a more readable format. + print_diff : bool, Optional + Whether to print the parameters that are different among the MDP files in a more readable format. + The default is :code:`False`. Returns ------- diff_params : dict - A dictionary of parameters that are different among the MDP files. - The keys are the parameter names and the values is a list of values of the - parameters in the MDP files. + A dictionary of parameters differing between MDP files. The keys are the parameter names and + the values is a list of values of the parameters in the MDP files. + + Example + ------- + >>> from ensemble_md.utils import gmx_parser + >>> mdp_list = ['A.mdp', 'B.mdp'] + >>> diff_params = gmx_parser.compare_MDPs(mdp_list, print_diff=True) + The following parameters are different among the MDP files: + wl_scale + - A.mdp: None + - B.mdp: 0.8 + ... + >>> print(diff_params) + {'wl_scale': [None, 0.8], ...} """ diff_params = {} for i in range(len(mdp_list)): diff --git a/ensemble_md/utils/utils.py b/ensemble_md/utils/utils.py index e84878ff..f072c8ae 100644 --- a/ensemble_md/utils/utils.py +++ b/ensemble_md/utils/utils.py @@ -8,7 +8,7 @@ # # #################################################################### """ -The :obj:`.utils` module provides useful utility functions. +The :obj:`.utils` module provides useful utility functions for running or analyzing REXEE simulations. """ import sys import glob @@ -21,41 +21,35 @@ class Logger: """ - Redirects the STDOUT and STDERR to a specified output file while preserving them on screen. + A logger class that redirects the STDOUT and STDERR to a specified output file while + preserving the output on screen. This is useful for logging terminal output to a file + for later analysis while still seeing the output in real-time during execution. Parameters ---------- logfile : str - Name of the output file to write the logged messages. + The file path to which the standard output and standard error should be logged. Attributes ---------- - terminal : file object - The file object that represents the original STDOUT (i.e., the screen). - log : file object - The file object that represents the logfile where messages will be written. + terminal : :code:`io.TextIOWrapper` object + The original standard output object, typically :code:`sys.stdout`. + log : :code:`io.TextIOWrapper` object + File object used to log the output in append mode. """ def __init__(self, logfile): - """ - Initializes a Logger instance. - - Parameters - ---------- - logfile : str - Name of the output file to write the logged messages. - """ self.terminal = sys.stdout self.log = open(logfile, "a") def write(self, message): """ - Writes the given message to both the STDOUT and the logfile. + Writes a message to the terminal and to the log file. Parameters ---------- message : str - The message to be written to STDOUT and logfile. + The message to be written to STDOUT and the log file. """ self.terminal.write(message) self.log.write(message) @@ -63,7 +57,7 @@ def write(self, message): def flush(self): """ This method is needed for Python 3 compatibility. This handles the flush command by doing nothing. - You might want to specify some extra behavior here. + Some extra behaviors may be specified here. """ # self.terminal.log() pass @@ -71,15 +65,15 @@ def flush(self): def run_gmx_cmd(arguments, prompt_input=None): """ - Runs a GROMACS command as a subprocess + Runs a GROMACS command through a subprocess call. Parameters ---------- arguments : list - A list of arguments that compose of the GROMACS command to run, e.g. + A list of arguments that compose of the GROMACS command to run, e.g., :code:`['gmx', 'mdrun', '-deffnm', 'sys']`. - prompt_input : str or None - The input to be passed to the GROMACS command when it prompts for input. + prompt_input : str or None, Optional + The input to be passed to the interative prompt launched by the GROMACS command, if any. Returns ------- @@ -89,7 +83,6 @@ def run_gmx_cmd(arguments, prompt_input=None): The STDOUT of the process. stderr: str or None The STDERR or the process. - """ try: result = subprocess.run(arguments, capture_output=True, text=True, input=prompt_input, check=True) @@ -102,7 +95,7 @@ def run_gmx_cmd(arguments, prompt_input=None): def format_time(t): """ - Converts time in seconds to the "most readable" format. + Converts time in seconds to a more readable format. Parameters ---------- @@ -112,7 +105,9 @@ def format_time(t): Returns ------- t_str : str - A string in the format of "XX day XX hour(s) XX minute(s) XX second(s)". + A string representing the time duration in a format of "X hour(s) Y minute(s) Z second(s)", adjusting the units + as necessary based on the input duration, e.g., 1 hour(s) 0 minute(s) 0 second(s) for 3600 seconds and + 15 minute(s) 30 second(s) for 930 seconds. """ hh_mm_ss = str(datetime.timedelta(seconds=t)).split(":") @@ -133,30 +128,23 @@ def format_time(t): return t_str -def _autoconvert(s): +def _convert_to_numeric(s): """ - Converts input to a numerical type if possible. Used for the MDP parser. - Modified from `utilities.py in GromacsWrapper `_. - Copyright (c) 2009 Oliver Beckstein + Converts the input to a numerical type when possible. This internal function is used for the MDP parser. Parameters ---------- - s : str or any - The input value to be converted to a numerical type if possible. If :code:`s` is not a string, - it is returned as is. + s : any + The input value to be converted to a numerical type if possible. The data type of :code:`s` is + usually :code:`str` but can be any. However, if :code:`s` is not a string, it will be returned as is. Returns ------- - numerical : int, float, numpy.ndarray, or any + numerical : any The converted numerical value. If :code:`s` can be converted to a single numerical value, that value is returned as an :code:`int` or :code:`float`. If :code:`s` can be converted to - multiple numerical values, a :code:`numpy.ndarray` containing those values is returned. + multiple numerical values, a list containing those values is returned. If :code:`s` cannot be converted to a numerical value, :code:`s` is returned as is. - - Raises - ------ - ValueError - If :code:`s` cannot be converted to a numerical value. """ if type(s) is not str: return s @@ -167,27 +155,27 @@ def _autoconvert(s): return s[0] else: return s - """ - if len(s) != 0 and type(s[0]) == str: - # For the case like pull_coord1_dim = Y Y Y - return s - else: - return np.array(s) - """ except (ValueError, AttributeError): pass - raise ValueError("Failed to autoconvert {0!r}".format(s)) + raise ValueError(f"Failed to convert {s} to a numeric value.") def _get_subplot_dimension(n_panels): """ - Gets the numbers of rows and columns in a subplot such that - the arrangement of the . + Gets the number of rows and columns for a subplot based on the number of panels such + that the subplots are arranged in a grid that is as square as possible. A greater number + of columns is preferred to a greater number of rows. Parameters ---------- n_panels : int - The number of panels in the subplot. + The number of panels to be arranged in subplots. + + Example + ------- + >>> from ensemble_md.utils import utils + >>> utils._get_subplot_dimension(10) + (4, 3) """ if int(np.sqrt(n_panels) + 0.5) ** 2 == n_panels: # perfect square number @@ -260,13 +248,13 @@ def calc_rmse(data, ref): def get_time_metrics(log): """ - Gets the time-based metrics from a log file, including the core time (s), + Gets the time-based metrics from a log file of a REXEE simulation, including the core time, wall time, and performance (ns/day). Parameters ---------- log : str - The input log file. + The file path of the input log file. Returns ------- @@ -290,17 +278,17 @@ def get_time_metrics(log): def analyze_REXEE_time(n_iter=None, log_files=None): """ - Perform simple data analysis on the wall times and performances of all iterations of an REXEE simulation. + Performs simple data analysis on the wall times and performances of all iterations of an REXEE simulation. Parameters ---------- - n_iter : None or int + n_iter : None or int, Optional The number of iterations in the REXEE simulation. If None, the function will try to find the number of - iterations by counting the number of directories named "iteration_*" in the simulation directory - (i.e., :code:`sim_0`) in the current working directory or where the log files are located. - log_files : None or list - A list of lists log files with the shape of (n_iter, n_replicas). If None, the function will try to find - the log files by searching the current working directory. + iterations by counting the number of directories named in the format of :code`iteration_*` in the simulation + directory (specifically :code:`sim_0`) in the current working directory or where the log files are located. + log_files : None or list, Optional + A list of lists of log paths with the shape of :code:`(n_iter, n_replicas)`. If None, the function will try to + find the log files by searching the current working directory. Returns ------- @@ -310,7 +298,7 @@ def analyze_REXEE_time(n_iter=None, log_files=None): The total time spent in synchronizing all replicas, which is the sum of the differences between the longest and the shortest time elapsed to finish a iteration. t_wall_list : list - The list of wall times of finishing each mdrun command. + The list of wall times for finishing each GROMACS mdrun command. """ if n_iter is None: if log_files is None: