Skip to content

Commit

Permalink
added ase_atoms_writers in shry/core.py. It can be used instead of cif_writers()
Browse files Browse the repository at this point in the history
  • Loading branch information
kousuke-nakano committed Jan 30, 2024
1 parent 11096da commit 837fb8d
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 79 deletions.
27 changes: 27 additions & 0 deletions examples/example5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) SHRY Development Team.
# Distributed under the terms of the MIT License.

"""
Equivalent enumlib operations in SHRY
"""

from pymatgen.core import Structure

import shry
from shry import Substitutor

# Keep example output clean; progress bars are not useful here.
shry.const.DISABLE_PROGRESSBAR = True

# PbSnTe structure
cif_file = "PbSnTe.cif"
structure = Structure.from_file(cif_file)
structure *= (2, 2, 2)  # 2x2x2 supercell

# Generate ASE Atoms instances with SHRY.
s = Substitutor(structure)
# ase_atoms_writers() is a generator; materialize it into a list.
shry_ase_atoms = list(s.ase_atoms_writers())
shry_num_structs = s.count()
print(
    f"SHRY (group equivalent sites) resulted in {shry_num_structs} structures"
)
print(shry_ase_atoms[0])
129 changes: 50 additions & 79 deletions shry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from sympy.utilities.iterables import multiset_permutations
from tabulate import tabulate
from pymatgen.core.periodic_table import get_el_sp
from pymatgen.core import Structure

