Merge pull request #42 from wehs7661/more_tests
Adding more unit tests for the package
wehs7661 authored Apr 16, 2024
2 parents f5daa68 + 27320da commit 029896d
Showing 11 changed files with 2,983 additions and 70 deletions.
25 changes: 18 additions & 7 deletions ensemble_md/analysis/analyze_free_energy.py
@@ -96,10 +96,10 @@ def preprocess_data(files_list, temp, data_type, spacing=1, t=None, g=None):
Neff_max = int((len(data_series.values) - t) / g)

print(f'Subsampling and decorrelating the concatenated {data_type} data ...')
print(f' Adopted spacing: {spacing: .0f}')
print(f' {t / len(data_series) * 100: .1f}% of the {data_type} data was in the equilibrium region and therefore discarded.') # noqa: E501
print(f' Statistical inefficiency of {data_type}: {g: .1f}')
print(f' Number of effective samples: {Neff_max: .0f}\n')
print(f' Adopted spacing: {spacing:.0f}')
print(f' {t / len(data_series) * 100:.1f}% of the {data_type} data was in the equilibrium region and therefore discarded.') # noqa: E501
print(f' Statistical inefficiency of {data_type}: {g:.1f}')
print(f' Number of effective samples: {Neff_max:.0f}\n')

data_series_equil, data_equil = data_series[t:], data[t:]
indices = subsample_correlated_data(data_series_equil, g=g)
@@ -220,8 +220,11 @@ def _combine_df_adjacent(df_adjacent, df_err_adjacent, state_ranges, err_type):
mean, error = utils.weighted_mean(df_list, df_err_list)

if err_type == 'std':
# overwrite the error calculated above
error = np.std(df_list, ddof=1)
if len(df_list) == 1:
error = df_err_list[0]
else:
# overwrite the error calculated above
error = np.std(df_list, ddof=1)

df.append(mean)
df_err.append(error)
@@ -291,6 +294,10 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method='prop
if overlap_bool[i] is True:
print(f'Replaced the propagated error with the bootstrapped error for states {i} and {i + 1}: {df_err[i]:.5f} -> {error_bootstrap[i]:.5f}.') # noqa: E501
df_err[i] = error_bootstrap[i]
elif err_method == 'propagate':
pass
else:
raise ParameterError('Specified err_method not available.')

df.insert(0, 0)
df_err.insert(0, 0)
@@ -328,6 +335,8 @@ def calculate_df_rmse(estimators, df_ref, state_ranges):
df = np.array(estimators[i].delta_f_.iloc[0]) # the first state always has 0 free energy here
ref = df_ref[state_ranges[i]]
ref -= ref[0] # shift the free energy of the first state in the range to 0
print(df)
print(ref)
rmse_list.append(np.sqrt(np.sum((df - ref) ** 2) / len(df)))

return rmse_list
@@ -381,7 +390,9 @@ def average_weights(g_vecs, frac):
for i in range(N):
dg.append(g_vecs[i][-1] - g_vecs[i][0])
n = int(np.floor(N * frac))
if n <= 1:
print('The number of samples to be averaged is less than 2, so all samples will be averaged.')
dg_avg = np.mean(dg[-n:])
dg_avg_err = np.std(dg_avg[-n:], ddof=1)
dg_avg_err = np.std(dg[-n:], ddof=1)

return dg_avg, dg_avg_err
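
To make the intent of the err_type='std' fallback above concrete, here is a minimal, self-contained sketch; it is not part of this commit and not the package's actual utils.weighted_mean — the inverse-variance weighting and the names combine_adjacent_estimates/df_list/df_err_list are assumptions for illustration only.

import numpy as np

def combine_adjacent_estimates(df_list, df_err_list, err_type='propagate'):
    # Inverse-variance weighted mean of overlapping estimates (illustrative only)
    weights = 1 / np.asarray(df_err_list) ** 2
    mean = np.sum(weights * np.asarray(df_list)) / np.sum(weights)
    error = np.sqrt(1 / np.sum(weights))  # propagated error

    if err_type == 'std':
        if len(df_list) == 1:
            pass  # keep the propagated error; np.std with ddof=1 is undefined (NaN) for one value
        else:
            error = np.std(df_list, ddof=1)  # overwrite with the sample standard deviation

    return mean, error

print(combine_adjacent_estimates([1.2], [0.05], err_type='std'))         # ~(1.2, 0.05): falls back to the propagated error
print(combine_adjacent_estimates([1.2, 1.4], [0.05, 0.07], err_type='std'))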
3 changes: 2 additions & 1 deletion ensemble_md/analysis/analyze_traj.py
@@ -266,7 +266,8 @@ def traj2transmtx(traj, N, normalize=True):
Parameters
---------
traj : list
A list of state indices showing the trajectory in the state space.
A list of state indices showing the trajectory in the state space. The index
is assumed to start from 0.
N : int
The size (N) of the expected transition matrix (N by N).
normalize : bool
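
The 0-based indexing note added above matters because each state index is used directly as a row/column of the transition matrix. Below is a hedged sketch of that counting logic; it mirrors what traj2transmtx is documented to do, not its actual source, and traj2transmtx_sketch is a made-up name.

import numpy as np

def traj2transmtx_sketch(traj, N, normalize=True):
    # Count i -> j transitions for a 0-indexed state trajectory
    mtx = np.zeros((N, N))
    for i, j in zip(traj[:-1], traj[1:]):
        mtx[i, j] += 1  # states are assumed to run from 0 to N - 1
    if normalize:
        row_sums = mtx.sum(axis=1, keepdims=True)
        mtx = np.divide(mtx, row_sums, out=np.zeros_like(mtx), where=row_sums != 0)
    return mtx

print(traj2transmtx_sketch([0, 1, 1, 2, 0], N=3, normalize=False))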
123 changes: 92 additions & 31 deletions ensemble_md/analysis/clustering.py
@@ -8,7 +8,9 @@
# #
####################################################################
import numpy as np
import matplotlib.pyplot as plt
from ensemble_md.utils.utils import run_gmx_cmd
from ensemble_md.analysis import analyze_traj


def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkage', cutoff=0.1, suffix=None):
@@ -44,6 +46,18 @@ def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkag
suffix : str
The suffix for the output files. The default is :code:`None`, which means no suffix will be added.
"""
# Check input parameters
required_keys_1 = ['traj', 'config', 'xvg', 'index']
for key in required_keys_1:
if key not in inputs:
raise ValueError(f'The key "{key}" is missing in the inputs dictionary.')
required_keys_2 = ['center', 'rmsd', 'output']
for key in required_keys_2:
if key not in grps:
raise ValueError(f'The key "{key}" is missing in the grps dictionary.')
if coupled_only and inputs['xvg'] is None:
raise ValueError('The parameter "coupled_only" is set to True but no XVG file is provided.')

# Check if the index file is provided
if inputs['index'] is None:
print('Running gmx make_ndx to generate an index file ...')
@@ -60,13 +74,13 @@ def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkag
content = f.read()
for key in grps:
if grps[key] not in content:
raise ValueError(f'The group {grps[key]} is not present in the provided/generated index file.')
raise ValueError(f'The group "{grps[key]}" is not present in the provided/generated index file.')

outputs = {
'nojump': 'nojump.xtc',
'center': 'center.xtc',
'rmsd-clust': 'rmsd-clust.xpm',
'rmsd-dist': 'rmsd-dist.xvg',
'rmsd-clust': 'rmsd_clust.xpm',
'rmsd-dist': 'rmsd_dist.xvg',
'cluster-log': 'cluster.log',
'cluster-pdb': 'clusters.pdb',
'rmsd': 'rmsd.xvg', # inter-medoid RMSD
@@ -93,16 +107,14 @@ def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkag
]

if coupled_only:
if inputs['xvg'] is None:
raise ValueError('The parameter "coupled_only" is set to True but no XVG file is provided.')
args.extend([
'-drop', inputs['xvg'],
'-dropover', '0'
])

returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["center"]}\n{grps["output"]}\n')
if returncode != 0:
print(f'Error with return code: {returncode}):\n{stderr}')
raise ValueError(f'Error with return code {returncode}:\n{stderr}')

print('Centering the system ...')
args = [
@@ -117,7 +129,7 @@ def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkag
]
returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["center"]}\n{grps["output"]}\n')
if returncode != 0:
print(f'Error with return code: {returncode}):\n{stderr}')
raise ValueError(f'Error with return code {returncode}:\n{stderr}')

if coupled_only is True:
N_coupled = np.count_nonzero(lambda_data == 0)
@@ -138,7 +150,7 @@ def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkag
]
returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["rmsd"]}\n{grps["output"]}\n')
if returncode != 0:
print(f'Error with return code: {returncode}):\n{stderr}')
raise ValueError(f'Error with return code {returncode}:\n{stderr}')

rmsd_range, rmsd_avg, n_clusters = get_cluster_info(outputs['cluster-log'])

@@ -151,9 +163,12 @@ def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkag
for i in range(1, n_clusters + 1):
print(f' - Cluster {i} accounts for {sizes[i] * 100:.2f}% of the total configurations.')

n_transitions, t_transitions = count_transitions(clusters)
print(f'Number of transitions between the two biggest clusters: {n_transitions}')
print(f'Time frames of the transitions (ps): {t_transitions}')
if n_clusters == 2:
transmtx, _, t_transitions = analyze_transitions(clusters, normalize=False) # Note that this is a 2D count matrix. # noqa: E501
n_transitions = np.sum(transmtx) - np.trace(transmtx) # This is the sum of all off-diagonal elements. np.trace calculates the sum of the diagonal elements. # noqa: E501
print(f'Number of transitions between the two clusters: {n_transitions}')
if n_transitions > 0:
print(f'Time frames of the transitions (ps): {t_transitions[(1, 2)]}')

print('Calculating the inter-medoid RMSD between the two biggest clusters ...')
# Note that we pass outputs['cluster-pdb'] to -s so that the first medoid will be used as the reference
@@ -261,35 +276,81 @@ def get_cluster_members(cluster_log):
return clusters, sizes


def count_transitions(clusters):
def analyze_transitions(clusters, normalize=True, plot_type=None):
"""
Counts the number of transitions between the two biggest clusters.
Analyzes transitions between clusters, including estimating the transition matrix, generating/plotting a trajectory
showing which cluster each configuration belongs to, and/or plotting the distribution of the clusters.
Parameters
----------
clusters : dict
A dictionary that contains the cluster index (starting from 1) as the key and the list of members
(configurations at different timeframes) as the value.
(configurations at different timeframes in ps) as the value.
plot_type : str
The type of the figure to be plotted. The default is :code:`None`, which means no figure will be plotted.
The other options are :code:`'bar'` and :code:`'xy'`. The former plots the distribution of the clusters,
while the latter plots the trajectory showing which cluster each configuration belongs to.
Returns
-------
n_transitions : int
The number of transitions between the two biggest clusters.
t_transitions : list
The list of time frames when the transitions occur.
transmtx: np.ndarray
The transition matrix.
traj: np.ndarray
The trajectory showing which cluster each configuration belongs to.
t_transitions: dict
A dictionary with keys being pairs of cluster indices and values being the time frames of transitions
between the two clusters. If there is no transition, an empty dictionary will be returned.
"""
# Combine and sort all cluster members for the first two biggest clusters while keeping track of their origin
all_members = [(member, 1) for member in clusters[1]] + [(member, 2) for member in clusters[2]]
# Combine all cluster members and sort them
all_members = []
for key in clusters:
all_members.extend([(member, key) for member in clusters[key]])
all_members.sort()

# Count transitions and record time frames
n_transitions = 0
t_transitions = []
last_cluster = all_members[0][1] # the cluster index of the last time frame in the previous iteration
for member in all_members[1:]:
if member[1] != last_cluster:
n_transitions += 1
last_cluster = member[1]
t_transitions.append(member[0])

return n_transitions, t_transitions
# Generate the trajectory
t = np.array([member[0] for member in all_members])
traj = np.array([member[1] for member in all_members])

# Generate the transition matrix
# Since traj2transmtx assumes an index starting from 0, we subtract 1 from the trajectory
transmtx = analyze_traj.traj2transmtx(traj - 1, len(clusters), normalize=normalize)

# Generate the dictionary of transitions
t_transitions = {}
for i in range(len(traj) - 1):
if traj[i] != traj[i + 1]:
pair = tuple(sorted((traj[i], traj[i + 1])))
if pair not in t_transitions:
t_transitions[pair] = [t[i + 1]]
else:
t_transitions[pair].append(t[i + 1])

if plot_type is not None:
if plot_type == 'bar':
fig = plt.figure()
ax = fig.add_subplot(111)
plt.bar(clusters.keys(), [len(clusters[i]) for i in clusters], width=0.35)
plt.xlabel('Cluster index')
plt.ylabel('Number of configurations')
plt.grid()
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
plt.savefig('cluster_distribution.png', dpi=600)
elif plot_type == 'xy':
fig = plt.figure()
ax = fig.add_subplot(111)
if len(t) > 1000:
t = t / 1000 # convert to ns
units = 'ns'
else:
units = 'ps'
plt.plot(t, traj)
plt.xlabel(f'Time frame ({units})')
plt.ylabel('Cluster index')
ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
plt.grid()
plt.savefig('cluster_traj.png', dpi=600)
else:
raise ValueError(f'Invalid plot type: {plot_type}. The plot type must be either "bar" or "xy" or unspecified.') # noqa: E501

return transmtx, traj, t_transitions
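
As a usage illustration of the new analyze_transitions API documented above (not part of this commit; the cluster membership data are hypothetical, and the import path simply follows this file's location in the package):

from ensemble_md.analysis.clustering import analyze_transitions

# Hypothetical cluster membership: cluster index (1-based) -> time frames in ps
clusters = {
    1: [0, 2, 4, 10, 12],
    2: [6, 8, 14, 16],
}

transmtx, traj, t_transitions = analyze_transitions(clusters, normalize=False)
print(transmtx)        # 2 x 2 count matrix; off-diagonal entries count inter-cluster transitions
print(traj)            # cluster index of each time frame, sorted by time
print(t_transitions)   # e.g. {(1, 2): [6, 10, 14]}, the time frames at which the cluster changes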
9 changes: 7 additions & 2 deletions ensemble_md/analysis/synthesize_data.py
Expand Up @@ -12,6 +12,7 @@
"""
import numpy as np
from ensemble_md.analysis import analyze_traj
from ensemble_md.analysis import analyze_matrix


def synthesize_traj(trans_mtx, n_frames=100000, method='transmtx', start=0, seed=None):
Expand Down Expand Up @@ -48,8 +49,12 @@ def synthesize_traj(trans_mtx, n_frames=100000, method='transmtx', start=0, seed
if seed is not None:
np.random.seed(seed)
N = len(trans_mtx) # Can be the number of states or replicas depending on the type of the input matrix

if start >= N:
raise ValueError(f'The starting state {start} is out of the range of the input transition matrix.')

if method == 'equil_prob':
equil_prob = analyze_traj.calc_equil_prob(trans_mtx)
equil_prob = analyze_matrix.calc_equil_prob(trans_mtx)
syn_traj = np.random.choice(N, size=n_frames, p=equil_prob.reshape(N))
elif method == 'transmtx':
check_row = sum([np.isclose(np.sum(trans_mtx[i]), 1, atol=1e-8) for i in range(len(trans_mtx))])
@@ -59,7 +64,7 @@ def synthesize_traj(trans_mtx, n_frames=100000, method='transmtx', start=0, seed
elif check_col == N:
mtx = trans_mtx.T
else:
raise ValueError('The input matrix is not normalized')
raise ValueError('The input matrix is not normalized.')

syn_traj = np.zeros(n_frames, dtype=int)
syn_traj[0] = start
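
A quick usage sketch of the updated synthesize_traj behavior (not part of this commit; the 2x2 row-normalized matrix and all argument values are illustrative only):

import numpy as np
from ensemble_md.analysis.synthesize_data import synthesize_traj

trans_mtx = np.array([[0.9, 0.1],
                      [0.2, 0.8]])  # row-normalized toy transition matrix

traj = synthesize_traj(trans_mtx, n_frames=1000, method='transmtx', start=0, seed=0)
print(traj[:10])

# With the added check, an out-of-range starting state now fails fast:
try:
    synthesize_traj(trans_mtx, n_frames=1000, start=5)
except ValueError as err:
    print(err)  # The starting state 5 is out of the range of the input transition matrix.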
21 changes: 21 additions & 0 deletions ensemble_md/tests/data/cluster.log
@@ -0,0 +1,21 @@
Using linkage method for clustering
Using RMSD cutoff 0.13 nm
The RMSD ranges from 0.0236461 to 0.316756 nm
Average RMSD is 0.182848
Number of structures for matrix 56
Energy of the matrix is 0.21062.

Found 2 clusters

Writing middle structure for each cluster to clusters_0.pdb

cl. | #st rmsd | middle rmsd | cluster members
1 | 27 0.069 | 300 .055 | 0 178 184 186 300 302 304
| | | 306 308 310 312 318 362 366
| | | 370 372 374 376 378 380 382
| | | 390 460 464 468 470 476
2 | 29 0.075 | 12092 .061 | 11910 11992 11996 12014 12054 12058 12062
| | | 12064 12084 12092 12098 12100 12102 12104
| | | 12106 12108 12110 12112 12114 12116 12118
| | | 12120 12242 12262 12310 12318 12330 12334
| | | 12340