Skip to content

Commit

Permalink
Read structure from file (orest-d#14)
Browse files Browse the repository at this point in the history
We can use the Topology class like the Trajectory information instead of 
redefining another way to extract the species information from the file.

The position are actually in crystal units not in cartesian ones

Add from_file method
  • Loading branch information
martin-schlipf authored Aug 12, 2020
1 parent a08f4f1 commit 25a4f40
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 44 deletions.
20 changes: 11 additions & 9 deletions src/py4vasp/data/structure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from py4vasp.data import _util
from py4vasp.data import Viewer3d
from py4vasp.data import _util, Viewer3d, Topology
import ase
import numpy as np

Expand All @@ -8,26 +7,29 @@ class Structure:
def __init__(self, raw_structure):
self._raw = raw_structure

@classmethod
def from_file(cls, file=None):
return _util.from_file(cls, file, "structure")

def read(self):
return self.to_dict()

def to_dict(self):
return {
"cell": self._raw.cell.lattice_vectors[:],
"cartesian_positions": self._raw.cartesian_positions[:],
"species": self._raw.species,
"cell": self._raw.cell.scale * self._raw.cell.lattice_vectors[:],
"positions": self._raw.positions[:],
"elements": Topology(self._raw.topology).elements(),
}

def __len__(self):
return len(self._raw.cartesian_positions)
return len(self._raw.positions)

def to_ase(self, supercell=None):
data = self.to_dict()
species = [_util.decode_if_possible(sp) for sp in data["species"]]
structure = ase.Atoms(
symbols=species,
symbols=data["elements"],
cell=data["cell"],
positions=data["cartesian_positions"],
scaled_positions=data["positions"],
pbc=True,
)
if supercell is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/py4vasp/raw/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def structure(self):
"""
self._assert_not_closed()
return raw.Structure(
topology=self.topology(),
cell=self.cell(),
cartesian_positions=self._h5f["results/positions/cartesian_positions"],
species=self._h5f["results/positions/species"],
positions=self._h5f["results/positions/position_ions"],
)

def energy(self):
Expand Down
4 changes: 2 additions & 2 deletions src/py4vasp/raw/rawdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ class Cell:

@dataclass
class Structure:
topology: Topology
cell: Cell
cartesian_positions: np.ndarray
species: np.ndarray = None
positions: np.ndarray
__eq__ = _dataclass_equal


Expand Down
57 changes: 38 additions & 19 deletions tests/data/test_structure.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,63 @@
from unittest.mock import patch
from py4vasp.data import Structure
from .test_topology import raw_topology
import py4vasp.data as data
import py4vasp.raw as raw
import pytest
import numpy as np


@pytest.fixture
def raw_structure():
number_atoms = 6
def raw_structure(raw_topology):
number_atoms = len(raw_topology.elements)
shape = (number_atoms, 3)
return raw.Structure(
cell=raw.Cell(scale=1.0, lattice_vectors=np.eye(3)),
cartesian_positions=np.arange(np.prod(shape)).reshape(shape),
species=np.array(["C"] * number_atoms, dtype="S2"),
structure = raw.Structure(
topology=raw_topology,
cell=raw.Cell(scale=2.0, lattice_vectors=np.eye(3)),
positions=np.arange(np.prod(shape)).reshape(shape) / np.prod(shape),
)
structure.actual_cell = structure.cell.scale * structure.cell.lattice_vectors
return structure


def get_messages_after_structure_information(view):
message_archive = view.get_state()["_ngl_msg_archive"]
all_messages = [(msg["methodName"], msg["args"]) for msg in message_archive]
return all_messages[1:] # first message is structure data
def test_from_file(raw_structure, mock_file, check_read):
with mock_file("structure", raw_structure) as mocks:
check_read(Structure, mocks, raw_structure)


def test_read(raw_structure, Assert):
actual = Structure(raw_structure).read()
Assert.allclose(actual["cell"], raw_structure.cell.lattice_vectors)
Assert.allclose(actual["cartesian_positions"], raw_structure.cartesian_positions)
assert (actual["species"] == raw_structure.species).all()
Assert.allclose(actual["cell"], raw_structure.actual_cell)
Assert.allclose(actual["positions"], raw_structure.positions)
assert actual["elements"] == raw_structure.topology.elements


def test_to_ase(raw_structure, Assert):
structure = Structure(raw_structure).to_ase()
Assert.allclose(structure.cell.array, raw_structure.cell.lattice_vectors)
Assert.allclose(structure.positions, raw_structure.cartesian_positions)
assert all(structure.symbols == "C6")
Assert.allclose(structure.cell.array, raw_structure.actual_cell)
Assert.allclose(structure.get_scaled_positions(), raw_structure.positions)
assert all(structure.symbols == "Sr2TiO4")
assert all(structure.pbc)


def test_tilted_unitcell(raw_structure, Assert):
cell = np.array([[4, 0, 0], [0, 4, 0], [2, 2, 6]])
inv_cell = np.linalg.inv(cell)
cartesian_positions = (
(0, 0, 0),
(4, 4, 4),
(2, 2, 2),
(2, 2, 0),
(2, 4, 2),
(4, 2, 2),
(2, 2, 4),
)
raw_structure.cell = raw.Cell(scale=1, lattice_vectors=cell)
raw_structure.positions = cartesian_positions @ inv_cell
structure = Structure(raw_structure).to_ase()
Assert.allclose(structure.positions, cartesian_positions)


def test_plot(raw_structure):
cm_init = patch.object(data.Viewer3d, "__init__", autospec=True, return_value=None)
cm_cell = patch.object(data.Viewer3d, "show_cell")
Expand All @@ -51,14 +71,13 @@ def test_plot(raw_structure):
def test_supercell(raw_structure, Assert):
structure = Structure(raw_structure)
number_atoms = len(structure)
cell = raw_structure.cell.lattice_vectors
# scale all dimensions by constant factor
scale = 2
supercell = structure.to_ase(supercell=scale)
assert len(supercell) == number_atoms * scale ** 3
Assert.allclose(supercell.cell.array, cell * scale)
Assert.allclose(supercell.cell.array, raw_structure.actual_cell * scale)
# scale differently for each dimension
scale = (2, 1, 3)
supercell = structure.to_ase(supercell=scale)
assert len(supercell) == number_atoms * np.prod(scale)
Assert.allclose(supercell.cell.array, cell * scale)
Assert.allclose(supercell.cell.array, raw_structure.actual_cell * scale)
11 changes: 10 additions & 1 deletion tests/data/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

@pytest.fixture
def raw_topology():
return raw.Topology(
topology = raw.Topology(
number_ion_types=np.array((2, 1, 4)),
ion_types=np.array(("Sr", "Ti", "O "), dtype="S"),
)
topology.names = ["Sr_1", "Sr_2", "Ti_1", "O_1", "O_2", "O_3", "O_4"]
topology.elements = ["Sr", "Sr", "Ti", "O", "O", "O", "O"]
return topology


def test_raw_topology(raw_topology):
Expand All @@ -31,6 +34,12 @@ def test_raw_topology(raw_topology):
assert topology["*"] == Selection(indices=range(index[-1]))


def test_atom_labels(raw_topology):
topology = Topology(raw_topology)
assert topology.names() == raw_topology.names
assert topology.elements() == raw_topology.elements


def test_from_file(raw_topology, mock_file, check_read):
with mock_file("topology", raw_topology) as mocks:
check_read(Topology, mocks, raw_topology)
Expand Down
5 changes: 2 additions & 3 deletions tests/data/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ def raw_trajectory(raw_topology):

def test_read_trajectory(raw_trajectory, Assert):
trajectory = Trajectory(raw_trajectory).read()
topology = Topology(raw_trajectory.topology)
assert trajectory["names"] == topology.names()
assert trajectory["elements"] == topology.elements()
assert trajectory["names"] == raw_trajectory.topology.names
assert trajectory["elements"] == raw_trajectory.topology.elements
Assert.allclose(trajectory["positions"], raw_trajectory.positions)
Assert.allclose(trajectory["lattice_vectors"], raw_trajectory.lattice_vectors)

Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_viewer3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from py4vasp.data import Structure, Viewer3d
from py4vasp.data.viewer3d import _Arrow3d, _x_axis, _y_axis, _z_axis
from py4vasp.exceptions import RefinementException
from .test_structure import raw_structure
from .test_structure import raw_structure, raw_topology
import numpy as np
import pytest
import nglview
Expand Down Expand Up @@ -81,8 +81,8 @@ def _assert_arrow_message(message, arrow):


def test_arrows(viewer3d, assert_arrow_message):
number_atoms = len(viewer3d.raw_structure.cartesian_positions)
positions = viewer3d.raw_structure.cartesian_positions
positions = viewer3d.raw_structure.positions @ viewer3d.raw_structure.actual_cell
number_atoms = len(positions)
color = [0.1, 0.1, 0.8]
arrows = create_arrows(viewer3d, number_atoms)
messages = last_messages(viewer3d, number_atoms)
Expand All @@ -101,7 +101,7 @@ def test_arrows(viewer3d, assert_arrow_message):
def test_supercell(raw_structure, assert_arrow_message):
supercell = (1, 2, 3)
viewer = make_viewer(raw_structure, supercell)
number_atoms = len(raw_structure.cartesian_positions)
number_atoms = len(raw_structure.positions)
create_arrows(viewer, number_atoms)
assert count_messages(viewer) == np.prod(supercell) * number_atoms

Expand Down
8 changes: 4 additions & 4 deletions tests/raw/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,17 +358,17 @@ def test_structure(tmpdir):

def reference_structure():
structure = raw.Structure(
topology=reference_topology(),
cell=reference_cell(),
cartesian_positions=np.linspace(np.zeros(3), np.ones(3), num_atoms),
species=np.array(["C"] * num_atoms, dtype="S2"),
positions=np.linspace(np.zeros(3), np.ones(3), num_atoms),
)
return structure


def write_structure(h5f, structure):
write_topology(h5f, structure.topology)
write_cell(h5f, structure.cell)
h5f["results/positions/cartesian_positions"] = structure.cartesian_positions
h5f["results/positions/species"] = structure.species
h5f["results/positions/position_ions"] = structure.positions


def check_structure(file, reference):
Expand Down

0 comments on commit 25a4f40

Please sign in to comment.