diff --git a/glue_ar/common/scatter.py b/glue_ar/common/scatter.py index 5c330a7..e111955 100644 --- a/glue_ar/common/scatter.py +++ b/glue_ar/common/scatter.py @@ -1,5 +1,5 @@ from functools import partial -from numpy import clip, isfinite, isnan, ndarray, ones, sqrt +from numpy import array, clip, isfinite, isnan, ndarray, ones, sqrt from typing import Callable, Dict, List, Optional, Tuple, Union from glue.utils import ensure_numerical @@ -7,7 +7,7 @@ from glue_ar.common.shapes import rectangular_prism_points, rectangular_prism_triangulation, \ sphere_points, sphere_triangles -from glue_ar.utils import Bounds, NoneType, Viewer3DState, mask_for_bounds +from glue_ar.utils import Bounds, NoneType, Viewer3DState, get_stretches, mask_for_bounds try: from glue_jupyter.ipyvolume.scatter import Scatter3DLayerState @@ -93,6 +93,28 @@ def sizes_for_scatter_layer(layer_state: ScatterLayerState3D, return sizes +def clip_vector_data(viewer_state: Viewer3DState, + layer_state: ScatterLayerState3D, + bounds: Bounds, + mask: Optional[ndarray] = None) -> ndarray: + if isinstance(layer_state, ScatterLayerState): + atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute] + else: + atts = [layer_state.vx_att, layer_state.vy_att, layer_state.vz_att] + vector_data = [layer_state.layer[att].ravel()[mask] for att in atts] + + stretches = get_stretches(viewer_state) + if viewer_state.native_aspect: + factor = max((abs(b[1] - b[0]) * s for b, s in zip(bounds, stretches))) + vector_data = [[0.5 * t / factor for t in v] for v in vector_data] + else: + bound_factors = [abs(b[1] - b[0]) * s for b, s in zip(bounds, stretches)] + vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)] + vector_data = array(list(zip(*vector_data))) + + return vector_data + + def sphere_points_getter(theta_resolution: int, phi_resolution: int) -> PointsGetter: diff --git a/glue_ar/common/scatter_gltf.py b/glue_ar/common/scatter_gltf.py index a317884..3dc75e3 100644 --- a/glue_ar/common/scatter_gltf.py +++ b/glue_ar/common/scatter_gltf.py @@ -1,7 +1,7 @@ from gltflib import AccessorType, BufferTarget, ComponentType, PrimitiveMode from glue_vispy_viewers.common.viewer_state import Vispy3DViewerState from glue_vispy_viewers.scatter.layer_state import ScatterLayerState -from numpy import array, isfinite, ndarray +from numpy import isfinite, ndarray from numpy.linalg import norm from typing import List, Literal, Optional, Tuple @@ -12,13 +12,14 @@ from glue_ar.common.shapes import cone_triangles, cone_points, cylinder_points, cylinder_triangles, \ normalize, rectangular_prism_triangulation, sphere_triangles from glue_ar.gltf_utils import add_points_to_bytearray, add_triangles_to_bytearray, index_mins, index_maxes -from glue_ar.utils import Viewer3DState, iterable_has_nan, hex_to_components, layer_color, \ +from glue_ar.utils import Viewer3DState, get_stretches, iterable_has_nan, hex_to_components, layer_color, \ unique_id, xyz_bounds, xyz_for_layer, Bounds from glue_ar.common.gltf_builder import GLTFBuilder from glue_ar.common.scatter import Scatter3DLayerState, ScatterLayerState3D, \ PointsGetter, box_points_getter, IPYVOLUME_POINTS_GETTERS, \ - IPYVOLUME_TRIANGLE_GETTERS, VECTOR_OFFSETS, radius_for_scatter_layer, \ - scatter_layer_mask, sizes_for_scatter_layer, sphere_points_getter, NoneType + IPYVOLUME_TRIANGLE_GETTERS, VECTOR_OFFSETS, clip_vector_data, \ + radius_for_scatter_layer, scatter_layer_mask, sizes_for_scatter_layer, \ + sphere_points_getter, NoneType try: from glue_jupyter.common.state3d import ViewerState3D @@ -39,20 +40,7 @@ def add_vectors_gltf(builder: GLTFBuilder, materials: Optional[List[int]] = None, mask: Optional[ndarray] = None): - if isinstance(layer_state, ScatterLayerState): - atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute] - else: - atts = [layer_state.vx_att, layer_state.vy_att, layer_state.vz_att] - vector_data = [layer_state.layer[att].ravel()[mask] for att in atts] - - if viewer_state.native_aspect: - factor = max((abs(b[1] - b[0]) for b in bounds)) - vector_data = [[0.5 * t / factor for t in v] for v in vector_data] - else: - bound_factors = [abs(b[1] - b[0]) for b in bounds] - vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)] - vector_data = array(list(zip(*vector_data))) - + vector_data = clip_vector_data(viewer_state, layer_state, bounds, mask) offset = VECTOR_OFFSETS[layer_state.vector_origin] if layer_state.vector_origin == "tip": offset += tip_height @@ -166,12 +154,13 @@ def add_error_bars_gltf(builder: GLTFBuilder, # NB: This ordering is intentional to account for glTF coordinate system gltf_index = ['z', 'y', 'x'].index(axis) - axis_range = abs(bounds[index][1] - bounds[index][0]) + stretches = get_stretches(viewer_state) if viewer_state.native_aspect: - max_range = max((abs(b[1] - b[0]) for b in bounds)) - factor = 1 / max_range + max_side = max(abs(b[1] - b[0]) * s for b, s in zip(bounds, stretches)) + factor = 1 / max_side else: - factor = 1 / axis_range + axis_factor = abs(bounds[index][1] - bounds[index][0]) * stretches[index] + factor = 1 / axis_factor err_values *= factor barr = bytearray() diff --git a/glue_ar/common/scatter_usd.py b/glue_ar/common/scatter_usd.py index ec6a6be..0526783 100644 --- a/glue_ar/common/scatter_usd.py +++ b/glue_ar/common/scatter_usd.py @@ -3,12 +3,12 @@ from glue_vispy_viewers.scatter.layer_state import ScatterLayerState from glue_vispy_viewers.scatter.viewer_state import Vispy3DViewerState -from numpy import array, ndarray +from numpy import ndarray from numpy.linalg import norm from glue_ar.common.export_options import ar_layer_export from glue_ar.common.scatter import IPYVOLUME_POINTS_GETTERS, IPYVOLUME_TRIANGLE_GETTERS, VECTOR_OFFSETS, PointsGetter, \ - ScatterLayerState3D, box_points_getter, radius_for_scatter_layer, \ + ScatterLayerState3D, box_points_getter, clip_vector_data, radius_for_scatter_layer, \ scatter_layer_mask, sizes_for_scatter_layer, sphere_points_getter from glue_ar.common.scatter_export_options import ARIpyvolumeScatterExportOptions, ARVispyScatterExportOptions from glue_ar.common.usd_builder import USDBuilder @@ -38,20 +38,7 @@ def add_vectors_usd(builder: USDBuilder, colors: Optional[List[Tuple[int, int, int]]] = None, mask: Optional[ndarray] = None): - if isinstance(layer_state, ScatterLayerState): - atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute] - else: - atts = [layer_state.vx_att, layer_state.vy_att, layer_state.vz_att] - vector_data = [layer_state.layer[att].ravel()[mask] for att in atts] - - if viewer_state.native_aspect: - factor = max((abs(b[1] - b[0]) for b in bounds)) - vector_data = [[0.5 * t / factor for t in v] for v in vector_data] - else: - bound_factors = [abs(b[1] - b[0]) for b in bounds] - vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)] - vector_data = array(list(zip(*vector_data))) - + vector_data = clip_vector_data(viewer_state, layer_state, bounds, mask) offset = VECTOR_OFFSETS[layer_state.vector_origin] if layer_state.vector_origin == "tip": offset += tip_height diff --git a/glue_ar/utils.py b/glue_ar/utils.py index b943f2e..b01aac9 100644 --- a/glue_ar/utils.py +++ b/glue_ar/utils.py @@ -171,11 +171,7 @@ def layer_color(layer_state: LayerState) -> str: def clip_sides(viewer_state: Viewer3DState, clip_size: float = 1.0) -> Tuple[float, float, float]: - stretches = tuple( - getattr(viewer_state, f"{axis}_stretch", 1.0) - for axis in ("x", "y", "z") - ) - + stretches = get_stretches(viewer_state) bounds = xyz_bounds(viewer_state, with_resolution=False) resolution = get_resolution(viewer_state) x_range = viewer_state.x_max - viewer_state.x_min @@ -224,6 +220,13 @@ def mask_for_bounds(viewer_state: Viewer3DState, (data[viewer_state.z_att] <= bounds[2][1]) +def get_stretches(viewer_state: Viewer3DState) -> Tuple[float, float, float]: + return tuple( + getattr(viewer_state, f"{axis}_stretch", 1.0) + for axis in ("x", "y", "z") + ) + + # TODO: Worry about efficiency later # and just generally make this better def xyz_for_layer(viewer_state: Viewer3DState, @@ -237,10 +240,7 @@ def xyz_for_layer(viewer_state: Viewer3DState, vals = [xs, ys, zs] if scaled: - stretches = tuple( - getattr(viewer_state, f"{axis}_stretch", 1.0) - for axis in ("x", "y", "z") - ) + stretches = get_stretches(viewer_state) bounds = xyz_bounds(viewer_state, with_resolution=False) vals = bring_into_clip(vals, bounds, preserve_aspect=preserve_aspect, stretches=stretches)