Skip to content

Commit

Permalink
Refactor vector data and stretch functionality into common function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Carifio24 committed Nov 14, 2024
1 parent 0a11430 commit 3c0f3ba
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 48 deletions.
26 changes: 24 additions & 2 deletions glue_ar/common/scatter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from functools import partial
from numpy import clip, isfinite, isnan, ndarray, ones, sqrt
from numpy import array, clip, isfinite, isnan, ndarray, ones, sqrt
from typing import Callable, Dict, List, Optional, Tuple, Union

from glue.utils import ensure_numerical
from glue_vispy_viewers.scatter.layer_state import ScatterLayerState

from glue_ar.common.shapes import rectangular_prism_points, rectangular_prism_triangulation, \
sphere_points, sphere_triangles
from glue_ar.utils import Bounds, NoneType, Viewer3DState, mask_for_bounds
from glue_ar.utils import Bounds, NoneType, Viewer3DState, get_stretches, mask_for_bounds

try:
from glue_jupyter.ipyvolume.scatter import Scatter3DLayerState
Expand Down Expand Up @@ -93,6 +93,28 @@ def sizes_for_scatter_layer(layer_state: ScatterLayerState3D,
return sizes


def clip_vector_data(viewer_state: Viewer3DState,
layer_state: ScatterLayerState3D,
bounds: Bounds,
mask: Optional[ndarray] = None) -> ndarray:
if isinstance(layer_state, ScatterLayerState):
atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute]
else:
atts = [layer_state.vx_att, layer_state.vy_att, layer_state.vz_att]
vector_data = [layer_state.layer[att].ravel()[mask] for att in atts]

stretches = get_stretches(viewer_state)
if viewer_state.native_aspect:
factor = max((abs(b[1] - b[0]) * s for b, s in zip(bounds, stretches)))
vector_data = [[0.5 * t / factor for t in v] for v in vector_data]
else:
bound_factors = [abs(b[1] - b[0]) * s for b, s in zip(bounds, stretches)]
vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)]
vector_data = array(list(zip(*vector_data)))

return vector_data


def sphere_points_getter(theta_resolution: int,
phi_resolution: int) -> PointsGetter:

Expand Down
30 changes: 9 additions & 21 deletions glue_ar/common/scatter_gltf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from glue_ar.common.shapes import cone_triangles, cone_points, cylinder_points, cylinder_triangles, \
normalize, rectangular_prism_triangulation, sphere_triangles
from glue_ar.gltf_utils import add_points_to_bytearray, add_triangles_to_bytearray, index_mins, index_maxes
from glue_ar.utils import Viewer3DState, iterable_has_nan, hex_to_components, layer_color, \
from glue_ar.utils import Viewer3DState, get_stretches, iterable_has_nan, hex_to_components, layer_color, \
unique_id, xyz_bounds, xyz_for_layer, Bounds
from glue_ar.common.gltf_builder import GLTFBuilder
from glue_ar.common.scatter import Scatter3DLayerState, ScatterLayerState3D, \
PointsGetter, box_points_getter, IPYVOLUME_POINTS_GETTERS, \
IPYVOLUME_TRIANGLE_GETTERS, VECTOR_OFFSETS, radius_for_scatter_layer, \
scatter_layer_mask, sizes_for_scatter_layer, sphere_points_getter, NoneType
IPYVOLUME_TRIANGLE_GETTERS, VECTOR_OFFSETS, clip_vector_data, radius_for_scatter_layer, \
scatter_layer_mask, sizes_for_scatter_layer, sphere_points_getter, NoneType

try:
from glue_jupyter.common.state3d import ViewerState3D
Expand All @@ -39,20 +39,7 @@ def add_vectors_gltf(builder: GLTFBuilder,
materials: Optional[List[int]] = None,
mask: Optional[ndarray] = None):

if isinstance(layer_state, ScatterLayerState):
atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute]
else:
atts = [layer_state.vx_att, layer_state.vy_att, layer_state.vz_att]
vector_data = [layer_state.layer[att].ravel()[mask] for att in atts]

if viewer_state.native_aspect:
factor = max((abs(b[1] - b[0]) for b in bounds))
vector_data = [[0.5 * t / factor for t in v] for v in vector_data]
else:
bound_factors = [abs(b[1] - b[0]) for b in bounds]
vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)]
vector_data = array(list(zip(*vector_data)))

