Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ajfriedman22 committed Sep 19, 2024
2 parents b221831 + c789ec2 commit 0451c84
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 87 deletions.
36 changes: 11 additions & 25 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import warnings
import importlib
import subprocess
import mdtraj as md
import pandas as pd
import numpy as np
from mpi4py import MPI
from itertools import combinations
Expand Down Expand Up @@ -1511,24 +1513,17 @@ def run_REXEE(self, n, swap_pattern=None):
# rank 3 that has not been generated, which will lead to an I/O error.
comm.barrier()

def default_coords_fn(self, molA_file_name: str, molB_file_name: str):
def default_coords_fn(self, molA_file_name, molB_file_name):
"""
Swap coordinates between two GRO files
Swaps coordinates between two GRO files.
Parameters
----------
molA_file_name : str
GRO file name for the moleucle to be swapped
GRO file name for the moleucle to be swapped.
molB_file_name : str
GRO file name for the other moleucle to be swapped
Return
------
None
GRO file name for the other moleucle to be swapped.
"""
# Step 1: Load necessary files
import mdtraj as md
import pandas as pd

# Determine name for transformed residue
molA_dir = molA_file_name.rsplit('/', 1)[0] + '/'
molB_dir = molB_file_name.rsplit('/', 1)[0] + '/'
Expand All @@ -1541,27 +1536,27 @@ def default_coords_fn(self, molA_file_name: str, molB_file_name: str):
connection_map = pd.read_csv('residue_connect.csv')
swap_map = pd.read_csv('residue_swap_map.csv')

# Step 2: Read the GRO input coordinate files and open temporary Output files
# Step 1: Read the GRO input coordinate files and open temporary Output files
molA_file = open(molA_file_name, 'r').readlines() # open input file
molB_new_file_name = 'B_hybrid_swap.gro'
molB_new = open(molB_new_file_name, 'w')
molB_file = open(molB_file_name, 'r').readlines() # open input file
molA_new_file_name = 'A_hybrid_swap.gro'
molA_new = open(molA_new_file_name, 'w')

# Step 3: Determine atoms for alignment and swapping
# Step 2: Determine atoms for alignment and swapping
nameA = coordinate_swap.identify_res(molA.topology, swap_map['Swap A'].to_list() + swap_map['Swap B'].to_list()) # noqa: E501
nameB = coordinate_swap.identify_res(molB.topology, swap_map['Swap A'].to_list() + swap_map['Swap B'].to_list()) # noqa: E501
df_atom_swap = coordinate_swap.find_common(molA_file, molB_file, nameA, nameB)

# Step 4: Fix break if present for solvated systems only
# Step 3: Fix break if present for solvated systems only
if len(molA.topology.select('water')) != 0:
A_dimensions = coordinate_swap.get_dimensions(molA_file)
B_dimensions = coordinate_swap.get_dimensions(molB_file)
molA = coordinate_swap.fix_break(molA, nameA, A_dimensions, connection_map[connection_map['Resname'] == nameA]) # noqa: E501
molB = coordinate_swap.fix_break(molB, nameB, B_dimensions, connection_map[connection_map['Resname'] == nameB]) # noqa: E501

# Step 5: Determine coordinates of atoms which need to be reconstructed as we swap coordinates between molecules # noqa: E501
# Step 4: Determine coordinates of atoms which need to be reconstructed as we swap coordinates between molecules # noqa: E501
miss_B = df_atom_swap[(df_atom_swap['Swap'] == 'B2A') & (df_atom_swap['Direction'] == 'miss')]['Name'].to_list() # noqa: E501
miss_A = df_atom_swap[(df_atom_swap['Swap'] == 'A2B') & (df_atom_swap['Direction'] == 'miss')]['Name'].to_list() # noqa: E501
if len(miss_B) != 0:
Expand All @@ -1588,18 +1583,9 @@ def default_coords_fn(self, molA_file_name: str, molB_file_name: str):

def process_top(self):
"""
Process the input topologies in order to determine the atoms for alignment in the default GRO swapping
Processes the input topologies in order to determine the atoms for alignment in the default GRO swapping
function. Output as csv files to prevent needing to re-run this step.
Parameters
----------
None
Return
------
None
"""
import pandas as pd

if not os.path.exists('residue_connect.csv'):
df_top = pd.DataFrame()
for f, file_name in enumerate(self.top):
Expand Down
147 changes: 87 additions & 60 deletions ensemble_md/tests/test_coordinate_swap.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,87 @@
####################################################################
# #
# ensemble_md, #
# a python package for running GROMACS simulation ensembles #
# #
# Written by Wei-Tse Hsu <wehs7661@colorado.edu> #
# Copyright (c) 2022 University of Colorado Boulder #
# #
####################################################################
"""
Unit tests for the module coordinate_swap.py.
"""
from ensemble_md.utils import coordinate_swap
import numpy as np
import pandas as pd

def test_get_dimenstion():
test_file1 = open('ensemble_md/tests/data/coord_swap/input_A.gro', 'r')
test_file2 = open('ensemble_md/tests/data/coord_swap/input_B.gro', 'r')
assert coordinate_swap.get_dimensions(test_file1) == [2.74964, 2.74964, 2.74964]
assert coordinate_swap.get_dimensions(test_file2) == [2.74243, 2.74243, 2.74243]

def test_find_common():
test_file1 = open('ensemble_md/tests/data/coord_swap/input_A.gro', 'r')
test_file2 = open('ensemble_md/tests/data/coord_swap/input_B.gro', 'r')
test_df = coordinate_swap.find_common(test_file1, test_file2, 'C2D', 'D2E')
df = pd.read_csv('ensemble_md/tests/data/coord_swap/find_common.csv')

