Skip to content

Commit

Permalink
Merge pull request #7 from ajjackson/spglib
Browse files Browse the repository at this point in the history
Update get-primitive for modern versions of spglib
  • Loading branch information
ajjackson authored May 29, 2024
2 parents 1272f70 + 4d6e96e commit d23103b
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 82 deletions.
205 changes: 124 additions & 81 deletions mctools/generic/get_primitive.py
Original file line number Diff line number Diff line change
@@ -1,115 +1,158 @@
import argparse
"""Get primitive cell from crystal structure using spglib"""

from argparse import ArgumentParser
from pathlib import Path
from typing import Any, Dict, List, Optional

import ase
import ase.io
import spglib


def main():
parser = argparse.ArgumentParser(
description="Find a primitive unit cell using pyspglib")
def get_parser() -> ArgumentParser:
parser = ArgumentParser(
description="Find a primitive unit cell using spglib")
parser.add_argument(
'input_file',
type=str,
default='POSCAR',
help="Path to crystal structure file, recognisable by ASE")
"input_file",
type=Path,
default="POSCAR",
help="Path to crystal structure file, recognisable by ASE",
)
parser.add_argument(
'--input_format',
"--input-format",
dest="input_format",
type=str,
help="Format for input file (needed if ASE can't guess from filename)")
help="Format for input file (needed if ASE can't guess from filename)",
)
parser.add_argument(
'-t',
'--threshold',
"-t",
"--threshold",
type=float,
default=1e-05,
help=("Distance threshold in AA for symmetry reduction "
"(corresponds to spglib 'symprec' keyword)"))
help=(
"Distance threshold in AA for symmetry reduction "
"(corresponds to spglib 'symprec' keyword)"
),
)
parser.add_argument(
'-a',
'--angle_tolerance',
"-a",
"--angle-tolerance",
dest="angle_tolerance",
type=float,
default=-1.0,
help="Angle tolerance for symmetry reduction")
help="Angle tolerance for symmetry reduction",
)
parser.add_argument(
'-o', '--output_file', default=None, help="Path/filename for output")
"-o",
"--output-file",
type=Path,
default=None,
dest="output_file",
help="Path/filename for output",
)
parser.add_argument(
'--output_format',
"--output-format",
dest="output_format",
type=str,
help="Format for input file (needed if ASE can't guess from filename)")
default=None,
help="Format for input file (needed if ASE can't guess from filename)",
)
parser.add_argument(
'-v',
'--verbose',
"-v",
"--verbose",
action="store_true",
help="Print output to screen even when writing to file.")
args = parser.parse_args()
help="Print output to screen even when writing to file.",
)
parser.add_argument(
"-p",
"--precision",
type=int,
help=("Number of decimal places for float display. "
"(Output files are not affected)"),
default=6,
)
return parser

get_primitive(**vars(args))

def main(params: Optional[List[str]] = None):
parser = get_parser()

def get_primitive(input_file='POSCAR',
input_format=None,
output_file=None,
output_format=None,
threshold=1e-5,
angle_tolerance=-1.,
verbose=False):
if params:
args = parser.parse_args(params)
else:
args = parser.parse_args()

if output_file is None:
verbose = True
get_primitive(**snake_case_args(vars(args)))

if verbose:

