From 3d8c5c8f21aa952604514c813b488c926359637b Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 09:28:24 +0100 Subject: [PATCH 01/11] get_spacegroup: spglib api update, remove some redundant logic --- mctools/generic/get_primitive.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/mctools/generic/get_primitive.py b/mctools/generic/get_primitive.py index e95ab66..860d7d9 100755 --- a/mctools/generic/get_primitive.py +++ b/mctools/generic/get_primitive.py @@ -68,19 +68,14 @@ def vprint(*args): def vprint(*args): pass - try: - if input_format is None: - A = ase.io.read(input_file) - else: - A = ase.io.read(input_file, format=input_format) - except IOError as e: - raise Exception("I/O error({0}): {1}".format(e.errno, e.strerror)) + atoms = ase.io.read(input_file, format=input_format) + cell = (atoms.cell.array, atoms.get_scaled_positions(), atoms.numbers) vprint( "# Space group: ", str( spglib.get_spacegroup( - A, symprec=threshold, angle_tolerance=angle_tolerance)), + cell, symprec=threshold, angle_tolerance=angle_tolerance)), '\n') cell, positions, atomic_numbers = spglib.find_primitive( From 5618fda8c9aa1b89624a861a00ada8b0c1d90b13 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 09:33:27 +0100 Subject: [PATCH 02/11] Configure line length for black I'm not running black on everything (yet...) but black-macchiato can by a nice tool for formatting while editing. I prefer to stick to 79-char lines at this stage, so putthing this in pyprojct.toml makes life easy. --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b6c95a8..73c8f8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,9 @@ Repository = "https://github.com/ajjackson/mctools" [tool.setuptools.dynamic] readme = {file = ["README.md"]} +[tool.black] +line-length = 79 + [project.scripts] fold-prim = "mctools.generic.fold_prim:main" get-energy = "mctools.generic.get_energy:main" From e8d0d816cb62f2c881d688552a2c7a718cc5005d Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 09:42:25 +0100 Subject: [PATCH 03/11] get_primitive spglib updates --- mctools/generic/get_primitive.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mctools/generic/get_primitive.py b/mctools/generic/get_primitive.py index 860d7d9..08409c2 100755 --- a/mctools/generic/get_primitive.py +++ b/mctools/generic/get_primitive.py @@ -58,28 +58,26 @@ def get_primitive(input_file='POSCAR', verbose = True if verbose: - def vprint(*args): for arg in args: print(arg,) print("") else: - def vprint(*args): pass atoms = ase.io.read(input_file, format=input_format) - cell = (atoms.cell.array, atoms.get_scaled_positions(), atoms.numbers) + atoms_spglib = (atoms.cell.array, atoms.get_scaled_positions(), atoms.numbers) vprint( "# Space group: ", str( spglib.get_spacegroup( - cell, symprec=threshold, angle_tolerance=angle_tolerance)), + atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance)), '\n') cell, positions, atomic_numbers = spglib.find_primitive( - A, symprec=threshold, angle_tolerance=angle_tolerance) + atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance) if positions is None: print("This space group doesn't have a more primitive unit cell.") From 75c47bc1325dc61044ef3ce3765779b59681fbb3 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 09:42:46 +0100 Subject: [PATCH 04/11] ASE VASP writer defaults to vasp5 now, can simplify logic here --- mctools/generic/get_primitive.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/mctools/generic/get_primitive.py b/mctools/generic/get_primitive.py index 08409c2..59467d5 100755 --- a/mctools/generic/get_primitive.py +++ b/mctools/generic/get_primitive.py @@ -97,12 +97,5 @@ def vprint(*args): cell=cell, numbers=atomic_numbers, pbc=True) - if output_format is None: - try: - atoms.write(output_file, vasp5=True) - except TypeError: - atoms.write(output_file) - elif output_format == "vasp": - atoms.write(output_file, format="vasp", vasp5=True) - else: - atoms.write(output_file, format=output_format) + + atoms.write(output_file, format=output_format) From 24eae93edbb4fa99cc0d3a85b9ec6dadef77ce1f Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 09:46:08 +0100 Subject: [PATCH 05/11] get_primitive returns primitive even when starting with primitive cell --- mctools/generic/get_primitive.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/mctools/generic/get_primitive.py b/mctools/generic/get_primitive.py index 59467d5..57bcea1 100755 --- a/mctools/generic/get_primitive.py +++ b/mctools/generic/get_primitive.py @@ -79,23 +79,19 @@ def vprint(*args): cell, positions, atomic_numbers = spglib.find_primitive( atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance) - if positions is None: - print("This space group doesn't have a more primitive unit cell.") + vprint("Primitive cell vectors:") + vprint(cell, '\n') + vprint("Atomic positions and proton numbers:") + for position, number in zip(positions, atomic_numbers): + vprint(position, '\t', number) + if output_file is None: + pass else: - vprint("Primitive cell vectors:") - vprint(cell, '\n') - vprint("Atomic positions and proton numbers:") - for position, number in zip(positions, atomic_numbers): - vprint(position, '\t', number) - - if output_file is None: - pass - else: - atoms = ase.Atoms( - scaled_positions=positions, - cell=cell, - numbers=atomic_numbers, - pbc=True) + atoms = ase.Atoms( + scaled_positions=positions, + cell=cell, + numbers=atomic_numbers, + pbc=True) - atoms.write(output_file, format=output_format) + atoms.write(output_file, format=output_format) From e82ea5363421a3e92ff424566df292a5e82bdbd3 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 09:56:20 +0100 Subject: [PATCH 06/11] get_primitive: avoid underscores in CLI parameters --- mctools/generic/get_primitive.py | 62 +++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/mctools/generic/get_primitive.py b/mctools/generic/get_primitive.py index 57bcea1..543f74b 100755 --- a/mctools/generic/get_primitive.py +++ b/mctools/generic/get_primitive.py @@ -1,4 +1,5 @@ import argparse +from typing import Any, Dict import ase import ase.io @@ -7,43 +8,64 @@ def main(): parser = argparse.ArgumentParser( - description="Find a primitive unit cell using pyspglib") + description="Find a primitive unit cell using spglib") parser.add_argument( - 'input_file', + "input_file", type=str, - default='POSCAR', - help="Path to crystal structure file, recognisable by ASE") + default="POSCAR", + help="Path to crystal structure file, recognisable by ASE", + ) parser.add_argument( - '--input_format', + "--input-format", + dest="input_format", type=str, - help="Format for input file (needed if ASE can't guess from filename)") + help="Format for input file (needed if ASE can't guess from filename)", + ) parser.add_argument( - '-t', - '--threshold', + "-t", + "--threshold", type=float, default=1e-05, - help=("Distance threshold in AA for symmetry reduction " - "(corresponds to spglib 'symprec' keyword)")) + help=( + "Distance threshold in AA for symmetry reduction " + "(corresponds to spglib 'symprec' keyword)" + ), + ) parser.add_argument( - '-a', - '--angle_tolerance', + "-a", + "--angle-tolerance", + dest="angle_tolerance", type=float, default=-1.0, - help="Angle tolerance for symmetry reduction") + help="Angle tolerance for symmetry reduction", + ) parser.add_argument( - '-o', '--output_file', default=None, help="Path/filename for output") + "-o", + "--output-file", + default=None, + dest="output_file", + help="Path/filename for output", + ) parser.add_argument( - '--output_format', + "--output-format", + dest="output_format", type=str, - help="Format for input file (needed if ASE can't guess from filename)") + help="Format for input file (needed if ASE can't guess from filename)", + ) parser.add_argument( - '-v', - '--verbose', + "-v", + "--verbose", action="store_true", - help="Print output to screen even when writing to file.") + help="Print output to screen even when writing to file.", + ) args = parser.parse_args() - get_primitive(**vars(args)) + get_primitive(**snake_case_args(vars(args))) + + +def snake_case_args(kwarg_dict: Dict[str, Any]) -> Dict[str, Any]: + """Convert user-friendly hyphenated arguments to python_friendly ones""" + return {key.replace("-", "_"): value for key, value in kwarg_dict.items()} def get_primitive(input_file='POSCAR', From 77e50e9d1d74f23c6c9ee1320a42e1838df956f5 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 10:12:11 +0100 Subject: [PATCH 07/11] Improve appearance of get_primitive output, add --precision param --- mctools/generic/get_primitive.py | 42 +++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/mctools/generic/get_primitive.py b/mctools/generic/get_primitive.py index 543f74b..02ef6ab 100755 --- a/mctools/generic/get_primitive.py +++ b/mctools/generic/get_primitive.py @@ -58,6 +58,14 @@ def main(): action="store_true", help="Print output to screen even when writing to file.", ) + parser.add_argument( + "-p", + "--precision", + type=int, + help=("Number of decimal places for float display. " + "(Output files are not affected)"), + default=6, + ) args = parser.parse_args() get_primitive(**snake_case_args(vars(args))) @@ -74,38 +82,50 @@ def get_primitive(input_file='POSCAR', output_format=None, threshold=1e-5, angle_tolerance=-1., - verbose=False): + verbose=False, + precision=6): if output_file is None: verbose = True if verbose: + def vprint(*args): for arg in args: - print(arg,) + print(arg, end="") print("") + else: + def vprint(*args): pass + float_format_str = f"{{:{precision+4}.{precision}f}}" + + def format_float(x: float) -> str: + return float_format_str.format(x) + atoms = ase.io.read(input_file, format=input_format) - atoms_spglib = (atoms.cell.array, atoms.get_scaled_positions(), atoms.numbers) + atoms_spglib = ( + atoms.cell.array, + atoms.get_scaled_positions(), + atoms.numbers, + ) - vprint( - "# Space group: ", - str( - spglib.get_spacegroup( - atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance)), - '\n') + spacegroup = spglib.get_spacegroup( + atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance) + vprint(f"Space group: {spacegroup}") cell, positions, atomic_numbers = spglib.find_primitive( atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance) vprint("Primitive cell vectors:") - vprint(cell, '\n') + for row in cell: + vprint(' '.join(map(format_float, row))) + vprint("Atomic positions and proton numbers:") for position, number in zip(positions, atomic_numbers): - vprint(position, '\t', number) + vprint(' '.join(map(format_float, position)), '\t', number) if output_file is None: pass From 35ce84cd0fc94c4522f41d4c443d2d693ced075f Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 11:34:06 +0100 Subject: [PATCH 08/11] Refactor get_primitive argument-handling, add test --- mctools/generic/get_primitive.py | 20 +++++++--- tests/test_get_primitive.py | 65 ++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 5 deletions(-) create mode 100644 tests/test_get_primitive.py diff --git a/mctools/generic/get_primitive.py b/mctools/generic/get_primitive.py index 02ef6ab..fc6a8c2 100755 --- a/mctools/generic/get_primitive.py +++ b/mctools/generic/get_primitive.py @@ -1,13 +1,13 @@ -import argparse -from typing import Any, Dict +from argparse import ArgumentParser +from typing import Any, Dict, List, Optional import ase import ase.io import spglib -def main(): - parser = argparse.ArgumentParser( +def get_parser() -> ArgumentParser: + parser = ArgumentParser( description="Find a primitive unit cell using spglib") parser.add_argument( "input_file", @@ -50,6 +50,7 @@ def main(): "--output-format", dest="output_format", type=str, + default=None, help="Format for input file (needed if ASE can't guess from filename)", ) parser.add_argument( @@ -66,7 +67,16 @@ def main(): "(Output files are not affected)"), default=6, ) - args = parser.parse_args() + return parser + + +def main(params: Optional[List[str]] = None): + parser = get_parser() + + if params: + args = parser.parse_args(params) + else: + args = parser.parse_args() get_primitive(**snake_case_args(vars(args))) diff --git a/tests/test_get_primitive.py b/tests/test_get_primitive.py new file mode 100644 index 0000000..6581b89 --- /dev/null +++ b/tests/test_get_primitive.py @@ -0,0 +1,65 @@ +import textwrap + +import ase +import ase.build +from numpy.random import RandomState +import pytest + +from mctools.generic.get_primitive import main + + +FILENAME = "rattled_cu.extxyz" +OUTFILENAME = "POSCAR" +REF_STDOUT = textwrap.dedent( + """\ + Space group: Fm-3m (225) + Primitive cell vectors: + 0.000 1.805 1.805 + 1.805 0.000 1.805 + 1.805 1.805 0.000 + Atomic positions and proton numbers: + 0.500 0.500 0.500\t29 + """ +) +REF_POSCAR = textwrap.dedent( + """\ + Cu + 1.0000000000000000 + 0.0000000000000000 1.8050080045893520 1.8050080045893520 + 1.8050080045893520 0.0000000000000000 1.8050080045893520 + 1.8050080045893520 1.8050080045893520 0.0000000000000000 + Cu + 1 + Cartesian + 1.8050080045893520 1.8050080045893520 1.8050080045893520 + """ +) + +@pytest.fixture +def rattled_cu() -> ase.Atoms: + rng = RandomState(seed=1) + + atoms = ase.build.bulk("Cu", cubic=True) * (2, 2, 2) + atoms.rattle(stdev=1e-3, seed=1) + atoms.set_cell(atoms.cell.array + 1e-4 * rng.rand(3, 3)) + + return atoms + +def test_get_primitive(rattled_cu, tmp_path, capsys) -> None: + rattled_cu.write(tmp_path / FILENAME) + + main([str(tmp_path / FILENAME), + "--input-format=extxyz", + "--threshold=1e-2", + "--angle-tolerance=1", + "-o", + str(tmp_path / OUTFILENAME), + "-v", + "--precision=3" + ]) + + captured = capsys.readouterr() + assert captured.out == REF_STDOUT + + with open(tmp_path / OUTFILENAME, "r") as fd: + assert fd.read() == REF_POSCAR From edd69078ad94f02da227ad1dcc133b48bcfefdda Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 11:34:22 +0100 Subject: [PATCH 09/11] Scope test_get_spacegroup fixture to avoid recalculation --- tests/test_get_spacegroup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_get_spacegroup.py b/tests/test_get_spacegroup.py index aa41c16..d281b5f 100644 --- a/tests/test_get_spacegroup.py +++ b/tests/test_get_spacegroup.py @@ -26,7 +26,7 @@ ) -@pytest.fixture +@pytest.fixture(scope="module") def symmetry_broken_cu() -> ase.Atoms: atoms = ase.build.bulk("Cu", cubic=True) * (2, 2, 2) From 6b57c1d83aa0f63edb9ed1742ea6faec0b702d67 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 11:37:09 +0100 Subject: [PATCH 10/11] Linting: test_get_primitive --- tests/test_get_primitive.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_get_primitive.py b/tests/test_get_primitive.py index 6581b89..264c9d9 100644 --- a/tests/test_get_primitive.py +++ b/tests/test_get_primitive.py @@ -32,9 +32,10 @@ 1 Cartesian 1.8050080045893520 1.8050080045893520 1.8050080045893520 - """ + """ # noqa:W291 ) + @pytest.fixture def rattled_cu() -> ase.Atoms: rng = RandomState(seed=1) @@ -45,6 +46,7 @@ def rattled_cu() -> ase.Atoms: return atoms + def test_get_primitive(rattled_cu, tmp_path, capsys) -> None: rattled_cu.write(tmp_path / FILENAME) From 4d6e96e15cb8916b17689ab1034e459f1879e41f Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 29 May 2024 11:53:01 +0100 Subject: [PATCH 11/11] Refactor get_primitive - create an intermediate function that receives/returns ASE Atoms. This should be useful for Python scripting, as well as breaking up the main program here. - Simplify "verbose" printing logic - Iteration over atoms.cell doesn't require "array" to be addressed --- mctools/generic/get_primitive.py | 91 ++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 41 deletions(-) diff --git a/mctools/generic/get_primitive.py b/mctools/generic/get_primitive.py index fc6a8c2..ac325e6 100755 --- a/mctools/generic/get_primitive.py +++ b/mctools/generic/get_primitive.py @@ -1,4 +1,7 @@ +"""Get primitive cell from crystal structure using spglib""" + from argparse import ArgumentParser +from pathlib import Path from typing import Any, Dict, List, Optional import ase @@ -11,7 +14,7 @@ def get_parser() -> ArgumentParser: description="Find a primitive unit cell using spglib") parser.add_argument( "input_file", - type=str, + type=Path, default="POSCAR", help="Path to crystal structure file, recognisable by ASE", ) @@ -42,6 +45,7 @@ def get_parser() -> ArgumentParser: parser.add_argument( "-o", "--output-file", + type=Path, default=None, dest="output_file", help="Path/filename for output", @@ -86,29 +90,47 @@ def snake_case_args(kwarg_dict: Dict[str, Any]) -> Dict[str, Any]: return {key.replace("-", "_"): value for key, value in kwarg_dict.items()} -def get_primitive(input_file='POSCAR', - input_format=None, - output_file=None, - output_format=None, - threshold=1e-5, - angle_tolerance=-1., - verbose=False, - precision=6): +def get_primitive_atoms( + atoms: ase.Atoms, + threshold: float = 1e-5, + angle_tolerance: float = -1.0, + print_spacegroup: bool = False, +) -> ase.Atoms: + """Convert ASE Atoms to primitive cell using spglib""" + atoms_spglib = ( + atoms.cell.array, + atoms.get_scaled_positions(), + atoms.numbers, + ) - if output_file is None: - verbose = True + spacegroup = spglib.get_spacegroup( + atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance) + if print_spacegroup: + print(f"Space group: {spacegroup}") - if verbose: + cell, positions, atomic_numbers = spglib.find_primitive( + atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance) - def vprint(*args): - for arg in args: - print(arg, end="") - print("") + primitive_atoms = ase.Atoms( + scaled_positions=positions, + cell=cell, + numbers=atomic_numbers, + pbc=True) - else: + return primitive_atoms - def vprint(*args): - pass + +def get_primitive(input_file: Path = Path('POSCAR'), + input_format: Optional[str] = None, + output_file: Optional[Path] = None, + output_format: Optional[str] = None, + threshold: float = 1e-5, + angle_tolerance: float = -1., + verbose: bool = False, + precision: int = 6) -> None: + + if output_file is None: + verbose = True float_format_str = f"{{:{precision+4}.{precision}f}}" @@ -116,34 +138,21 @@ def format_float(x: float) -> str: return float_format_str.format(x) atoms = ase.io.read(input_file, format=input_format) - atoms_spglib = ( - atoms.cell.array, - atoms.get_scaled_positions(), - atoms.numbers, + atoms = get_primitive_atoms( + atoms, threshold=threshold, angle_tolerance=angle_tolerance, print_spacegroup=verbose ) - spacegroup = spglib.get_spacegroup( - atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance) - vprint(f"Space group: {spacegroup}") - - cell, positions, atomic_numbers = spglib.find_primitive( - atoms_spglib, symprec=threshold, angle_tolerance=angle_tolerance) - - vprint("Primitive cell vectors:") - for row in cell: - vprint(' '.join(map(format_float, row))) + if verbose: + print("Primitive cell vectors:") + for row in atoms.cell: + print(" ".join(map(format_float, row))) - vprint("Atomic positions and proton numbers:") - for position, number in zip(positions, atomic_numbers): - vprint(' '.join(map(format_float, position)), '\t', number) + print("Atomic positions and proton numbers:") + for position, number in zip(atoms.get_scaled_positions(), atoms.numbers): + print(" ".join(map(format_float, position)) + f"\t{number}") if output_file is None: pass else: - atoms = ase.Atoms( - scaled_positions=positions, - cell=cell, - numbers=atomic_numbers, - pbc=True) atoms.write(output_file, format=output_format)