-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from ajjackson/spglib
Update get-primitive for modern versions of spglib
- Loading branch information
Showing
4 changed files
with
195 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,115 +1,158 @@ | ||
import argparse | ||
"""Get primitive cell from crystal structure using spglib""" | ||
|
||
from argparse import ArgumentParser | ||
from pathlib import Path | ||
from typing import Any, Dict, List, Optional | ||
|
||
import ase | ||
import ase.io | ||
import spglib | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser( | ||
description="Find a primitive unit cell using pyspglib") | ||
def get_parser() -> ArgumentParser: | ||
parser = ArgumentParser( | ||
description="Find a primitive unit cell using spglib") | ||
parser.add_argument( | ||
'input_file', | ||
type=str, | ||
default='POSCAR', | ||
help="Path to crystal structure file, recognisable by ASE") | ||
"input_file", | ||
type=Path, | ||
default="POSCAR", | ||
help="Path to crystal structure file, recognisable by ASE", | ||
) | ||
parser.add_argument( | ||
'--input_format', | ||
"--input-format", | ||
dest="input_format", | ||
type=str, | ||
help="Format for input file (needed if ASE can't guess from filename)") | ||
help="Format for input file (needed if ASE can't guess from filename)", | ||
) | ||
parser.add_argument( | ||
'-t', | ||
'--threshold', | ||
"-t", | ||
"--threshold", | ||
type=float, | ||
default=1e-05, | ||
help=("Distance threshold in AA for symmetry reduction " | ||
"(corresponds to spglib 'symprec' keyword)")) | ||
help=( | ||
"Distance threshold in AA for symmetry reduction " | ||
"(corresponds to spglib 'symprec' keyword)" | ||
), | ||
) | ||
parser.add_argument( | ||
'-a', | ||
'--angle_tolerance', | ||
"-a", | ||
"--angle-tolerance", | ||
dest="angle_tolerance", | ||
type=float, | ||
default=-1.0, | ||
help="Angle tolerance for symmetry reduction") | ||
help="Angle tolerance for symmetry reduction", | ||
) | ||
parser.add_argument( | ||
'-o', '--output_file', default=None, help="Path/filename for output") | ||
"-o", | ||
"--output-file", | ||
type=Path, | ||
default=None, | ||
dest="output_file", | ||
help="Path/filename for output", | ||
) | ||
parser.add_argument( | ||
'--output_format', | ||
"--output-format", | ||
dest="output_format", | ||
type=str, | ||
help="Format for input file (needed if ASE can't guess from filename)") | ||
default=None, | ||
help="Format for input file (needed if ASE can't guess from filename)", | ||
) | ||
parser.add_argument( | ||
'-v', | ||
'--verbose', | ||
"-v", | ||
"--verbose", | ||
action="store_true", | ||
help="Print output to screen even when writing to file.") | ||
args = parser.parse_args() | ||
help="Print output to screen even when writing to file.", | ||
) | ||
parser.add_argument( | ||
"-p", | ||
"--precision", | ||
type=int, | ||
help=("Number of decimal places for float display. " | ||
"(Output files are not affected)"), | ||
default=6, | ||
) | ||
return parser | ||
|
||
get_primitive(**vars(args)) | ||
|
||
def main(params: Optional[List[str]] = None): | ||
parser = get_parser() | ||
|
||
def get_primitive(input_file='POSCAR', | ||
input_format=None, | ||
output_file=None, | ||
output_format=None, | ||
threshold=1e-5, | ||
angle_tolerance=-1., | ||
verbose=False): | ||
if params: | ||
args = parser.parse_args(params) | ||
else: | ||
args = parser.parse_args() | ||
|
||
if output_file is None: | ||
verbose = True | ||
get_primitive(**snake_case_args(vars(args))) | ||
|
||
if verbose: | ||
|
||
def vprint(*args): | ||
for arg in args: | ||
print(arg,) | ||
print("") | ||
else: | ||
def snake_case_args(kwarg_dict: Dict[str, Any]) -> Dict[str, Any]: | ||
"""Convert user-friendly hyphenated arguments to python_friendly ones""" | ||
return {key.replace("-", "_"): value for key, value in kwarg_dict.items()} | ||
|
||
def vprint(*args): | ||
pass | ||
|
||
try: | ||
if input_format is None: | ||
A = ase.io.read(input_file) | ||
else: | ||
A = ase.io.read(input_file, format=input_format) | ||
except IOError as e: | ||
raise Exception("I/O error({0}): {1}".format(e.errno, e.strerror)) | ||
def get_primitive_atoms( | ||
atoms: ase.Atoms, | ||
threshold: float = 1e-5, | ||
angle_tolerance: float = -1.0, | ||
print_spacegroup: bool = False, | ||
) -> ase.Atoms: | ||
"""Convert ASE Atoms to primitive cell using spglib""" | ||
atoms_spglib = ( | ||
atoms.cell.array, | ||
atoms.get_scaled_positions(), | ||
atoms.numbers, | ||
) | ||
|
||
vprint( | ||
"# Space group: ", | ||
str( | ||
spglib.get_spacegroup( | ||
A, symprec=threshold, angle_tolerance=angle_tolerance)), | ||
'\n') | ||
spacegroup = spglib.get_spacegroup( | ||
atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance) | ||
if print_spacegroup: | ||
print(f"Space group: {spacegroup}") | ||
|
||
cell, positions, atomic_numbers = spglib.find_primitive( | ||
A, symprec=threshold, angle_tolerance=angle_tolerance) | ||
atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance) | ||
|
||
primitive_atoms = ase.Atoms( | ||
scaled_positions=positions, | ||
cell=cell, | ||
numbers=atomic_numbers, | ||
pbc=True) | ||
|
||
return primitive_atoms | ||
|
||
|
||
if positions is None: | ||
print("This space group doesn't have a more primitive unit cell.") | ||
def get_primitive(input_file: Path = Path('POSCAR'), | ||
input_format: Optional[str] = None, | ||
output_file: Optional[Path] = None, | ||
output_format: Optional[str] = None, | ||
threshold: float = 1e-5, | ||
angle_tolerance: float = -1., | ||
verbose: bool = False, | ||
precision: int = 6) -> None: | ||
|
||
if output_file is None: | ||
verbose = True | ||
|
||
float_format_str = f"{{:{precision+4}.{precision}f}}" | ||
|
||
def format_float(x: float) -> str: | ||
return float_format_str.format(x) | ||
|
||
atoms = ase.io.read(input_file, format=input_format) | ||
atoms = get_primitive_atoms( | ||
atoms, threshold=threshold, angle_tolerance=angle_tolerance, print_spacegroup=verbose | ||
) | ||
|
||
if verbose: | ||
print("Primitive cell vectors:") | ||
for row in atoms.cell: | ||
print(" ".join(map(format_float, row))) | ||
|
||
print("Atomic positions and proton numbers:") | ||
for position, number in zip(atoms.get_scaled_positions(), atoms.numbers): | ||
print(" ".join(map(format_float, position)) + f"\t{number}") | ||
|
||
if output_file is None: | ||
pass | ||
else: | ||
vprint("Primitive cell vectors:") | ||
vprint(cell, '\n') | ||
vprint("Atomic positions and proton numbers:") | ||
for position, number in zip(positions, atomic_numbers): | ||
vprint(position, '\t', number) | ||
|
||
if output_file is None: | ||
pass | ||
else: | ||
atoms = ase.Atoms( | ||
scaled_positions=positions, | ||
cell=cell, | ||
numbers=atomic_numbers, | ||
pbc=True) | ||
if output_format is None: | ||
try: | ||
atoms.write(output_file, vasp5=True) | ||
except TypeError: | ||
atoms.write(output_file) | ||
elif output_format == "vasp": | ||
atoms.write(output_file, format="vasp", vasp5=True) | ||
else: | ||
atoms.write(output_file, format=output_format) | ||
|
||
atoms.write(output_file, format=output_format) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import textwrap | ||
|
||
import ase | ||
import ase.build | ||
from numpy.random import RandomState | ||
import pytest | ||
|
||
from mctools.generic.get_primitive import main | ||
|
||
|
||
FILENAME = "rattled_cu.extxyz" | ||
OUTFILENAME = "POSCAR" | ||
REF_STDOUT = textwrap.dedent( | ||
"""\ | ||
Space group: Fm-3m (225) | ||
Primitive cell vectors: | ||
0.000 1.805 1.805 | ||
1.805 0.000 1.805 | ||
1.805 1.805 0.000 | ||
Atomic positions and proton numbers: | ||
0.500 0.500 0.500\t29 | ||
""" | ||
) | ||
REF_POSCAR = textwrap.dedent( | ||
"""\ | ||
Cu | ||
1.0000000000000000 | ||
0.0000000000000000 1.8050080045893520 1.8050080045893520 | ||
1.8050080045893520 0.0000000000000000 1.8050080045893520 | ||
1.8050080045893520 1.8050080045893520 0.0000000000000000 | ||
Cu | ||
1 | ||
Cartesian | ||
1.8050080045893520 1.8050080045893520 1.8050080045893520 | ||
""" # noqa:W291 | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def rattled_cu() -> ase.Atoms: | ||
rng = RandomState(seed=1) | ||
|
||
atoms = ase.build.bulk("Cu", cubic=True) * (2, 2, 2) | ||
atoms.rattle(stdev=1e-3, seed=1) | ||
atoms.set_cell(atoms.cell.array + 1e-4 * rng.rand(3, 3)) | ||
|
||
return atoms | ||
|
||
|
||
def test_get_primitive(rattled_cu, tmp_path, capsys) -> None: | ||
rattled_cu.write(tmp_path / FILENAME) | ||
|
||
main([str(tmp_path / FILENAME), | ||
"--input-format=extxyz", | ||
"--threshold=1e-2", | ||
"--angle-tolerance=1", | ||
"-o", | ||
str(tmp_path / OUTFILENAME), | ||
"-v", | ||
"--precision=3" | ||
]) | ||
|
||
captured = capsys.readouterr() | ||
assert captured.out == REF_STDOUT | ||
|
||
with open(tmp_path / OUTFILENAME, "r") as fd: | ||
assert fd.read() == REF_POSCAR |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters