diff --git a/mctools/generic/get_energy.py b/mctools/generic/get_energy.py index dc3267a..bf64a92 100644 --- a/mctools/generic/get_energy.py +++ b/mctools/generic/get_energy.py @@ -1,4 +1,6 @@ from argparse import ArgumentParser +from typing import List, Optional + import ase.io @@ -8,7 +10,7 @@ def get_energy(filename): return atoms.get_total_energy() -def main(): +def main(params: Optional[List[str]] = None): """Get calculated energy from output file using ASE""" parser = ArgumentParser(description="Read energy from output") @@ -16,7 +18,10 @@ def main(): default="vasprun.xml", help="Path to ab initio output file") - args = parser.parse_args() + if params: + args = parser.parse_args(params) + else: + args = parser.parse_args() energy = get_energy(args.filename) print(energy) diff --git a/tests/test_get_energy.py b/tests/test_get_energy.py index 9c80d69..5ef60da 100644 --- a/tests/test_get_energy.py +++ b/tests/test_get_energy.py @@ -3,7 +3,7 @@ import ase.io import pytest -from mctools.generic.get_energy import get_energy +from mctools.generic.get_energy import get_energy, main ENERGY = 3.141 @@ -21,3 +21,11 @@ def test_get_energy(methane_with_energy, tmp_path) -> None: ase.io.write(tmp_path / FILENAME, methane_with_energy) assert get_energy(tmp_path / FILENAME) == pytest.approx(ENERGY) + + +def test_main(methane_with_energy, tmp_path, capsys) -> None: + ase.io.write(tmp_path / FILENAME, methane_with_energy) + + main([str(tmp_path / FILENAME)]) + captured = capsys.readouterr() + assert float(captured.out) == pytest.approx(ENERGY)