Skip to content

Commit

Permalink
Add nglview support (orest-d#8)
Browse files Browse the repository at this point in the history
* Refactored structure viewer into its own object
* Add unit tests for structure viewer - structure and arrows
  • Loading branch information
orest-d committed Aug 3, 2020
1 parent bcfad57 commit 2fba540
Show file tree
Hide file tree
Showing 9 changed files with 1,106 additions and 272 deletions.
1,156 changes: 887 additions & 269 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ python = "^3.7"
h5py = "^2.10.0"
numpy = "^1.17.4"
pandas = "^0.25.3"
cufflinks = "^0.17.0"
cufflinks = "^0.17.3"
pymatgen = "^2020.4.2"
nglview = "^2.7.5"
ase = "^3.19.0"

[tool.poetry.dev-dependencies]
pytest = "^3.0"
black = {version = "^18.3-alpha.0", allows-prereleases = true}
pytest-cov = "^2.8.1"
pylint = "^2.5.0"
sphinx = "^3.0.4"

[tool.poetry2conda]
Expand All @@ -27,6 +31,7 @@ numpy = {channel = "anaconda"}
pandas = {channel = "anaconda"}
cufflinks = {channel = "conda-forge", name = "cufflinks-py"}
mdtraj = {channel = "conda-forge"}
ase = {channel = "conda-forge"}
sphinx = {channel = "anaconda"}

[build-system]
Expand Down
1 change: 1 addition & 0 deletions src/py4vasp/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .energy import Energy
from .kpoints import Kpoints
from .projectors import Projectors
from .structure import Structure

import plotly.io as pio
import cufflinks as cf
Expand Down
4 changes: 2 additions & 2 deletions src/py4vasp/data/projectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from py4vasp.data import _util
from py4vasp.exceptions import UsageException

_selection_doc = """
_selection_doc = r"""
selection : str
A string specifying the projection of the orbitals. There are three distinct
possibilities:
Expand All @@ -33,7 +33,7 @@
_end_spec = ")"
_seperators = (" ", ",")
_range_separator = "-"
_range = re.compile(r"^(\d+)" + re.escape(_range_separator) + "(\d+)$")
_range = re.compile(r"^(\d+)" + re.escape(_range_separator) + r"(\d+)$")
_whitespace_begin_spec = re.compile(r"\s*" + re.escape(_begin_spec) + r"\s*")
_whitespace_end_spec = re.compile(r"\s*" + re.escape(_end_spec) + r"\s*")
_whitespace_range = re.compile(r"\s*" + re.escape(_range_separator) + r"\s*")
Expand Down
86 changes: 86 additions & 0 deletions src/py4vasp/data/structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from py4vasp.data import _util
from py4vasp.exceptions import RefinementException
import functools
import numpy as np


class StructureViewer:
"""Collection of data and elements to be displayed in a structure viewer"""

def __init__(self, structure, show_cell=True, supercell=None, show_axes=False, axes_length=3, arrows=None):
self.structure = structure
self.show_cell = show_cell
self.supercell = supercell
self.show_axes = show_axes
self.axes_length = axes_length
self.arrows = arrows

def with_arrows(self, arrows):
self.arrows = arrows
return self

def show(self):
import nglview
from nglview.shape import Shape

structure = self.structure.to_pymatgen()
if self.supercell is not None:
structure.make_supercell(self.supercell)

view = nglview.show_pymatgen(structure)
if self.show_cell:
view.add_representation(repr_type="unitcell")
if self.show_axes or self.arrows is not None:
shape = Shape(view=view)
if self.show_axes:
shape.add_arrow(
[0, 0, 0], [self.axes_length, 0, 0], [1, 0, 0], 0.2)
shape.add_arrow(
[0, 0, 0], [0, self.axes_length, 0], [0, 1, 0], 0.2)
shape.add_arrow(
[0, 0, 0], [0, 0, self.axes_length], [0, 0, 1], 0.2)
if self.arrows is not None:
for (x, y, z), (vx, vy, vz) in zip(structure.cart_coords, self.arrows):
shape.add_arrow(
[x, y, z], [x+vx, y+vy, z+vz], [0.1, 0.1, 0.8], 0.2)

return view


class Structure:
def __init__(self, raw_structure):
self._raw = raw_structure
self.structure_viewer = None

def read(self):
return self.to_dict()

def to_dict(self):
return {
"cell": self._raw.cell.lattice_vectors[:],
"cartesian_positions": self._raw.cartesian_positions[:],
"species": list(self._raw.species),
}

def __len__(self):
return len(self._raw.cartesian_positions)

def to_pymatgen(self):
import pymatgen as mg
return mg.Structure(
lattice=mg.Lattice(self._raw.cell.lattice_vectors),
species=[specie.decode("ascii") for specie in self._raw.species],
coords=self._raw.cartesian_positions,
coords_are_cartesian=True
)

def plot(self, show_cell=True, supercell=None, show_axes=False, axes_length=3):
self.structure_viewer = StructureViewer(
self, show_cell=show_cell, supercell=supercell, show_axes=show_axes, axes_length=axes_length)
return self.structure_viewer.show()

def plot_arrows(self, arrows):
if self.structure_viewer is None:
self.plot()
self.structure_viewer = self.structure_viewer.with_arrows(arrows)
return self.structure_viewer.show()
14 changes: 14 additions & 0 deletions src/py4vasp/raw/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ def cell(self):
lattice_vectors=self._h5f["results/positions/lattice_vectors"],
)

