Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector and error bar stretches #86

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions glue_ar/common/scatter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from functools import partial
from numpy import clip, isfinite, isnan, ndarray, ones, sqrt
from numpy import array, clip, isfinite, isnan, ndarray, ones, sqrt
from typing import Callable, Dict, List, Optional, Tuple, Union

from glue.utils import ensure_numerical
from glue_vispy_viewers.scatter.layer_state import ScatterLayerState

from glue_ar.common.shapes import rectangular_prism_points, rectangular_prism_triangulation, \
sphere_points, sphere_triangles
from glue_ar.utils import Bounds, NoneType, Viewer3DState, mask_for_bounds
from glue_ar.utils import Bounds, NoneType, Viewer3DState, get_stretches, mask_for_bounds

try:
from glue_jupyter.ipyvolume.scatter import Scatter3DLayerState
Expand Down Expand Up @@ -93,6 +93,28 @@
return sizes


def clip_vector_data(viewer_state: Viewer3DState,
layer_state: ScatterLayerState3D,
bounds: Bounds,
mask: Optional[ndarray] = None) -> ndarray:
if isinstance(layer_state, ScatterLayerState):
atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute]

Check warning on line 101 in glue_ar/common/scatter.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/common/scatter.py#L100-L101

Added lines #L100 - L101 were not covered by tests
else:
atts = [layer_state.vx_att, layer_state.vy_att, layer_state.vz_att]
vector_data = [layer_state.layer[att].ravel()[mask] for att in atts]

Check warning on line 104 in glue_ar/common/scatter.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/common/scatter.py#L103-L104

Added lines #L103 - L104 were not covered by tests

stretches = get_stretches(viewer_state)
if viewer_state.native_aspect:
factor = max((abs(b[1] - b[0]) * s for b, s in zip(bounds, stretches)))
vector_data = [[0.5 * t / factor for t in v] for v in vector_data]

Check warning on line 109 in glue_ar/common/scatter.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/common/scatter.py#L106-L109

Added lines #L106 - L109 were not covered by tests
else:
bound_factors = [abs(b[1] - b[0]) * s for b, s in zip(bounds, stretches)]
vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)]
vector_data = array(list(zip(*vector_data)))

Check warning on line 113 in glue_ar/common/scatter.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/common/scatter.py#L111-L113

Added lines #L111 - L113 were not covered by tests

return vector_data

Check warning on line 115 in glue_ar/common/scatter.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/common/scatter.py#L115

Added line #L115 was not covered by tests


def sphere_points_getter(theta_resolution: int,
phi_resolution: int) -> PointsGetter:

Expand Down
33 changes: 11 additions & 22 deletions glue_ar/common/scatter_gltf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from gltflib import AccessorType, BufferTarget, ComponentType, PrimitiveMode
from glue_vispy_viewers.common.viewer_state import Vispy3DViewerState
from glue_vispy_viewers.scatter.layer_state import ScatterLayerState
from numpy import array, isfinite, ndarray
from numpy import isfinite, ndarray
from numpy.linalg import norm

from typing import List, Literal, Optional, Tuple
Expand All @@ -12,13 +12,14 @@
from glue_ar.common.shapes import cone_triangles, cone_points, cylinder_points, cylinder_triangles, \
normalize, rectangular_prism_triangulation, sphere_triangles
from glue_ar.gltf_utils import add_points_to_bytearray, add_triangles_to_bytearray, index_mins, index_maxes
from glue_ar.utils import Viewer3DState, iterable_has_nan, hex_to_components, layer_color, \
from glue_ar.utils import Viewer3DState, get_stretches, iterable_has_nan, hex_to_components, layer_color, \
unique_id, xyz_bounds, xyz_for_layer, Bounds
from glue_ar.common.gltf_builder import GLTFBuilder
from glue_ar.common.scatter import Scatter3DLayerState, ScatterLayerState3D, \
PointsGetter, box_points_getter, IPYVOLUME_POINTS_GETTERS, \
IPYVOLUME_TRIANGLE_GETTERS, VECTOR_OFFSETS, radius_for_scatter_layer, \
scatter_layer_mask, sizes_for_scatter_layer, sphere_points_getter, NoneType
IPYVOLUME_TRIANGLE_GETTERS, VECTOR_OFFSETS, clip_vector_data, \
radius_for_scatter_layer, scatter_layer_mask, sizes_for_scatter_layer, \
sphere_points_getter, NoneType

try:
from glue_jupyter.common.state3d import ViewerState3D
Expand All @@ -39,20 +40,7 @@
materials: Optional[List[int]] = None,
mask: Optional[ndarray] = None):

if isinstance(layer_state, ScatterLayerState):
atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute]
else:
atts = [layer_state.vx_att, layer_state.vy_att, layer_state.vz_att]
vector_data = [layer_state.layer[att].ravel()[mask] for att in atts]

if viewer_state.native_aspect:
factor = max((abs(b[1] - b[0]) for b in bounds))
vector_data = [[0.5 * t / factor for t in v] for v in vector_data]
else:
bound_factors = [abs(b[1] - b[0]) for b in bounds]
vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)]
vector_data = array(list(zip(*vector_data)))

vector_data = clip_vector_data(viewer_state, layer_state, bounds, mask)