# shry modules
from . import const
Expand Down Expand Up @@ -142,9 +143,7 @@ def formula(self) -> str:
"""
sym_amt = self.get_el_amt_dict()
syms = sorted(sym_amt, key=lambda sym: get_el_sp(sym).X)
formula = [
f"{s}{formula_double_format_tol(sym_amt[s], False)}" for s in syms
]
formula = [f"{s}{formula_double_format_tol(sym_amt[s], False)}" for s in syms]
return " ".join(formula)


Expand All @@ -170,12 +169,10 @@ def to_s(x):
return f"{x:0.6f}"

outs.append(
"abc : "
+ " ".join([to_s(i).rjust(10) for i in self.lattice.abc])
"abc : " + " ".join([to_s(i).rjust(10) for i in self.lattice.abc])
)
outs.append(
"angles: "
+ " ".join([to_s(i).rjust(10) for i in self.lattice.angles])
"angles: " + " ".join([to_s(i).rjust(10) for i in self.lattice.angles])
)
if self._charge:
if self._charge >= 0:
Expand Down Expand Up @@ -247,8 +244,7 @@ def _loop_to_list(self, loop):
s.append(line)
else:
sublines = [
line[i : i + self.maxlen]
for i in range(0, len(line), self.maxlen)
line[i : i + self.maxlen] for i in range(0, len(line), self.maxlen)
]
s.extend(sublines)
return s
Expand Down Expand Up @@ -309,9 +305,7 @@ def __init__(self, data, loops, header):
for l in self.loops:
if k in l:
for _k in l:
self.data[_k] = list(
map(self._format_field, self.data[_k])
)
self.data[_k] = list(map(self._format_field, self.data[_k]))
loop_id = "loop_\n " + "\n ".join(l)
self.string_cache[loop_id] = self._loop_to_list(l)
formatted.extend(l)
Expand Down Expand Up @@ -551,9 +545,7 @@ def structure(self, structure):
except TypeError as exc:
raise RuntimeError("Couldn't find symmetry.") from exc

logging.info(
f"Space group: {sga.get_hall()} ({sga.get_space_group_number()})"
)
logging.info(f"Space group: {sga.get_hall()} ({sga.get_space_group_number()})")
logging.info(f"Total {len(self._symmops)} symmetry operations")
logging.info(sga.get_symmetrized_structure())
equivalent_atoms = sga.get_symmetry_dataset()["equivalent_atoms"]
Expand All @@ -569,9 +561,7 @@ def structure(self, structure):
disorder_sites.append(site)
# Ad hoc fix: if occupancy is less than 1, stop.
# TODO: Automatic vacancy handling
if not np.isclose(
site.species.num_atoms, 1.0, atol=self._atol
):
if not np.isclose(site.species.num_atoms, 1.0, atol=self._atol):
logging.warning(
f"The occupancy of the site {site.species} is {site.species.num_atoms}."
)
Expand All @@ -581,9 +571,7 @@ def structure(self, structure):
logging.warning(
"If you want to consider vacancy sites, please add pseudo atoms."
)
raise RuntimeError(
"The sum of number of occupancies is not 1."
)
raise RuntimeError("The sum of number of occupancies is not 1.")
if not disorder_sites:
logging.warning("No disorder sites found within the Structure.")

Expand All @@ -594,15 +582,11 @@ def structure(self, structure):
self._groupby = lambda x: x.properties["equivalent_atoms"]
disorder_sites.sort(key=self._groupby)

for orbit, sites in itertools.groupby(
disorder_sites, key=self._groupby
):
for orbit, sites in itertools.groupby(disorder_sites, key=self._groupby):
# Can it fit?
sites = tuple(sites)
composition = sites[0].species.to_int_dict()
integer_formula = "".join(
e + str(a) for e, a in composition.items()
)
integer_formula = "".join(e + str(a) for e, a in composition.items())
formula_unit_sum = sum(composition.values())
if len(sites) % formula_unit_sum:
raise NeedSupercellError(
Expand All @@ -620,12 +604,8 @@ def structure(self, structure):
# DMAT
indices = [x.properties["index"] for x in sites]
self._group_indices[orbit] = indices
group_dmat = self._structure.distance_matrix[
np.ix_(indices, indices)
]
self._group_dmat[orbit] = self.ordinalize(
group_dmat, atol=self._atol
)
group_dmat = self._structure.distance_matrix[np.ix_(indices, indices)]
self._group_dmat[orbit] = self.ordinalize(group_dmat, atol=self._atol)

# PERM
coords = [x.frac_coords for x in sites]
Expand Down Expand Up @@ -784,14 +764,12 @@ def _sorted_compositions(self):

def _disorder_elements(self):
return {
orbit: tuple(x.keys())
for orbit, x in self._sorted_compositions().items()
orbit: tuple(x.keys()) for orbit, x in self._sorted_compositions().items()
}

def _disorder_amounts(self):
return {
orbit: tuple(x.values())
for orbit, x in self._sorted_compositions().items()
orbit: tuple(x.values()) for orbit, x in self._sorted_compositions().items()
}

def make_patterns(self):
Expand Down Expand Up @@ -867,9 +845,7 @@ def maker_recurse_unit(aut, pattern, orbit, amount):
def maker_recurse_c(aut, pattern, orbit, chain):
if len(chain) > 0:
amount = chain.pop()
for aut, pattern in maker_recurse_unit(
aut, pattern, orbit, amount
):
for aut, pattern in maker_recurse_unit(aut, pattern, orbit, amount):
_chain = chain.copy()
yield from maker_recurse_c(aut, pattern, orbit, _chain)
else:
Expand All @@ -879,9 +855,7 @@ def maker_recurse_o(aut, pattern, ochain):
if len(ochain) > 0:
orbit, sites = ochain.pop()

chain = list(rscum(self._disorder_amounts()[orbit][::-1]))[
::-1
]
chain = list(rscum(self._disorder_amounts()[orbit][::-1]))[::-1]
indices = np.arange(len(sites))

for aut, pattern in maker_recurse_c(
Expand Down Expand Up @@ -917,18 +891,14 @@ def total_count(self):
"""
Total number of combinations.
"""
ocount = (
multinomial_coeff(x) for x in self._disorder_amounts().values()
)
ocount = (multinomial_coeff(x) for x in self._disorder_amounts().values())
return functools.reduce(lambda x, y: x * y, ocount, 1)

def count(self):
"""
Final number of patterns.
"""
logging.info(
f"\nCounting unique patterns for {self.structure.formula}"
)
logging.info(f"\nCounting unique patterns for {self.structure.formula}")

if len(self._symmops):
enumerator = self._enumerator_collection.get(
Expand Down Expand Up @@ -1013,6 +983,26 @@ def structure_writers(self, symprec=None):
for _, p in self.make_patterns():
yield self._get_structure(p)

def ase_atoms_writers(self, symprec=None):
    """
    ASE Atoms instances generator.

    Yields one ``ase.Atoms`` object per unique substitution pattern,
    converted from the pymatgen ``Structure`` produced by
    ``self._get_structure()``.

    Args:
        symprec: Unused. Accepted only so the signature matches
            the sibling writers (e.g. ``structure_writers``).

    Yields:
        ase.Atoms: an ordered structure for each unique pattern.
    """
    # Imported lazily so ASE stays an optional dependency of shry.
    from ase import Atoms

    def _from_pymatgen_struct_to_ase_atoms(structure: Structure) -> Atoms:
        # Minimal conversion: element symbols, Cartesian coordinates,
        # lattice matrix, and fully periodic boundary conditions.
        return Atoms(
            symbols=[specie.symbol for specie in structure.species],
            positions=structure.cart_coords,
            cell=structure.lattice.matrix,
            pbc=True,
        )

    # This one does not need symprec.
    # Just to keep the signature the same.
    del symprec
    for _, p in self.make_patterns():
        yield _from_pymatgen_struct_to_ase_atoms(self._get_structure(p))

def ewalds(self, symprec=None):
"""
Ewald energy generator.
Expand Down Expand Up @@ -1086,7 +1076,7 @@ def _get_cifwriter(self, p, symprec=None):
cfkey = list(cfkey)[0]
block = AltCifBlock.from_string(str(cifwriter.ciffile.data[cfkey]))
cifwriter.ciffile.data[cfkey] = block

