From 837fb8d8e8336d7534191d8f74fcf55fd85f8d3c Mon Sep 17 00:00:00 2001 From: kousuke-nakano <37653569+kousuke-nakano@users.noreply.github.com> Date: Tue, 30 Jan 2024 23:24:32 +0900 Subject: [PATCH] added ase_atoms_writers in shry/core.py. It can be used instead of cif_writers() --- examples/example5.py | 27 +++++++++ shry/core.py | 129 +++++++++++++++++-------------------------- 2 files changed, 77 insertions(+), 79 deletions(-) create mode 100644 examples/example5.py diff --git a/examples/example5.py b/examples/example5.py new file mode 100644 index 0000000..a129c87 --- /dev/null +++ b/examples/example5.py @@ -0,0 +1,27 @@ +# Copyright (c) SHRY Development Team. +# Distributed under the terms of the MIT License. + +""" +Equivalent enumlib operations in SHRY +""" + +from pymatgen.core import Structure +import shry +from shry import Substitutor + +shry.const.DISABLE_PROGRESSBAR = True + +# PbSnTe structure +cif_file = "PbSnTe.cif" +structure = Structure.from_file(cif_file) +structure *= (2, 2, 2) + +# Generate ASE atoms instances with shry +s = Substitutor(structure) +# Shry uses generator; below is to put the Structures into a list +shry_ase_atoms = [x for x in s.ase_atoms_writers()] +shry_num_structs = s.count() +print( + f"SHRY (group equivalent sites) resulted in {shry_num_structs} structures" +) +print(shry_ase_atoms[0]) diff --git a/shry/core.py b/shry/core.py index 4ae41e1..a036f2a 100755 --- a/shry/core.py +++ b/shry/core.py @@ -32,6 +32,7 @@ from sympy.utilities.iterables import multiset_permutations from tabulate import tabulate from pymatgen.core.periodic_table import get_el_sp +from pymatgen.core import Structure # shry modules from . import const @@ -142,9 +143,7 @@ def formula(self) -> str: """ sym_amt = self.get_el_amt_dict() syms = sorted(sym_amt, key=lambda sym: get_el_sp(sym).X) - formula = [ - f"{s}{formula_double_format_tol(sym_amt[s], False)}" for s in syms - ] + formula = [f"{s}{formula_double_format_tol(sym_amt[s], False)}" for s in syms] return " ".join(formula) @@ -170,12 +169,10 @@ def to_s(x): return f"{x:0.6f}" outs.append( - "abc : " - + " ".join([to_s(i).rjust(10) for i in self.lattice.abc]) + "abc : " + " ".join([to_s(i).rjust(10) for i in self.lattice.abc]) ) outs.append( - "angles: " - + " ".join([to_s(i).rjust(10) for i in self.lattice.angles]) + "angles: " + " ".join([to_s(i).rjust(10) for i in self.lattice.angles]) ) if self._charge: if self._charge >= 0: @@ -247,8 +244,7 @@ def _loop_to_list(self, loop): s.append(line) else: sublines = [ - line[i : i + self.maxlen] - for i in range(0, len(line), self.maxlen) + line[i : i + self.maxlen] for i in range(0, len(line), self.maxlen) ] s.extend(sublines) return s @@ -309,9 +305,7 @@ def __init__(self, data, loops, header): for l in self.loops: if k in l: for _k in l: - self.data[_k] = list( - map(self._format_field, self.data[_k]) - ) + self.data[_k] = list(map(self._format_field, self.data[_k])) loop_id = "loop_\n " + "\n ".join(l) self.string_cache[loop_id] = self._loop_to_list(l) formatted.extend(l) @@ -551,9 +545,7 @@ def structure(self, structure): except TypeError as exc: raise RuntimeError("Couldn't find symmetry.") from exc - logging.info( - f"Space group: {sga.get_hall()} ({sga.get_space_group_number()})" - ) + logging.info(f"Space group: {sga.get_hall()} ({sga.get_space_group_number()})") logging.info(f"Total {len(self._symmops)} symmetry operations") logging.info(sga.get_symmetrized_structure()) equivalent_atoms = sga.get_symmetry_dataset()["equivalent_atoms"] @@ -569,9 +561,7 @@ def structure(self, structure): disorder_sites.append(site) # Ad hoc fix: if occupancy is less than 1, stop. # TODO: Automatic vacancy handling - if not np.isclose( - site.species.num_atoms, 1.0, atol=self._atol - ): + if not np.isclose(site.species.num_atoms, 1.0, atol=self._atol): logging.warning( f"The occupancy of the site {site.species} is {site.species.num_atoms}." ) @@ -581,9 +571,7 @@ def structure(self, structure): logging.warning( "If you want to consider vacancy sites, please add pseudo atoms." ) - raise RuntimeError( - "The sum of number of occupancies is not 1." - ) + raise RuntimeError("The sum of number of occupancies is not 1.") if not disorder_sites: logging.warning("No disorder sites found within the Structure.") @@ -594,15 +582,11 @@ def structure(self, structure): self._groupby = lambda x: x.properties["equivalent_atoms"] disorder_sites.sort(key=self._groupby) - for orbit, sites in itertools.groupby( - disorder_sites, key=self._groupby - ): + for orbit, sites in itertools.groupby(disorder_sites, key=self._groupby): # Can it fit? sites = tuple(sites) composition = sites[0].species.to_int_dict() - integer_formula = "".join( - e + str(a) for e, a in composition.items() - ) + integer_formula = "".join(e + str(a) for e, a in composition.items()) formula_unit_sum = sum(composition.values()) if len(sites) % formula_unit_sum: raise NeedSupercellError( @@ -620,12 +604,8 @@ def structure(self, structure): # DMAT indices = [x.properties["index"] for x in sites] self._group_indices[orbit] = indices - group_dmat = self._structure.distance_matrix[ - np.ix_(indices, indices) - ] - self._group_dmat[orbit] = self.ordinalize( - group_dmat, atol=self._atol - ) + group_dmat = self._structure.distance_matrix[np.ix_(indices, indices)] + self._group_dmat[orbit] = self.ordinalize(group_dmat, atol=self._atol) # PERM coords = [x.frac_coords for x in sites] @@ -784,14 +764,12 @@ def _sorted_compositions(self): def _disorder_elements(self): return { - orbit: tuple(x.keys()) - for orbit, x in self._sorted_compositions().items() + orbit: tuple(x.keys()) for orbit, x in self._sorted_compositions().items() } def _disorder_amounts(self): return { - orbit: tuple(x.values()) - for orbit, x in self._sorted_compositions().items() + orbit: tuple(x.values()) for orbit, x in self._sorted_compositions().items() } def make_patterns(self): @@ -867,9 +845,7 @@ def maker_recurse_unit(aut, pattern, orbit, amount): def maker_recurse_c(aut, pattern, orbit, chain): if len(chain) > 0: amount = chain.pop() - for aut, pattern in maker_recurse_unit( - aut, pattern, orbit, amount - ): + for aut, pattern in maker_recurse_unit(aut, pattern, orbit, amount): _chain = chain.copy() yield from maker_recurse_c(aut, pattern, orbit, _chain) else: @@ -879,9 +855,7 @@ def maker_recurse_o(aut, pattern, ochain): if len(ochain) > 0: orbit, sites = ochain.pop() - chain = list(rscum(self._disorder_amounts()[orbit][::-1]))[ - ::-1 - ] + chain = list(rscum(self._disorder_amounts()[orbit][::-1]))[::-1] indices = np.arange(len(sites)) for aut, pattern in maker_recurse_c( @@ -917,18 +891,14 @@ def total_count(self): """ Total number of combinations. """ - ocount = ( - multinomial_coeff(x) for x in self._disorder_amounts().values() - ) + ocount = (multinomial_coeff(x) for x in self._disorder_amounts().values()) return functools.reduce(lambda x, y: x * y, ocount, 1) def count(self): """ Final number of patterns. """ - logging.info( - f"\nCounting unique patterns for {self.structure.formula}" - ) + logging.info(f"\nCounting unique patterns for {self.structure.formula}") if len(self._symmops): enumerator = self._enumerator_collection.get( @@ -1013,6 +983,26 @@ def structure_writers(self, symprec=None): for _, p in self.make_patterns(): yield self._get_structure(p) + def ase_atoms_writers(self, symprec=None): + from ase import Atoms + + def _from_pymatgen_struct_to_ase_atoms(structure: Structure) -> Atoms: + return Atoms( + symbols=[specie.symbol for specie in structure.species], + positions=structure.cart_coords, + cell=structure.lattice.matrix, + pbc=True, + ) + + """ + ASE atoms instances generator. + """ + # This one does not need symprec. + # Just to keep the signature the same. + del symprec + for _, p in self.make_patterns(): + yield _from_pymatgen_struct_to_ase_atoms(self._get_structure(p)) + def ewalds(self, symprec=None): """ Ewald energy generator. @@ -1086,7 +1076,7 @@ def _get_cifwriter(self, p, symprec=None): cfkey = list(cfkey)[0] block = AltCifBlock.from_string(str(cifwriter.ciffile.data[cfkey])) cifwriter.ciffile.data[cfkey] = block - + self._template_cifwriter = cifwriter self._template_structure = template_structure else: @@ -1127,8 +1117,7 @@ def _get_cifwriter(self, p, symprec=None): # Flattened list of species @ disorder sites specie = [y for x in des.values() for y in x] z_map = [ - cell_specie.index(Composition({specie[j]: 1})) - for j in range(len(p)) + cell_specie.index(Composition({specie[j]: 1})) for j in range(len(p)) ] zs = [cell_specie.index(x.species) for x in template_structure] @@ -1168,8 +1157,7 @@ def _get_cifwriter(self, p, symprec=None): sorted( j, key=lambda s: tuple( - abs(x) - for x in template_structure.sites[s].frac_coords + abs(x) for x in template_structure.sites[s].frac_coords ), )[0], len(j), @@ -1187,9 +1175,7 @@ def _get_cifwriter(self, p, symprec=None): ), ) - block["_symmetry_space_group_name_H-M"] = space_group_data[ - "international" - ] + block["_symmetry_space_group_name_H-M"] = space_group_data["international"] block["_symmetry_Int_Tables_number"] = space_group_data["number"] block["_symmetry_equiv_pos_site_id"] = [ str(i) for i in range(1, len(ops) + 1) @@ -1425,9 +1411,7 @@ def reindex(perm_list): # TODO: More intuitive if we do this first, then the previous one. # Relabel to match column position relabel_index = perm_list[0] - relabel_element = np.vectorize( - {s: i for i, s in enumerate(relabel_index)}.get - ) + relabel_element = np.vectorize({s: i for i, s in enumerate(relabel_index)}.get) try: perm_list = relabel_element(perm_list) except TypeError as exc: @@ -1527,9 +1511,7 @@ def cached_ap(self, n): for a, p in self._search(start=start, stop=_n) ] else: - ap = [ - (a, p) for a, p in self._search(start=start, stop=_n) - ] + ap = [(a, p) for a, p in self._search(start=start, stop=_n)] self._auts[n], self._patterns[n] = zip(*ap) for a, p in zip(self._auts[n], self._patterns[n]): @@ -1772,9 +1754,7 @@ def _invar_search(self, start=0, stop=None): leaf_array = np.flatnonzero(leaf_mask) # Calculate subobject Ts for all leaves - leaf_subobj_ts = self._get_subobj_ts( - pattern, leaf_array, subobj_ts - ) + leaf_subobj_ts = self._get_subobj_ts(pattern, leaf_array, subobj_ts) # Reject all leaves where any T is smaller than the new row's T delta_t = leaf_subobj_ts[:, :-1] - leaf_subobj_ts[:, -1:] @@ -1785,9 +1765,7 @@ def _invar_search(self, start=0, stop=None): # Discard symmetry duplicates from the remaining leaves if aut.size > 1 and not_reject_mask.sum() > 1: not_reject_leaf = leaf_array[not_reject_mask] - leaf_reps = self._perms[np.ix_(aut, not_reject_leaf)].min( - axis=0 - ) + leaf_reps = self._perms[np.ix_(aut, not_reject_leaf)].min(axis=0) leaf_indices = leaf_array.searchsorted(leaf_reps) uniq_mask = np.zeros(leaf_array.shape, dtype="bool") uniq_mask[leaf_indices] = True @@ -1807,9 +1785,7 @@ def _invar_search(self, start=0, stop=None): _pbs = self._bit_perm[:, x] + pbs _subobj_ts = leaf_subobj_ts[i] - _subobj_ts[j:] = np.concatenate( - (_subobj_ts[-1:], _subobj_ts[j:-1]) - ) + _subobj_ts[j:] = np.concatenate((_subobj_ts[-1:], _subobj_ts[j:-1])) _i = np.concatenate((pattern[:j], [x], pattern[j:])) # NOTE: just in case I fail to consistently sort perm @@ -1919,9 +1895,7 @@ def ci(self): logging.error(f"IMAP: {index_map}") logging.error(f"P: {permutation}") logging.error(f"BP: {permutations}") - raise RuntimeError( - "Check permutation list." - ) from exc + raise RuntimeError("Check permutation list.") from exc cycles.append(cycle) counter = collections.Counter(len(cycle) for cycle in cycles) cycle_index[i].append(counter) @@ -2020,10 +1994,7 @@ def exmul(arrays): counts = [ functools.reduce( lambda x, y: x * y, - [ - multinomial_coeff(tuple(p[j])) - for p, j in zip(f_parts, i) - ], + [multinomial_coeff(tuple(p[j])) for p, j in zip(f_parts, i)], ) for i in match_i ]