diff --git a/ensemble_md/replica_exchange_EE.py b/ensemble_md/replica_exchange_EE.py index dc8f4f0..4500a4a 100644 --- a/ensemble_md/replica_exchange_EE.py +++ b/ensemble_md/replica_exchange_EE.py @@ -1609,9 +1609,6 @@ def process_top(self): # Determine the atom names corresponding to the atom numbers start_line, atom_name, atom_num, state = coordinate_swap.get_names(input_file, self.resname_list[f]) - print(f'start line #: {start_line}') - print(f'atom name: {atom_name}') - print(f'atom #: {atom_num}') # Determine the connectivity of all atoms connect_1, connect_2, state_1, state_2 = [], [], [], [] # Atom 1 and atom 2 which are connected and which state they are dummy atoms # noqa: E501 diff --git a/ensemble_md/tests/test_coordinate_swap.py b/ensemble_md/tests/test_coordinate_swap.py index 36db8e2..4bb03b7 100644 --- a/ensemble_md/tests/test_coordinate_swap.py +++ b/ensemble_md/tests/test_coordinate_swap.py @@ -346,6 +346,7 @@ def test_swap_name(): def test_get_names(): top_files = ['A-B.itp', 'B-C.itp', 'C-D.itp', 'D-E.itp', 'E-F.itp'] + resnames = ['A2B', 'B2C', 'C2D', 'D2E', 'E2F'] start_lines = [26, 29, 33, 32, 36] names = [['S1', 'C2', 'N3', 'C4', 'C5', 'C6', 'H1', 'H2', 'H3', 'H4', 'H17', 'DC7', 'HV5', 'HV6', 'HV7'], ['S1', 'C2', 'N3', 'C4', 'C5', 'C6', 'C7', 'H1', 'H2', 'H3', 'H4', 'H5', 'H6', 'H7', 'DC8', 'HV8', 'HV9', 'HV10'], ['S1', 'C2', 'N3', 'C4', 'C5', 'C6', 'C7', 'C8', 'H1', 'H2', 'H3', 'H4', 'H6', 'H7', 'H8', 'H9', 'H10', 'DC9', 'HV5', 'HV11', 'HV12', 'HV13'], ['S1', 'C2', 'N3', 'C4', 'C5', 'C6', 'C7', 'C9', 'H1', 'H2', 'H3', 'H5', 'H6', 'H7', 'H11', 'H12', 'H13', 'DC8', 'HV8', 'HV9', 'HV10'], ['S1', 'C2', 'N3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'H1', 'H2', 'H3', 'H6', 'H7', 'H8', 'H9', 'H10', 'H11', 'H12', 'H13', 'DC10', 'HV4', 'HV14', 'HV15', 'HV16']] # noqa: E501 @@ -357,7 +358,7 @@ def test_get_names(): [-1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, 1, 1, 1, 0, 0, 0, 0, 0]] for i, top_file in enumerate(top_files): top = open(f'{input_path}/coord_swap/{top_file}', 'r').readlines() - test_start_line, test_names, test_lambda_states = coordinate_swap.get_names(top) + test_start_line, test_names, test_nums, test_lambda_states = coordinate_swap.get_names(top, resnames[i]) assert test_start_line == start_lines[i] assert test_names == names[i] assert test_lambda_states == lambda_states[i] diff --git a/ensemble_md/utils/coordinate_swap.py b/ensemble_md/utils/coordinate_swap.py index bec22f4..166d34f 100644 --- a/ensemble_md/utils/coordinate_swap.py +++ b/ensemble_md/utils/coordinate_swap.py @@ -1099,16 +1099,26 @@ def _swap_name(init_atom_names, new_resname, df_top): continue atom_num = re.findall(r'[0-9]+', atom)[0] atom_identifier = re.findall(r'[a-zA-Z]+', atom)[0] - if 'D' in atom_identifier: - atom_identifier = atom_identifier.strip('D') - if 'V' in atom_identifier: - atom_identifier = atom_identifier.strip('V') - if f'{atom_identifier}V{atom_num}' in new_names: - new_atom_names.append(f'{atom_identifier}V{atom_num}') - elif f'{atom_identifier}{atom_num}' in new_names: - new_atom_names.append(f'{atom_identifier}{atom_num}') - elif f'D{atom_identifier}{atom_num}' in new_names: - new_atom_names.append(f'D{atom_identifier}{atom_num}') + if list(atom_identifier)[0] == 'D': + element = list(atom_identifier)[1] + if len(list(atom_identifier)) > 2: + extra = ''.join(list(atom_identifier)[2:]) + else: + extra = '' + else: + element = list(atom_identifier)[0] + if len(list(atom_identifier)) > 1: + extra = ''.join(list(atom_identifier)[1:]) + else: + extra = '' + if 'V' in extra: + extra = extra.strip('V') + if f'{element}V{extra}{atom_num}' in new_names: + new_atom_names.append(f'{element}V{extra}{atom_num}') + elif f'{element}{extra}{atom_num}' in new_names: + new_atom_names.append(f'{element}{extra}{atom_num}') + elif f'D{element}{extra}{atom_num}' in new_names: + new_atom_names.append(f'D{element}{extra}{atom_num}') else: raise Exception(f'Compatible atom could not be found for {atom}') return new_atom_names @@ -1192,17 +1202,23 @@ def determine_connection(main_only, other_only, main_name, other_name, df_top, m miss, D2R, R2D = [], [], [] align_atom, angle_atom = [], [] for atom in main_only: - element = atom.strip('0123456789') - e_split = list(element) + raw_element = atom.strip('0123456789') + e_split = list(raw_element) if e_split[0] == 'D': real_element = ''.join(e_split[1:]) elif len(e_split) > 2 and e_split[1] == 'V': del e_split[1] real_element = ''.join(e_split) else: - real_element = element + real_element = raw_element + if len(real_element) != 1: + element = list(real_element)[0] + identifier = ''.join(list(real_element)[1:]) + else: + element = real_element + identifier = '' num = atom.strip('ABCDEFGHIJKLMNOPQRSTUVWXYZ') - if f'D{atom}' in other_only or f'{element}V{num}' in other_only: + if f'D{atom}' in other_only or f'{element}V{identifier}{num}' in other_only: D2R.append(atom) elif f'{real_element}{num}' in other_only: R2D.append(atom) @@ -1223,7 +1239,7 @@ def determine_connection(main_only, other_only, main_name, other_name, df_top, m # If the atom connects to non-missing atoms than keep these as anchors for a in connected_atoms: - if a not in miss: + if a not in miss and a not in anchor_atoms: anchor_atoms.append(a) # Seperate missing atoms connected to each anchor