diff --git a/pyscf/grad/rhf.py b/pyscf/grad/rhf.py index 11ddcd5522..59b86d56a3 100644 --- a/pyscf/grad/rhf.py +++ b/pyscf/grad/rhf.py @@ -27,6 +27,7 @@ from pyscf import lib from pyscf.lib import logger from pyscf.scf import _vhf +from pyscf.gto.mole import is_au def grad_elec(mf_grad, mo_energy=None, mo_coeff=None, mo_occ=None, atmlst=None): @@ -292,7 +293,7 @@ def dump_flags(self, verbose=None): self.base.__class__.__name__) log.info('******** %s for %s ********', self.__class__, self.base.__class__) - if 'ANG' in self.unit.upper(): + if not is_au(self.unit): raise NotImplementedError('unit Eh/Ang is not supported') else: log.info('unit = Eh/Bohr') diff --git a/pyscf/gto/mole.py b/pyscf/gto/mole.py index 091e030cff..7faba12c3f 100644 --- a/pyscf/gto/mole.py +++ b/pyscf/gto/mole.py @@ -410,9 +410,9 @@ def str2atm(line): axes = numpy.eye(3) if isinstance(unit, (str, unicode)): - if unit.upper().startswith(('B', 'AU')): + if is_au(unit): unit = 1. - else: #unit[:3].upper() == 'ANG': + else: unit = 1./param.BOHR else: unit = 1./unit @@ -2086,6 +2086,10 @@ def fromstring(string, format='xyz'): else: raise NotImplementedError +def is_au(unit): + '''Return whether the unit is recogized as A.U. or not + ''' + return unit.upper().startswith(('B', 'AU')) # # Mole class handles three layers: input, internal format, libcint arguments. @@ -3025,9 +3029,9 @@ def set_geom_(self, atoms_or_coords, unit=None, symmetry=None, if isinstance(atoms_or_coords, numpy.ndarray) and not symmetry: if isinstance(unit, (str, unicode)): - if unit.upper().startswith(('B', 'AU')): + if is_au(unit): unit = 1. - else: #unit[:3].upper() == 'ANG': + else: unit = 1./param.BOHR else: unit = 1./unit @@ -3161,7 +3165,7 @@ def atom_coord(self, atm_id, unit='Bohr'): [ 0. 0. 2.07869874] ''' ptr = self._atm[atm_id,PTR_COORD] - if unit[:3].upper() == 'ANG': + if not is_au(unit): return self._env[ptr:ptr+3] * param.BOHR else: return self._env[ptr:ptr+3].copy() @@ -3170,7 +3174,7 @@ def atom_coords(self, unit='Bohr'): '''np.asarray([mol.atom_coords(i) for i in range(mol.natm)])''' ptr = self._atm[:,PTR_COORD] c = self._env[numpy.vstack((ptr,ptr+1,ptr+2)).T].copy() - if unit[:3].upper() == 'ANG': + if not is_au(unit): c *= param.BOHR return c diff --git a/pyscf/pbc/gto/cell.py b/pyscf/pbc/gto/cell.py index 8911ea28a8..a5479f8be4 100644 --- a/pyscf/pbc/gto/cell.py +++ b/pyscf/pbc/gto/cell.py @@ -36,7 +36,7 @@ from pyscf.gto import mole from pyscf.gto import moleintor from pyscf.gto.mole import (_symbol, _rm_digit, _atom_symbol, _std_symbol, - _std_symbol_without_ghost, charge, is_ghost_atom) # noqa + _std_symbol_without_ghost, charge, is_ghost_atom, is_au) # noqa from pyscf.gto.mole import conc_env, uncontract from pyscf.pbc.gto import basis from pyscf.pbc.gto import pseudo @@ -1627,7 +1627,7 @@ def lattice_vectors(self): else: a = np.asarray(self.a, dtype=np.double) if isinstance(self.unit, (str, unicode)): - if self.unit.startswith(('B','b','au','AU')): + if is_au(self.unit): return a else: return a/param.BOHR