vector_data = clip_vector_data(viewer_state, layer_state, bounds, mask)
offset = VECTOR_OFFSETS[layer_state.vector_origin]
if layer_state.vector_origin == "tip":
offset += tip_height
Expand Down Expand Up @@ -166,12 +153,13 @@ def add_error_bars_gltf(builder: GLTFBuilder,
# NB: This ordering is intentional to account for glTF coordinate system
gltf_index = ['z', 'y', 'x'].index(axis)

axis_range = abs(bounds[index][1] - bounds[index][0])
stretches = get_stretches(viewer_state)
if viewer_state.native_aspect:
max_range = max((abs(b[1] - b[0]) for b in bounds))
factor = 1 / max_range
max_side = max(abs(b[1] - b[0]) * s for b, s in zip(bounds, stretches))
factor = 1 / max_side
else:
factor = 1 / axis_range
axis_factor = abs(bounds[index][1] - bounds[index][0]) * stretches[index]
factor = 1 / axis_factor
err_values *= factor

barr = bytearray()
Expand Down
19 changes: 3 additions & 16 deletions glue_ar/common/scatter_usd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from glue_vispy_viewers.scatter.layer_state import ScatterLayerState
from glue_vispy_viewers.scatter.viewer_state import Vispy3DViewerState
from numpy import array, ndarray
from numpy import ndarray
from numpy.linalg import norm

from glue_ar.common.export_options import ar_layer_export
from glue_ar.common.scatter import IPYVOLUME_POINTS_GETTERS, IPYVOLUME_TRIANGLE_GETTERS, VECTOR_OFFSETS, PointsGetter, \
ScatterLayerState3D, box_points_getter, radius_for_scatter_layer, \
ScatterLayerState3D, box_points_getter, clip_vector_data, radius_for_scatter_layer, \
scatter_layer_mask, sizes_for_scatter_layer, sphere_points_getter
from glue_ar.common.scatter_export_options import ARIpyvolumeScatterExportOptions, ARVispyScatterExportOptions
from glue_ar.common.usd_builder import USDBuilder
Expand Down Expand Up @@ -38,20 +38,7 @@ def add_vectors_usd(builder: USDBuilder,
colors: Optional[List[Tuple[int, int, int]]] = None,
mask: Optional[ndarray] = None):

if isinstance(layer_state, ScatterLayerState):
atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute]
else:
atts = [layer_state.vx_att, layer_state.vy_att, layer_state.vz_att]
vector_data = [layer_state.layer[att].ravel()[mask] for att in atts]

if viewer_state.native_aspect:
factor = max((abs(b[1] - b[0]) for b in bounds))
vector_data = [[0.5 * t / factor for t in v] for v in vector_data]
else:
bound_factors = [abs(b[1] - b[0]) for b in bounds]
vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)]
vector_data = array(list(zip(*vector_data)))

vector_data = clip_vector_data(viewer_state, layer_state, bounds, mask)
offset = VECTOR_OFFSETS[layer_state.vector_origin]
if layer_state.vector_origin == "tip":
offset += tip_height
Expand Down
18 changes: 9 additions & 9 deletions glue_ar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,7 @@ def layer_color(layer_state: LayerState) -> str:
def clip_sides(viewer_state: Viewer3DState,
clip_size: float = 1.0) -> Tuple[float, float, float]:

stretches = tuple(
getattr(viewer_state, f"{axis}_stretch", 1.0)
for axis in ("x", "y", "z")
)

stretches = get_stretches(viewer_state)
bounds = xyz_bounds(viewer_state, with_resolution=False)
resolution = get_resolution(viewer_state)
x_range = viewer_state.x_max - viewer_state.x_min
Expand Down Expand Up @@ -224,6 +220,13 @@ def mask_for_bounds(viewer_state: Viewer3DState,
(data[viewer_state.z_att] <= bounds[2][1])


def get_stretches(viewer_state: Viewer3DState) -> Tuple[float, float, float]:
return tuple(
getattr(viewer_state, f"{axis}_stretch", 1.0)
for axis in ("x", "y", "z")
)


# TODO: Worry about efficiency later
# and just generally make this better
def xyz_for_layer(viewer_state: Viewer3DState,
Expand All @@ -237,10 +240,7 @@ def xyz_for_layer(viewer_state: Viewer3DState,
vals = [xs, ys, zs]

if scaled:
stretches = tuple(
getattr(viewer_state, f"{axis}_stretch", 1.0)
for axis in ("x", "y", "z")
)
stretches = get_stretches(viewer_state)
bounds = xyz_bounds(viewer_state, with_resolution=False)
vals = bring_into_clip(vals, bounds, preserve_aspect=preserve_aspect, stretches=stretches)

Expand Down

0 comments on commit 3c0f3ba

Please sign in to comment.