Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vector field support #42

Merged
merged 8 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions znvis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@
from znvis.mesh.custom import CustomMesh
from znvis.mesh.cylinder import Cylinder
from znvis.mesh.sphere import Sphere
from znvis.mesh.arrow import Arrow
from znvis.particle.particle import Particle
from znvis.particle.vector_field import VectorField
from znvis.visualizer.visualizer import Visualizer

__all__ = [
Particle.__name__,
Sphere.__name__,
Arrow.__name__,
VectorField.__name__,
Visualizer.__name__,
Cylinder.__name__,
CustomMesh.__name__,
Expand Down
89 changes: 89 additions & 0 deletions znvis/mesh/arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
ZnVis: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:

Summary
-------
Create a sphere mesh.
"""

from dataclasses import dataclass

import numpy as np
import open3d as o3d

from znvis.transformations.rotation_matrices import rotation_matrix

from znvis.mesh import Mesh


@dataclass
class Arrow(Mesh):
"""
A class to produce arrow meshes.

Attributes
----------
scale : float
Scale of the arrow
resolution : int
Resolution of the mesh.
"""
scale: float = 1.0
resolution: int = 10

def create_mesh(
self, starting_position: np.ndarray, direction: np.ndarray = None
) -> o3d.geometry.TriangleMesh:
"""
Create a mesh object defined by the dataclass.

Parameters
----------
starting_position : np.ndarray shape=(3,)
Starting position of the mesh.
direction : np.ndarray shape=(3,) (default = None)
Direction of the mesh.

Returns
-------
mesh : o3d.geometry.TriangleMesh
"""

direction_length = np.linalg.norm(direction)

cylinder_radius = 0.06 * direction_length * self.scale
cylinder_height = 0.85 * direction_length * self.scale
cone_radius = 0.15 * direction_length * self.scale
cone_height = 0.15 * direction_length * self.scale

arrow = o3d.geometry.TriangleMesh.create_arrow(
cylinder_radius=cylinder_radius,
cylinder_height=cylinder_height,
cone_radius=cone_radius,
cone_height=cone_height,
resolution=self.resolution
)

arrow.compute_vertex_normals()
matrix = rotation_matrix(np.array([0, 0, 1]), direction)
arrow.rotate(matrix, center=(0, 0, 0))

# Translate the arrow to the starting position and center the origin
arrow.translate(starting_position.astype(float))

return arrow
111 changes: 111 additions & 0 deletions znvis/particle/vector_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
ZnVis: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:

Summary
-------
Module for the particle parent class
"""

import typing
from dataclasses import dataclass

import numpy as np
from rich.progress import track

from znvis.mesh.arrow import Arrow


@dataclass
class VectorField:
"""
A class to represent a vector field.

Attributes
----------
name : str
Name of the vector field
mesh : Mesh
Mesh to use
position : np.ndarray
Position tensor of the shape (n_steps, n_vectors, n_dims)
direction : np.ndarray
Direction tensor of the shape (n_steps, n_vectors, n_dims)
mesh_list : list
A list of mesh objects, one for each time step.
smoothing : bool (default=False)
If true, apply smoothing to each mesh object as it is rendered.
This will slow down the initial construction of the mesh objects
but not the deployment.
"""

name: str
mesh: Arrow = None # Should be an instance of the Arrow class
position: np.ndarray = None
direction: np.ndarray = None
mesh_list: typing.List[Arrow] = None

smoothing: bool = False

def _create_mesh(self, position, direction):
"""
Create a mesh object for the vector field.

Parameters
----------
position : np.ndarray
Position of the arrow
direction : np.ndarray
Direction of the arrow

Returns
-------
mesh : o3d.geometry.TriangleMesh
A mesh object
"""

mesh = self.mesh.create_mesh(position, direction)
if self.smoothing:
return mesh.filter_smooth_taubin(100)
else:
return mesh

def construct_mesh_list(self):
"""
Constructor the mesh list for the class.

The mesh list is a list of mesh objects for each
time step in the parsed trajectory.

Returns
-------
Updates the class attributes mesh_list
"""
self.mesh_list = []
try:
n_particles = int(self.position.shape[1])
n_time_steps = int(self.position.shape[0])
except ValueError:
raise ValueError("There is no data for this vector field.")

for i in track(range(n_time_steps), description=f"Building {self.name} Mesh"):
for j in range(n_particles):
if j == 0:
mesh = self._create_mesh(self.position[i][j], self.direction[i][j])
else:
mesh += self._create_mesh(self.position[i][j], self.direction[i][j])
self.mesh_list.append(mesh)
44 changes: 44 additions & 0 deletions znvis/visualizer/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Visualizer:
def __init__(
self,
particles: typing.List[znvis.Particle],
vector_field: typing.List[znvis.VectorField] = None,
output_folder: typing.Union[str, pathlib.Path] = "./",
frame_rate: int = 24,
number_of_steps: int = None,
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(
The format of the video to be generated.
"""
self.particles = particles
self.vector_field = vector_field
self.frame_rate = frame_rate
self.bounding_box = bounding_box() if bounding_box else None

Expand Down Expand Up @@ -305,6 +307,11 @@ def _initialize_particles(self):

self._draw_particles(initial=True)

def _initialize_vector_field(self):
for item in self.vector_field:
item.construct_mesh_list()
self._draw_vector_field(initial=True)

def _draw_particles(self, visualizer=None, initial: bool = False):
"""
Draw the particles on the visualizer.
Expand Down Expand Up @@ -344,6 +351,36 @@ def _draw_particles(self, visualizer=None, initial: bool = False):
item.name, item.mesh_list[self.counter], item.mesh.o3d_material
)


def _draw_vector_field(self, visualizer=None, initial: bool = False):
"""
Draw the vector field on the visualizer.

Parameters
----------
initial : bool (default = True)
If true, no particles are removed.

Returns
-------
updates the information in the visualizer.
-----
"""
if visualizer is None:
visualizer = self.vis

if initial:
for i, item in enumerate(self.vector_field):
visualizer.add_geometry(
item.name, item.mesh_list[self.counter], item.mesh.o3d_material
)
else:
for i, item in enumerate(self.vector_field):
visualizer.remove_geometry(item.name)
visualizer.add_geometry(
item.name, item.mesh_list[self.counter], item.mesh.o3d_material
)

def _continuous_trajectory(self, vis):
"""
Button command for running the simulation in the visualizer.
Expand Down Expand Up @@ -509,6 +546,11 @@ def _update_particles(self, visualizer=None, step: int = None):
step = self.counter

self._draw_particles(visualizer=visualizer) # draw the particles.

# draw the vector field if it exists.
if self.vector_field is not None:
self._draw_vector_field(visualizer=visualizer)

visualizer.post_redraw() # re-draw the window.

def run_visualization(self):
Expand All @@ -521,6 +563,8 @@ def run_visualization(self):
"""
self._initialize_app()
self._initialize_particles()
if self.vector_field is not None:
self._initialize_vector_field()

self.vis.reset_camera_to_default()
self.app.run()