def structure(self):
""" Read the structure information.
Returns
-------
raw.Structure
"""
self._assert_not_closed()
return raw.Structure(
cell = self.cell(),
cartesian_positions = self._h5f["results/positions/cartesian_positions"],
species = self._h5f["results/positions/species"],
)

def energy(self):
""" Read the energies during the ionic convergence.
Expand Down
6 changes: 6 additions & 0 deletions src/py4vasp/raw/rawdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class Cell:
"Lattice vectors defining the unit cell."
__eq__ = _dataclass_equal

@dataclass
class Structure:
cell: Cell
cartesian_positions: np.ndarray
species: np.ndarray = None
__eq__ = _dataclass_equal

@dataclass
class Kpoints:
Expand Down
75 changes: 75 additions & 0 deletions tests/data/test_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from py4vasp.data import Structure
from py4vasp.exceptions import RefinementException
import py4vasp.raw as raw
import pytest
import numpy as np


@pytest.fixture
def raw_structure():
number_atoms = 20
shape = (number_atoms, 3)
return raw.Structure(
cell=raw.Cell(scale=1.0, lattice_vectors=np.eye(3)),
cartesian_positions=np.arange(np.prod(shape)).reshape(shape),
species=np.array(["C"]*number_atoms, dtype="S2"),
)


def test_read(raw_structure, Assert):
actual = Structure(raw_structure).read()
assert (actual["cell"] == raw_structure.cell.lattice_vectors).all()
Assert.allclose(actual["cartesian_positions"],
raw_structure.cartesian_positions)
assert (actual["species"] == raw_structure.species).all()


def test_to_pymatgen(raw_structure, Assert):
structure = Structure(raw_structure)
mg_structure = structure.to_pymatgen()
a, b, c = mg_structure.lattice.as_dict()["matrix"]
assert a == [1, 0, 0]
assert b == [0, 1, 0]
assert c == [0, 0, 1]


def test_plot(raw_structure, Assert):
structure = Structure(raw_structure)
assert structure.structure_viewer is None
view = structure.plot()
assert structure.structure_viewer is not None

view = structure.plot(show_cell=False)
assert [(msg["methodName"], msg["args"])
for msg in view.get_state()["_ngl_msg_archive"]][1:] == []

view = structure.plot(show_cell=True)
assert [(msg["methodName"], msg["args"]) for msg in view.get_state()[
"_ngl_msg_archive"]][1:] == [('addRepresentation', ['unitcell'])]

view = structure.plot(show_cell=False, show_axes=False)
assert [(msg["methodName"], msg["args"])
for msg in view.get_state()["_ngl_msg_archive"]][1:] == []

view = structure.plot(show_cell=False, show_axes=True, axes_length=5)
assert [(msg["methodName"], msg["args"]) for msg in view.get_state()["_ngl_msg_archive"]][1:] == [
('addShape', [
'shape', [('arrow', [0, 0, 0], [5, 0, 0], [1, 0, 0], 0.2)]]),
('addShape', [
'shape', [('arrow', [0, 0, 0], [0, 5, 0], [0, 1, 0], 0.2)]]),
('addShape', [
'shape', [('arrow', [0, 0, 0], [0, 0, 5], [0, 0, 1], 0.2)]]),
]


def test_plot_arrows(raw_structure, Assert):
structure = Structure(raw_structure)
assert structure.structure_viewer is None
view = structure.plot_arrows([(0, 0, 1) for i in range(len(structure))])
assert structure.structure_viewer is not None
assert [(msg["methodName"], msg["args"]) for msg in view.get_state()[
"_ngl_msg_archive"]][1] == ('addRepresentation', ['unitcell'])
assert [(msg["methodName"], msg["args"]) for msg in view.get_state()[
"_ngl_msg_archive"]][2] == ('addShape', ['shape', [('arrow', [0.0, 1.0, 2.0], [0.0, 1.0, 3.0], [0.1, 0.1, 0.8], 0.2)]])
assert sum(msg["methodName"] == 'addShape' for msg in view.get_state()[
"_ngl_msg_archive"]) == len(structure)
29 changes: 29 additions & 0 deletions tests/raw/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,32 @@ def check_kpoints(file, reference):
assert actual == reference
assert isinstance(actual.number, Integral)
assert isinstance(actual.mode, str)


def test_structure(tmpdir):
setup = SetupTest(
directory=tmpdir,
options=itertools.product((True, False)),
create_reference=reference_structure,
write_reference=write_structure,
check_actual=check_structure,
)
generic_test(setup)


def reference_structure():
structure = raw.Structure(
cell=reference_cell(),
cartesian_positions=np.linspace(np.zeros(3), np.ones(3), num_atoms),
species = np.array(["C"]*num_atoms, dtype="S2")
)
return structure

def write_structure(h5f, structure):
write_cell(h5f, structure.cell)
h5f["results/positions/cartesian_positions"] = structure.cartesian_positions
h5f["results/positions/species"] = structure.species

def check_structure(file, reference):
actual = file.structure()
assert actual == reference

0 comments on commit 2fba540

Please sign in to comment.