diff --git a/src/py4vasp/data/structure.py b/src/py4vasp/data/structure.py index e3f2bf8..b475e12 100644 --- a/src/py4vasp/data/structure.py +++ b/src/py4vasp/data/structure.py @@ -1,5 +1,4 @@ -from py4vasp.data import _util -from py4vasp.data import Viewer3d +from py4vasp.data import _util, Viewer3d, Topology import ase import numpy as np @@ -8,26 +7,29 @@ class Structure: def __init__(self, raw_structure): self._raw = raw_structure + @classmethod + def from_file(cls, file=None): + return _util.from_file(cls, file, "structure") + def read(self): return self.to_dict() def to_dict(self): return { - "cell": self._raw.cell.lattice_vectors[:], - "cartesian_positions": self._raw.cartesian_positions[:], - "species": self._raw.species, + "cell": self._raw.cell.scale * self._raw.cell.lattice_vectors[:], + "positions": self._raw.positions[:], + "elements": Topology(self._raw.topology).elements(), } def __len__(self): - return len(self._raw.cartesian_positions) + return len(self._raw.positions) def to_ase(self, supercell=None): data = self.to_dict() - species = [_util.decode_if_possible(sp) for sp in data["species"]] structure = ase.Atoms( - symbols=species, + symbols=data["elements"], cell=data["cell"], - positions=data["cartesian_positions"], + scaled_positions=data["positions"], pbc=True, ) if supercell is not None: diff --git a/src/py4vasp/raw/file.py b/src/py4vasp/raw/file.py index 87a6a75..952a6d9 100644 --- a/src/py4vasp/raw/file.py +++ b/src/py4vasp/raw/file.py @@ -152,9 +152,9 @@ def structure(self): """ self._assert_not_closed() return raw.Structure( + topology=self.topology(), cell=self.cell(), - cartesian_positions=self._h5f["results/positions/cartesian_positions"], - species=self._h5f["results/positions/species"], + positions=self._h5f["results/positions/position_ions"], ) def energy(self): diff --git a/src/py4vasp/raw/rawdata.py b/src/py4vasp/raw/rawdata.py index fd0c3b6..a8b169b 100644 --- a/src/py4vasp/raw/rawdata.py +++ b/src/py4vasp/raw/rawdata.py @@ -60,9 +60,9 @@ class Cell: @dataclass class Structure: + topology: Topology cell: Cell - cartesian_positions: np.ndarray - species: np.ndarray = None + positions: np.ndarray __eq__ = _dataclass_equal diff --git a/tests/data/test_structure.py b/tests/data/test_structure.py index e6f5630..d9f32be 100644 --- a/tests/data/test_structure.py +++ b/tests/data/test_structure.py @@ -1,5 +1,6 @@ from unittest.mock import patch from py4vasp.data import Structure +from .test_topology import raw_topology import py4vasp.data as data import py4vasp.raw as raw import pytest @@ -7,37 +8,56 @@ @pytest.fixture -def raw_structure(): - number_atoms = 6 +def raw_structure(raw_topology): + number_atoms = len(raw_topology.elements) shape = (number_atoms, 3) - return raw.Structure( - cell=raw.Cell(scale=1.0, lattice_vectors=np.eye(3)), - cartesian_positions=np.arange(np.prod(shape)).reshape(shape), - species=np.array(["C"] * number_atoms, dtype="S2"), + structure = raw.Structure( + topology=raw_topology, + cell=raw.Cell(scale=2.0, lattice_vectors=np.eye(3)), + positions=np.arange(np.prod(shape)).reshape(shape) / np.prod(shape), ) + structure.actual_cell = structure.cell.scale * structure.cell.lattice_vectors + return structure -def get_messages_after_structure_information(view): - message_archive = view.get_state()["_ngl_msg_archive"] - all_messages = [(msg["methodName"], msg["args"]) for msg in message_archive] - return all_messages[1:] # first message is structure data +def test_from_file(raw_structure, mock_file, check_read): + with mock_file("structure", raw_structure) as mocks: + check_read(Structure, mocks, raw_structure) def test_read(raw_structure, Assert): actual = Structure(raw_structure).read() - Assert.allclose(actual["cell"], raw_structure.cell.lattice_vectors) - Assert.allclose(actual["cartesian_positions"], raw_structure.cartesian_positions) - assert (actual["species"] == raw_structure.species).all() + Assert.allclose(actual["cell"], raw_structure.actual_cell) + Assert.allclose(actual["positions"], raw_structure.positions) + assert actual["elements"] == raw_structure.topology.elements def test_to_ase(raw_structure, Assert): structure = Structure(raw_structure).to_ase() - Assert.allclose(structure.cell.array, raw_structure.cell.lattice_vectors) - Assert.allclose(structure.positions, raw_structure.cartesian_positions) - assert all(structure.symbols == "C6") + Assert.allclose(structure.cell.array, raw_structure.actual_cell) + Assert.allclose(structure.get_scaled_positions(), raw_structure.positions) + assert all(structure.symbols == "Sr2TiO4") assert all(structure.pbc) +def test_tilted_unitcell(raw_structure, Assert): + cell = np.array([[4, 0, 0], [0, 4, 0], [2, 2, 6]]) + inv_cell = np.linalg.inv(cell) + cartesian_positions = ( + (0, 0, 0), + (4, 4, 4), + (2, 2, 2), + (2, 2, 0), + (2, 4, 2), + (4, 2, 2), + (2, 2, 4), + ) + raw_structure.cell = raw.Cell(scale=1, lattice_vectors=cell) + raw_structure.positions = cartesian_positions @ inv_cell + structure = Structure(raw_structure).to_ase() + Assert.allclose(structure.positions, cartesian_positions) + + def test_plot(raw_structure): cm_init = patch.object(data.Viewer3d, "__init__", autospec=True, return_value=None) cm_cell = patch.object(data.Viewer3d, "show_cell") @@ -51,14 +71,13 @@ def test_plot(raw_structure): def test_supercell(raw_structure, Assert): structure = Structure(raw_structure) number_atoms = len(structure) - cell = raw_structure.cell.lattice_vectors # scale all dimensions by constant factor scale = 2 supercell = structure.to_ase(supercell=scale) assert len(supercell) == number_atoms * scale ** 3 - Assert.allclose(supercell.cell.array, cell * scale) + Assert.allclose(supercell.cell.array, raw_structure.actual_cell * scale) # scale differently for each dimension scale = (2, 1, 3) supercell = structure.to_ase(supercell=scale) assert len(supercell) == number_atoms * np.prod(scale) - Assert.allclose(supercell.cell.array, cell * scale) + Assert.allclose(supercell.cell.array, raw_structure.actual_cell * scale) diff --git a/tests/data/test_topology.py b/tests/data/test_topology.py index d22a145..e809464 100644 --- a/tests/data/test_topology.py +++ b/tests/data/test_topology.py @@ -9,10 +9,13 @@ @pytest.fixture def raw_topology(): - return raw.Topology( + topology = raw.Topology( number_ion_types=np.array((2, 1, 4)), ion_types=np.array(("Sr", "Ti", "O "), dtype="S"), ) + topology.names = ["Sr_1", "Sr_2", "Ti_1", "O_1", "O_2", "O_3", "O_4"] + topology.elements = ["Sr", "Sr", "Ti", "O", "O", "O", "O"] + return topology def test_raw_topology(raw_topology): @@ -31,6 +34,12 @@ def test_raw_topology(raw_topology): assert topology["*"] == Selection(indices=range(index[-1])) +def test_atom_labels(raw_topology): + topology = Topology(raw_topology) + assert topology.names() == raw_topology.names + assert topology.elements() == raw_topology.elements + + def test_from_file(raw_topology, mock_file, check_read): with mock_file("topology", raw_topology) as mocks: check_read(Topology, mocks, raw_topology) diff --git a/tests/data/test_trajectory.py b/tests/data/test_trajectory.py index 5839407..17f4f43 100644 --- a/tests/data/test_trajectory.py +++ b/tests/data/test_trajectory.py @@ -22,9 +22,8 @@ def raw_trajectory(raw_topology): def test_read_trajectory(raw_trajectory, Assert): trajectory = Trajectory(raw_trajectory).read() - topology = Topology(raw_trajectory.topology) - assert trajectory["names"] == topology.names() - assert trajectory["elements"] == topology.elements() + assert trajectory["names"] == raw_trajectory.topology.names + assert trajectory["elements"] == raw_trajectory.topology.elements Assert.allclose(trajectory["positions"], raw_trajectory.positions) Assert.allclose(trajectory["lattice_vectors"], raw_trajectory.lattice_vectors) diff --git a/tests/data/test_viewer3d.py b/tests/data/test_viewer3d.py index 93430ee..7dcc856 100644 --- a/tests/data/test_viewer3d.py +++ b/tests/data/test_viewer3d.py @@ -2,7 +2,7 @@ from py4vasp.data import Structure, Viewer3d from py4vasp.data.viewer3d import _Arrow3d, _x_axis, _y_axis, _z_axis from py4vasp.exceptions import RefinementException -from .test_structure import raw_structure +from .test_structure import raw_structure, raw_topology import numpy as np import pytest import nglview @@ -81,8 +81,8 @@ def _assert_arrow_message(message, arrow): def test_arrows(viewer3d, assert_arrow_message): - number_atoms = len(viewer3d.raw_structure.cartesian_positions) - positions = viewer3d.raw_structure.cartesian_positions + positions = viewer3d.raw_structure.positions @ viewer3d.raw_structure.actual_cell + number_atoms = len(positions) color = [0.1, 0.1, 0.8] arrows = create_arrows(viewer3d, number_atoms) messages = last_messages(viewer3d, number_atoms) @@ -101,7 +101,7 @@ def test_arrows(viewer3d, assert_arrow_message): def test_supercell(raw_structure, assert_arrow_message): supercell = (1, 2, 3) viewer = make_viewer(raw_structure, supercell) - number_atoms = len(raw_structure.cartesian_positions) + number_atoms = len(raw_structure.positions) create_arrows(viewer, number_atoms) assert count_messages(viewer) == np.prod(supercell) * number_atoms diff --git a/tests/raw/test_file.py b/tests/raw/test_file.py index 0564d2d..741f69d 100644 --- a/tests/raw/test_file.py +++ b/tests/raw/test_file.py @@ -358,17 +358,17 @@ def test_structure(tmpdir): def reference_structure(): structure = raw.Structure( + topology=reference_topology(), cell=reference_cell(), - cartesian_positions=np.linspace(np.zeros(3), np.ones(3), num_atoms), - species=np.array(["C"] * num_atoms, dtype="S2"), + positions=np.linspace(np.zeros(3), np.ones(3), num_atoms), ) return structure def write_structure(h5f, structure): + write_topology(h5f, structure.topology) write_cell(h5f, structure.cell) - h5f["results/positions/cartesian_positions"] = structure.cartesian_positions - h5f["results/positions/species"] = structure.species + h5f["results/positions/position_ions"] = structure.positions def check_structure(file, reference):