Merge pull request #36 from wehs7661/analysis_tweaks
Some minor tweaks for the analysis codes
wehs7661 authored Feb 29, 2024
2 parents bce555f + 2a51ffc commit 074dd47
Showing 2 changed files with 137 additions and 98 deletions.
214 changes: 128 additions & 86 deletions ensemble_md/analysis/clustering.py
@@ -1,8 +1,17 @@
####################################################################
# #
# ensemble_md, #
# a python package for running GROMACS simulation ensembles #
# #
# Written by Wei-Tse Hsu <wehs7661@colorado.edu> #
# Copyright (c) 2022 University of Colorado Boulder #
# #
####################################################################
import numpy as np
from ensemble_md.utils.utils import run_gmx_cmd


def cluster_traj(gmx_executable, inputs, grps, method='linkage', cutoff=0.1, suffix=None):
def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkage', cutoff=0.1, suffix=None):
"""
Performs clustering analysis on a trajectory using the GROMACS command :code:`gmx cluster`.
Note that only fully coupled configurations are considered.
@@ -12,27 +21,46 @@ def cluster_traj(gmx_executable, inputs, grps, method='linkage', cutoff=0.1, suf
gmx_executable : str
The path to the GROMACS executable.
inputs : dict
A dictionary that contains the file names of the input trajectory file (XTC or TRR),
the configuration file (TPR or GRO), the file that contains the time series of the
state index, and the index file (NDX). It must include the keys :code:`traj`, :code:`config`,
:code:`xvg`, and :code:`index`. Note that the value for the key :code:`index` can be :code:`None`.
A dictionary that contains the different input files required for the clustering analysis.
The dictionary must have the following four keys: :code:`traj` (input trajectory file in
XTC or TRR format), :code:`config` (the configuration file in TPR or GRO format),
:code:`xvg` (a GROMACS XVG file), and :code:`index` (an index/NDX file), with the values
being the paths. Note that the value of the key :code:`index` can be :code:`None`, in which
case the function will use a default index file generated by :code:`gmx make_ndx`. If the
parameter :code:`coupled_only` is set to :code:`True`, an XVG file that contains the time
series of the state index (e.g., :code:`dhdl.xvg`) must be provided with the key :code:`xvg`.
Otherwise, the key :code:`xvg` can be set to :code:`None`.
grps : dict
A dictionary that contains the names of the groups in the index file (NDX) for
centering the system, calculating the RMSD, and outputting. The keys are
:code:`center`, :code:`rmsd`, and :code:`output`.
coupled_only : bool
Whether to consider only the fully coupled configurations. The default is :code:`True`.
method : str
The clustering method to be used with the GROMACS command :code:`gmx cluster`. The default is 'linkage'.
Check the GROMACS documentation for other available options.
cutoff : float
The cutoff in RMSD for clustering. The default is 0.1.
The RMSD cutoff for clustering in nm. The default is 0.1.
suffix : str
The suffix for the output files. The default is :code:`None`, which means no suffix will be added.
"""
# First check if all specified groups are present in the index file
# Check if the index file is provided
if inputs['index'] is None:
print('Running gmx make_ndx to generate an index file ...')
args = [
gmx_executable, 'make_ndx',
'-f', inputs['config'],
'-o', 'index.ndx',
]
returncode, stdout, stderr = run_gmx_cmd(args, prompt_input='q\n')
inputs['index'] = 'index.ndx'

# Check if the groups are present in the index file
with open(inputs['index'], 'r') as f:
content = f.read()
for key in grps:
if grps[key] not in content:
raise ValueError(f'The group {grps[key]} is not present in the index file.')
raise ValueError(f'The group {grps[key]} is not present in the provided/generated index file.')

outputs = {
'nojump': 'nojump.xtc',
@@ -47,90 +75,104 @@ def cluster_traj(gmx_executable, inputs, grps, method='linkage', cutoff=0.1, suf
for key in outputs:
outputs[key] = outputs[key].replace('.', f'_{suffix}.')

print('Eliminating jumps across periodic boundaries for the input trajectory ...')
args = [
gmx_executable, 'trjconv',
'-f', inputs['traj'],
'-s', inputs['config'],
'-o', outputs['nojump'],
'-center', 'yes',
'-pbc', 'nojump',
'-drop', inputs['xvg'],
'-dropover', '0'
]
if inputs['index'] is not None:
args.extend(['-n', inputs['index']])
returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["center"]}\n{grps["output"]}\n')
if returncode != 0:
print(f'Error with return code {returncode}:\n{stderr}')

print('Centering the system ...')
args = [
gmx_executable, 'trjconv',
'-f', outputs['nojump'],
'-s', inputs['config'],
'-o', outputs['center'],
'-center', 'yes',
'-pbc', 'mol',
'-ur', 'compact',
]
if inputs['index'] is not None:
args.extend(['-n', inputs['index']])
returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["center"]}\n{grps["output"]}\n')
if returncode != 0:
print(f'Error with return code {returncode}:\n{stderr}')

print('Performing clustering analysis ...')
args = [
gmx_executable, 'cluster',
'-f', outputs['center'],
'-s', inputs['config'],
'-o', outputs['rmsd-clust'],
'-dist', outputs['rmsd-dist'],
'-g', outputs['cluster-log'],
'-cl', outputs['cluster-pdb'],
'-cutoff', str(cutoff),
'-method', method,
]
if inputs['index'] is not None:
args.extend(['-n', inputs['index']])
returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["rmsd"]}\n{grps["output"]}\n')
if returncode != 0:
print(f'Error with return code {returncode}:\n{stderr}')

rmsd_range, rmsd_avg, n_clusters = get_cluster_info(outputs['cluster-log'])

print(f'Range of RMSD values: from {rmsd_range[0]:.3f} to {rmsd_range[1]:.3f} nm')
print(f'Average RMSD: {rmsd_avg:.3f} nm')
print(f'Number of clusters: {n_clusters}')

if n_clusters > 1:
clusters, sizes = get_cluster_members(outputs['cluster-log'])
for i in range(1, n_clusters + 1):
print(f' - Cluster {i} accounts for {sizes[i] * 100:.2f}% of the total configurations.')

n_transitions, t_transitions = count_transitions(clusters)
print(f'Number of transitions between the two biggest clusters: {n_transitions}')
print(f'Time frames of the transitions (ps): {t_transitions}')

print('Calculating the inter-medoid RMSD between the two biggest clusters ...')
# Note that we pass outputs['cluster-pdb'] to -s so that the first medoid will be used as the reference
# Check if there is any fully coupled state in the trajectory
lambda_data = np.transpose(np.loadtxt(inputs['xvg'], comments=['#', '@']))[1]
if coupled_only is True and 0 not in lambda_data:
print('Terminating clustering analysis since no fully coupled state is present in the input trajectory while coupled_only is set to True.') # noqa: E501
else:
# Either coupled_only is False or coupled_only is True but there are coupled configurations.
print('Eliminating jumps across periodic boundaries for the input trajectory ...')
args = [
gmx_executable, 'rms',
'-f', outputs['cluster-pdb'],
'-s', outputs['cluster-pdb'],
'-o', outputs['rmsd'],
gmx_executable, 'trjconv',
'-f', inputs['traj'],
'-s', inputs['config'],
'-n', inputs['index'],
'-o', outputs['nojump'],
'-center', 'yes',
'-pbc', 'nojump',
]
if inputs['index'] is not None:
args.extend(['-n', inputs['index']])

# Here we simply assume same groups for least-squares fitting and RMSD calculation
returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["rmsd"]}\n{grps["rmsd"]}\n')
if coupled_only:
if inputs['xvg'] is None:
raise ValueError('The parameter "coupled_only" is set to True but no XVG file is provided.')
args.extend([
'-drop', inputs['xvg'],
'-dropover', '0'
])

returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["center"]}\n{grps["output"]}\n')
if returncode != 0:
print(f'Error with return code {returncode}:\n{stderr}')

print('Centering the system ...')
args = [
gmx_executable, 'trjconv',
'-f', outputs['nojump'],
'-s', inputs['config'],
'-n', inputs['index'],
'-o', outputs['center'],
'-center', 'yes',
'-pbc', 'mol',
'-ur', 'compact',
]
returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["center"]}\n{grps["output"]}\n')
if returncode != 0:
print(f'Error with return code {returncode}:\n{stderr}')

if coupled_only is True:
N_coupled = np.count_nonzero(lambda_data == 0)
print(f'Number of fully coupled configurations: {N_coupled}')

print('Performing clustering analysis ...')
args = [
gmx_executable, 'cluster',
'-f', outputs['center'],
'-s', inputs['config'],
'-n', inputs['index'],
'-o', outputs['rmsd-clust'],
'-dist', outputs['rmsd-dist'],
'-g', outputs['cluster-log'],
'-cl', outputs['cluster-pdb'],
'-cutoff', str(cutoff),
'-method', method,
]
returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["rmsd"]}\n{grps["output"]}\n')
if returncode != 0:
print(f'Error with return code {returncode}:\n{stderr}')

rmsd = np.transpose(np.loadtxt(outputs['rmsd'], comments=['@', '#']))[1][1] # inter-medoid RMSD
print(f'Inter-medoid RMSD between the two biggest clusters: {rmsd:.3f} nm')
rmsd_range, rmsd_avg, n_clusters = get_cluster_info(outputs['cluster-log'])

print(f'Range of RMSD values: from {rmsd_range[0]:.3f} to {rmsd_range[1]:.3f} nm')
print(f'Average RMSD: {rmsd_avg:.3f} nm')
print(f'Number of clusters: {n_clusters}')

if n_clusters > 1:
clusters, sizes = get_cluster_members(outputs['cluster-log'])
for i in range(1, n_clusters + 1):
print(f' - Cluster {i} accounts for {sizes[i] * 100:.2f}% of the total configurations.')

n_transitions, t_transitions = count_transitions(clusters)
print(f'Number of transitions between the two biggest clusters: {n_transitions}')
print(f'Time frames of the transitions (ps): {t_transitions}')

print('Calculating the inter-medoid RMSD between the two biggest clusters ...')
# Note that we pass outputs['cluster-pdb'] to -s so that the first medoid will be used as the reference
args = [
gmx_executable, 'rms',
'-f', outputs['cluster-pdb'],
'-s', outputs['cluster-pdb'],
'-o', outputs['rmsd'],
]
if inputs['index'] is not None:
args.extend(['-n', inputs['index']])

# Here we simply assume same groups for least-squares fitting and RMSD calculation
returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["rmsd"]}\n{grps["rmsd"]}\n')
if returncode != 0:
print(f'Error with return code {returncode}:\n{stderr}')

rmsd = np.transpose(np.loadtxt(outputs['rmsd'], comments=['@', '#']))[1][1] # inter-medoid RMSD
print(f'Inter-medoid RMSD between the two biggest clusters: {rmsd:.3f} nm')


def get_cluster_info(cluster_log):
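
For context, a minimal sketch of how the updated cluster_traj signature might be invoked; the GROMACS executable name, file names, and index groups below are hypothetical placeholders and not part of this commit:

    from ensemble_md.analysis.clustering import cluster_traj

    inputs = {
        'traj': 'traj.xtc',    # input trajectory (XTC or TRR)
        'config': 'sys.tpr',   # configuration file (TPR or GRO)
        'xvg': 'dhdl.xvg',     # state-index time series, needed when coupled_only=True
        'index': None,         # None triggers the new gmx make_ndx fallback
    }
    grps = {
        'center': 'Protein',   # group used for centering
        'rmsd': 'Protein',     # group used for the RMSD calculation
        'output': 'System',    # group written to the output files
    }
    cluster_traj('gmx', inputs, grps, coupled_only=True, method='linkage', cutoff=0.1, suffix='rep0')
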
21 changes: 9 additions & 12 deletions ensemble_md/cli/analyze_REXEE.py
@@ -88,7 +88,6 @@ def main():
sys.stdout = utils.Logger(logfile=args.output)
sys.stderr = utils.Logger(logfile=args.output)
section_idx = 0
poor_sampling = None

rc('font', **{
'family': 'sans-serif',
@@ -264,9 +263,6 @@ def main():
print(f' - Trajectory {j} ({len(t_list[j])} events): {np.mean(t_list[j]):.2f} {units}')
print(f' - Average of the above: {np.mean([np.mean(i) for i in t_list]):.2f} {units} (std: {np.std([np.mean(i) for i in t_list], ddof=1):.2f} {units})') # noqa: E501

if np.sum(np.isnan([np.mean(i) for i in t_list])) != 0:
poor_sampling = True

if REXEE.msm is True:
section_idx += 1

@@ -389,9 +385,6 @@ def main():

# Section 4 (or Section 3). Free energy calculations
if REXEE.free_energy is True:
if poor_sampling is True:
print('\nFree energy calculation is not performed since the sampling appears poor.')
sys.exit()
section_idx += 1
print(f'\n[ Section {section_idx}. Free energy calculations ]')

@@ -401,19 +394,22 @@
if os.path.isfile(f'{args.dir}/u_nk_data.pickle') is True:
print('Loading the preprocessed data u_nk ...')
with open(f'{args.dir}/u_nk_data.pickle', 'rb') as handle:
data_list = pickle.load(handle)
data_all = pickle.load(handle)
data_list, t_idx_list, g_list = data_all[0], data_all[1], data_all[2]
else: # should always be 'dhdl'
if os.path.isfile(f'{args.dir}/dHdl_data.pickle') is True:
print('Loading the preprocessed data dHdl ...')
with open(f'{args.dir}/dHdl_data.pickle', 'rb') as handle:
data_list = pickle.load(handle)
data_all = pickle.load(handle)
data_list, t_idx_list, g_list = data_all[0], data_all[1], data_all[2]

if data_list == []:
files_list = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(REXEE.n_sim)]
data_list, t_idx_list, g_list = analyze_free_energy.preprocess_data(files_list, REXEE.temp, REXEE.df_data_type, REXEE.df_spacing) # noqa: E501

data_all = [data_list, t_idx_list, g_list]
with open(f'{args.dir}/{REXEE.df_data_type}_data.pickle', 'wb') as handle:
pickle.dump(data_list, handle, protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(data_all, handle, protocol=pickle.HIGHEST_PROTOCOL)

# 4-2. Calculate the free energy profile
f, f_err, estimators = analyze_free_energy.calculate_free_energy(data_list, REXEE.state_ranges, REXEE.df_method, REXEE.err_method, REXEE.n_bootstrap, REXEE.seed) # noqa: E501
@@ -439,9 +435,10 @@ def main():
print(f'Averaged start index: {t_avg}')
print(f'Averaged statistical inefficiency: {g_avg:.2f}')

data_list, _, _ = analyze_free_energy.preprocess_data(files_list, REXEE.temp, REXEE.df_data_type, REXEE.df_spacing, t_avg, g_avg) # noqa: E501
data_list, t_idx_list, g_list = analyze_free_energy.preprocess_data(files_list, REXEE.temp, REXEE.df_data_type, REXEE.df_spacing, t_avg, g_avg) # noqa: E501
data_all = [data_list, t_idx_list, g_list]
with open(f'{args.dir}/{REXEE.df_data_type}_data_avg_subsampling.pickle', 'wb') as handle:
pickle.dump(data_list, handle, protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(data_all, handle, protocol=pickle.HIGHEST_PROTOCOL)

f, f_err, estimators = analyze_free_energy.calculate_free_energy(data_list, REXEE.state_ranges, REXEE.df_method, REXEE.err_method, REXEE.n_bootstrap, REXEE.seed) # noqa: E501
print('Plotting the full-range free energy profile ...')
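
For reference, the pickled free-energy data now bundles the preprocessed data with the start indices and statistical inefficiencies. A minimal sketch of reading such a file outside the CLI, assuming a hypothetical pickle path produced by this version of analyze_REXEE:

    import pickle

    # The commit changes the pickled object from data_list alone to
    # data_all = [data_list, t_idx_list, g_list], so all three are unpacked here.
    with open('u_nk_data.pickle', 'rb') as handle:  # hypothetical path
        data_list, t_idx_list, g_list = pickle.load(handle)

    print(f'Number of replicas: {len(data_list)}')
    print(f'Start indices for subsampling: {t_idx_list}')
    print(f'Statistical inefficiencies: {g_list}')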