for index, row in df.iterrows():
test_row = test_df[test_df['Name'] == row['Name']]
assert row['Atom Name Number'] == int(test_row['Atom Name Number'].to_list()[0])
assert row['Element'] == test_row['Element'].to_list()[0]
assert row['Direction'] == test_row['Direction'].to_list()[0]
assert row['Swap'] == test_row['Swap'].to_list()[0]
assert row['File line'] == int(test_row['File line'].to_list()[0])
assert row['Final Type'] == test_row['Final Type'].to_list()[0]

def test_rotate_point_around_axis():
initial_point = np.array([0.16, 0.19, -0.05])
vertex = np.array([0, 0, 0])
axis = np.array([0.15, 0.82, 0.14])
angle = 0.13
rotated_point = [0.1693233, 0.18548463, -0.0335421]
assert coordinate_swap.rotate_point_around_axis(initial_point, vertex, axis, angle) == rotated_point

def test_find_rotation_angle():
initial_point = np.array([0.16, 0.19, -0.05])
vertex = np.array([0, 0, 0])
axis = np.array([0.15, 0.82, 0.14])
rotated_point = [0.1693233, 0.18548463, -0.0335421]
angle = 0.13
test_angle = coordinate_swap.find_rotation_angle(initial_point, vertex, rotated_point, axis)
assert np.isclose(angle, test_angle, 10**(-5))

def test_compute_angle():
vec1_start = [0.13, 0.15, 0.16]
vec1_end = [0.16, 0.18, 0.17]
vec2_start = [0.13, 0.15, 0.16]
vec2_end = [0.11, 0.23, 0.05]
assert np.isclose(coordinate_swap.compute_angle([vec1_start, vec1_end, vec2_start, vec2_end]), 0.991836017536949, atol=10**(-4))
####################################################################
# #
# ensemble_md, #
# a python package for running GROMACS simulation ensembles #
# #
# Written by Wei-Tse Hsu <wehs7661@colorado.edu> #
# Copyright (c) 2022 University of Colorado Boulder #
# #
####################################################################
"""
Unit tests for the module coordinate_swap.py.
"""
from ensemble_md.utils import coordinate_swap
import numpy as np
import pandas as pd
import os

current_path = os.path.dirname(os.path.abspath(__file__))
input_path = os.path.join(current_path, "data")


def test_get_dimensions():
gro = os.path.join(input_path, 'sys.gro')
f = open(gro, 'r')
lines = f.readlines()
f.close()
vec = coordinate_swap.get_dimensions(lines)
assert vec == [3.32017, 3.32017, 2.34772, 0.00000, 0.00000, 0.00000, 0.00000, 1.66009, 1.66009]

# Write a flat file with cubic box dimensions
f = open('test.gro', 'w')
f.write('test\n')
f.write(' 1.00000 2.00000 3.00000\n')
f.close()

f = open('test.gro', 'r')
lines = f.readlines()
f.close()
vec = coordinate_swap.get_dimensions(lines)
assert vec == [1.0, 2.0, 3.0]

os.remove('test.gro')

def test_find_common():
test_file1 = open(f'{input_path}/coord_swap/input_A.gro', 'r')
test_file2 = open(f'{input_path}/coord_swap/input_B.gro', 'r')
test_df = coordinate_swap.find_common(test_file1, test_file2, 'C2D', 'D2E')
df = pd.read_csv(f'{input_path}/coord_swap/find_common.csv')

for index, row in df.iterrows():
test_row = test_df[test_df['Name'] == row['Name']]
assert row['Atom Name Number'] == int(test_row['Atom Name Number'].to_list()[0])
assert row['Element'] == test_row['Element'].to_list()[0]
assert row['Direction'] == test_row['Direction'].to_list()[0]
assert row['Swap'] == test_row['Swap'].to_list()[0]
assert row['File line'] == int(test_row['File line'].to_list()[0])
assert row['Final Type'] == test_row['Final Type'].to_list()[0]

def test_rotate_point_around_axis():
initial_point = np.array([0.16, 0.19, -0.05])
vertex = np.array([0, 0, 0])
axis = np.array([0.15, 0.82, 0.14])
angle = 0.13
rotated_point = [0.1693233, 0.18548463, -0.0335421]
assert coordinate_swap.rotate_point_around_axis(initial_point, vertex, axis, angle) == rotated_point

def test_find_rotation_angle():
initial_point = np.array([0.16, 0.19, -0.05])
vertex = np.array([0, 0, 0])
axis = np.array([0.15, 0.82, 0.14])
rotated_point = [0.1693233, 0.18548463, -0.0335421]
angle = 0.13
test_angle = coordinate_swap.find_rotation_angle(initial_point, vertex, rotated_point, axis)
assert np.isclose(angle, test_angle, 10**(-5))

def test_compute_angle():
coords_1 = [
np.array([0.0, 0.0, 0.0]),
np.array([1.0, 0.0, 0.0]),
np.array([0.0, 1.0, 0.0])
]
coords_2 = coords_1[-1::-1]
coords_3 = [coords_1[1], coords_1[0], coords_1[2]]

assert np.isclose(coordinate_swap.compute_angle(coords_1), np.pi / 4)
assert np.isclose(coordinate_swap.compute_angle(coords_2), np.pi / 4)
assert np.isclose(coordinate_swap.compute_angle(coords_3), np.pi / 2)
4 changes: 2 additions & 2 deletions ensemble_md/utils/coordinate_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,8 @@ def compute_angle(coords):
Parameters
----------
coords : numpy.ndarray
Four points which define two vectors to compute the angle between.
coords : list
A list of numpy arrays containing the XYZ coordinates of 3 points, for which the angle 1-2-3 is to be computed.
Returns
-------
Expand Down

0 comments on commit 0451c84

Please sign in to comment.