diff --git a/abipy/dynamics/cpx.py b/abipy/dynamics/cpx.py index 4ed4dfdc0..0e39d5050 100644 --- a/abipy/dynamics/cpx.py +++ b/abipy/dynamics/cpx.py @@ -23,7 +23,9 @@ import numpy as np import pandas as pd import dataclasses +import abipy.core.abinit_units as abu +from pathlib import Path from monty.functools import lazy_property from monty.bisect import find_le from abipy.core.mixins import TextFile, NotebookWriter @@ -33,8 +35,7 @@ get_fig_plotly, add_plotly_fig_kwargs, PlotlyRowColDesc, plotly_klabels, plotly_set_lims) - -def parse_file_with_blocks(filepath: PathLike, block_size: int) -> tuple[int, np.ndarray, list]: +def parse_file_with_blocks(filepath: PathLike, block_len: int) -> tuple[int, np.ndarray, list]: """ Parse a QE file whose format is: @@ -47,11 +48,11 @@ def parse_file_with_blocks(filepath: PathLike, block_size: int) -> tuple[int, np 0.78345550025202E+01 0.67881458009071E+01 0.12273445304341E+01 0.20762325213327E+01 0.59955571558384E+01 0.47335647385293E+01 - i.e. a header followed by block_size lines + i.e. a header followed by block_len lines Args: filepath: Filename - block_size: Number of lines in block. + block_len: Number of lines in block. Return: (number_of_steps, array_to_be_reshaped, list_of_headers) """ @@ -59,7 +60,7 @@ def parse_file_with_blocks(filepath: PathLike, block_size: int) -> tuple[int, np with open(filepath, "rt") as fh: for il, line in enumerate(fh): toks = line.split() - if il % (block_size + 1) == 0: + if il % (block_len + 1) == 0: #assert len(toks) == 2 headers.append([int(toks[0]), float(toks[1])]) nsteps += 1 @@ -250,3 +251,135 @@ def traj_to_qepos(traj_filepath: PathLike, pos_filepath: PathLike) -> None: fh.write(str(it) + '\n') for ia in range(natoms): fh.write(str(pos_tac[it,ia,0]) + ' ' + str(pos_tac[it,ia,1]) + ' ' + str(pos_tac[it,ia,2]) + '\n') + + +class Qe2Extxyz: + """ + Convert QE/CP output files into ASE extended xyz format. + + Example: + + from abipy.dynamics.cpx import Qe2Extxyz + converter = Qe2Extxyz.from_input("cp.in") + coverter.write_xyz("extended.xyz", take_every=1) + """ + + @classmethod + def from_input(cls, qe_input, prefix="cp"): + """ + Build object from QE/CP input assuming all the other output files + are located in the same directory with the given prefix. + """ + qe_input = Path(str(qe_input)).absolute() + directory = qe_input.cwd() + #from ase.io.espresso import read_fortran_namelist + #with open(, "rt") as fh: + # atoms = read_espresso_in(fh) + # fh.seek(0) + # sections, card_lines = read_fortran_namelist(fh) + #control = sections["control"] + #prefix = control.get("prefix", default_prefix) + #from ase.io.espresso import read_espresso_in + #with open(qe_input, "rt") as fh: + # self.initial_atoms = read_espresso_in(fh) + + pos_filepath = directory / f"{prefix}.pos" + cell_filepath = directory / f"{prefix}.cel" + evp_filepath = directory / f"{prefix}.evp" + str_filepath = directory / f"{prefix}.str" + if not str_filepath.exists(): + # File with stresses is optional. + str_filepath = None + + return cls(qe_input, pos_filepath, cell_filepath, evp_filepath, str_filepath=str_filepath) + + def __init__(self, qe_input, pos_filepath, cell_filepath, evp_filepath, str_filepath=None): + """ + Args: + qe_input: QE/CP input file + pos_filepath: File with cartesian positions. + cell_filepath: File with lattice vectors. + evp_filepath: File with energies + """ + print(f""" +Reading symbols from {qe_input=} +Reading positions from: {pos_filepath=} +Reading cell from: {cell_filepath=} +Reading energies from: {evp_filepath=} +Reading stresses from: {str_filepath=} + """) + + # Parse input file to get initial_atoms and symbols + from ase.io.espresso import read_espresso_in + with open(qe_input, "rt") as fh: + self.initial_atoms = read_espresso_in(fh) + natom = len(self.initial_atoms) + print("initial_atoms:", self.initial_atoms) + + # Parse cell. + cell_nsteps, cells, cell_headers = parse_file_with_blocks(cell_filepath, 3) + # FIXME: row vectors or column vectors? + self.cells = np.reshape(cells, (cell_nsteps, 3, 3)) + + # Parse stress. + self.stresses = None + if str_filepath is not None: + str_nsteps, stresses, str_headers = parse_file_with_blocks(str_filepath, 3) + if str_nsteps != cell_nsteps: + raise RuntimeError(f"{str_nsteps=} != {cell_nsteps=}") + + # FIXME: Clarify units for positions, forces and stresses. + + # Parse Cartesian positions. + pos_nsteps, pos_cart, pos_headers = parse_file_with_blocks(pos_filepath, natom) + self.pos_cart = np.reshape(pos_cart, (pos_nsteps, natom, 3)) + if pos_nsteps != cell_nsteps: + raise RuntimeError(f"{pos_nsteps=} != {cell_nsteps=}") + + # Parse Cartesian forces. + for_nsteps, forces_list, for_headers = parse_file_with_blocks(for_filepath, natom) + self.forces_list = np.reshape(forces_list, (for_nsteps, natom, 3)) + if for_nsteps != cell_nsteps: + raise RuntimeError(f"{for_nsteps=} != {cell_nsteps=}") + + # Get energies from evl file and convert from Ha to eV. + with EvpFile(evp_filepath) as evp: + self.energies = evp.df["etot"].values * abu.Ha_to_eV + if len(self.energies) != cell_nsteps: + raise RuntimeError(f"{len(energies)=} != {cell_nsteps=}") + + self.nsteps = cell_nsteps + + def write_xyz(self, xyz_filepath, take_every=1) -> None: + """ + Write results in ASE extended xyz format. + + Args: + xyz_filepath: Name of the XYZ file. + take_every: Used to downsample the trajectory. + """ + from ase.io import write + with open(xyz_filepath, "wt") as fh: + for istep, atoms in enumerate(self.yield_atoms()): + if istep % take_every != 0: continue + write(fh, atoms, format='extxyz', append=True) + + def yield_atoms(self): + """Yields ASE atoms along the trajectory.""" + from ase.atoms import Atoms + from ase.calculators.singlepoint import SinglePointCalculator + for istep in range(self.nsteps): + atoms = Atoms(symbols=self.initial_atoms.symbols, + positions=self.pos_cart[istep], + cell=self.cells[istep], + pbc=True, + ) + + # Attach calculator with results. + atoms.calc = SinglePointCalculator(atoms, + energy=self.energies[istep], + free_energy=self.energies[istep], + #forces=self.forces[istep], + stress=self.stresses[istep] if self.stresses is not None else None, + ) + yield atoms diff --git a/abipy/examples/plot/plot_lruj.py b/abipy/examples/plot/plot_lruj.py index b86b3f5a1..c93b92b64 100755 --- a/abipy/examples/plot/plot_lruj.py +++ b/abipy/examples/plot/plot_lruj.py @@ -4,6 +4,8 @@ ===================== This example shows how to parse the output file produced by lruj and plot the results + + See also https://docs.abinit.org/tutorial/lruj/#43-execution-of-the-lruj-post-processinng-utility """ from abipy.electrons.lruj import LrujAnalyzer, LrujResults @@ -17,8 +19,7 @@ #%% # Plot the fits. -lr.plot(degrees="all", insetdegree=4, ptcolor0='blue', - ptitle="Hello World", fontsize=9) +lr.plot(degrees="all", insetdegree=4, ptcolor0='blue', ptitle="Hello World", fontsize=9) #filepaths = [ # "tlruj_2.o_DS1_LRUJ.nc", diff --git a/abipy/ml/aseml.py b/abipy/ml/aseml.py index 77522ef35..822ca56a7 100644 --- a/abipy/ml/aseml.py +++ b/abipy/ml/aseml.py @@ -202,13 +202,14 @@ class AseTrajectoryPlotter: """ Plot an ASE trajectory with matplotlib. """ - def __init__(self, traj: Trajectory): self.traj = traj + self.natom = len(traj[0]) + self.traj_size = len(traj) @classmethod def from_file(cls, filepath: PathLike) -> AseTrajectoryPlotter: - """Initialize object from file.""" + """Initialize an instance from file filepath""" return cls(read(filepath, index=":")) def __str__(self) -> str: @@ -223,72 +224,128 @@ def to_string(self, verbose=0) -> str: app(first.to_string(verbose=verbose)) else: first, last = AseResults.from_atoms(self.traj[0]), AseResults.from_atoms(self.traj[-1]) - raise NotImplementedError() + app("First configuration:") + app(first.to_string(verbose=verbose)) + app("Last configuration:") + app(last.to_string(verbose=verbose)) return "\n".join(lines) - #@add_fig_kwargs - #def plot_lattice(self, what_list=("abc", "angles", "volume"), ax_list=None, - # fontsize=8, xlims=None, **kwargs) -> Figure: - # """ - # Plot lattice lengths/angles/volume as a function the trajectory index. - # """ - # energies = [ene=float(atoms.get_potential_energy()) for atoms in self.traj] - # - # stress_voigt = atoms.get_stress() - # forces=atoms.get_forces(), - # try: - # magmoms = atoms.get_magnetic_moments() - # except PropertyNotImplementedError: - # magmoms = None - - #@add_fig_kwargs - #def plot_lattice(self, what_list=("abc", "angles", "volume"), ax_list=None, - # fontsize=8, xlims=None, **kwargs) -> Figure: - # """ - # Plot lattice lengths/angles/volume as a function the trajectory index. - - # Args: - # what_list: List of strings specifying the quantities to plot. Default all - # ax_list: List of axis or None if a new figure should be created. - # fontsize: fontsize for legends and titles - # xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)`` - # or scalar e.g. ``left``. If left (right) is None, default values are used. - # """ - # what_list = list_strings(what_list) - # ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=1, ncols=len(what_list), - # sharex=True, sharey=False, squeeze=False) - # markers = ["o", "^", "v"] - - # if "abc" in what_list: - # # plot lattice parameters. - # for i, label in enumerate(["a", "b", "c"]): - # ax.plot(self.times, [lattice.abc[i] for lattice in self.lattices], - # label=label, marker=markers[i]) - # ax.set_ylabel("abc (A)") - - # if "angles" in what_list: - # # plot lattice angles. - # for i, label in enumerate(["alpha", "beta", "gamma"]): - # ax.plot(self.times, [lattice.angles[i] for lattice in self.lattices], - # label=label, marker=markers[i]) - # ax.set_ylabel(r"$\alpha\beta\gamma$ (degree)") - - # if "volume" in what_list: - # # plot lattice volume. - # marker = "o" - # ax.plot(self.times, [lattice.volume for lattice in self.lattices], - # label="Volume", marker=marker) - # ax.set_ylabel(r'$V\, (A^3)$') - - # for ix, ax in enumerate(ax_list): - # set_axlims(ax, xlims, "x") - # if ix == len(ax_list) - 1: - # ax.set_xlabel('t (ps)', fontsize=fontsize) - # ax.legend(loc="best", shadow=True, fontsize=fontsize) - - # return fig + @add_fig_kwargs + def plot(self, fontsize=8, xlims=None, **kwargs) -> Figure: + """ + Plot energies, force stats, and pressure as a function of the trajectory index. + """ + ax_list, fig, plt = get_axarray_fig_plt(None, nrows=3, ncols=1, + sharex=True, sharey=False, squeeze=True) + + # Plot total energy in eV. + energies = [float(atoms.get_potential_energy()) for atoms in self.traj] + ax = ax_list[0] + marker = "o" + ax.plot(energies, marker=marker) + ax.set_ylabel('Energy (eV)') + + # Plot Force stats. + forces_traj = np.reshape([atoms.get_forces() for atoms in self.traj], (self.traj_size, self.natom, 3)) + fmin_steps, fmax_steps, fmean_steps, fstd_steps = [], [], [], [] + for forces in forces_traj: + fmods = np.sqrt([np.dot(force, force) for force in forces]) + fmean_steps.append(fmods.mean()) + fstd_steps.append(fmods.std()) + fmin_steps.append(fmods.min()) + fmax_steps.append(fmods.max()) + + markers = ["o", "^", "v", "X"] + ax = ax_list[1] + ax.plot(fmin_steps, label="min |F|", marker=markers[0]) + ax.plot(fmax_steps, label="max |F|", marker=markers[1]) + ax.plot(fmean_steps, label="mean |F|", marker=markers[2]) + #ax.plot(fstd_steps, label="std |F|", marker=markers[3]) + ax.set_ylabel('F stats (eV/A)') + ax.legend(loc="best", shadow=True, fontsize=fontsize) + + # Plot pressure. + voigt_stresses_traj = np.reshape([atoms.get_stress() for atoms in self.traj], (self.traj_size, 6)) + pressures = [-sum(vs[0:3])/3 for vs in voigt_stresses_traj] + ax = ax_list[2] + ax.plot(pressures, marker=marker) + ax.set_ylabel('Pressure (GPa)') + + for ix, ax in enumerate(ax_list): + set_axlims(ax, xlims, "x") + ax.grid(True) + if ix == len(ax_list) - 1: + ax.set_xlabel('Trajectory index', fontsize=fontsize) + return fig + + @add_fig_kwargs + def plot_lattice(self, ax_list=None, + fontsize=8, xlims=None, **kwargs) -> Figure: + """ + Plot lattice lengths/angles/volume as a function the of the trajectory index. + + Args: + ax_list: List of axis or None if a new figure should be created. + fontsize: fontsize for legends and titles + xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)`` + or scalar e.g. ``left``. If left (right) is None, default values are used. + """ + ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=3, ncols=1, + sharex=True, sharey=False, squeeze=False) + ax_list = ax_list.ravel() + + def cell_dict(atoms): + return dict(zip(_CELLPAR_KEYS, atoms.cell.cellpar())) + + cellpar_list = [cell_dict(atoms) for atoms in self.traj] + df = pd.DataFrame(cellpar_list) + #print(df) + + # plot lattice parameters. + ax = ax_list[0] + markers = ["o", "^", "v"] + for i, label in enumerate(["a", "b", "c"]): + ax.plot(df[label].values, label=label, marker=markers[i]) + ax.set_ylabel("abc (A)") + + # plot lattice angles. + ax = ax_list[1] + for i, label in enumerate(["angle(b,c)", "angle(a,c)", "angle(a,b)"]): + ax.plot(df[label].values, label=label, marker=markers[i]) + ax.set_ylabel(r"$\alpha\beta\gamma$ (degree)") + + # plot lattice volume. + ax = ax_list[2] + volumes = [atoms.get_volume() for atoms in self.traj] + marker = "o" + ax.plot(volumes, label="Volume", marker=marker) + ax.set_ylabel(r'$V\, (A^3)$') + + for ix, ax in enumerate(ax_list): + set_axlims(ax, xlims, "x") + if ix == len(ax_list) - 1: + ax.set_xlabel('Trajectory index', fontsize=fontsize) + ax.legend(loc="best", shadow=True, fontsize=fontsize) + + return fig + + +def get_fstats(cart_forces: np.ndarray) -> dict: + """ + Return dictionary with statistics on cart_forces. + """ + fmods = np.array([np.linalg.norm(f) for f in cart_forces]) + #fmods = np.sqrt(np.einsum('ij, ij->i', cart_forces, cart_forces)) + #return AttrDict( + return dict( + fmin=fmods.min(), + fmax=fmods.max(), + fmean=fmods.mean(), + fstd=fmods.std(), + drift=np.linalg.norm(cart_forces.sum(axis=0)), + ) @dataclasses.dataclass @@ -361,26 +418,27 @@ def to_string(self, verbose: int = 0) -> str: for k, v in fstats.items(): app(f"{k} = {v} (eV/Ang)") - #if verbose: if True: + #if verbose: app('Forces (eV/Ang):') positions = self.atoms.get_positions() - df = pd.DataFrame(dict( + data = dict( x=positions[:,0], y=positions[:,1], z=positions[:,2], fx=self.forces[:,0], fy=self.forces[:,1], fz=self.forces[:,2], - )) - app(df.to_string()) + ) + # Add magmoms if available. + if self.magmoms is not None: + data["magmoms"] = self.magmoms - #if self.magmoms is not None: - # for ia, (atom, magmoms) in enumerate(zip(self.atoms, self.magmoms)): - # print(atom, magmoms) + df = pd.DataFrame(data) + app(df.to_string()) app('Stress tensor:') - for row in self.strees: + for row in self.stress: app(str(row)) return "\n".join(lines) @@ -389,16 +447,7 @@ def get_fstats(self) -> dict: """ Return dictionary with statistics on forces. """ - fmods = np.array([np.linalg.norm(force) for force in self.forces]) - #fmods = np.sqrt(np.einsum('ij, ij->i', forces, forces)) - #return AttrDict( - return dict( - fmin=fmods.min(), - fmax=fmods.max(), - fmean=fmods.mean(), - #fstd=fmods.std(), - drift=np.linalg.norm(self.forces.sum(axis=0)), - ) + return get_fstats(self.forces) def get_dict4pandas(self, with_geo=True, with_fstats=True) -> dict: """ @@ -1310,7 +1359,7 @@ def pip_install(nn_name) -> int: installed, versions = get_installed_nn_names(verbose=verbose, printout=False) for name in nn_names: - #print(f"About to install nn_name={name}") + print(f"About to install nn_name={name} ...") if name in black_list: print("Cannot install {name} with pip!") continue @@ -1346,9 +1395,11 @@ class CalcBuilder: "pyace", "nequip", "metatensor", + "deepmd", ] - def __init__(self, name: str, **kwargs): + + def __init__(self, name: str, dftd3_args=None, **kwargs): self.name = name # Extract nn_type and model_name from name @@ -1362,6 +1413,15 @@ def __init__(self, name: str, **kwargs): if self.nn_type not in self.ALL_NN_TYPES: raise ValueError(f"Invalid {name=}, it should be in {self.ALL_NN_TYPES=}") + # Handle DFTD3. + self.dftd3_args = dftd3_args + if self.dftd3_args and not isinstance(dftd3_args, dict): + # Load parameters from Yaml file. + self.dftd3_args = yaml_safe_load_path(self.dftd3_args) + + if self.dftd3_args: + print("Activating dftd3 with arguments:", self.dftd3_args) + self._model = None def __str__(self): @@ -1408,9 +1468,9 @@ class MyM3GNetCalculator(_MyCalculator, M3GNetCalculator): """Add abi_forces and abi_stress""" cls = MyM3GNetCalculator if with_delta else M3GNetCalculator - return cls(potential=self._model) + calc = cls(potential=self._model) - if self.nn_type == "matgl": + elif self.nn_type == "matgl": # See https://github.com/materialsvirtuallab/matgl try: import matgl @@ -1429,9 +1489,9 @@ class MyM3GNetCalculator(_MyCalculator, M3GNetCalculator): """Add abi_forces and abi_stress""" cls = MyM3GNetCalculator if with_delta else M3GNetCalculator - return cls(potential=self._model) + calc = cls(potential=self._model) - if self.nn_type == "chgnet": + elif self.nn_type == "chgnet": try: from chgnet.model.dynamics import CHGNetCalculator from chgnet.model.model import CHGNet @@ -1452,9 +1512,9 @@ class MyCHGNetCalculator(_MyCalculator, CHGNetCalculator): """Add abi_forces and abi_stress""" cls = MyCHGNetCalculator if with_delta else CHGNetCalculator - return cls(model=self._model) + calc = cls(model=self._model) - if self.nn_type == "alignn": + elif self.nn_type == "alignn": try: from alignn.ff.ff import AlignnAtomwiseCalculator, default_path, get_figshare_model_ff except ImportError as exc: @@ -1468,9 +1528,9 @@ class MyAlignnCalculator(_MyCalculator, AlignnAtomwiseCalculator): model_name = default_path() if self.model_name is None else self.model_name cls = MyAlignnCalculator if with_delta else AlignnAtomwiseCalculator - return cls(path=model_name) + calc = cls(path=model_name) - if self.nn_type == "pyace": + elif self.nn_type == "pyace": try: from pyace import PyACECalculator except ImportError as exc: @@ -1483,9 +1543,9 @@ class MyPyACECalculator(_MyCalculator, PyACECalculator): raise RuntimeError("PyACECalculator requires model_path e.g. nn_name='pyace@FILEPATH'") cls = MyPyACECalculator if with_delta else PyACECalculator - return cls(basis_set=self.model_path) + calc = cls(basis_set=self.model_path) - if self.nn_type == "mace": + elif self.nn_type == "mace": try: from mace.calculators import MACECalculator except ImportError as exc: @@ -1501,9 +1561,9 @@ class MyMACECalculator(_MyCalculator, MACECalculator): raise RuntimeError("MACECalculator requires model_path e.g. nn_name='mace@FILEPATH'") cls = MyMACECalculator if with_delta else MACECalculator - return cls(model_paths=self.model_path, device="cpu") #, default_dtype='float32') + calc = cls(model_paths=self.model_path, device="cpu") #, default_dtype='float32') - if self.nn_type == "nequip": + elif self.nn_type == "nequip": try: from nequip.ase.nequip_calculator import NequIPCalculator except ImportError as exc: @@ -1516,9 +1576,9 @@ class MyNequIPCalculator(_MyCalculator, NequIPCalculator): raise RuntimeError("NequIPCalculator requires model_path e.g. nn_name='nequip:FILEPATH'") cls = MyNequIPCalculator if with_delta else NequIPCalculator - return cls.from_deployed_model(modle_path=self.model_path, species_to_type_name=None) + calc = cls.from_deployed_model(modle_path=self.model_path, species_to_type_name=None) - if self.nn_type == "metatensor": + elif self.nn_type == "metatensor": try: from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator except ImportError as exc: @@ -1531,9 +1591,32 @@ class MyMetatensorCalculator(_MyCalculator, MetatensorCalculator): raise RuntimeError("MetaTensorCalculator requires model_path e.g. nn_name='metatensor:FILEPATH'") cls = MyMetaTensorCalculator if with_delta else MetatensorCalculator - return cls(self.model_path) + calc = cls(self.model_path) + + elif self.nn_type == "deepmd": + try: + from deepmd.calculator import DP + except ImportError as exc: + raise ImportError("deepmd not installed. See https://tutorials.deepmodeling.com/") from exc + + class MyDpCalculator(_MyCalculator, DP): + """Add abi_forces and abi_stress""" + + if self.model_path is None: + raise RuntimeError("DeepMD calculator requires model_path e.g. nn_name='deepmd:FILEPATH'") + + cls = MyDp if with_delta else Dp + calc = cls(self.model_path) + + else: + raise ValueError(f"Invalid {self.nn_type=}") + + # Include DFTD3 vDW corrections on top of ML potential. + if self.dftd3_args is not None: + from ase.calculators.dftd3 import DFTD3 + calc = DFTD3(dft=calc, **self.dftd3_args) - raise ValueError(f"Invalid {self.nn_type=}") + return calc class MlBase(HasPickleIO): @@ -2890,7 +2973,7 @@ def run(self): res = AseResults.from_atoms(self.atoms) print(res.to_string(verbose=self.verbose)) - # Write ASE traj file with results. + # Write ASE trajectory file with results. with open(self.workdir / "gs.traj", "wb") as fd: write_traj(fd, [self.atoms]) diff --git a/abipy/scripts/abiml.py b/abipy/scripts/abiml.py index 53a088579..2fa3e43f3 100755 --- a/abipy/scripts/abiml.py +++ b/abipy/scripts/abiml.py @@ -147,6 +147,8 @@ def add_nn_name_opt(f): """Add CLI options to select the NN potential.""" f = click.option("--nn-name", "-nn", default=DEFAULT_NN, show_default=True, help=f"ML potential to be used. Supported values are: {aseml.CalcBuilder.ALL_NN_TYPES}")(f) + #f = click.option("--dftd3", , default="no", show_default=True, + # help=f"Activate DFD3.")(f) return f diff --git a/abipy/scripts/abiopen.py b/abipy/scripts/abiopen.py index ca82362d5..9075f6af0 100755 --- a/abipy/scripts/abiopen.py +++ b/abipy/scripts/abiopen.py @@ -12,6 +12,7 @@ import argparse import subprocess import abipy.tools.cli_parsers as cli +from abipy.tools.plotting import Exposer from pprint import pprint from shutil import which @@ -143,8 +144,6 @@ def get_parser(with_epilog=False): "Default: FastList" ) parser.add_argument("--port", default=0, type=int, help="Allows specifying a specific port when serving panel app.") - - #add_expose_options_to_parser(parser) # Expose option. @@ -335,11 +334,15 @@ def handle_ase_traj(options): """Handle ASE trajectory file.""" from abipy.ml.aseml import AseTrajectoryPlotter plotter = AseTrajectoryPlotter.from_file(options.filepath) - print(plotter.to_string(verbose=options.verbose)) - #if len(plotter.traj) > 1): - plotter.plot_cell() - #plotter.plot_cell() + print(plotter.to_string(verbose=options.verbose)) + if options.expose: + print(plotter.to_string(verbose=options.verbose)) + if len(plotter.traj) > 1: + plot_kws = dict(show=False) + with Exposer.as_exposer("mpl") as e: + e(plotter.plot(**plot_kws)) + e(plotter.plot_lattice(**plot_kws)) return 0 diff --git a/abipy/tools/plotting.py b/abipy/tools/plotting.py index b52f46375..2b6f44118 100644 --- a/abipy/tools/plotting.py +++ b/abipy/tools/plotting.py @@ -1026,7 +1026,7 @@ class Exposer: Example: - kws = dict(show=False) + plot_kws = dict(show=False) with Exposer.as_exposer("panel") as e: e(obj.plot1(**plot_kws)) e(obj.plot2(**plot_kws))