self._template_cifwriter = cifwriter
self._template_structure = template_structure
else:
Expand Down Expand Up @@ -1127,8 +1117,7 @@ def _get_cifwriter(self, p, symprec=None):
# Flattened list of species @ disorder sites
specie = [y for x in des.values() for y in x]
z_map = [
cell_specie.index(Composition({specie[j]: 1}))
for j in range(len(p))
cell_specie.index(Composition({specie[j]: 1})) for j in range(len(p))
]
zs = [cell_specie.index(x.species) for x in template_structure]

Expand Down Expand Up @@ -1168,8 +1157,7 @@ def _get_cifwriter(self, p, symprec=None):
sorted(
j,
key=lambda s: tuple(
abs(x)
for x in template_structure.sites[s].frac_coords
abs(x) for x in template_structure.sites[s].frac_coords
),
)[0],
len(j),
Expand All @@ -1187,9 +1175,7 @@ def _get_cifwriter(self, p, symprec=None):
),
)

block["_symmetry_space_group_name_H-M"] = space_group_data[
"international"
]
block["_symmetry_space_group_name_H-M"] = space_group_data["international"]
block["_symmetry_Int_Tables_number"] = space_group_data["number"]
block["_symmetry_equiv_pos_site_id"] = [
str(i) for i in range(1, len(ops) + 1)
Expand Down Expand Up @@ -1425,9 +1411,7 @@ def reindex(perm_list):
# TODO: More intuitive if we do this first, then the previous one.
# Relabel to match column position
relabel_index = perm_list[0]
relabel_element = np.vectorize(
{s: i for i, s in enumerate(relabel_index)}.get
)
relabel_element = np.vectorize({s: i for i, s in enumerate(relabel_index)}.get)
try:
perm_list = relabel_element(perm_list)
except TypeError as exc:
Expand Down Expand Up @@ -1527,9 +1511,7 @@ def cached_ap(self, n):
for a, p in self._search(start=start, stop=_n)
]
else:
ap = [
(a, p) for a, p in self._search(start=start, stop=_n)
]
ap = [(a, p) for a, p in self._search(start=start, stop=_n)]
self._auts[n], self._patterns[n] = zip(*ap)

for a, p in zip(self._auts[n], self._patterns[n]):
Expand Down Expand Up @@ -1772,9 +1754,7 @@ def _invar_search(self, start=0, stop=None):
leaf_array = np.flatnonzero(leaf_mask)

# Calculate subobject Ts for all leaves
leaf_subobj_ts = self._get_subobj_ts(
pattern, leaf_array, subobj_ts
)
leaf_subobj_ts = self._get_subobj_ts(pattern, leaf_array, subobj_ts)

# Reject all leaves where any T is smaller than the new row's T
delta_t = leaf_subobj_ts[:, :-1] - leaf_subobj_ts[:, -1:]
Expand All @@ -1785,9 +1765,7 @@ def _invar_search(self, start=0, stop=None):
# Discard symmetry duplicates from the remaining leaves
if aut.size > 1 and not_reject_mask.sum() > 1:
not_reject_leaf = leaf_array[not_reject_mask]
leaf_reps = self._perms[np.ix_(aut, not_reject_leaf)].min(
axis=0
)
leaf_reps = self._perms[np.ix_(aut, not_reject_leaf)].min(axis=0)
leaf_indices = leaf_array.searchsorted(leaf_reps)
uniq_mask = np.zeros(leaf_array.shape, dtype="bool")
uniq_mask[leaf_indices] = True
Expand All @@ -1807,9 +1785,7 @@ def _invar_search(self, start=0, stop=None):
_pbs = self._bit_perm[:, x] + pbs

_subobj_ts = leaf_subobj_ts[i]
_subobj_ts[j:] = np.concatenate(
(_subobj_ts[-1:], _subobj_ts[j:-1])
)
_subobj_ts[j:] = np.concatenate((_subobj_ts[-1:], _subobj_ts[j:-1]))

_i = np.concatenate((pattern[:j], [x], pattern[j:]))
# NOTE: just in case I fail to consistently sort perm
Expand Down Expand Up @@ -1919,9 +1895,7 @@ def ci(self):
logging.error(f"IMAP: {index_map}")
logging.error(f"P: {permutation}")
logging.error(f"BP: {permutations}")
raise RuntimeError(
"Check permutation list."
) from exc
raise RuntimeError("Check permutation list.") from exc
cycles.append(cycle)
counter = collections.Counter(len(cycle) for cycle in cycles)
cycle_index[i].append(counter)
Expand Down Expand Up @@ -2020,10 +1994,7 @@ def exmul(arrays):
counts = [
functools.reduce(
lambda x, y: x * y,
[
multinomial_coeff(tuple(p[j]))
for p, j in zip(f_parts, i)
],
[multinomial_coeff(tuple(p[j])) for p, j in zip(f_parts, i)],
)
for i in match_i
]
Expand Down

0 comments on commit 837fb8d

Please sign in to comment.