Skip to content

Commit

Permalink
Cleanup NGL Viewer class (orest-d#11)
Browse files Browse the repository at this point in the history
A lot of the functionality previously contained in the Structure class
is now moved to the Viewer3d class. The main reason for this change is
that the structure class does not need to now about NGL at all.
Also the interface is much leaner and more options can be choosen by the
user afterwards if desired and don't expand the default interface.
  • Loading branch information
martin-schlipf committed Aug 5, 2020
1 parent 2fba540 commit db64e22
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 109 deletions.
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
name = "py4vasp"
version = "0.1.0"
description = ""
authors = ["Martin Schlipf <martin.schlipf@gmail.com>"]
authors = [
"Martin Schlipf <martin.schlipf@gmail.com>",
"Orest Dubay <orest-d@users.noreply.github.com>"
]

[tool.poetry.dependencies]
python = "^3.7"
Expand Down Expand Up @@ -31,7 +34,9 @@ numpy = {channel = "anaconda"}
pandas = {channel = "anaconda"}
cufflinks = {channel = "conda-forge", name = "cufflinks-py"}
mdtraj = {channel = "conda-forge"}
nglview = {channel = "conda-forge"}
ase = {channel = "conda-forge"}
pymatgen = {channel = "conda-forge"}
sphinx = {channel = "anaconda"}

[build-system]
Expand Down
1 change: 1 addition & 0 deletions src/py4vasp/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .energy import Energy
from .kpoints import Kpoints
from .projectors import Projectors
from .viewer3d import Viewer3d
from .structure import Structure

import plotly.io as pio
Expand Down
67 changes: 8 additions & 59 deletions src/py4vasp/data/structure.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,10 @@
from py4vasp.data import _util
from py4vasp.exceptions import RefinementException
import functools
import numpy as np


class StructureViewer:
"""Collection of data and elements to be displayed in a structure viewer"""

def __init__(self, structure, show_cell=True, supercell=None, show_axes=False, axes_length=3, arrows=None):
self.structure = structure
self.show_cell = show_cell
self.supercell = supercell
self.show_axes = show_axes
self.axes_length = axes_length
self.arrows = arrows

def with_arrows(self, arrows):
self.arrows = arrows
return self

def show(self):
import nglview
from nglview.shape import Shape

structure = self.structure.to_pymatgen()
if self.supercell is not None:
structure.make_supercell(self.supercell)

view = nglview.show_pymatgen(structure)
if self.show_cell:
view.add_representation(repr_type="unitcell")
if self.show_axes or self.arrows is not None:
shape = Shape(view=view)
if self.show_axes:
shape.add_arrow(
[0, 0, 0], [self.axes_length, 0, 0], [1, 0, 0], 0.2)
shape.add_arrow(
[0, 0, 0], [0, self.axes_length, 0], [0, 1, 0], 0.2)
shape.add_arrow(
[0, 0, 0], [0, 0, self.axes_length], [0, 0, 1], 0.2)
if self.arrows is not None:
for (x, y, z), (vx, vy, vz) in zip(structure.cart_coords, self.arrows):
shape.add_arrow(
[x, y, z], [x+vx, y+vy, z+vz], [0.1, 0.1, 0.8], 0.2)

return view
from py4vasp.data import Viewer3d


class Structure:
def __init__(self, raw_structure):
self._raw = raw_structure
self.structure_viewer = None

def read(self):
return self.to_dict()
Expand All @@ -64,23 +18,18 @@ def to_dict(self):

def __len__(self):
return len(self._raw.cartesian_positions)

def to_pymatgen(self):
import pymatgen as mg

return mg.Structure(
lattice=mg.Lattice(self._raw.cell.lattice_vectors),
species=[specie.decode("ascii") for specie in self._raw.species],
coords=self._raw.cartesian_positions,
coords_are_cartesian=True
coords_are_cartesian=True,
)

def plot(self, show_cell=True, supercell=None, show_axes=False, axes_length=3):
self.structure_viewer = StructureViewer(
self, show_cell=show_cell, supercell=supercell, show_axes=show_axes, axes_length=axes_length)
return self.structure_viewer.show()

def plot_arrows(self, arrows):
if self.structure_viewer is None:
self.plot()
self.structure_viewer = self.structure_viewer.with_arrows(arrows)
return self.structure_viewer.show()
def plot(self, supercell=None):
viewer = Viewer3d(self, supercell=supercell)
viewer.show_cell()
return viewer
73 changes: 73 additions & 0 deletions src/py4vasp/data/viewer3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import NamedTuple
import nglview
import numpy as np


class _Arrow3d(NamedTuple):
tail: np.ndarray
tip: np.ndarray
color: np.ndarray
radius: float = 0.2


_x_axis = _Arrow3d(tail=np.zeros(3), tip=np.array((3, 0, 0)), color=[1, 0, 0])
_y_axis = _Arrow3d(tail=np.zeros(3), tip=np.array((0, 3, 0)), color=[0, 1, 0])
_z_axis = _Arrow3d(tail=np.zeros(3), tip=np.array((0, 0, 3)), color=[0, 0, 1])


class Viewer3d:
"""Collection of data and elements to be displayed in a structure viewer"""

def __init__(self, structure, supercell=None):
self._structure = structure
self._axes = None
self._arrows = []
self.supercell = supercell
self._ngl = self.show()

def _ipython_display_(self):
self._ngl._ipython_display_()

def show_cell(self):
self._ngl.add_unitcell()

def hide_cell(self):
self._ngl.remove_unitcell()

def show_axes(self):
if self._axes is not None:
return
self._axes = (
self._make_arrow(_x_axis),
self._make_arrow(_y_axis),
self._make_arrow(_z_axis),
)

def hide_axes(self):
if self._axes is None:
return
for axis in self._axes:
self._ngl.remove_component(axis)
self._axes = None

def show_arrows_at_atoms(self, arrows, color=[0.1, 0.1, 0.8]):
structure = self._structure.to_pymatgen()
for tail, arrow in zip(structure.cart_coords, arrows):
tip = tail + arrow
arrow = _Arrow3d(tail, tip, color)
self._arrows.append(self._make_arrow(arrow))

def hide_arrows_at_atoms(self):
for arrow in self._arrows:
self._ngl.remove_component(arrow)
self._arrows = []

def _make_arrow(self, arrow):
return self._ngl.shape.add_arrow(*arrow)

def show(self):
structure = self._structure.to_pymatgen()
if self.supercell is not None:
structure.make_supercell(self.supercell)
view = nglview.show_pymatgen(structure)
return view
6 changes: 3 additions & 3 deletions src/py4vasp/raw/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def structure(self):
"""
self._assert_not_closed()
return raw.Structure(
cell = self.cell(),
cartesian_positions = self._h5f["results/positions/cartesian_positions"],
species = self._h5f["results/positions/species"],
cell=self.cell(),
cartesian_positions=self._h5f["results/positions/cartesian_positions"],
species=self._h5f["results/positions/species"],
)

def energy(self):
Expand Down
2 changes: 2 additions & 0 deletions src/py4vasp/raw/rawdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ class Cell:
"Lattice vectors defining the unit cell."
__eq__ = _dataclass_equal


@dataclass
class Structure:
cell: Cell
cartesian_positions: np.ndarray
species: np.ndarray = None
__eq__ = _dataclass_equal


@dataclass
class Kpoints:
"**k** points at which wave functions are calculated."
Expand Down
64 changes: 19 additions & 45 deletions tests/data/test_structure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import patch
from py4vasp.data import Structure
from py4vasp.exceptions import RefinementException
import py4vasp.data as data
import py4vasp.raw as raw
import pytest
import numpy as np
Expand All @@ -12,19 +13,24 @@ def raw_structure():
return raw.Structure(
cell=raw.Cell(scale=1.0, lattice_vectors=np.eye(3)),
cartesian_positions=np.arange(np.prod(shape)).reshape(shape),
species=np.array(["C"]*number_atoms, dtype="S2"),
species=np.array(["C"] * number_atoms, dtype="S2"),
)


def get_messages_after_structure_information(view):
message_archive = view.get_state()["_ngl_msg_archive"]
all_messages = [(msg["methodName"], msg["args"]) for msg in message_archive]
return all_messages[1:] # first message is structure data


def test_read(raw_structure, Assert):
actual = Structure(raw_structure).read()
assert (actual["cell"] == raw_structure.cell.lattice_vectors).all()
Assert.allclose(actual["cartesian_positions"],
raw_structure.cartesian_positions)
Assert.allclose(actual["cartesian_positions"], raw_structure.cartesian_positions)
assert (actual["species"] == raw_structure.species).all()


def test_to_pymatgen(raw_structure, Assert):
def test_to_pymatgen(raw_structure):
structure = Structure(raw_structure)
mg_structure = structure.to_pymatgen()
a, b, c = mg_structure.lattice.as_dict()["matrix"]
Expand All @@ -33,43 +39,11 @@ def test_to_pymatgen(raw_structure, Assert):
assert c == [0, 0, 1]


def test_plot(raw_structure, Assert):
structure = Structure(raw_structure)
assert structure.structure_viewer is None
view = structure.plot()
assert structure.structure_viewer is not None

view = structure.plot(show_cell=False)
assert [(msg["methodName"], msg["args"])
for msg in view.get_state()["_ngl_msg_archive"]][1:] == []

view = structure.plot(show_cell=True)
assert [(msg["methodName"], msg["args"]) for msg in view.get_state()[
"_ngl_msg_archive"]][1:] == [('addRepresentation', ['unitcell'])]

view = structure.plot(show_cell=False, show_axes=False)
assert [(msg["methodName"], msg["args"])
for msg in view.get_state()["_ngl_msg_archive"]][1:] == []

view = structure.plot(show_cell=False, show_axes=True, axes_length=5)
assert [(msg["methodName"], msg["args"]) for msg in view.get_state()["_ngl_msg_archive"]][1:] == [
('addShape', [
'shape', [('arrow', [0, 0, 0], [5, 0, 0], [1, 0, 0], 0.2)]]),
('addShape', [
'shape', [('arrow', [0, 0, 0], [0, 5, 0], [0, 1, 0], 0.2)]]),
('addShape', [
'shape', [('arrow', [0, 0, 0], [0, 0, 5], [0, 0, 1], 0.2)]]),
]


def test_plot_arrows(raw_structure, Assert):
structure = Structure(raw_structure)
assert structure.structure_viewer is None
view = structure.plot_arrows([(0, 0, 1) for i in range(len(structure))])
assert structure.structure_viewer is not None
assert [(msg["methodName"], msg["args"]) for msg in view.get_state()[
"_ngl_msg_archive"]][1] == ('addRepresentation', ['unitcell'])
assert [(msg["methodName"], msg["args"]) for msg in view.get_state()[
"_ngl_msg_archive"]][2] == ('addShape', ['shape', [('arrow', [0.0, 1.0, 2.0], [0.0, 1.0, 3.0], [0.1, 0.1, 0.8], 0.2)]])
assert sum(msg["methodName"] == 'addShape' for msg in view.get_state()[
"_ngl_msg_archive"]) == len(structure)
def test_plot(raw_structure):
cm_init = patch.object(data.Viewer3d, "__init__", autospec=True, return_value=None)
cm_cell = patch.object(data.Viewer3d, "show_cell")
with cm_init as init, cm_cell as cell:
structure = Structure(raw_structure)
structure.plot()
init.assert_called_once()
cell.assert_called_once()
Loading

0 comments on commit db64e22

Please sign in to comment.