def vprint(*args):
for arg in args:
print(arg,)
print("")
else:
def snake_case_args(kwarg_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Convert user-friendly hyphenated arguments to python_friendly ones"""
return {key.replace("-", "_"): value for key, value in kwarg_dict.items()}

def vprint(*args):
pass

try:
if input_format is None:
A = ase.io.read(input_file)
else:
A = ase.io.read(input_file, format=input_format)
except IOError as e:
raise Exception("I/O error({0}): {1}".format(e.errno, e.strerror))
def get_primitive_atoms(
atoms: ase.Atoms,
threshold: float = 1e-5,
angle_tolerance: float = -1.0,
print_spacegroup: bool = False,
) -> ase.Atoms:
"""Convert ASE Atoms to primitive cell using spglib"""
atoms_spglib = (
atoms.cell.array,
atoms.get_scaled_positions(),
atoms.numbers,
)

vprint(
"# Space group: ",
str(
spglib.get_spacegroup(
A, symprec=threshold, angle_tolerance=angle_tolerance)),
'\n')
spacegroup = spglib.get_spacegroup(
atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance)
if print_spacegroup:
print(f"Space group: {spacegroup}")

cell, positions, atomic_numbers = spglib.find_primitive(
A, symprec=threshold, angle_tolerance=angle_tolerance)
atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance)

primitive_atoms = ase.Atoms(
scaled_positions=positions,
cell=cell,
numbers=atomic_numbers,
pbc=True)

return primitive_atoms


if positions is None:
print("This space group doesn't have a more primitive unit cell.")
def get_primitive(input_file: Path = Path('POSCAR'),
input_format: Optional[str] = None,
output_file: Optional[Path] = None,
output_format: Optional[str] = None,
threshold: float = 1e-5,
angle_tolerance: float = -1.,
verbose: bool = False,
precision: int = 6) -> None:

if output_file is None:
verbose = True

float_format_str = f"{{:{precision+4}.{precision}f}}"

def format_float(x: float) -> str:
return float_format_str.format(x)

atoms = ase.io.read(input_file, format=input_format)
atoms = get_primitive_atoms(
atoms, threshold=threshold, angle_tolerance=angle_tolerance, print_spacegroup=verbose
)

if verbose:
print("Primitive cell vectors:")
for row in atoms.cell:
print(" ".join(map(format_float, row)))

print("Atomic positions and proton numbers:")
for position, number in zip(atoms.get_scaled_positions(), atoms.numbers):
print(" ".join(map(format_float, position)) + f"\t{number}")

if output_file is None:
pass
else:
vprint("Primitive cell vectors:")
vprint(cell, '\n')
vprint("Atomic positions and proton numbers:")
for position, number in zip(positions, atomic_numbers):
vprint(position, '\t', number)

if output_file is None:
pass
else:
atoms = ase.Atoms(
scaled_positions=positions,
cell=cell,
numbers=atomic_numbers,
pbc=True)
if output_format is None:
try:
atoms.write(output_file, vasp5=True)
except TypeError:
atoms.write(output_file)
elif output_format == "vasp":
atoms.write(output_file, format="vasp", vasp5=True)
else:
atoms.write(output_file, format=output_format)

atoms.write(output_file, format=output_format)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ Repository = "https://github.com/ajjackson/mctools"
[tool.setuptools.dynamic]
readme = {file = ["README.md"]}

[tool.black]
line-length = 79

[project.scripts]
fold-prim = "mctools.generic.fold_prim:main"
get-energy = "mctools.generic.get_energy:main"
Expand Down
67 changes: 67 additions & 0 deletions tests/test_get_primitive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import textwrap

import ase
import ase.build
from numpy.random import RandomState
import pytest

from mctools.generic.get_primitive import main


FILENAME = "rattled_cu.extxyz"
OUTFILENAME = "POSCAR"
REF_STDOUT = textwrap.dedent(
"""\
Space group: Fm-3m (225)
Primitive cell vectors:
0.000 1.805 1.805
1.805 0.000 1.805
1.805 1.805 0.000
Atomic positions and proton numbers:
0.500 0.500 0.500\t29
"""
)
REF_POSCAR = textwrap.dedent(
"""\
Cu
1.0000000000000000
0.0000000000000000 1.8050080045893520 1.8050080045893520
1.8050080045893520 0.0000000000000000 1.8050080045893520
1.8050080045893520 1.8050080045893520 0.0000000000000000
Cu
1
Cartesian
1.8050080045893520 1.8050080045893520 1.8050080045893520
""" # noqa:W291
)


@pytest.fixture
def rattled_cu() -> ase.Atoms:
rng = RandomState(seed=1)

atoms = ase.build.bulk("Cu", cubic=True) * (2, 2, 2)
atoms.rattle(stdev=1e-3, seed=1)
atoms.set_cell(atoms.cell.array + 1e-4 * rng.rand(3, 3))

return atoms


def test_get_primitive(rattled_cu, tmp_path, capsys) -> None:
rattled_cu.write(tmp_path / FILENAME)

main([str(tmp_path / FILENAME),
"--input-format=extxyz",
"--threshold=1e-2",
"--angle-tolerance=1",
"-o",
str(tmp_path / OUTFILENAME),
"-v",
"--precision=3"
])

captured = capsys.readouterr()
assert captured.out == REF_STDOUT

with open(tmp_path / OUTFILENAME, "r") as fd:
assert fd.read() == REF_POSCAR
2 changes: 1 addition & 1 deletion tests/test_get_spacegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)


@pytest.fixture
@pytest.fixture(scope="module")
def symmetry_broken_cu() -> ase.Atoms:
atoms = ase.build.bulk("Cu", cubic=True) * (2, 2, 2)

Expand Down

0 comments on commit d23103b

Please sign in to comment.