From 456d46788cad64c14ba9e2a341a34658ec8192e4 Mon Sep 17 00:00:00 2001 From: Martin Schlipf Date: Mon, 10 Aug 2020 11:32:38 +0200 Subject: [PATCH] Create supercells using ASE instead of pymatgen (#12) * Replace pymatgen by ase * Use plot as wrapper routine * Implement supercell for structure class * Implement supercell for Viewer3d class * Use a from_structure method to initialize viewer * Raise exception if atom positions are needed but not set --- pyproject.toml | 11 ++++------ src/py4vasp/data/structure.py | 30 ++++++++++++++++---------- src/py4vasp/data/viewer3d.py | 36 ++++++++++++++++++------------- tests/data/test_structure.py | 33 +++++++++++++++++++++-------- tests/data/test_viewer3d.py | 40 ++++++++++++++++++++++++++++------- 5 files changed, 100 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 27e21ac..cb30773 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,13 +13,12 @@ h5py = "^2.10.0" numpy = "^1.17.4" pandas = "^0.25.3" cufflinks = "^0.17.3" -pymatgen = "^2020.4.2" nglview = "^2.7.5" ase = "^3.19.0" [tool.poetry.dev-dependencies] pytest = "^3.0" -black = {version = "^18.3-alpha.0", allows-prereleases = true} +black = {version = "^18.3-alpha.0", allow-prereleases = true} pytest-cov = "^2.8.1" pylint = "^2.5.0" sphinx = "^3.0.4" @@ -36,13 +35,11 @@ cufflinks = {channel = "conda-forge", name = "cufflinks-py"} mdtraj = {channel = "conda-forge"} nglview = {channel = "conda-forge"} ase = {channel = "conda-forge"} -pymatgen = {channel = "conda-forge"} sphinx = {channel = "anaconda"} -[build-system] -requires = ["poetry>=1.0"] -build-backend = "poetry.masonry.api" - [tool.dephell.main] from = {format = "poetry", path = "pyproject.toml"} to = {format = "setuppy", path = "setup.py"} +[build-system] +requires = ["poetry>=1.0"] +build-backend = "poetry.masonry.api" diff --git a/src/py4vasp/data/structure.py b/src/py4vasp/data/structure.py index 7495bb6..e3f2bf8 100644 --- a/src/py4vasp/data/structure.py +++ b/src/py4vasp/data/structure.py @@ -1,5 +1,7 @@ from py4vasp.data import _util from py4vasp.data import Viewer3d +import ase +import numpy as np class Structure: @@ -13,23 +15,29 @@ def to_dict(self): return { "cell": self._raw.cell.lattice_vectors[:], "cartesian_positions": self._raw.cartesian_positions[:], - "species": list(self._raw.species), + "species": self._raw.species, } def __len__(self): return len(self._raw.cartesian_positions) - def to_pymatgen(self): - import pymatgen as mg - - return mg.Structure( - lattice=mg.Lattice(self._raw.cell.lattice_vectors), - species=[specie.decode("ascii") for specie in self._raw.species], - coords=self._raw.cartesian_positions, - coords_are_cartesian=True, + def to_ase(self, supercell=None): + data = self.to_dict() + species = [_util.decode_if_possible(sp) for sp in data["species"]] + structure = ase.Atoms( + symbols=species, + cell=data["cell"], + positions=data["cartesian_positions"], + pbc=True, ) + if supercell is not None: + structure *= supercell + return structure + + def plot(self, *args): + return self.to_ngl(*args) - def plot(self, supercell=None): - viewer = Viewer3d(self, supercell=supercell) + def to_ngl(self, supercell=None): + viewer = Viewer3d.from_structure(self, supercell=supercell) viewer.show_cell() return viewer diff --git a/src/py4vasp/data/viewer3d.py b/src/py4vasp/data/viewer3d.py index 53a253c..7600781 100644 --- a/src/py4vasp/data/viewer3d.py +++ b/src/py4vasp/data/viewer3d.py @@ -1,3 +1,4 @@ +from py4vasp.exceptions import RefinementException from typing import NamedTuple import nglview import numpy as np @@ -18,12 +19,22 @@ class _Arrow3d(NamedTuple): class Viewer3d: """Collection of data and elements to be displayed in a structure viewer""" - def __init__(self, structure, supercell=None): - self._structure = structure - self._axes = None - self._arrows = [] - self.supercell = supercell - self._ngl = self.show() + _positions = None + _multiple_cells = 1 + _axes = None + _arrows = [] + + def __init__(self, ngl): + self._ngl = ngl + + @classmethod + def from_structure(cls, structure, supercell=None): + ase = structure.to_ase(supercell) + res = cls(nglview.show_ase(ase)) + res._positions = ase.positions + if supercell is not None: + res._multiple_cells = np.prod(supercell) + return res def _ipython_display_(self): self._ngl._ipython_display_() @@ -51,8 +62,10 @@ def hide_axes(self): self._axes = None def show_arrows_at_atoms(self, arrows, color=[0.1, 0.1, 0.8]): - structure = self._structure.to_pymatgen() - for tail, arrow in zip(structure.cart_coords, arrows): + if self._positions is None: + raise RefinementException("Positions of atoms are not known.") + arrows = np.repeat(arrows, self._multiple_cells, axis=0) + for tail, arrow in zip(self._positions, arrows): tip = tail + arrow arrow = _Arrow3d(tail, tip, color) self._arrows.append(self._make_arrow(arrow)) @@ -64,10 +77,3 @@ def hide_arrows_at_atoms(self): def _make_arrow(self, arrow): return self._ngl.shape.add_arrow(*arrow) - - def show(self): - structure = self._structure.to_pymatgen() - if self.supercell is not None: - structure.make_supercell(self.supercell) - view = nglview.show_pymatgen(structure) - return view diff --git a/tests/data/test_structure.py b/tests/data/test_structure.py index 1c5eb6e..e6f5630 100644 --- a/tests/data/test_structure.py +++ b/tests/data/test_structure.py @@ -8,7 +8,7 @@ @pytest.fixture def raw_structure(): - number_atoms = 20 + number_atoms = 6 shape = (number_atoms, 3) return raw.Structure( cell=raw.Cell(scale=1.0, lattice_vectors=np.eye(3)), @@ -25,18 +25,17 @@ def get_messages_after_structure_information(view): def test_read(raw_structure, Assert): actual = Structure(raw_structure).read() - assert (actual["cell"] == raw_structure.cell.lattice_vectors).all() + Assert.allclose(actual["cell"], raw_structure.cell.lattice_vectors) Assert.allclose(actual["cartesian_positions"], raw_structure.cartesian_positions) assert (actual["species"] == raw_structure.species).all() -def test_to_pymatgen(raw_structure): - structure = Structure(raw_structure) - mg_structure = structure.to_pymatgen() - a, b, c = mg_structure.lattice.as_dict()["matrix"] - assert a == [1, 0, 0] - assert b == [0, 1, 0] - assert c == [0, 0, 1] +def test_to_ase(raw_structure, Assert): + structure = Structure(raw_structure).to_ase() + Assert.allclose(structure.cell.array, raw_structure.cell.lattice_vectors) + Assert.allclose(structure.positions, raw_structure.cartesian_positions) + assert all(structure.symbols == "C6") + assert all(structure.pbc) def test_plot(raw_structure): @@ -47,3 +46,19 @@ def test_plot(raw_structure): structure.plot() init.assert_called_once() cell.assert_called_once() + + +def test_supercell(raw_structure, Assert): + structure = Structure(raw_structure) + number_atoms = len(structure) + cell = raw_structure.cell.lattice_vectors + # scale all dimensions by constant factor + scale = 2 + supercell = structure.to_ase(supercell=scale) + assert len(supercell) == number_atoms * scale ** 3 + Assert.allclose(supercell.cell.array, cell * scale) + # scale differently for each dimension + scale = (2, 1, 3) + supercell = structure.to_ase(supercell=scale) + assert len(supercell) == number_atoms * np.prod(scale) + Assert.allclose(supercell.cell.array, cell * scale) diff --git a/tests/data/test_viewer3d.py b/tests/data/test_viewer3d.py index e033ef6..93430ee 100644 --- a/tests/data/test_viewer3d.py +++ b/tests/data/test_viewer3d.py @@ -1,6 +1,7 @@ from unittest.mock import patch from py4vasp.data import Structure, Viewer3d from py4vasp.data.viewer3d import _Arrow3d, _x_axis, _y_axis, _z_axis +from py4vasp.exceptions import RefinementException from .test_structure import raw_structure import numpy as np import pytest @@ -9,10 +10,14 @@ @pytest.fixture def viewer3d(raw_structure): + return make_viewer(raw_structure) + + +def make_viewer(raw_structure, supercell=None): structure = Structure(raw_structure) - viewer = structure.plot() + viewer = structure.plot(supercell) + viewer.raw_structure = raw_structure viewer.default_messages = count_messages(viewer, setup=True) - viewer.positions = raw_structure.cartesian_positions return viewer @@ -76,14 +81,13 @@ def _assert_arrow_message(message, arrow): def test_arrows(viewer3d, assert_arrow_message): - direction = np.array((0, 0, 1)) - number_atoms = len(viewer3d._structure) - arrows = np.repeat([direction], number_atoms, axis=0) - viewer3d.show_arrows_at_atoms(arrows) - messages = last_messages(viewer3d, number_atoms) + number_atoms = len(viewer3d.raw_structure.cartesian_positions) + positions = viewer3d.raw_structure.cartesian_positions color = [0.1, 0.1, 0.8] + arrows = create_arrows(viewer3d, number_atoms) + messages = last_messages(viewer3d, number_atoms) assert len(messages) == number_atoms - for message, tail, arrow in zip(messages, viewer3d.positions, arrows): + for message, tail, arrow in zip(messages, positions, arrows): tip = tail + arrow assert_arrow_message(message, _Arrow3d(tail, tip, color)) viewer3d.show_arrows_at_atoms(arrows) @@ -92,3 +96,23 @@ def test_arrows(viewer3d, assert_arrow_message): # ngl deletes the sent messages to indicate removal of the shapes assert count_messages(viewer3d) == 0 viewer3d.hide_arrows_at_atoms() + + +def test_supercell(raw_structure, assert_arrow_message): + supercell = (1, 2, 3) + viewer = make_viewer(raw_structure, supercell) + number_atoms = len(raw_structure.cartesian_positions) + create_arrows(viewer, number_atoms) + assert count_messages(viewer) == np.prod(supercell) * number_atoms + + +def test_bare_ngl_cannot_add_arrows_at_atoms(viewer3d): + viewer = Viewer3d(viewer3d._ngl) + with pytest.raises(RefinementException): + create_arrows(viewer, 1) + + +def create_arrows(viewer, number_atoms): + arrows = np.repeat([(0, 0, 1)], number_atoms, axis=0) + viewer.show_arrows_at_atoms(arrows) + return arrows