diff --git a/mctools/generic/get_spacegroup.py b/mctools/generic/get_spacegroup.py index 95fbfe3..1fbc923 100755 --- a/mctools/generic/get_spacegroup.py +++ b/mctools/generic/get_spacegroup.py @@ -2,27 +2,28 @@ import argparse import os +from pathlib import Path +from typing import Optional import ase.io import spglib -def get_spacegroup(filename=False, format=False): - - if filename: - pass - elif os.path.isfile('geometry.in'): - filename = 'geometry.in' - elif os.path.isfile('POSCAR'): - filename = 'POSCAR' +def get_default_file() -> Path: + for candidate in ("geometry.in", "POSCAR", "castep.cell"): + if (input := Path.cwd() / candidate).is_file(): + return input else: - raise Exception('No input file!') + raise ValueError("Input file not specified, no default found.") - if format: - atoms = ase.io.read(filename, format=format) - else: - atoms = ase.io.read(filename) +def get_spacegroup(filename: Optional[Path] = None, + format: Optional[str] = None): + + if filename is None: + filename = get_default_file() + + atoms = ase.io.read(str(filename), format=format) cell = (atoms.cell.array, atoms.get_scaled_positions(), atoms.numbers) print("| Threshold / Å | Space group |") @@ -35,9 +36,9 @@ def get_spacegroup(filename=False, format=False): def main(): parser = argparse.ArgumentParser() - parser.add_argument('filename', action='store', default=False, + parser.add_argument('filename', type=Path, default=None, nargs="?", help="Input structure file") - parser.add_argument('-f', '--format', action='store', default=False, + parser.add_argument('-f', '--format', type=str, default=None, help="File format for ASE importer") args = parser.parse_args() get_spacegroup(filename=args.filename, format=args.format)