Skip to content

Commit

Permalink
Create supercells using ASE instead of pymatgen (orest-d#12)
Browse files Browse the repository at this point in the history
* Replace pymatgen by ase
* Use plot as wrapper routine
* Implement supercell for structure class
* Implement supercell for Viewer3d class
* Use a from_structure method to initialize viewer
* Raise exception if atom positions are needed but not set
  • Loading branch information
martin-schlipf committed Aug 10, 2020
1 parent db64e22 commit 456d467
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 50 deletions.
11 changes: 4 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@ h5py = "^2.10.0"
numpy = "^1.17.4"
pandas = "^0.25.3"
cufflinks = "^0.17.3"
pymatgen = "^2020.4.2"
nglview = "^2.7.5"
ase = "^3.19.0"

[tool.poetry.dev-dependencies]
pytest = "^3.0"
black = {version = "^18.3-alpha.0", allows-prereleases = true}
black = {version = "^18.3-alpha.0", allow-prereleases = true}
pytest-cov = "^2.8.1"
pylint = "^2.5.0"
sphinx = "^3.0.4"
Expand All @@ -36,13 +35,11 @@ cufflinks = {channel = "conda-forge", name = "cufflinks-py"}
mdtraj = {channel = "conda-forge"}
nglview = {channel = "conda-forge"}
ase = {channel = "conda-forge"}
pymatgen = {channel = "conda-forge"}
sphinx = {channel = "anaconda"}

[build-system]
requires = ["poetry>=1.0"]
build-backend = "poetry.masonry.api"

[tool.dephell.main]
from = {format = "poetry", path = "pyproject.toml"}
to = {format = "setuppy", path = "setup.py"}
[build-system]
requires = ["poetry>=1.0"]
build-backend = "poetry.masonry.api"
30 changes: 19 additions & 11 deletions src/py4vasp/data/structure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from py4vasp.data import _util
from py4vasp.data import Viewer3d
import ase
import numpy as np


class Structure:
Expand All @@ -13,23 +15,29 @@ def to_dict(self):
return {
"cell": self._raw.cell.lattice_vectors[:],
"cartesian_positions": self._raw.cartesian_positions[:],
"species": list(self._raw.species),
"species": self._raw.species,
}

def __len__(self):
return len(self._raw.cartesian_positions)

def to_pymatgen(self):
import pymatgen as mg

return mg.Structure(
lattice=mg.Lattice(self._raw.cell.lattice_vectors),
species=[specie.decode("ascii") for specie in self._raw.species],
coords=self._raw.cartesian_positions,
coords_are_cartesian=True,
def to_ase(self, supercell=None):
data = self.to_dict()
species = [_util.decode_if_possible(sp) for sp in data["species"]]
structure = ase.Atoms(
symbols=species,
cell=data["cell"],
positions=data["cartesian_positions"],
pbc=True,
)
if supercell is not None:
structure *= supercell
return structure

def plot(self, *args):
return self.to_ngl(*args)

def plot(self, supercell=None):
viewer = Viewer3d(self, supercell=supercell)
def to_ngl(self, supercell=None):
viewer = Viewer3d.from_structure(self, supercell=supercell)
viewer.show_cell()
return viewer
36 changes: 21 additions & 15 deletions src/py4vasp/data/viewer3d.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from py4vasp.exceptions import RefinementException
from typing import NamedTuple
import nglview
import numpy as np
Expand All @@ -18,12 +19,22 @@ class _Arrow3d(NamedTuple):
class Viewer3d:
"""Collection of data and elements to be displayed in a structure viewer"""

def __init__(self, structure, supercell=None):
self._structure = structure
self._axes = None
self._arrows = []
self.supercell = supercell
self._ngl = self.show()
_positions = None
_multiple_cells = 1
_axes = None
_arrows = []

def __init__(self, ngl):
self._ngl = ngl

@classmethod
def from_structure(cls, structure, supercell=None):
ase = structure.to_ase(supercell)
res = cls(nglview.show_ase(ase))
res._positions = ase.positions
if supercell is not None:
res._multiple_cells = np.prod(supercell)
return res

def _ipython_display_(self):
self._ngl._ipython_display_()
Expand Down Expand Up @@ -51,8 +62,10 @@ def hide_axes(self):
self._axes = None

def show_arrows_at_atoms(self, arrows, color=[0.1, 0.1, 0.8]):
structure = self._structure.to_pymatgen()
for tail, arrow in zip(structure.cart_coords, arrows):
if self._positions is None:
raise RefinementException("Positions of atoms are not known.")
arrows = np.repeat(arrows, self._multiple_cells, axis=0)
for tail, arrow in zip(self._positions, arrows):
tip = tail + arrow
arrow = _Arrow3d(tail, tip, color)
self._arrows.append(self._make_arrow(arrow))
Expand All @@ -64,10 +77,3 @@ def hide_arrows_at_atoms(self):

def _make_arrow(self, arrow):
return self._ngl.shape.add_arrow(*arrow)

def show(self):
structure = self._structure.to_pymatgen()
if self.supercell is not None:
structure.make_supercell(self.supercell)
view = nglview.show_pymatgen(structure)
return view
33 changes: 24 additions & 9 deletions tests/data/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@pytest.fixture
def raw_structure():
number_atoms = 20
number_atoms = 6
shape = (number_atoms, 3)
return raw.Structure(
cell=raw.Cell(scale=1.0, lattice_vectors=np.eye(3)),
Expand All @@ -25,18 +25,17 @@ def get_messages_after_structure_information(view):

def test_read(raw_structure, Assert):
actual = Structure(raw_structure).read()
assert (actual["cell"] == raw_structure.cell.lattice_vectors).all()
Assert.allclose(actual["cell"], raw_structure.cell.lattice_vectors)
Assert.allclose(actual["cartesian_positions"], raw_structure.cartesian_positions)
assert (actual["species"] == raw_structure.species).all()


def test_to_pymatgen(raw_structure):
structure = Structure(raw_structure)
mg_structure = structure.to_pymatgen()
a, b, c = mg_structure.lattice.as_dict()["matrix"]
assert a == [1, 0, 0]
assert b == [0, 1, 0]
assert c == [0, 0, 1]
def test_to_ase(raw_structure, Assert):
structure = Structure(raw_structure).to_ase()
Assert.allclose(structure.cell.array, raw_structure.cell.lattice_vectors)
Assert.allclose(structure.positions, raw_structure.cartesian_positions)
assert all(structure.symbols == "C6")
assert all(structure.pbc)


def test_plot(raw_structure):
Expand All @@ -47,3 +46,19 @@ def test_plot(raw_structure):
structure.plot()
init.assert_called_once()
cell.assert_called_once()


def test_supercell(raw_structure, Assert):
structure = Structure(raw_structure)
number_atoms = len(structure)
cell = raw_structure.cell.lattice_vectors
# scale all dimensions by constant factor
scale = 2
supercell = structure.to_ase(supercell=scale)
assert len(supercell) == number_atoms * scale ** 3
Assert.allclose(supercell.cell.array, cell * scale)
# scale differently for each dimension
scale = (2, 1, 3)
supercell = structure.to_ase(supercell=scale)
assert len(supercell) == number_atoms * np.prod(scale)
Assert.allclose(supercell.cell.array, cell * scale)
40 changes: 32 additions & 8 deletions tests/data/test_viewer3d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import patch
from py4vasp.data import Structure, Viewer3d
from py4vasp.data.viewer3d import _Arrow3d, _x_axis, _y_axis, _z_axis
from py4vasp.exceptions import RefinementException
from .test_structure import raw_structure
import numpy as np
import pytest
Expand All @@ -9,10 +10,14 @@

@pytest.fixture
def viewer3d(raw_structure):
return make_viewer(raw_structure)


def make_viewer(raw_structure, supercell=None):
structure = Structure(raw_structure)
viewer = structure.plot()
viewer = structure.plot(supercell)
viewer.raw_structure = raw_structure
viewer.default_messages = count_messages(viewer, setup=True)
viewer.positions = raw_structure.cartesian_positions
return viewer


Expand Down Expand Up @@ -76,14 +81,13 @@ def _assert_arrow_message(message, arrow):


def test_arrows(viewer3d, assert_arrow_message):
direction = np.array((0, 0, 1))
number_atoms = len(viewer3d._structure)
arrows = np.repeat([direction], number_atoms, axis=0)
viewer3d.show_arrows_at_atoms(arrows)
messages = last_messages(viewer3d, number_atoms)
number_atoms = len(viewer3d.raw_structure.cartesian_positions)
positions = viewer3d.raw_structure.cartesian_positions
color = [0.1, 0.1, 0.8]
arrows = create_arrows(viewer3d, number_atoms)
messages = last_messages(viewer3d, number_atoms)
assert len(messages) == number_atoms
for message, tail, arrow in zip(messages, viewer3d.positions, arrows):
for message, tail, arrow in zip(messages, positions, arrows):
tip = tail + arrow
assert_arrow_message(message, _Arrow3d(tail, tip, color))
viewer3d.show_arrows_at_atoms(arrows)
Expand All @@ -92,3 +96,23 @@ def test_arrows(viewer3d, assert_arrow_message):
# ngl deletes the sent messages to indicate removal of the shapes
assert count_messages(viewer3d) == 0
viewer3d.hide_arrows_at_atoms()


def test_supercell(raw_structure, assert_arrow_message):
supercell = (1, 2, 3)
viewer = make_viewer(raw_structure, supercell)
number_atoms = len(raw_structure.cartesian_positions)
create_arrows(viewer, number_atoms)
assert count_messages(viewer) == np.prod(supercell) * number_atoms


def test_bare_ngl_cannot_add_arrows_at_atoms(viewer3d):
viewer = Viewer3d(viewer3d._ngl)
with pytest.raises(RefinementException):
create_arrows(viewer, 1)


def create_arrows(viewer, number_atoms):
arrows = np.repeat([(0, 0, 1)], number_atoms, axis=0)
viewer.show_arrows_at_atoms(arrows)
return arrows

0 comments on commit 456d467

Please sign in to comment.