From 1887b0313894f7df6428560d46487b27d3c912a8 Mon Sep 17 00:00:00 2001 From: chrisjonesBSU Date: Tue, 2 Jul 2024 15:07:22 -0600 Subject: [PATCH 1/5] add changes from ruff hook --- .pre-commit-config.yaml | 24 ++--- docs/source/conf.py | 5 +- foyer/atomtyper.py | 16 +-- foyer/forcefield.py | 118 ++++++---------------- foyer/smarts_graph.py | 47 +++------ foyer/tests/base_test.py | 4 +- foyer/tests/test_atomtyping.py | 13 +-- foyer/tests/test_forcefield.py | 106 +++++-------------- foyer/tests/test_forcefield_parameters.py | 44 ++------ foyer/tests/test_graph.py | 8 +- foyer/tests/test_opls.py | 24 ++--- foyer/tests/test_plugin.py | 4 +- foyer/tests/test_smarts.py | 8 +- foyer/tests/test_topology_graph.py | 12 +-- foyer/tests/test_utils.py | 4 +- foyer/tests/utils.py | 1 + foyer/topology_graph.py | 7 +- foyer/utils/io.py | 20 ++-- foyer/utils/nbfixes.py | 5 +- foyer/validator.py | 22 ++-- foyer/xml_writer.py | 56 ++++------ setup.py | 6 +- 22 files changed, 157 insertions(+), 397 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 480b85a2..882a17ae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,27 +8,27 @@ ci: skip: [] submodules: false repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.5.0 + hooks: + # Run the linter. + - id: ruff + args: [--line-length=80, --fix] + # Run the formatter. + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - exclude: 'setup.cfg' -- repo: https://github.com/psf/black - rev: 24.4.2 - hooks: - - id: black - args: [--line-length=80] + exclude: 'setup.cfg|foyer/tests/files/.*' + - repo: https://github.com/pycqa/isort rev: 5.13.2 hooks: - id: isort name: isort (python) args: [--profile=black, --line-length=80] -- repo: https://github.com/pycqa/pydocstyle - rev: '6.3.0' - hooks: - - id: pydocstyle - exclude: ^(foyer/tests/|docs/|devtools/|setup.py) - args: [--convention=numpy] + exclude: "foyer/tests/files/.*" diff --git a/docs/source/conf.py b/docs/source/conf.py index d00fd8ee..ff30b3ad 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -18,6 +18,8 @@ import pathlib import sys +import sphinx_rtd_theme + sys.path.insert(0, os.path.abspath("../..")) sys.path.insert(0, os.path.abspath("sphinxext")) @@ -25,8 +27,6 @@ os.system("python {} --name".format((base_path / "../../setup.py").resolve())) -import foyer - # -- Project information ----------------------------------------------------- project = "foyer" @@ -147,7 +147,6 @@ # a list of builtin themes. # # html_theme = 'alabaster' -import sphinx_rtd_theme html_theme = "sphinx_rtd_theme" hhtml_theme_path = [sphinx_rtd_theme.get_html_theme_path()] diff --git a/foyer/atomtyper.py b/foyer/atomtyper.py index f12a160d..fe477de8 100644 --- a/foyer/atomtyper.py +++ b/foyer/atomtyper.py @@ -79,9 +79,7 @@ def find_atomtypes(structure, forcefield, max_iter=10): topology_graph = TopologyGraph.from_gmso_topology(structure) if isinstance(forcefield, Forcefield): - atomtype_rules = AtomTypingRulesProvider.from_foyer_forcefield( - forcefield - ) + atomtype_rules = AtomTypingRulesProvider.from_foyer_forcefield(forcefield) elif isinstance(forcefield, AtomTypingRulesProvider): atomtype_rules = forcefield else: @@ -110,9 +108,7 @@ def find_atomtypes(structure, forcefield, max_iter=10): atomic_number = atom_data.atomic_number atomic_symbol = atom_data.element try: - element_from_num = ele.element_from_atomic_number( - atomic_number - ).symbol + element_from_num = ele.element_from_atomic_number(atomic_number).symbol element_from_sym = ele.element_from_symbol(atomic_symbol).symbol assert element_from_num == element_from_sym system_elements.add(element_from_num) @@ -210,13 +206,9 @@ def _iterate_rules(rules, topology_graph, typemap, max_iter): def _resolve_atomtypes(topology_graph, typemap): """Determine the final atomtypes from the white- and blacklists.""" - atoms = { - atom_idx: data for atom_idx, data in topology_graph.atoms(data=True) - } + atoms = {atom_idx: data for atom_idx, data in topology_graph.atoms(data=True)} for atom_id, atom in typemap.items(): - atomtype = [ - rule_name for rule_name in atom["whitelist"] - atom["blacklist"] - ] + atomtype = [rule_name for rule_name in atom["whitelist"] - atom["blacklist"]] if len(atomtype) == 1: atom["atomtype"] = atomtype[0] elif len(atomtype) > 1: diff --git a/foyer/forcefield.py b/foyer/forcefield.py index 1e02f296..3c6ecdc4 100644 --- a/foyer/forcefield.py +++ b/foyer/forcefield.py @@ -146,7 +146,7 @@ def generate_topology(non_omm_topology, non_element_types=None, residues=None): return _topology_from_parmed(non_omm_topology, non_element_types) elif has_mbuild: mb = import_("mbuild") - if (non_omm_topology, mb.Compound): + if all([non_omm_topology, mb.Compound]): pmd_comp_struct = non_omm_topology.to_parmed(residues=residues) return _topology_from_parmed(pmd_comp_struct, non_element_types) else: @@ -162,16 +162,12 @@ def generate_topology(non_omm_topology, non_element_types=None, residues=None): def _structure_from_residue(residue, parent_structure): """Convert a ParmEd Residue to an equivalent Structure.""" structure = pmd.Structure() - orig_to_copy = ( - dict() - ) # Clone a lot of atoms to avoid any of parmed's tracking + orig_to_copy = dict() # Clone a lot of atoms to avoid any of parmed's tracking for atom in residue.atoms: new_atom = copy(atom) new_atom._idx = atom.idx orig_to_copy[atom] = new_atom - structure.add_atom( - new_atom, resname=residue.name, resnum=residue.number - ) + structure.add_atom(new_atom, resname=residue.name, resnum=residue.number) for bond in parent_structure.bonds: if bond.atom1 in residue.atoms and bond.atom2 in residue.atoms: @@ -198,10 +194,7 @@ def _topology_from_parmed(structure, non_element_types): if pmd_atom.name in non_element_types: element = non_element_types[pmd_atom.name] else: - if ( - isinstance(pmd_atom.atomic_number, int) - and pmd_atom.atomic_number != 0 - ): + if isinstance(pmd_atom.atomic_number, int) and pmd_atom.atomic_number != 0: element = elem.Element.getByAtomicNumber(pmd_atom.atomic_number) else: element = elem.Element.getBySymbol(pmd_atom.name) @@ -221,9 +214,7 @@ def _topology_from_parmed(structure, non_element_types): topology.addBond(atom1, atom2) atom1.bond_partners.append(atom2) atom2.bond_partners.append(atom1) - if structure.box_vectors and np.any( - [x._value for x in structure.box_vectors] - ): + if structure.box_vectors and np.any([x._value for x in structure.box_vectors]): topology.setPeriodicBoxVectors(structure.box_vectors) positions = structure.positions @@ -293,9 +284,7 @@ def _unwrap_typemap(structure, residue_map): for res_ref, val in residue_map.items(): if id(res.name) == id(res_ref): for i, atom in enumerate(res.atoms): - master_typemap[int(atom.idx)]["atomtype"] = val[i][ - "atomtype" - ] + master_typemap[int(atom.idx)]["atomtype"] = val[i]["atomtype"] return master_typemap @@ -325,9 +314,7 @@ def _separate_urey_bradleys(system, topology): ) not in bonds: ub_force.addBond(*force.getBondParameters(bond_idx)) else: - harmonic_bond_force.addBond( - *force.getBondParameters(bond_idx) - ) + harmonic_bond_force.addBond(*force.getBondParameters(bond_idx)) system.removeForce(force_idx) system.addForce(harmonic_bond_force) @@ -499,9 +486,7 @@ class Forcefield(app.ForceField): """ - def __init__( - self, forcefield_files=None, name=None, validation=True, debug=False - ): + def __init__(self, forcefield_files=None, name=None, validation=True, debug=False): self.atomTypeDefinitions = dict() self.atomTypeOverrides = dict() self.atomTypeDesc = dict() @@ -539,13 +524,9 @@ def __init__( if len(preprocessed_files) == 1: self._version = self._parse_version_number(preprocessed_files[0]) self._name = self._parse_name(preprocessed_files[0]) - self._combining_rule = self._parse_combining_rule( - preprocessed_files[0] - ) + self._combining_rule = self._parse_combining_rule(preprocessed_files[0]) elif len(preprocessed_files) > 1: - self._version = [ - self._parse_version_number(f) for f in preprocessed_files - ] + self._version = [self._parse_version_number(f) for f in preprocessed_files] self._name = [self._parse_name(f) for f in preprocessed_files] self._combining_rule = [ self._parse_combining_rule(f) for f in preprocessed_files @@ -639,9 +620,7 @@ def _parse_name(self, forcefield_file): try: return root.attrib["name"] except KeyError: - warnings.warn( - "No force field name found in force field XML file." - ) + warnings.warn("No force field name found in force field XML file.") return None def _parse_combining_rule(self, forcefield_file): @@ -651,9 +630,7 @@ def _parse_combining_rule(self, forcefield_file): try: return root.attrib["combining_rule"] except KeyError: - warnings.warn( - "No combining rule found in force field XML file." - ) + warnings.warn("No combining rule found in force field XML file.") return "lorentz" def _create_element(self, element, mass): @@ -679,9 +656,7 @@ def registerAtomType(self, parameters): """Register a new atom type.""" name = parameters["name"] if name in self._atomTypes: - raise ValueError( - "Found multiple definitions for atom type: " + name - ) + raise ValueError("Found multiple definitions for atom type: " + name) atom_class = parameters["class"] mass = _convertParameterToNumber(parameters["mass"]) element = None @@ -846,10 +821,7 @@ def run_atomtyping(self, structure, use_residue_map=True, **kwargs): # Need to call this only once and store results for later id() comparisons for res_id, res in enumerate(structure.residues): - if ( - structure.residues[res_id].name - not in residue_map.keys() - ): + if structure.residues[res_id].name not in residue_map.keys(): tmp_res = _structure_from_residue(res, structure) typemap = find_atomtypes(tmp_res, forcefield=self) residue_map[res.name] = typemap @@ -877,9 +849,7 @@ def parametrize_system( **kwargs, ): """Create system based on resulting typemapping.""" - topology, positions = _topology_from_parmed( - structure, self.non_element_types - ) + topology, positions = _topology_from_parmed(structure, self.non_element_types) system = self.createSystem(topology, *args, **kwargs) @@ -918,9 +888,7 @@ def parametrize_system( ) if self.combining_rule == "geometric": - self._patch_parmed_adjusts( - structure, combining_rule=self.combining_rule - ) + self._patch_parmed_adjusts(structure, combining_rule=self.combining_rule) total_charge = sum([atom.charge for atom in structure.atoms]) if not np.allclose(total_charge, 0): @@ -1032,9 +1000,7 @@ def createSystem( elem.hydrogen, None, ): - transfer_mass = hydrogenMass - sys.getParticleMass( - atom2.index - ) + transfer_mass = hydrogenMass - sys.getParticleMass(atom2.index) sys.setParticleMass(atom2.index, hydrogenMass) mass = sys.getParticleMass(atom1.index) - transfer_mass sys.setParticleMass(atom1.index, mass) @@ -1091,9 +1057,7 @@ def createSystem( bonded_to = data.bondedToAtom[atom] if len(bonded_to) > 2: for subset in itertools.combinations(bonded_to, 3): - data.impropers.append( - (atom, subset[0], subset[1], subset[2]) - ) + data.impropers.append((atom, subset[0], subset[1], subset[2])) # Identify bonds that should be implemented with constraints if constraints == AllBonds or constraints == HAngles: @@ -1188,15 +1152,9 @@ def createSystem( site.originWeights[1], site.originWeights[2], ), - mm.Vec3( - site.xWeights[0], site.xWeights[1], site.xWeights[2] - ), - mm.Vec3( - site.yWeights[0], site.yWeights[1], site.yWeights[2] - ), - mm.Vec3( - site.localPos[0], site.localPos[1], site.localPos[2] - ), + mm.Vec3(site.xWeights[0], site.xWeights[1], site.xWeights[2]), + mm.Vec3(site.yWeights[0], site.yWeights[1], site.yWeights[2]), + mm.Vec3(site.localPos[0], site.localPos[1], site.localPos[2]), ) sys.setVirtualSite(index, local_coord_site) @@ -1263,9 +1221,7 @@ def _write_references_to_file(self, atom_types, references_file): for atomtype, dois in atomtype_references.items(): for doi in dois: unique_references[doi].append(atomtype) - unique_references = collections.OrderedDict( - sorted(unique_references.items()) - ) + unique_references = collections.OrderedDict(sorted(unique_references.items())) with open(references_file, "w") as f: for doi, atomtypes in unique_references.items(): url = "http://api.crossref.org/works/{}/transform/application/x-bibtex".format( @@ -1338,11 +1294,7 @@ def get_parameters(self, group, key, keys_are_atom_classes=False): if group not in param_extractors: raise ValueError(f"Cannot extract parameters for {group}") - key = ( - [key] - if isinstance(key, str) or not isinstance(key, Iterable) - else key - ) + key = [key] if isinstance(key, str) or not isinstance(key, Iterable) else key validate_type(key, str) @@ -1367,18 +1319,14 @@ def _extract_non_bonded_params(self, atom_type): atom_type = atom_type[0] - non_bonded_forces_gen = self.get_generator( - ff=self, gen_type=NonbondedGenerator - ) + non_bonded_forces_gen = self.get_generator(ff=self, gen_type=NonbondedGenerator) non_bonded_params = non_bonded_forces_gen.params.paramsForType try: return non_bonded_params[atom_type] except KeyError: - raise MissingParametersError( - f"Missing parameters for atom {atom_type}" - ) + raise MissingParametersError(f"Missing parameters for atom {atom_type}") def _extract_harmonic_bond_params(self, atom_types): """Return parameters for a specific HarmonicBondForce between atom types.""" @@ -1548,9 +1496,7 @@ def _extract_rb_proper_params(self, atom_types): f"be extracted for four atoms. Provided {len(atom_types)}" ) - rb_torsion_force_gen = self.get_generator( - ff=self, gen_type=RBTorsionGenerator - ) + rb_torsion_force_gen = self.get_generator(ff=self, gen_type=RBTorsionGenerator) wildcard = self._atomClasses[""] ( @@ -1600,9 +1546,7 @@ def _extract_rb_improper_params(self, atom_types): f"be extracted for four atoms. Provided {len(atom_types)}" ) - rb_torsion_force_gen = self.get_generator( - ff=self, gen_type=RBTorsionGenerator - ) + rb_torsion_force_gen = self.get_generator(ff=self, gen_type=RBTorsionGenerator) match = self._match_impropers(atom_types, rb_torsion_force_gen) @@ -1622,9 +1566,7 @@ def map_atom_classes_to_types(self, atom_classes_keys, strict=False): # When to do this substitution with wildcards? substitution = self._atomClasses.get(key) if not substitution: - raise ValueError( - f"Atom class {key} is missing from the Forcefield" - ) + raise ValueError(f"Atom class {key} is missing from the Forcefield") atom_type_keys.append(next(iter(substitution))) return atom_type_keys @@ -1715,9 +1657,7 @@ def get_generator(ff, gen_type): @staticmethod def substitute_wildcards(atom_types, wildcard): """Return possible wildcard options.""" - return tuple( - atom_type or next(iter(wildcard)) for atom_type in atom_types - ) + return tuple(atom_type or next(iter(wildcard)) for atom_type in atom_types) pmd.Structure.write_foyer = write_foyer diff --git a/foyer/smarts_graph.py b/foyer/smarts_graph.py index febe2759..2fd3d01b 100644 --- a/foyer/smarts_graph.py +++ b/foyer/smarts_graph.py @@ -46,7 +46,7 @@ def __init__( overrides=None, typemap=None, *args, - **kwargs + **kwargs, ): super(SMARTSGraph, self).__init__(*args, **kwargs) @@ -120,23 +120,17 @@ def _atom_expr_matches(self, atom_expr, atom, bond_partners): elif atom_expr.data in ("and_expression", "weak_and_expression"): return self._atom_expr_matches( atom_expr.children[0], atom, bond_partners - ) and self._atom_expr_matches( - atom_expr.children[1], atom, bond_partners - ) + ) and self._atom_expr_matches(atom_expr.children[1], atom, bond_partners) elif atom_expr.data == "or_expression": return self._atom_expr_matches( atom_expr.children[0], atom, bond_partners - ) or self._atom_expr_matches( - atom_expr.children[1], atom, bond_partners - ) + ) or self._atom_expr_matches(atom_expr.children[1], atom, bond_partners) elif atom_expr.data == "atom_id": return self._atom_id_matches( atom_expr.children[0], atom, bond_partners, self.typemap ) elif atom_expr.data == "atom_symbol": - return self._atom_id_matches( - atom_expr, atom, bond_partners, self.typemap - ) + return self._atom_id_matches(atom_expr, atom, bond_partners, self.typemap) else: raise TypeError( "Expected atom_id, atom_symbol, and_expression, " @@ -162,9 +156,7 @@ def _atom_id_matches(atom_id, atom, bond_partners, typemap): else: return atomic_num == pt.AtomicNum[str(atom_id.children[0])] elif atom_id.data == "has_label": - label = atom_id.children[0][ - 1: - ] # Strip the % sign from the beginning. + label = atom_id.children[0][1:] # Strip the % sign from the beginning. return label in typemap[atom_idx]["whitelist"] elif atom_id.data == "neighbor_count": return len(bond_partners) == int(atom_id.children[0]) @@ -205,9 +197,7 @@ def find_matches(self, topology_graph, typemap): """ # Note: Needs to be updated in sync with the grammar in `smarts.py`. ring_tokens = ["ring_size", "ring_count"] - has_ring_rules = any( - list(self.ast.find_data(token)) for token in ring_tokens - ) + has_ring_rules = any(list(self.ast.find_data(token)) for token in ring_tokens) topology_graph.add_bond_partners() _prepare_atoms(topology_graph, typemap, compute_cycles=has_ring_rules) @@ -220,9 +210,7 @@ def find_matches(self, topology_graph, typemap): element = next(atom.find_data("atom_symbol")).children[0] except IndexError: try: - atomic_num = next( - atom.find_data("atomic_num") - ).children[0] + atomic_num = next(atom.find_data("atomic_num")).children[0] element = pt.Element[int(atomic_num)] except IndexError: element = None @@ -275,9 +263,7 @@ def candidate_pairs_iter(self): else: # First we determine the candidate node for G2 other_node = min(G2_nodes - set(self.core_2)) - host_nodes = ( - self.valid_nodes if other_node == 0 else self.G1.nodes() - ) + host_nodes = self.valid_nodes if other_node == 0 else self.G1.nodes() for node in host_nodes: if node not in self.core_1: yield node, other_node @@ -326,23 +312,17 @@ def _find_chordless_cycles(bond_graph, max_cycle_size): """ new_possible_rings = [] for possible_ring in possible_rings: - next_neighbors = list( - bond_graph.neighbors(possible_ring[-1]) - ) + next_neighbors = list(bond_graph.neighbors(possible_ring[-1])) for next_neighbor in next_neighbors: if next_neighbor != possible_ring[-2]: - new_possible_rings.append( - possible_ring + [next_neighbor] - ) + new_possible_rings.append(possible_ring + [next_neighbor]) possible_rings = new_possible_rings for possible_ring in possible_rings: if bond_graph.has_edge(possible_ring[-1], last_node): if any( [ - bond_graph.has_edge( - possible_ring[-1], internal_node - ) + bond_graph.has_edge(possible_ring[-1], internal_node) for internal_node in possible_ring[1:-2] ] ): @@ -351,10 +331,7 @@ def _find_chordless_cycles(bond_graph, max_cycle_size): cycles[i].append(possible_ring) connected = True - if ( - not possible_rings - or len(possible_rings[0]) == max_cycle_size - ): + if not possible_rings or len(possible_rings[0]) == max_cycle_size: break return cycles diff --git a/foyer/tests/base_test.py b/foyer/tests/base_test.py index 7b4ff932..69f38e39 100644 --- a/foyer/tests/base_test.py +++ b/foyer/tests/base_test.py @@ -7,9 +7,7 @@ from foyer import forcefields from foyer.smarts import SMARTS -OPLS_TEST_FILE_DIR = Path( - resource_filename("foyer", "opls_validation") -).resolve() +OPLS_TEST_FILE_DIR = Path(resource_filename("foyer", "opls_validation")).resolve() class BaseTest: diff --git a/foyer/tests/test_atomtyping.py b/foyer/tests/test_atomtyping.py index 6dd05274..eda9fbcf 100644 --- a/foyer/tests/test_atomtyping.py +++ b/foyer/tests/test_atomtyping.py @@ -1,6 +1,5 @@ import parmed as pmd import pytest -from pkg_resources import resource_filename from foyer import Forcefield from foyer.exceptions import FoyerError @@ -13,16 +12,12 @@ class TestRunAtomTyping(BaseTest): def missing_overrides_ff(self): return Forcefield(get_fn("missing_overrides.xml")) - def test_missing_overrides( - self, opls_validation_benzene, missing_overrides_ff - ): + def test_missing_overrides(self, opls_validation_benzene, missing_overrides_ff): with pytest.raises(FoyerError): missing_overrides_ff.apply(opls_validation_benzene) def test_missing_definition(self, missing_overrides_ff): - structure = pmd.load_file( - get_fn("silly_chemistry.mol2"), structure=True - ) + structure = pmd.load_file(get_fn("silly_chemistry.mol2"), structure=True) with pytest.raises(FoyerError): missing_overrides_ff.apply(structure) @@ -36,9 +31,7 @@ def test_element_not_found(self, oplsaa, atomic_num, symbol): # The central C has bad element info, which will trigger an error # during the atomtyping step top_graph = TopologyGraph() - top_graph.add_atom( - name="C", index=0, atomic_number=atomic_num, symbol=symbol - ) + top_graph.add_atom(name="C", index=0, atomic_number=atomic_num, symbol=symbol) for i in range(1, 5): top_graph.add_atom(name="H", index=i, atomic_number=1, symbol="H") top_graph.add_bond(0, i) diff --git a/foyer/tests/test_forcefield.py b/foyer/tests/test_forcefield.py index 2e9ad6ca..785b66d9 100644 --- a/foyer/tests/test_forcefield.py +++ b/foyer/tests/test_forcefield.py @@ -66,7 +66,7 @@ def test_load_files(self, ff_file): def test_duplicate_type_definitions(self): with pytest.raises(ValueError): - ff4 = Forcefield(name="oplsaa", forcefield_files=FORCEFIELDS) + Forcefield(name="oplsaa", forcefield_files=FORCEFIELDS) def test_missing_type_definitions(self): with pytest.raises(FoyerError): @@ -132,14 +132,12 @@ def test_write_refs(self, requests_mock, oplsaa): text=RESPONSE_BIB_ETHANE_JA962170, ) mol2 = mb.load(get_fn("ethane.mol2")) - ethane = oplsaa.apply(mol2, references_file="ethane.bib") + oplsaa.apply(mol2, references_file="ethane.bib") assert os.path.isfile("ethane.bib") with open(get_fn("ethane.bib")) as file1: with open("ethane.bib") as file2: diff = list( - difflib.unified_diff( - file1.readlines(), file2.readlines(), n=0 - ) + difflib.unified_diff(file1.readlines(), file2.readlines(), n=0) ) assert not diff @@ -163,14 +161,12 @@ def test_write_refs_multiple(self, requests_mock): ) mol2 = mb.load(get_fn("ethane.mol2")) oplsaa = Forcefield(forcefield_files=get_fn("refs-multi.xml")) - ethane = oplsaa.apply(mol2, references_file="ethane-multi.bib") + oplsaa.apply(mol2, references_file="ethane-multi.bib") assert os.path.isfile("ethane-multi.bib") with open(get_fn("ethane-multi.bib")) as file1: with open("ethane-multi.bib") as file2: diff = list( - difflib.unified_diff( - file1.readlines(), file2.readlines(), n=0 - ) + difflib.unified_diff(file1.readlines(), file2.readlines(), n=0) ) assert not diff @@ -188,7 +184,7 @@ def test_write_bad_ref(self, requests_mock): mol2 = mb.load(get_fn("ethane.mol2")) oplsaa = Forcefield(forcefield_files=get_fn("refs-bad.xml")) with pytest.warns(UserWarning): - ethane = oplsaa.apply(mol2, references_file="ethane.bib") + oplsaa.apply(mol2, references_file="ethane.bib") def test_preserve_resname(self, oplsaa): untyped_ethane = pmd.load_file(get_fn("ethane.mol2"), structure=True) @@ -211,9 +207,7 @@ def test_from_mbuild_customtype(self): import mbuild as mb mol2 = mb.load(get_fn("ethane_customtype.pdb")) - customtype_ff = Forcefield( - forcefield_files=get_fn("validate_customtypes.xml") - ) + customtype_ff = Forcefield(forcefield_files=get_fn("validate_customtypes.xml")) ethane = customtype_ff.apply(mol2) assert sum((1 for at in ethane.atoms if at.type == "C3")) == 2 @@ -227,12 +221,8 @@ def test_from_mbuild_customtype(self): def test_improper_dihedral(self): untyped_benzene = pmd.load_file(get_fn("benzene.mol2"), structure=True) - ff_improper = Forcefield( - forcefield_files=get_fn("improper_dihedral.xml") - ) - benzene = ff_improper.apply( - untyped_benzene, assert_dihedral_params=False - ) + ff_improper = Forcefield(forcefield_files=get_fn("improper_dihedral.xml")) + benzene = ff_improper.apply(untyped_benzene, assert_dihedral_params=False) assert len(benzene.dihedrals) == 18 assert len([dih for dih in benzene.dihedrals if dih.improper]) == 6 assert len([dih for dih in benzene.dihedrals if not dih.improper]) == 12 @@ -300,9 +290,7 @@ def test_residue_map(self, oplsaa): struct_without = ethane oplsaa._apply_typemap(struct_with, map_with) oplsaa._apply_typemap(struct_without, map_without) - for atom_with, atom_without in zip( - struct_with.atoms, struct_without.atoms - ): + for atom_with, atom_without in zip(struct_with.atoms, struct_without.atoms): assert atom_with.type == atom_without.type b_with = atom_with.bond_partners b_without = atom_without.bond_partners @@ -352,22 +340,12 @@ def test_topology_precedence(self): assert ( len( - [ - bond - for bond in typed_ethane.bonds - if round(bond.type.req, 2) == 1.15 - ] + [bond for bond in typed_ethane.bonds if round(bond.type.req, 2) == 1.15] ) == 6 ) assert ( - len( - [ - bond - for bond in typed_ethane.bonds - if round(bond.type.req, 2) == 1.6 - ] - ) + len([bond for bond in typed_ethane.bonds if round(bond.type.req, 2) == 1.6]) == 1 ) assert ( @@ -392,11 +370,7 @@ def test_topology_precedence(self): ) assert ( len( - [ - rb - for rb in typed_ethane.rb_torsions - if round(rb.type.c0, 3) == 0.287 - ] + [rb for rb in typed_ethane.rb_torsions if round(rb.type.c0, 3) == 0.287] ) == 9 ) @@ -458,9 +432,7 @@ def test_assert_bonds(self): with pytest.raises(Exception): ff.apply(derponium) - thing = ff.apply( - derponium, assert_bond_params=False, assert_angle_params=False - ) + thing = ff.apply(derponium, assert_bond_params=False, assert_angle_params=False) assert any(b.type is None for b in thing.bonds) def test_apply_subfuncs(self, oplsaa): @@ -518,9 +490,7 @@ def test_write_xml(self, filename, oplsaa): for adj in typed.adjusts: type1 = adj.atom1.atom_type type2 = adj.atom1.atom_type - sigma_factor_pre = adj.type.sigma / ( - (type1.sigma + type2.sigma) / 2 - ) + sigma_factor_pre = adj.type.sigma / ((type1.sigma + type2.sigma) / 2) epsilon_factor_pre = adj.type.epsilon / ( (type1.epsilon * type2.epsilon) ** 0.5 ) @@ -528,9 +498,7 @@ def test_write_xml(self, filename, oplsaa): for adj in typed_by_partial.adjusts: type1 = adj.atom1.atom_type type2 = adj.atom1.atom_type - sigma_factor_post = adj.type.sigma / ( - (type1.sigma + type2.sigma) / 2 - ) + sigma_factor_post = adj.type.sigma / ((type1.sigma + type2.sigma) / 2) epsilon_factor_post = adj.type.epsilon / ( (type1.epsilon * type2.epsilon) ** 0.5 ) @@ -543,18 +511,14 @@ def test_write_xml(self, filename, oplsaa): oplsaa = Forcefield(get_fn("oplsaa-periodic.xml")) typed = oplsaa.apply(mol) - typed.write_foyer( - filename="opls-snippet.xml", forcefield=oplsaa, unique=True - ) + typed.write_foyer(filename="opls-snippet.xml", forcefield=oplsaa, unique=True) oplsaa_partial = Forcefield("opls-snippet.xml") typed_by_partial = oplsaa_partial.apply(mol) for adj in typed.adjusts: type1 = adj.atom1.atom_type type2 = adj.atom1.atom_type - sigma_factor_pre = adj.type.sigma / ( - (type1.sigma + type2.sigma) / 2 - ) + sigma_factor_pre = adj.type.sigma / ((type1.sigma + type2.sigma) / 2) epsilon_factor_pre = adj.type.epsilon / ( (type1.epsilon * type2.epsilon) ** 0.5 ) @@ -562,9 +526,7 @@ def test_write_xml(self, filename, oplsaa): for adj in typed_by_partial.adjusts: type1 = adj.atom1.atom_type type2 = adj.atom1.atom_type - sigma_factor_post = adj.type.sigma / ( - (type1.sigma + type2.sigma) / 2 - ) + sigma_factor_post = adj.type.sigma / ((type1.sigma + type2.sigma) / 2) epsilon_factor_post = adj.type.epsilon / ( (type1.epsilon * type2.epsilon) ** 0.5 ) @@ -575,9 +537,7 @@ def test_write_xml(self, filename, oplsaa): @pytest.mark.parametrize("filename", ["ethane.mol2", "benzene.mol2"]) def test_write_xml_multiple_periodictorsions(self, filename): cmpd = pmd.load_file(get_fn(filename), structure=True) - ff = Forcefield( - forcefield_files=get_fn("oplsaa_multiperiodicitytorsion.xml") - ) + ff = Forcefield(forcefield_files=get_fn("oplsaa_multiperiodicitytorsion.xml")) typed_struc = ff.apply(cmpd, assert_dihedral_params=False) typed_struc.write_foyer( filename="multi-periodictorsions.xml", forcefield=ff, unique=True @@ -606,15 +566,13 @@ def test_load_xml(self, filename, oplsaa): typed = ff.apply(mol) typed.write_foyer(filename="snippet.xml", forcefield=ff, unique=True) - generated_ff = Forcefield("snippet.xml") + Forcefield("snippet.xml") def test_write_xml_overrides(self, oplsaa): # Test xml_writer new overrides and comments features mol = pmd.load_file(get_fn("styrene.mol2"), structure=True) typed = oplsaa.apply(mol, assert_dihedral_params=False) - typed.write_foyer( - filename="opls-styrene.xml", forcefield=oplsaa, unique=True - ) + typed.write_foyer(filename="opls-styrene.xml", forcefield=oplsaa, unique=True) styrene = ET.parse("opls-styrene.xml") atom_types = styrene.getroot().find("AtomTypes").findall("Type") for item in atom_types: @@ -637,9 +595,7 @@ def test_load_metadata(self): assert lj_ff.version == "0.4.1" assert lj_ff.name == "LJ" - lj_ff = Forcefield( - forcefield_files=[get_fn("lj.xml"), get_fn("lj2.xml")] - ) + lj_ff = Forcefield(forcefield_files=[get_fn("lj.xml"), get_fn("lj2.xml")]) assert lj_ff.version == ["0.4.1", "4.8.2"] assert lj_ff.name == ["LJ", "JL"] @@ -649,18 +605,14 @@ def test_load_metadata_single_xml(self): assert from_xml_ff.name == "LJ" def test_load_metadata_list_xml(self): - from_xml_ff = Forcefield( - forcefield_files=[get_fn("lj.xml"), get_fn("lj2.xml")] - ) + from_xml_ff = Forcefield(forcefield_files=[get_fn("lj.xml"), get_fn("lj2.xml")]) assert isinstance(from_xml_ff.version, List) assert isinstance(from_xml_ff.name, List) assert all([x in from_xml_ff.version for x in ["0.4.1", "4.8.2"]]) assert all([x in from_xml_ff.name for x in ["JL", "LJ"]]) with pytest.raises(FoyerError): - mismatch_comb_rule = Forcefield( - forcefield_files=[get_fn("lj.xml"), get_fn("lj3.xml")] - ) + Forcefield(forcefield_files=[get_fn("lj.xml"), get_fn("lj3.xml")]) def test_load_metadata_from_internal_forcefield_plugin_loader(self): from_xml_ff = forcefields.load_OPLSAA() @@ -684,9 +636,7 @@ def test_no_overlap_residue_atom_overlap(self): mol1.name = "CCC" mol2.name = "COC" - box = mb.fill_box( - [mol1, mol2], n_compounds=[2, 2], overlap=0.01, density=700 - ) + box = mb.fill_box([mol1, mol2], n_compounds=[2, 2], overlap=0.01, density=700) all_substructures = [] structure = box.to_parmed(residues=["CCC", "COC"]) @@ -746,9 +696,7 @@ def test_combining_rule_in_forcefield_overrides_apply_arg(self, oplsaa): ], ) @pytest.mark.skipif(not has_mbuild, reason="mbuild is not installed") - def test_combining_rule( - self, ff_name, expected_combining_rule, expected_14_sigma - ): + def test_combining_rule(self, ff_name, expected_combining_rule, expected_14_sigma): import mbuild as mb ff = Forcefield(get_fn(ff_name)) diff --git a/foyer/tests/test_forcefield_parameters.py b/foyer/tests/test_forcefield_parameters.py index 5740ec3e..2f5d5849 100644 --- a/foyer/tests/test_forcefield_parameters.py +++ b/foyer/tests/test_forcefield_parameters.py @@ -49,16 +49,8 @@ def test_gaff_angle_parameters(self, gaff): def test_gaff_angle_parameters_reversed(self, gaff): assert np.allclose( - list( - gaff.get_parameters( - "harmonic_angles", ["f", "c2", "ha"] - ).values() - ), - list( - gaff.get_parameters( - "harmonic_angles", ["ha", "c2", "f"] - ).values() - ), + list(gaff.get_parameters("harmonic_angles", ["f", "c2", "ha"]).values()), + list(gaff.get_parameters("harmonic_angles", ["ha", "c2", "f"]).values()), ) def test_gaff_missing_angle_parameters(self, gaff): @@ -70,9 +62,7 @@ def test_gaff_periodic_proper_parameters(self, gaff): "periodic_propers", ["c3", "c", "sh", "hs"] ) assert np.allclose(periodic_proper_params["periodicity"], [2.0, 1.0]) - assert np.allclose( - periodic_proper_params["k"], [9.414, 5.4392000000000005] - ) + assert np.allclose(periodic_proper_params["k"], [9.414, 5.4392000000000005]) assert np.allclose( periodic_proper_params["phase"], [3.141592653589793, 3.141592653589793], @@ -110,21 +100,15 @@ def test_gaff_periodic_improper_parameters(self, gaff): ) assert np.allclose(periodic_improper_params["periodicity"], [2.0]) assert np.allclose(periodic_improper_params["k"], [4.6024]) - assert np.allclose( - periodic_improper_params["phase"], [3.141592653589793] - ) + assert np.allclose(periodic_improper_params["phase"], [3.141592653589793]) def test_gaff_periodic_improper_parameters_reversed(self, gaff): assert np.allclose( list( - gaff.get_parameters( - "periodic_impropers", ["c", "", "o", "o"] - ).values() + gaff.get_parameters("periodic_impropers", ["c", "", "o", "o"]).values() ), list( - gaff.get_parameters( - "periodic_impropers", ["c", "o", "", "o"] - ).values() + gaff.get_parameters("periodic_impropers", ["c", "o", "", "o"]).values() ), ) @@ -147,16 +131,12 @@ def test_opls_get_parameters_atoms_list(self, oplsaa): assert atom_params["epsilon"] == 0.29288 def test_opls_get_parameters_atom_class(self, oplsaa): - atom_params = oplsaa.get_parameters( - "atoms", "CA", keys_are_atom_classes=True - ) + atom_params = oplsaa.get_parameters("atoms", "CA", keys_are_atom_classes=True) assert atom_params["sigma"] == 0.355 assert atom_params["epsilon"] == 0.29288 def test_opls_get_parameters_bonds(self, oplsaa): - bond_params = oplsaa.get_parameters( - "harmonic_bonds", ["opls_760", "opls_145"] - ) + bond_params = oplsaa.get_parameters("harmonic_bonds", ["opls_760", "opls_145"]) assert bond_params["length"] == 0.146 assert bond_params["k"] == 334720.0 @@ -177,14 +157,10 @@ def test_opls_get_parameters_bonds_reversed(self, oplsaa): def test_opls_get_parameters_bonds_atom_classes_reversed(self, oplsaa): assert np.allclose( list( - oplsaa.get_parameters( - "harmonic_bonds", ["C_2", "O_2"], True - ).values() + oplsaa.get_parameters("harmonic_bonds", ["C_2", "O_2"], True).values() ), list( - oplsaa.get_parameters( - "harmonic_bonds", ["O_2", "C_2"], True - ).values() + oplsaa.get_parameters("harmonic_bonds", ["O_2", "C_2"], True).values() ), ) diff --git a/foyer/tests/test_graph.py b/foyer/tests/test_graph.py index 04576426..4d77d434 100644 --- a/foyer/tests/test_graph.py +++ b/foyer/tests/test_graph.py @@ -38,9 +38,7 @@ def test_lazy_cycle_finding(self): ring_tokens = ["R1", "r6"] for token in ring_tokens: - rule = SMARTSGraph( - smarts_string="[C;{}]".format(token), typemap=typemap - ) + rule = SMARTSGraph(smarts_string="[C;{}]".format(token), typemap=typemap) list(rule.find_matches(TopologyGraph.from_parmed(mol2), typemap)) assert all(["cycles" in typemap[a.idx] for a in mol2.atoms]) @@ -51,9 +49,7 @@ def test_cycle_finding_multiple(self): for atom in mol2.atoms } - _prepare_atoms( - TopologyGraph.from_parmed(mol2), typemap, compute_cycles=True - ) + _prepare_atoms(TopologyGraph.from_parmed(mol2), typemap, compute_cycles=True) cycle_lengths = [ list(map(len, typemap[atom.idx]["cycles"])) for atom in mol2.atoms ] diff --git a/foyer/tests/test_opls.py b/foyer/tests/test_opls.py index edaa1800..57516824 100644 --- a/foyer/tests/test_opls.py +++ b/foyer/tests/test_opls.py @@ -49,9 +49,7 @@ def test_opls_metadata(self, oplsaa): assert oplsaa.combining_rule == "geometric" @pytest.mark.parametrize("mol_name", correctly_implemented) - def test_atomtyping( - self, mol_name, oplsaa, testfiles_dir=OPLS_TESTFILES_DIR - ): + def test_atomtyping(self, mol_name, oplsaa, testfiles_dir=OPLS_TESTFILES_DIR): files = glob.glob(os.path.join(testfiles_dir, mol_name, "*")) for mol_file in files: _, ext = os.path.splitext(mol_file) @@ -60,9 +58,7 @@ def test_atomtyping( gro_filename = "{}.gro".format(mol_name) top_path = os.path.join(testfiles_dir, mol_name, top_filename) gro_path = os.path.join(testfiles_dir, mol_name, gro_filename) - structure = pmd.load_file( - top_path, xyz=gro_path, parametrize=False - ) + structure = pmd.load_file(top_path, xyz=gro_path, parametrize=False) elif ext == ".mol2": mol2_path = os.path.join(testfiles_dir, mol_name, mol_file) structure = pmd.load_file(mol2_path, structure=True) @@ -74,12 +70,8 @@ def test_full_parametrization(self, oplsaa): structure = pmd.load_file(top, xyz=gro) parametrized = oplsaa.apply(structure) - assert ( - sum((1 for at in parametrized.atoms if at.type == "opls_145")) == 6 - ) - assert ( - sum((1 for at in parametrized.atoms if at.type == "opls_146")) == 6 - ) + assert sum((1 for at in parametrized.atoms if at.type == "opls_145")) == 6 + assert sum((1 for at in parametrized.atoms if at.type == "opls_146")) == 6 assert len(parametrized.bonds) == 12 assert all(x.type for x in parametrized.bonds) assert len(parametrized.angles) == 18 @@ -110,12 +102,8 @@ def test_improper_in_structure(self): ("13-difluorobenzene", 6), ] # found in the "impropers" sections of molecule_name.top for molecule, n_impropers in files_with_impropers: - top = os.path.join( - OPLS_TESTFILES_DIR, molecule + "/" + molecule + ".top" - ) - gro = os.path.join( - OPLS_TESTFILES_DIR, molecule + "/" + molecule + ".gro" - ) + top = os.path.join(OPLS_TESTFILES_DIR, molecule + "/" + molecule + ".top") + gro = os.path.join(OPLS_TESTFILES_DIR, molecule + "/" + molecule + ".gro") structure = pmd.load_file(top, xyz=gro) impropers = [] [ diff --git a/foyer/tests/test_plugin.py b/foyer/tests/test_plugin.py index 364cfaec..f5b590fc 100644 --- a/foyer/tests/test_plugin.py +++ b/foyer/tests/test_plugin.py @@ -14,7 +14,7 @@ def test_loading_forcefields(self): eval("foyer.forcefields." + func)() def test_load_forcefield(self): - OPLSAA = foyer.forcefields.get_forcefield(name="oplsaa") - TRAPPE_UA = foyer.forcefields.get_forcefield(name="trappe-ua") + foyer.forcefields.get_forcefield(name="oplsaa") + foyer.forcefields.get_forcefield(name="trappe-ua") with pytest.raises(ValueError): foyer.forcefields.get_forcefield("bogus_name") diff --git a/foyer/tests/test_smarts.py b/foyer/tests/test_smarts.py index 6bfa1ef1..2c7bfc08 100644 --- a/foyer/tests/test_smarts.py +++ b/foyer/tests/test_smarts.py @@ -81,9 +81,7 @@ def test_ringness(self, rule_match): for atom in not_ring_mol2.atoms } - rule_match( - not_ring_mol2_graph, typemap, "[#6]1[#6][#6][#6][#6][#6]1", False - ) + rule_match(not_ring_mol2_graph, typemap, "[#6]1[#6][#6][#6][#6][#6]1", False) def test_fused_ring(self, smarts_parser): mol2 = pmd.load_file(get_fn("fused.mol2"), structure=True) @@ -248,11 +246,11 @@ def test_hexa_coordinated(self): def test_optional_names_bad_syntax(self): bad_optional_names = ["_C", "XXX", "C"] with pytest.raises(FoyerError): - S = SMARTS(optional_names=bad_optional_names) + SMARTS(optional_names=bad_optional_names) def test_optional_names_good_syntax(self): good_optional_names = ["_C", "_CH2", "_CH"] - S = SMARTS(optional_names=good_optional_names) + SMARTS(optional_names=good_optional_names) def test_optional_name_parser(self): optional_names = ["_C", "_CH2", "_CH"] diff --git a/foyer/tests/test_topology_graph.py b/foyer/tests/test_topology_graph.py index b0e98981..3a641ace 100644 --- a/foyer/tests/test_topology_graph.py +++ b/foyer/tests/test_topology_graph.py @@ -12,9 +12,7 @@ @pytest.mark.skipif( - condition=( - is_running_on_windows() or (not (has_gmso or has_openff_toolkit)) - ), + condition=(is_running_on_windows() or (not (has_gmso or has_openff_toolkit))), reason="openff-toolkit and gmso not installed", ) class TestTopologyGraph(BaseTest): @@ -101,13 +99,9 @@ def test_atom_typing( oplsaa, ): # ToDo: More robust testing for atomtyping - openff_typemap = find_atomtypes( - openff_topology_graph, forcefield=oplsaa - ) + openff_typemap = find_atomtypes(openff_topology_graph, forcefield=oplsaa) gmso_typemap = find_atomtypes(gmso_topology_graph, forcefield=oplsaa) - parmed_typemap = find_atomtypes( - parmed_topology_graph, forcefield=oplsaa - ) + parmed_typemap = find_atomtypes(parmed_topology_graph, forcefield=oplsaa) assert openff_typemap assert gmso_typemap assert parmed_typemap diff --git a/foyer/tests/test_utils.py b/foyer/tests/test_utils.py index 1e1692a3..d83d86ab 100644 --- a/foyer/tests/test_utils.py +++ b/foyer/tests/test_utils.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401 import platform import numpy as np @@ -14,8 +15,7 @@ class TestUtils(BaseTest): platform.system() == "Windows" or pmd.version.major < 4 or ( - pmd.version.major == 4 - and pmd.version.minor == pmd.version.patchlevel == 0 + pmd.version.major == 4 and pmd.version.minor == pmd.version.patchlevel == 0 ), reason="obsolete parmed version", ) diff --git a/foyer/tests/utils.py b/foyer/tests/utils.py index 76b04a6d..227a01e1 100644 --- a/foyer/tests/utils.py +++ b/foyer/tests/utils.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401 import glob import platform import urllib.parse as parseurl diff --git a/foyer/topology_graph.py b/foyer/topology_graph.py index f2a7f6c3..e5cf6e03 100644 --- a/foyer/topology_graph.py +++ b/foyer/topology_graph.py @@ -206,9 +206,7 @@ def from_openff_topology(cls, openff_topology: "OpenFFTopology"): ) for bond in openff_topology.bonds: - atoms_indices = [ - openff_topology.atom_index(atom) for atom in bond.atoms - ] + atoms_indices = [openff_topology.atom_index(atom) for atom in bond.atoms] top_graph.add_bond(*atoms_indices) return top_graph @@ -261,8 +259,7 @@ def from_gmso_topology(cls, gmso_topology: "gmso.Topology"): for top_bond in gmso_topology.bonds: atoms_indices = [ - gmso_topology.get_index(atom) - for atom in top_bond.connection_members + gmso_topology.get_index(atom) for atom in top_bond.connection_members ] top_graph.add_bond(atoms_indices[0], atoms_indices[1]) diff --git a/foyer/utils/io.py b/foyer/utils/io.py index 94ef5e9a..cdb1f44a 100644 --- a/foyer/utils/io.py +++ b/foyer/utils/io.py @@ -17,9 +17,7 @@ class DelayImportError(ImportError, SkipTest): MESSAGES = dict() -MESSAGES[ - "mbuild" -] = """ +MESSAGES["mbuild"] = """ The code at {filename}:{line_number} requires the "mbuild" package @@ -29,9 +27,7 @@ class DelayImportError(ImportError, SkipTest): """ -MESSAGES[ - "gmso" -] = """ +MESSAGES["gmso"] = """ The code at {filename}:{line_number} requires the "gmso" package @@ -40,9 +36,7 @@ class DelayImportError(ImportError, SkipTest): # conda install -c conda-forge gmso """ -MESSAGES[ - "openff.toolkit" -] = """ +MESSAGES["openff.toolkit"] = """ The code at {filename}:{line_number} requires the "openff-toolkit" package @@ -76,7 +70,7 @@ def import_(module): """ try: return importlib.import_module(module) - except ImportError as e: + except ImportError: try: message = MESSAGES[module] except KeyError: @@ -85,7 +79,7 @@ def import_(module): + module + " package" ) - e = ImportError("No module named %s" % module) + ImportError("No module named %s" % module) ( frame, @@ -96,9 +90,7 @@ def import_(module): index, ) = inspect.getouterframes(inspect.currentframe())[1] - m = message.format( - filename=os.path.basename(filename), line_number=line_number - ) + m = message.format(filename=os.path.basename(filename), line_number=line_number) m = textwrap.dedent(m) bar = ( diff --git a/foyer/utils/nbfixes.py b/foyer/utils/nbfixes.py index cc0b50f8..f9bda7ec 100644 --- a/foyer/utils/nbfixes.py +++ b/foyer/utils/nbfixes.py @@ -29,8 +29,9 @@ def apply_nbfix(struct, atom_type1, atom_type2, sigma, epsilon): atypes_name = set(a.atom_type.name for a in struct_copy.atoms) if atom_type1 not in atypes_name or atom_type2 not in atypes_name: raise ValueError( - "Atom types {} and {} not found " - "in structure.".format(atom_type1, atom_type2) + "Atom types {} and {} not found " "in structure.".format( + atom_type1, atom_type2 + ) ) # Calculate rmin from sigma because parmed uses it internally diff --git a/foyer/validator.py b/foyer/validator.py index 417478c1..529eee45 100644 --- a/foyer/validator.py +++ b/foyer/validator.py @@ -21,16 +21,12 @@ def __init__(self, ff_file_name, debug=False): from foyer.forcefield import preprocess_forcefield_files try: - preprocessed_ff_file_name = preprocess_forcefield_files( - [ff_file_name] - ) + preprocessed_ff_file_name = preprocess_forcefield_files([ff_file_name]) ff_tree = etree.parse(preprocessed_ff_file_name[0]) self.validate_xsd(ff_tree) - self.atom_type_names = ff_tree.xpath( - "/ForceField/AtomTypes/Type/@name" - ) + self.atom_type_names = ff_tree.xpath("/ForceField/AtomTypes/Type/@name") self.atom_types = ff_tree.xpath("/ForceField/AtomTypes/Type") self.validate_class_type_exclusivity(ff_tree) @@ -53,9 +49,7 @@ def __init__(self, ff_file_name, debug=False): def validate_xsd(ff_tree, xsd_file=None): """Check consistency with forcefields/ff.xsd.""" if xsd_file is None: - xsd_file = join( - split(abspath(__file__))[0], "forcefields", "ff.xsd" - ) + xsd_file = join(split(abspath(__file__))[0], "forcefields", "ff.xsd") xmlschema_doc = etree.parse(xsd_file) xmlschema = etree.XMLSchema(xmlschema_doc) @@ -72,7 +66,7 @@ def validate_xsd(ff_tree, xsd_file=None): def create_error(keyword, message, line): atomtype = message[message.find("[") + 1 : message.find("]")] error_text = error_texts[keyword].format(atomtype, line) - return ValidationError(error_text, ex, line) + return ValidationError(error_text, keyword, line) try: xmlschema.assertValid(ff_tree) @@ -175,9 +169,7 @@ def validate_smarts(self, debug): except lark.ParseError as ex: if " col " in ex.args[0]: column = ex.args[0][ex.args[0].find(" col ") + 5 :].strip() - column = " at character {} of {}".format( - column, smarts_string - ) + column = " at character {} of {}".format(column, smarts_string) else: column = "" @@ -198,9 +190,7 @@ def validate_smarts(self, debug): name=name, overrides=entry.attrib.get("overrides"), ) - for atom_expr in nx.get_node_attributes( - smarts_graph, name="atom" - ).values(): + for atom_expr in nx.get_node_attributes(smarts_graph, name="atom").values(): labels = atom_expr.find_data("has_label") for label in labels: atom_type = label.children[0][1:] diff --git a/foyer/xml_writer.py b/foyer/xml_writer.py index 5238038f..4afe6413 100644 --- a/foyer/xml_writer.py +++ b/foyer/xml_writer.py @@ -5,7 +5,6 @@ import collections import warnings -import gmso import networkx as nx import numpy as np import parmed as pmd @@ -66,9 +65,7 @@ def write_foyer( # Assume if a Structure has a bond and bond type that the Structure is # parameterized. ParmEd uses the same logic to denote parameterization. if not (len(self.bonds) > 0 and self.bonds[0].type is not None): - raise Exception( - "Cannot write Foyer XML from an unparametrized " "Structure." - ) + raise Exception("Cannot write Foyer XML from an unparametrized " "Structure.") root = ET.Element("ForceField") # Write Forcefield information @@ -87,11 +84,11 @@ def write_foyer( _write_rb_torsions(root, self.rb_torsions, unique) # TO DO - elif isinstance(self, gmso.Topology): - raise FoyerError( - "Currently, cannot write foyer XML file from a gmso.Topology. " - "This feature will be implemented in future releases." - ) + # elif isinstance(self, gmso.Topology): + # raise FoyerError( + # "Currently, cannot write foyer XML file from a gmso.Topology. " + # "This feature will be implemented in future releases." + # ) _remove_duplicate_elements(root, unique) @@ -175,12 +172,8 @@ def _update_defs(atomtypes, nonbonded, forcefield): smarts_list = list() smarts_parser = forcefield.parser for smarts_string, name in zip(def_list, name_list): - smarts_graph = SMARTSGraph( - smarts_string, parser=smarts_parser, name=name - ) - for atom_expr in nx.get_node_attributes( - smarts_graph, name="atom" - ).values(): + smarts_graph = SMARTSGraph(smarts_string, parser=smarts_parser, name=name) + for atom_expr in nx.get_node_attributes(smarts_graph, name="atom").values(): labels = atom_expr.find_data("has_label") for label in labels: atom_type = label.children[0][1:] @@ -225,8 +218,7 @@ def _write_angles(root, angles, unique): for angle in angles: angle_force = ET.SubElement(angle_forces, "Angle") atypes = [ - atype - for atype in [angle.atom1.type, angle.atom2.type, angle.atom3.type] + atype for atype in [angle.atom1.type, angle.atom2.type, angle.atom3.type] ] if unique: # Sort the first and last atom types @@ -237,9 +229,7 @@ def _write_angles(root, angles, unique): angle_force.set("id3", str(angle.atom3.idx)) for id in range(3): angle_force.set("type{}".format(id + 1), atypes[id]) - angle_force.set( - "angle", str(round(angle.type.theteq * (np.pi / 180), 10)) - ) + angle_force.set("angle", str(round(angle.type.theteq * (np.pi / 180), 10))) angle_force.set("k", str(round(angle.type.k * 4.184 * 2, 3))) @@ -284,9 +274,7 @@ def _write_periodic_torsions(root, dihedrals, unique): for id in range(4): dihedral_force.set("type{}".format(id + 1), atypes[id]) dihedral_force.set("periodicity1", str(dihedral.type.per)) - dihedral_force.set( - "phase1", str(round(dihedral.type.phase * (np.pi / 180), 8)) - ) + dihedral_force.set("phase1", str(round(dihedral.type.phase * (np.pi / 180), 8))) dihedral_force.set("k1", str(round(dihedral.type.phi_k * 4.184, 3))) if last_dihedral_force is not None: # Check to see if this current dihedral force needs to be @@ -317,12 +305,12 @@ def _write_periodic_torsions(root, dihedrals, unique): last_dihedral_force.attrib["periodicity{}".format(n)] = ( dihedral_force.attrib["periodicity1"] ) - last_dihedral_force.attrib["phase{}".format(n)] = ( - dihedral_force.attrib["phase1"] - ) - last_dihedral_force.attrib["k{}".format(n)] = ( - dihedral_force.attrib["k1"] - ) + last_dihedral_force.attrib["phase{}".format(n)] = dihedral_force.attrib[ + "phase1" + ] + last_dihedral_force.attrib["k{}".format(n)] = dihedral_force.attrib[ + "k1" + ] periodic_torsion_forces.remove(dihedral_force) else: last_dihedral_force = dihedral_force @@ -389,11 +377,7 @@ def _write_rb_torsions(root, rb_torsions, unique): for c_id in range(6): rb_torsion_force.set( "c{}".format(c_id), - str( - round( - getattr(rb_torsion.type, "c{}".format(c_id)) * 4.184, 4 - ) - ), + str(round(getattr(rb_torsion.type, "c{}".format(c_id)) * 4.184, 4)), ) @@ -482,9 +466,7 @@ def _infer_lj14scale(struct, combining_rule: str): raise ValueError( "Unexpected 1-4 sigma value found in adj {}. Expected {}" "and found {}. This estimate was made assuming a combining " - "rule of {}".format( - adj, adj.type.sigma, expected_sigma, combining_rule - ) + "rule of {}".format(adj, adj.type.sigma, expected_sigma, combining_rule) ) lj14scale.append(adj.type.epsilon / expected_epsilon) diff --git a/setup.py b/setup.py index c607dd0c..4b8a0228 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -"""Foyer: Atomtyping and forcefield applying. """ +"""Foyer: Atomtyping and forcefield applying.""" from __future__ import print_function @@ -22,9 +22,7 @@ author="Janos Sallai, Christoph Klein", author_email="janos.sallai@vanderbilt.edu, christoph.klein@vanderbilt.edu", url="https://github.com/mosdef-hub/foyer", - download_url="https://github.com/mosdef-hub/foyer/tarball/{}".format( - __version__ - ), + download_url="https://github.com/mosdef-hub/foyer/tarball/{}".format(__version__), packages=find_packages(), package_data={ "foyer": [ From 07cc0520ea067def5961fbf44ce1035946175c70 Mon Sep 17 00:00:00 2001 From: chrisjonesBSU Date: Tue, 2 Jul 2024 21:43:42 -0600 Subject: [PATCH 2/5] raise error correctly --- foyer/utils/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/foyer/utils/io.py b/foyer/utils/io.py index cdb1f44a..945f6ed0 100644 --- a/foyer/utils/io.py +++ b/foyer/utils/io.py @@ -79,7 +79,7 @@ def import_(module): + module + " package" ) - ImportError("No module named %s" % module) + raise ImportError("No module named %s" % module) ( frame, From 48a442495881b4f298714f994f95402f62f4b60c Mon Sep 17 00:00:00 2001 From: chrisjonesBSU Date: Fri, 26 Jul 2024 17:35:41 -0600 Subject: [PATCH 3/5] update ruff-pre-commit version --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 882a17ae..5bd85924 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.5.0 + rev: v0.5.4 hooks: # Run the linter. - id: ruff From 17489bc8c8f67ae4d298f309c559a4e50f5b50f9 Mon Sep 17 00:00:00 2001 From: chrisjonesBSU Date: Thu, 15 Aug 2024 14:04:15 -0600 Subject: [PATCH 4/5] change != to is not --- foyer/xml_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/foyer/xml_writer.py b/foyer/xml_writer.py index 4afe6413..55655751 100644 --- a/foyer/xml_writer.py +++ b/foyer/xml_writer.py @@ -414,7 +414,7 @@ def _elements_equal(e1, e2): Note: This was grabbed, basically verbatim, from: https://stackoverflow.com/questions/7905380/testing-equivalence-of-xml-etree-elementtree """ - if type(e1) != type(e2): + if type(e1) is not type(e2): return False if e1.tag != e2.tag: return False From a7c48587e0475de783c82e9dbd671251ed8b99ac Mon Sep 17 00:00:00 2001 From: chrisjonesBSU Date: Thu, 15 Aug 2024 14:05:24 -0600 Subject: [PATCH 5/5] use latest ruff version --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 846b9b49..94160206 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.5.4 + rev: v0.5.7 hooks: # Run the linter. - id: ruff