Check warning on line 43 in glue_ar/common/scatter_gltf.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/common/scatter_gltf.py#L43

Added line #L43 was not covered by tests
offset = VECTOR_OFFSETS[layer_state.vector_origin]
if layer_state.vector_origin == "tip":
offset += tip_height
Expand Down Expand Up @@ -166,12 +154,13 @@
# NB: This ordering is intentional to account for glTF coordinate system
gltf_index = ['z', 'y', 'x'].index(axis)

axis_range = abs(bounds[index][1] - bounds[index][0])
stretches = get_stretches(viewer_state)

Check warning on line 157 in glue_ar/common/scatter_gltf.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/common/scatter_gltf.py#L157

Added line #L157 was not covered by tests
if viewer_state.native_aspect:
max_range = max((abs(b[1] - b[0]) for b in bounds))
factor = 1 / max_range
max_side = max(abs(b[1] - b[0]) * s for b, s in zip(bounds, stretches))
factor = 1 / max_side

Check warning on line 160 in glue_ar/common/scatter_gltf.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/common/scatter_gltf.py#L159-L160

Added lines #L159 - L160 were not covered by tests
else:
factor = 1 / axis_range
axis_factor = abs(bounds[index][1] - bounds[index][0]) * stretches[index]
factor = 1 / axis_factor

Check warning on line 163 in glue_ar/common/scatter_gltf.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/common/scatter_gltf.py#L162-L163

Added lines #L162 - L163 were not covered by tests
err_values *= factor

barr = bytearray()
Expand Down
19 changes: 3 additions & 16 deletions glue_ar/common/scatter_usd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from glue_vispy_viewers.scatter.layer_state import ScatterLayerState
from glue_vispy_viewers.scatter.viewer_state import Vispy3DViewerState
from numpy import array, ndarray
from numpy import ndarray
from numpy.linalg import norm

from glue_ar.common.export_options import ar_layer_export
from glue_ar.common.scatter import IPYVOLUME_POINTS_GETTERS, IPYVOLUME_TRIANGLE_GETTERS, VECTOR_OFFSETS, PointsGetter, \
ScatterLayerState3D, box_points_getter, radius_for_scatter_layer, \
ScatterLayerState3D, box_points_getter, clip_vector_data, radius_for_scatter_layer, \
scatter_layer_mask, sizes_for_scatter_layer, sphere_points_getter
from glue_ar.common.scatter_export_options import ARIpyvolumeScatterExportOptions, ARVispyScatterExportOptions
from glue_ar.common.usd_builder import USDBuilder
Expand Down Expand Up @@ -38,20 +38,7 @@
colors: Optional[List[Tuple[int, int, int]]] = None,
mask: Optional[ndarray] = None):

if isinstance(layer_state, ScatterLayerState):
atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute]
else:
atts = [layer_state.vx_att, layer_state.vy_att, layer_state.vz_att]
vector_data = [layer_state.layer[att].ravel()[mask] for att in atts]

if viewer_state.native_aspect:
factor = max((abs(b[1] - b[0]) for b in bounds))
vector_data = [[0.5 * t / factor for t in v] for v in vector_data]
else:
bound_factors = [abs(b[1] - b[0]) for b in bounds]
vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)]
vector_data = array(list(zip(*vector_data)))

vector_data = clip_vector_data(viewer_state, layer_state, bounds, mask)

Check warning on line 41 in glue_ar/common/scatter_usd.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/common/scatter_usd.py#L41

Added line #L41 was not covered by tests
offset = VECTOR_OFFSETS[layer_state.vector_origin]
if layer_state.vector_origin == "tip":
offset += tip_height
Expand Down
18 changes: 9 additions & 9 deletions glue_ar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,7 @@
def clip_sides(viewer_state: Viewer3DState,
clip_size: float = 1.0) -> Tuple[float, float, float]:

stretches = tuple(
getattr(viewer_state, f"{axis}_stretch", 1.0)
for axis in ("x", "y", "z")
)

stretches = get_stretches(viewer_state)

Check warning on line 174 in glue_ar/utils.py

View check run for this annotation

Codecov / codecov/patch

glue_ar/utils.py#L174

Added line #L174 was not covered by tests
bounds = xyz_bounds(viewer_state, with_resolution=False)
resolution = get_resolution(viewer_state)
x_range = viewer_state.x_max - viewer_state.x_min
Expand Down Expand Up @@ -224,6 +220,13 @@
(data[viewer_state.z_att] <= bounds[2][1])


def get_stretches(viewer_state: Viewer3DState) -> Tuple[float, float, float]:
return tuple(
getattr(viewer_state, f"{axis}_stretch", 1.0)
for axis in ("x", "y", "z")
)


# TODO: Worry about efficiency later
# and just generally make this better
def xyz_for_layer(viewer_state: Viewer3DState,
Expand All @@ -237,10 +240,7 @@
vals = [xs, ys, zs]

if scaled:
stretches = tuple(
getattr(viewer_state, f"{axis}_stretch", 1.0)
for axis in ("x", "y", "z")
)
stretches = get_stretches(viewer_state)
bounds = xyz_bounds(viewer_state, with_resolution=False)
vals = bring_into_clip(vals, bounds, preserve_aspect=preserve_aspect, stretches=stretches)

Expand Down
Loading