diff --git a/src/coordinax/_src/d3/__init__.py b/src/coordinax/_src/d3/__init__.py index 4144b8c..01069a7 100644 --- a/src/coordinax/_src/d3/__init__.py +++ b/src/coordinax/_src/d3/__init__.py @@ -3,20 +3,26 @@ from . import ( base, + base_spherical, cartesian, compat, constructor, cylindrical, generic, + lonlatspherical, + mathspherical, spherical, transform, ) from .base import * +from .base_spherical import * from .cartesian import * from .compat import * from .constructor import * from .cylindrical import * from .generic import * +from .lonlatspherical import * +from .mathspherical import * from .spherical import * from .transform import * @@ -24,7 +30,10 @@ __all__ += base.__all__ __all__ += cartesian.__all__ __all__ += cylindrical.__all__ +__all__ += base_spherical.__all__ __all__ += spherical.__all__ +__all__ += mathspherical.__all__ +__all__ += lonlatspherical.__all__ __all__ += generic.__all__ __all__ += transform.__all__ __all__ += compat.__all__ diff --git a/src/coordinax/_src/d3/base_spherical.py b/src/coordinax/_src/d3/base_spherical.py new file mode 100644 index 0000000..65f6540 --- /dev/null +++ b/src/coordinax/_src/d3/base_spherical.py @@ -0,0 +1,55 @@ +"""Built-in vector classes.""" + +__all__ = [ + "AbstractSphericalPosition", + "AbstractSphericalVelocity", + "AbstractSphericalAcceleration", +] + +from abc import abstractmethod +from typing_extensions import override + +from unxt import Quantity + +from .base import AbstractAcceleration3D, AbstractPosition3D, AbstractVelocity3D +from coordinax._src.utils import classproperty + +_90d = Quantity(90, "deg") +_180d = Quantity(180, "deg") +_360d = Quantity(360, "deg") + + +class AbstractSphericalPosition(AbstractPosition3D): + """Abstract spherical vector representation.""" + + @override + @classproperty + @classmethod + @abstractmethod + def differential_cls(cls) -> "type[AbstractSphericalVelocity]": ... + + +class AbstractSphericalVelocity(AbstractVelocity3D): + """Spherical differential representation.""" + + @override + @classproperty + @classmethod + @abstractmethod + def integral_cls(cls) -> type[AbstractSphericalPosition]: ... + + @override + @classproperty + @classmethod + @abstractmethod + def differential_cls(cls) -> "type[AbstractSphericalAcceleration]": ... + + +class AbstractSphericalAcceleration(AbstractAcceleration3D): + """Spherical acceleration representation.""" + + @override + @classproperty + @classmethod + @abstractmethod + def integral_cls(cls) -> type[AbstractSphericalVelocity]: ... diff --git a/src/coordinax/_src/d3/lonlatspherical.py b/src/coordinax/_src/d3/lonlatspherical.py new file mode 100644 index 0000000..6181090 --- /dev/null +++ b/src/coordinax/_src/d3/lonlatspherical.py @@ -0,0 +1,335 @@ +"""Built-in vector classes.""" + +__all__ = [ + "LonLatSphericalPosition", + "LonLatSphericalVelocity", + "LonLatSphericalAcceleration", + "LonCosLatSphericalVelocity", +] + +from functools import partial +from typing import final +from typing_extensions import override + +import equinox as eqx + +import quaxed.lax as qlax +import quaxed.numpy as jnp +from dataclassish.converters import Unless +from unxt import AbstractDistance, Distance, Quantity + +import coordinax._src.typing as ct +from .base_spherical import ( + AbstractSphericalAcceleration, + AbstractSphericalPosition, + AbstractSphericalVelocity, + _90d, + _180d, +) +from coordinax._src.checks import ( + check_azimuth_range, + check_polar_range, + check_r_non_negative, +) +from coordinax._src.converters import converter_azimuth_to_range +from coordinax._src.utils import classproperty + + +@final +class LonLatSphericalPosition(AbstractSphericalPosition): + """Spherical vector representation. + + .. note:: + + This class follows the Geographic / Astronomical convention. + + Parameters + ---------- + lon : Quantity['angle'] + The longitude (azimuthal) angle [0, 360) [deg] where 0 is the x-axis. + lat : Quantity['angle'] + The latitude (polar angle) [-90, 90] [deg] where 90 is the z-axis. + distance : Distance + Radial distance r (slant distance to origin), + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> cx.LonLatSphericalPosition(lon=Quantity(0, "deg"), lat=Quantity(0, "deg"), + ... distance=Quantity(3, "kpc")) + LonLatSphericalPosition( + lon=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("deg")), + lat=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("deg")), + distance=Distance(value=f32[], unit=Unit("kpc")) + ) + + The longitude and latitude angles are in the range [0, 360) and [-90, 90] degrees, + and the radial distance is non-negative. + When initializing, the longitude is wrapped to the [0, 360) degrees range. + + >>> vec = cx.LonLatSphericalPosition(lon=Quantity(365, "deg"), + ... lat=Quantity(90, "deg"), + ... distance=Quantity(3, "kpc")) + >>> vec.lon + Quantity['angle'](Array(5., dtype=float32), unit='deg') + + The latitude is not wrapped, but it is checked to be in the [-90, 90] degrees range. + + .. skip: next + + >>> try: + ... cx.LonLatSphericalPosition(lon=Quantity(0, "deg"), lat=Quantity(100, "deg"), + ... distance=Quantity(3, "kpc")) + ... except Exception as e: + ... print(e) + The inclination angle must be in the range [0, pi]... + + Likewise, the radial distance is checked to be non-negative. + + .. skip: next + + >>> try: + ... cx.LonLatSphericalPosition(lon=Quantity(0, "deg"), lat=Quantity(0, "deg"), + ... distance=Quantity(-3, "kpc")) + ... except Exception as e: + ... print(e) + The radial distance must be non-negative... + + """ + + lon: ct.BatchableAngle = eqx.field( + converter=lambda x: converter_azimuth_to_range( + Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120 + ) + ) + r"""Longitude (azimuthal) angle :math:`\in [0,360)`.""" + + lat: ct.BatchableAngle = eqx.field( + converter=partial(Quantity["angle"].constructor, dtype=float) + ) + r"""Latitude (polar) angle :math:`\in [-90,90]`.""" + + distance: ct.BatchableDistance = eqx.field( + converter=Unless(AbstractDistance, partial(Distance.constructor, dtype=float)) + ) + r"""Radial distance :math:`r \in [0,+\infty)`.""" + + def __check_init__(self) -> None: + """Check the validity of the initialization.""" + check_azimuth_range(self.lon) + check_polar_range(self.lat, -Quantity(90, "deg"), Quantity(90, "deg")) + check_r_non_negative(self.distance) + + @override + @classproperty + @classmethod + def differential_cls(cls) -> type["LonLatSphericalVelocity"]: + return LonLatSphericalVelocity + + @override + @partial(eqx.filter_jit, inline=True) + def norm(self) -> ct.BatchableDistance: + """Return the norm of the vector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + >>> s = cx.LonLatSphericalPosition(lon=Quantity(0, "deg"), + ... lat=Quantity(90, "deg"), + ... distance=Quantity(3, "kpc")) + >>> s.norm() + Distance(Array(3., dtype=float32), unit='kpc') + + """ + return self.distance + + +@LonLatSphericalPosition.constructor._f.register # type: ignore[attr-defined, misc] # noqa: SLF001 +def constructor( + cls: type[LonLatSphericalPosition], + *, + lon: Quantity["angle"], + lat: Quantity["angle"], + distance: Distance, +) -> LonLatSphericalPosition: + """Construct LonLatSphericalPosition, allowing for out-of-range values. + + Examples + -------- + >>> import coordinax as cx + + Let's start with a valid input: + + >>> cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"), + ... lat=Quantity(0, "deg"), + ... distance=Quantity(3, "kpc")) + LonLatSphericalPosition( + lon=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("deg")), + lat=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("deg")), + distance=Distance(value=f32[], unit=Unit("kpc")) + ) + + The distance can be negative, which wraps the longitude by 180 degrees and + flips the latitude: + + >>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"), + ... lat=Quantity(45, "deg"), + ... distance=Quantity(-3, "kpc")) + >>> vec.lon + Quantity['angle'](Array(180., dtype=float32), unit='deg') + >>> vec.lat + Quantity['angle'](Array(-45., dtype=float32), unit='deg') + >>> vec.distance + Distance(Array(3., dtype=float32), unit='kpc') + + The latitude can be outside the [-90, 90] deg range, causing the longitude + to be shifted by 180 degrees: + + >>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"), + ... lat=Quantity(-100, "deg"), + ... distance=Quantity(3, "kpc")) + >>> vec.lon + Quantity['angle'](Array(180., dtype=float32), unit='deg') + >>> vec.lat + Quantity['angle'](Array(-80., dtype=float32), unit='deg') + >>> vec.distance + Distance(Array(3., dtype=float32), unit='kpc') + + >>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"), + ... lat=Quantity(100, "deg"), + ... distance=Quantity(3, "kpc")) + >>> vec.lon + Quantity['angle'](Array(180., dtype=float32), unit='deg') + >>> vec.lat + Quantity['angle'](Array(80., dtype=float32), unit='deg') + >>> vec.distance + Distance(Array(3., dtype=float32), unit='kpc') + + The longitude can be outside the [0, 360) deg range. This is wrapped to the + [0, 360) deg range (actually the base constructor does this): + + >>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(365, "deg"), + ... lat=Quantity(0, "deg"), + ... distance=Quantity(3, "kpc")) + >>> vec.lon + Quantity['angle'](Array(5., dtype=float32), unit='deg') + + """ + # 1) Convert the inputs + fields = LonLatSphericalPosition.__dataclass_fields__ + lon = fields["lon"].metadata["converter"](lon) + lat = fields["lat"].metadata["converter"](lat) + distance = fields["distance"].metadata["converter"](distance) + + # 2) handle negative distances + distance_pred = distance < jnp.zeros_like(distance) + distance = qlax.select(distance_pred, -distance, distance) + lon = qlax.select(distance_pred, lon + _180d, lon) + lat = qlax.select(distance_pred, -lat, lat) + + # 3) Handle latitude outside of [-90, 90] degrees + # TODO: fix when lat < -180, lat > 180 + lat_pred = lat < -_90d + lat = qlax.select(lat_pred, -_180d - lat, lat) + lon = qlax.select(lat_pred, lon + _180d, lon) + + lat_pred = lat > _90d + lat = qlax.select(lat_pred, _180d - lat, lat) + lon = qlax.select(lat_pred, lon + _180d, lon) + + # 4) Construct. This also handles the longitude wrapping + return cls(lon=lon, lat=lat, distance=distance) + + +############################################################################## + + +@final +class LonLatSphericalVelocity(AbstractSphericalVelocity): + """Spherical differential representation.""" + + d_lon: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Longitude speed :math:`dlon/dt \in [-\infty, \infty].""" + + d_lat: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Latitude speed :math:`dlat/dt \in [-\infty, \infty].""" + + d_distance: ct.BatchableSpeed = eqx.field( + converter=partial(Quantity["speed"].constructor, dtype=float) + ) + r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" + + @classproperty + @classmethod + def integral_cls(cls) -> type[LonLatSphericalPosition]: + return LonLatSphericalPosition + + @classproperty + @classmethod + def differential_cls(cls) -> type["LonLatSphericalAcceleration"]: + return LonLatSphericalAcceleration + + +@final +class LonCosLatSphericalVelocity(AbstractSphericalVelocity): + """Spherical differential representation.""" + + d_lon_coslat: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Longitude * cos(Latitude) speed :math:`dlon/dt \in [-\infty, \infty].""" + + d_lat: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Latitude speed :math:`dlat/dt \in [-\infty, \infty].""" + + d_distance: ct.BatchableSpeed = eqx.field( + converter=partial(Quantity["speed"].constructor, dtype=float) + ) + r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" + + @classproperty + @classmethod + def integral_cls(cls) -> type[LonLatSphericalPosition]: + return LonLatSphericalPosition + + @classproperty + @classmethod + def differential_cls(cls) -> type["LonLatSphericalAcceleration"]: + return LonLatSphericalAcceleration + + +############################################################################## + + +@final +class LonLatSphericalAcceleration(AbstractSphericalAcceleration): + """Spherical acceleration representation.""" + + d2_lon: ct.BatchableAngularAcc = eqx.field( + converter=partial(Quantity["angular acceleration"].constructor, dtype=float) + ) + r"""Longitude acceleration :math:`d^2lon/dt^2 \in [-\infty, \infty].""" + + d2_lat: ct.BatchableAngularAcc = eqx.field( + converter=partial(Quantity["angular acceleration"].constructor, dtype=float) + ) + r"""Latitude acceleration :math:`d^2lat/dt^2 \in [-\infty, \infty].""" + + d2_distance: ct.BatchableAcc = eqx.field( + converter=partial(Quantity["acceleration"].constructor, dtype=float) + ) + r"""Radial acceleration :math:`d^2r/dt^2 \in [-\infty, \infty].""" + + @classproperty + @classmethod + def integral_cls(cls) -> type[LonLatSphericalVelocity]: + return LonLatSphericalVelocity diff --git a/src/coordinax/_src/d3/mathspherical.py b/src/coordinax/_src/d3/mathspherical.py new file mode 100644 index 0000000..68d2436 --- /dev/null +++ b/src/coordinax/_src/d3/mathspherical.py @@ -0,0 +1,290 @@ +"""Built-in vector classes.""" + +__all__ = [ + "MathSphericalPosition", + "MathSphericalVelocity", + "MathSphericalAcceleration", +] + +from functools import partial +from typing import final +from typing_extensions import override + +import equinox as eqx +import jax +from jaxtyping import ArrayLike +from quax import register + +import quaxed.lax as qlax +import quaxed.numpy as jnp +from dataclassish import replace +from dataclassish.converters import Unless +from unxt import AbstractDistance, AbstractQuantity, Distance, Quantity + +import coordinax._src.typing as ct +from .base_spherical import ( + AbstractSphericalAcceleration, + AbstractSphericalPosition, + AbstractSphericalVelocity, + _180d, + _360d, +) +from coordinax._src.checks import ( + check_azimuth_range, + check_polar_range, + check_r_non_negative, +) +from coordinax._src.converters import converter_azimuth_to_range +from coordinax._src.utils import classproperty + + +@final +class MathSphericalPosition(AbstractSphericalPosition): + """Spherical vector representation. + + .. note:: + + This class follows the Mathematics conventions. + + Parameters + ---------- + r : Distance + Radial distance r (slant distance to origin), + theta : Quantity['angle'] + Azimuthal angle [0, 360) [deg] where 0 is the x-axis. + phi : Quantity['angle'] + Polar angle [0, 180] [deg] where 0 is the z-axis. + + """ + + r: ct.BatchableDistance = eqx.field( + converter=Unless(AbstractDistance, partial(Distance.constructor, dtype=float)) + ) + r"""Radial distance :math:`r \in [0,+\infty)`.""" + + theta: ct.BatchableAngle = eqx.field( + converter=lambda x: converter_azimuth_to_range( + Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120 + ) + ) + r"""Azimuthal angle :math:`\theta \in [0,360)`.""" + + phi: ct.BatchableAngle = eqx.field( + converter=partial(Quantity["angle"].constructor, dtype=float) + ) + r"""Inclination angle :math:`\phi \in [0,180]`.""" + + def __check_init__(self) -> None: + """Check the validity of the initialization.""" + check_r_non_negative(self.r) + check_azimuth_range(self.theta) + check_polar_range(self.phi) + + @override + @classproperty + @classmethod + def differential_cls(cls) -> type["MathSphericalVelocity"]: + return MathSphericalVelocity + + @partial(eqx.filter_jit, inline=True) + def norm(self) -> ct.BatchableDistance: + """Return the norm of the vector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + >>> s = cx.MathSphericalPosition(r=Quantity(3, "kpc"), + ... theta=Quantity(90, "deg"), + ... phi=Quantity(0, "deg")) + >>> s.norm() + Distance(Array(3., dtype=float32), unit='kpc') + + """ + return self.r + + +@MathSphericalPosition.constructor._f.register # type: ignore[attr-defined, misc] # noqa: SLF001 +def constructor( + cls: type[MathSphericalPosition], + *, + r: AbstractQuantity, + theta: AbstractQuantity, + phi: AbstractQuantity, +) -> MathSphericalPosition: + """Construct MathSphericalPosition, allowing for out-of-range values. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + Let's start with a valid input: + + >>> cx.MathSphericalPosition.constructor(r=Quantity(3, "kpc"), + ... theta=Quantity(90, "deg"), + ... phi=Quantity(0, "deg")) + MathSphericalPosition( + r=Distance(value=f32[], unit=Unit("kpc")), + theta=Quantity[...](value=f32[], unit=Unit("deg")), + phi=Quantity[...](value=f32[], unit=Unit("deg")) + ) + + The radial distance can be negative, which wraps the azimuthal angle by 180 + degrees and flips the polar angle: + + >>> vec = cx.MathSphericalPosition.constructor(r=Quantity(-3, "kpc"), + ... theta=Quantity(100, "deg"), + ... phi=Quantity(45, "deg")) + >>> vec.r + Distance(Array(3., dtype=float32), unit='kpc') + >>> vec.theta + Quantity['angle'](Array(280., dtype=float32), unit='deg') + >>> vec.phi + Quantity[...](Array(135., dtype=float32), unit='deg') + + The polar angle can be outside the [0, 180] deg range, causing the azimuthal + angle to be shifted by 180 degrees: + + >>> vec = cx.MathSphericalPosition.constructor(r=Quantity(3, "kpc"), + ... theta=Quantity(0, "deg"), + ... phi=Quantity(190, "deg")) + >>> vec.r + Distance(Array(3., dtype=float32), unit='kpc') + >>> vec.theta + Quantity['angle'](Array(180., dtype=float32), unit='deg') + >>> vec.phi + Quantity['angle'](Array(170., dtype=float32), unit='deg') + + The azimuth can be outside the [0, 360) deg range. This is wrapped to the + [0, 360) deg range (actually the base constructor does this): + + >>> vec = cx.MathSphericalPosition.constructor(r=Quantity(3, "kpc"), + ... theta=Quantity(365, "deg"), + ... phi=Quantity(90, "deg")) + >>> vec.theta + Quantity['angle'](Array(5., dtype=float32), unit='deg') + + """ + # 1) Convert the inputs + fields = MathSphericalPosition.__dataclass_fields__ + r = fields["r"].metadata["converter"](r) + theta = fields["theta"].metadata["converter"](theta) + phi = fields["phi"].metadata["converter"](phi) + + # 2) handle negative distances + r_pred = r < jnp.zeros_like(r) + r = qlax.select(r_pred, -r, r) + theta = qlax.select(r_pred, theta + _180d, theta) + phi = qlax.select(r_pred, _180d - phi, phi) + + # 3) Handle polar angle outside of [0, 180] degrees + phi = jnp.mod(phi, _360d) # wrap to [0, 360) deg + phi_pred = phi < _180d + phi = qlax.select(phi_pred, phi, _360d - phi) + theta = qlax.select(phi_pred, theta, theta + _180d) + + # 4) Construct. This also handles the azimuthal angle wrapping + return cls(r=r, theta=theta, phi=phi) + + +@register(jax.lax.mul_p) # type: ignore[misc] +def _mul_p_vmsph( + lhs: ArrayLike, rhs: MathSphericalPosition, / +) -> MathSphericalPosition: + """Scale the polar position by a scalar. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + >>> import quaxed.numpy as jnp + + >>> v = cx.MathSphericalPosition(r=Quantity(3, "kpc"), + ... theta=Quantity(90, "deg"), + ... phi=Quantity(0, "deg")) + + >>> jnp.linalg.vector_norm(v, axis=-1) + Quantity['length'](Array(3., dtype=float32), unit='kpc') + + >>> nv = jnp.multiply(2, v) + >>> nv + MathSphericalPosition( + r=Distance(value=f32[], unit=Unit("kpc")), + theta=Quantity[...](value=f32[], unit=Unit("deg")), + phi=Quantity[...](value=f32[], unit=Unit("deg")) + ) + >>> nv.r + Distance(Array(6., dtype=float32), unit='kpc') + + """ + # Validation + lhs = eqx.error_if( + lhs, any(jax.numpy.shape(lhs)), f"must be a scalar, not {type(lhs)}" + ) + # Scale the radial distance + return replace(rhs, r=lhs * rhs.r) + + +############################################################################## + + +@final +class MathSphericalVelocity(AbstractSphericalVelocity): + """Spherical differential representation.""" + + d_r: ct.BatchableSpeed = eqx.field( + converter=partial(Quantity["speed"].constructor, dtype=float) + ) + r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" + + d_theta: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Azimuthal speed :math:`d\theta/dt \in [-\infty, \infty].""" + + d_phi: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Inclination speed :math:`d\phi/dt \in [-\infty, \infty].""" + + @override + @classproperty + @classmethod + def integral_cls(cls) -> type[MathSphericalPosition]: + return MathSphericalPosition + + @override + @classproperty + @classmethod + def differential_cls(cls) -> type["MathSphericalAcceleration"]: + return MathSphericalAcceleration + + +############################################################################## + + +@final +class MathSphericalAcceleration(AbstractSphericalAcceleration): + """Spherical acceleration representation.""" + + d2_r: ct.BatchableAcc = eqx.field( + converter=partial(Quantity["acceleration"].constructor, dtype=float) + ) + r"""Radial acceleration :math:`d^2r/dt^2 \in [-\infty, \infty].""" + + d2_theta: ct.BatchableAngularAcc = eqx.field( + converter=partial(Quantity["angular acceleration"].constructor, dtype=float) + ) + r"""Azimuthal acceleration :math:`d^2\theta/dt^2 \in [-\infty, \infty].""" + + d2_phi: ct.BatchableAngularAcc = eqx.field( + converter=partial(Quantity["angular acceleration"].constructor, dtype=float) + ) + r"""Inclination acceleration :math:`d^2\phi/dt^2 \in [-\infty, \infty].""" + + @override + @classproperty + @classmethod + def integral_cls(cls) -> type[MathSphericalVelocity]: + return MathSphericalVelocity diff --git a/src/coordinax/_src/d3/spherical.py b/src/coordinax/_src/d3/spherical.py index 647b757..e96cded 100644 --- a/src/coordinax/_src/d3/spherical.py +++ b/src/coordinax/_src/d3/spherical.py @@ -1,43 +1,30 @@ """Built-in vector classes.""" __all__ = [ - "AbstractSphericalPosition", - "AbstractSphericalVelocity", - "AbstractSphericalAcceleration", # Physics conventions "SphericalPosition", "SphericalVelocity", "SphericalAcceleration", - # Mathematics conventions - "MathSphericalPosition", - "MathSphericalVelocity", - "MathSphericalAcceleration", - # Geographic / Astronomical conventions - "LonLatSphericalPosition", - "LonLatSphericalVelocity", - "LonLatSphericalAcceleration", - "LonCosLatSphericalVelocity", ] -from abc import abstractmethod from functools import partial from typing import final -from typing_extensions import override import equinox as eqx -import jax -from jaxtyping import ArrayLike -from quax import register import quaxed.lax as qlax import quaxed.numpy as jnp -from dataclassish import replace from dataclassish.converters import Unless from unxt import AbstractDistance, AbstractQuantity, Distance, Quantity import coordinax._src.typing as ct -from .base import AbstractAcceleration3D, AbstractPosition3D, AbstractVelocity3D -from coordinax._src.base import AbstractAcceleration +from .base_spherical import ( + AbstractSphericalAcceleration, + AbstractSphericalPosition, + AbstractSphericalVelocity, + _180d, + _360d, +) from coordinax._src.checks import ( check_azimuth_range, check_polar_range, @@ -46,29 +33,14 @@ from coordinax._src.converters import converter_azimuth_to_range from coordinax._src.utils import classproperty -_90d = Quantity(90, "deg") -_180d = Quantity(180, "deg") -_360d = Quantity(360, "deg") - ############################################################################## # Position -class AbstractSphericalPosition(AbstractPosition3D): - """Abstract spherical vector representation.""" - - @classproperty - @classmethod - @abstractmethod - def differential_cls(cls) -> type["AbstractSphericalVelocity"]: ... - - -# ============================================================================ - - +# TODO: make this an alias for SphericalPolarPosition, the more correct description? @final class SphericalPosition(AbstractSphericalPosition): - """Spherical vector representation. + """Spherical-Polar coordinates. .. note:: @@ -198,390 +170,11 @@ def constructor( return cls(r=r, theta=theta, phi=phi) -# ============================================================================ - - -@final -class MathSphericalPosition(AbstractSphericalPosition): - """Spherical vector representation. - - .. note:: - - This class follows the Mathematics conventions. - - Parameters - ---------- - r : Distance - Radial distance r (slant distance to origin), - theta : Quantity['angle'] - Azimuthal angle [0, 360) [deg] where 0 is the x-axis. - phi : Quantity['angle'] - Polar angle [0, 180] [deg] where 0 is the z-axis. - - """ - - r: ct.BatchableDistance = eqx.field( - converter=Unless(AbstractDistance, partial(Distance.constructor, dtype=float)) - ) - r"""Radial distance :math:`r \in [0,+\infty)`.""" - - theta: ct.BatchableAngle = eqx.field( - converter=lambda x: converter_azimuth_to_range( - Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120 - ) - ) - r"""Azimuthal angle :math:`\theta \in [0,360)`.""" - - phi: ct.BatchableAngle = eqx.field( - converter=partial(Quantity["angle"].constructor, dtype=float) - ) - r"""Inclination angle :math:`\phi \in [0,180]`.""" - - def __check_init__(self) -> None: - """Check the validity of the initialization.""" - check_r_non_negative(self.r) - check_azimuth_range(self.theta) - check_polar_range(self.phi) - - @override - @classproperty - @classmethod - def differential_cls(cls) -> type["MathSphericalVelocity"]: - return MathSphericalVelocity - - @partial(eqx.filter_jit, inline=True) - def norm(self) -> ct.BatchableDistance: - """Return the norm of the vector. - - Examples - -------- - >>> from unxt import Quantity - >>> import coordinax as cx - >>> s = cx.MathSphericalPosition(r=Quantity(3, "kpc"), - ... theta=Quantity(90, "deg"), - ... phi=Quantity(0, "deg")) - >>> s.norm() - Distance(Array(3., dtype=float32), unit='kpc') - - """ - return self.r - - -@MathSphericalPosition.constructor._f.register # type: ignore[attr-defined, misc] # noqa: SLF001 -def constructor( - cls: type[MathSphericalPosition], - *, - r: AbstractQuantity, - theta: AbstractQuantity, - phi: AbstractQuantity, -) -> MathSphericalPosition: - """Construct MathSphericalPosition, allowing for out-of-range values. - - Examples - -------- - >>> from unxt import Quantity - >>> import coordinax as cx - - Let's start with a valid input: - - >>> cx.MathSphericalPosition.constructor(r=Quantity(3, "kpc"), - ... theta=Quantity(90, "deg"), - ... phi=Quantity(0, "deg")) - MathSphericalPosition( - r=Distance(value=f32[], unit=Unit("kpc")), - theta=Quantity[...](value=f32[], unit=Unit("deg")), - phi=Quantity[...](value=f32[], unit=Unit("deg")) - ) - - The radial distance can be negative, which wraps the azimuthal angle by 180 - degrees and flips the polar angle: - - >>> vec = cx.MathSphericalPosition.constructor(r=Quantity(-3, "kpc"), - ... theta=Quantity(100, "deg"), - ... phi=Quantity(45, "deg")) - >>> vec.r - Distance(Array(3., dtype=float32), unit='kpc') - >>> vec.theta - Quantity['angle'](Array(280., dtype=float32), unit='deg') - >>> vec.phi - Quantity[...](Array(135., dtype=float32), unit='deg') - - The polar angle can be outside the [0, 180] deg range, causing the azimuthal - angle to be shifted by 180 degrees: - - >>> vec = cx.MathSphericalPosition.constructor(r=Quantity(3, "kpc"), - ... theta=Quantity(0, "deg"), - ... phi=Quantity(190, "deg")) - >>> vec.r - Distance(Array(3., dtype=float32), unit='kpc') - >>> vec.theta - Quantity['angle'](Array(180., dtype=float32), unit='deg') - >>> vec.phi - Quantity['angle'](Array(170., dtype=float32), unit='deg') - - The azimuth can be outside the [0, 360) deg range. This is wrapped to the - [0, 360) deg range (actually the base constructor does this): - - >>> vec = cx.MathSphericalPosition.constructor(r=Quantity(3, "kpc"), - ... theta=Quantity(365, "deg"), - ... phi=Quantity(90, "deg")) - >>> vec.theta - Quantity['angle'](Array(5., dtype=float32), unit='deg') - - """ - # 1) Convert the inputs - fields = SphericalPosition.__dataclass_fields__ - r = fields["r"].metadata["converter"](r) - theta = fields["theta"].metadata["converter"](theta) - phi = fields["phi"].metadata["converter"](phi) - - # 2) handle negative distances - r_pred = r < jnp.zeros_like(r) - r = qlax.select(r_pred, -r, r) - theta = qlax.select(r_pred, theta + _180d, theta) - phi = qlax.select(r_pred, _180d - phi, phi) - - # 3) Handle polar angle outside of [0, 180] degrees - phi = jnp.mod(phi, _360d) # wrap to [0, 360) deg - phi_pred = phi < _180d - phi = qlax.select(phi_pred, phi, _360d - phi) - theta = qlax.select(phi_pred, theta, theta + _180d) - - # 4) Construct. This also handles the azimuthal angle wrapping - return cls(r=r, theta=theta, phi=phi) - - -# ============================================================================ - - -@final -class LonLatSphericalPosition(AbstractSphericalPosition): - """Spherical vector representation. - - .. note:: - - This class follows the Geographic / Astronomical convention. - - Parameters - ---------- - lon : Quantity['angle'] - The longitude (azimuthal) angle [0, 360) [deg] where 0 is the x-axis. - lat : Quantity['angle'] - The latitude (polar angle) [-90, 90] [deg] where 90 is the z-axis. - distance : Distance - Radial distance r (slant distance to origin), - - Examples - -------- - >>> from unxt import Quantity - >>> import coordinax as cx - - >>> cx.LonLatSphericalPosition(lon=Quantity(0, "deg"), lat=Quantity(0, "deg"), - ... distance=Quantity(3, "kpc")) - LonLatSphericalPosition( - lon=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("deg")), - lat=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("deg")), - distance=Distance(value=f32[], unit=Unit("kpc")) - ) - - The longitude and latitude angles are in the range [0, 360) and [-90, 90] degrees, - and the radial distance is non-negative. - When initializing, the longitude is wrapped to the [0, 360) degrees range. - - >>> vec = cx.LonLatSphericalPosition(lon=Quantity(365, "deg"), - ... lat=Quantity(90, "deg"), - ... distance=Quantity(3, "kpc")) - >>> vec.lon - Quantity['angle'](Array(5., dtype=float32), unit='deg') - - The latitude is not wrapped, but it is checked to be in the [-90, 90] degrees range. - - .. skip: next - - >>> try: - ... cx.LonLatSphericalPosition(lon=Quantity(0, "deg"), lat=Quantity(100, "deg"), - ... distance=Quantity(3, "kpc")) - ... except Exception as e: - ... print(e) - The inclination angle must be in the range [0, pi]... - - Likewise, the radial distance is checked to be non-negative. - - .. skip: next - - >>> try: - ... cx.LonLatSphericalPosition(lon=Quantity(0, "deg"), lat=Quantity(0, "deg"), - ... distance=Quantity(-3, "kpc")) - ... except Exception as e: - ... print(e) - The radial distance must be non-negative... - - """ - - lon: ct.BatchableAngle = eqx.field( - converter=lambda x: converter_azimuth_to_range( - Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120 - ) - ) - r"""Longitude (azimuthal) angle :math:`\in [0,360)`.""" - - lat: ct.BatchableAngle = eqx.field( - converter=partial(Quantity["angle"].constructor, dtype=float) - ) - r"""Latitude (polar) angle :math:`\in [-90,90]`.""" - - distance: ct.BatchableDistance = eqx.field( - converter=Unless(AbstractDistance, partial(Distance.constructor, dtype=float)) - ) - r"""Radial distance :math:`r \in [0,+\infty)`.""" - - def __check_init__(self) -> None: - """Check the validity of the initialization.""" - check_azimuth_range(self.lon) - check_polar_range(self.lat, -Quantity(90, "deg"), Quantity(90, "deg")) - check_r_non_negative(self.distance) - - @override - @classproperty - @classmethod - def differential_cls(cls) -> type["LonLatSphericalVelocity"]: - return LonLatSphericalVelocity - - @override - @partial(eqx.filter_jit, inline=True) - def norm(self) -> ct.BatchableDistance: - """Return the norm of the vector. - - Examples - -------- - >>> from unxt import Quantity - >>> import coordinax as cx - >>> s = cx.LonLatSphericalPosition(lon=Quantity(0, "deg"), - ... lat=Quantity(90, "deg"), - ... distance=Quantity(3, "kpc")) - >>> s.norm() - Distance(Array(3., dtype=float32), unit='kpc') - - """ - return self.distance - - -@LonLatSphericalPosition.constructor._f.register # type: ignore[attr-defined, misc] # noqa: SLF001 -def constructor( - cls: type[LonLatSphericalPosition], - *, - lon: Quantity["angle"], - lat: Quantity["angle"], - distance: Distance, -) -> LonLatSphericalPosition: - """Construct LonLatSphericalPosition, allowing for out-of-range values. - - Examples - -------- - >>> import coordinax as cx - - Let's start with a valid input: - - >>> cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"), - ... lat=Quantity(0, "deg"), - ... distance=Quantity(3, "kpc")) - LonLatSphericalPosition( - lon=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("deg")), - lat=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("deg")), - distance=Distance(value=f32[], unit=Unit("kpc")) - ) - - The distance can be negative, which wraps the longitude by 180 degrees and - flips the latitude: - - >>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"), - ... lat=Quantity(45, "deg"), - ... distance=Quantity(-3, "kpc")) - >>> vec.lon - Quantity['angle'](Array(180., dtype=float32), unit='deg') - >>> vec.lat - Quantity['angle'](Array(-45., dtype=float32), unit='deg') - >>> vec.distance - Distance(Array(3., dtype=float32), unit='kpc') - - The latitude can be outside the [-90, 90] deg range, causing the longitude - to be shifted by 180 degrees: - - >>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"), - ... lat=Quantity(-100, "deg"), - ... distance=Quantity(3, "kpc")) - >>> vec.lon - Quantity['angle'](Array(180., dtype=float32), unit='deg') - >>> vec.lat - Quantity['angle'](Array(-80., dtype=float32), unit='deg') - >>> vec.distance - Distance(Array(3., dtype=float32), unit='kpc') - - >>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"), - ... lat=Quantity(100, "deg"), - ... distance=Quantity(3, "kpc")) - >>> vec.lon - Quantity['angle'](Array(180., dtype=float32), unit='deg') - >>> vec.lat - Quantity['angle'](Array(80., dtype=float32), unit='deg') - >>> vec.distance - Distance(Array(3., dtype=float32), unit='kpc') - - The longitude can be outside the [0, 360) deg range. This is wrapped to the - [0, 360) deg range (actually the base constructor does this): - - >>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(365, "deg"), - ... lat=Quantity(0, "deg"), - ... distance=Quantity(3, "kpc")) - >>> vec.lon - Quantity['angle'](Array(5., dtype=float32), unit='deg') - - """ - # 1) Convert the inputs - fields = LonLatSphericalPosition.__dataclass_fields__ - lon = fields["lon"].metadata["converter"](lon) - lat = fields["lat"].metadata["converter"](lat) - distance = fields["distance"].metadata["converter"](distance) - - # 2) handle negative distances - distance_pred = distance < jnp.zeros_like(distance) - distance = qlax.select(distance_pred, -distance, distance) - lon = qlax.select(distance_pred, lon + _180d, lon) - lat = qlax.select(distance_pred, -lat, lat) - - # 3) Handle latitude outside of [-90, 90] degrees - # TODO: fix when lat < -180, lat > 180 - lat_pred = lat < -_90d - lat = qlax.select(lat_pred, -_180d - lat, lat) - lon = qlax.select(lat_pred, lon + _180d, lon) - - lat_pred = lat > _90d - lat = qlax.select(lat_pred, _180d - lat, lat) - lon = qlax.select(lat_pred, lon + _180d, lon) - - # 4) Construct. This also handles the longitude wrapping - return cls(lon=lon, lat=lat, distance=distance) - - ############################################################################## -class AbstractSphericalVelocity(AbstractVelocity3D): - """Spherical differential representation.""" - - @classproperty - @classmethod - @abstractmethod - def integral_cls(cls) -> type[SphericalPosition]: ... - - @classproperty - @classmethod - @abstractmethod - def differential_cls(cls) -> type[AbstractAcceleration]: ... - - @final -class SphericalVelocity(AbstractVelocity3D): +class SphericalVelocity(AbstractSphericalVelocity): """Spherical differential representation.""" d_r: ct.BatchableSpeed = eqx.field( @@ -610,110 +203,11 @@ def differential_cls(cls) -> type["SphericalAcceleration"]: return SphericalAcceleration -@final -class MathSphericalVelocity(AbstractVelocity3D): - """Spherical differential representation.""" - - d_r: ct.BatchableSpeed = eqx.field( - converter=partial(Quantity["speed"].constructor, dtype=float) - ) - r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" - - d_theta: ct.BatchableAngularSpeed = eqx.field( - converter=partial(Quantity["angular speed"].constructor, dtype=float) - ) - r"""Azimuthal speed :math:`d\theta/dt \in [-\infty, \infty].""" - - d_phi: ct.BatchableAngularSpeed = eqx.field( - converter=partial(Quantity["angular speed"].constructor, dtype=float) - ) - r"""Inclination speed :math:`d\phi/dt \in [-\infty, \infty].""" - - @classproperty - @classmethod - def integral_cls(cls) -> type[MathSphericalPosition]: - return MathSphericalPosition - - @classproperty - @classmethod - def differential_cls(cls) -> type["MathSphericalAcceleration"]: - return MathSphericalAcceleration - - -@final -class LonLatSphericalVelocity(AbstractVelocity3D): - """Spherical differential representation.""" - - d_lon: ct.BatchableAngularSpeed = eqx.field( - converter=partial(Quantity["angular speed"].constructor, dtype=float) - ) - r"""Longitude speed :math:`dlon/dt \in [-\infty, \infty].""" - - d_lat: ct.BatchableAngularSpeed = eqx.field( - converter=partial(Quantity["angular speed"].constructor, dtype=float) - ) - r"""Latitude speed :math:`dlat/dt \in [-\infty, \infty].""" - - d_distance: ct.BatchableSpeed = eqx.field( - converter=partial(Quantity["speed"].constructor, dtype=float) - ) - r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" - - @classproperty - @classmethod - def integral_cls(cls) -> type[LonLatSphericalPosition]: - return LonLatSphericalPosition - - @classproperty - @classmethod - def differential_cls(cls) -> type["LonLatSphericalAcceleration"]: - return LonLatSphericalAcceleration - - -@final -class LonCosLatSphericalVelocity(AbstractVelocity3D): - """Spherical differential representation.""" - - d_lon_coslat: ct.BatchableAngularSpeed = eqx.field( - converter=partial(Quantity["angular speed"].constructor, dtype=float) - ) - r"""Longitude * cos(Latitude) speed :math:`dlon/dt \in [-\infty, \infty].""" - - d_lat: ct.BatchableAngularSpeed = eqx.field( - converter=partial(Quantity["angular speed"].constructor, dtype=float) - ) - r"""Latitude speed :math:`dlat/dt \in [-\infty, \infty].""" - - d_distance: ct.BatchableSpeed = eqx.field( - converter=partial(Quantity["speed"].constructor, dtype=float) - ) - r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" - - @classproperty - @classmethod - def integral_cls(cls) -> type[LonLatSphericalPosition]: - return LonLatSphericalPosition - - @classproperty - @classmethod - def differential_cls(cls) -> type["LonLatSphericalAcceleration"]: - return LonLatSphericalAcceleration - - ############################################################################## -class AbstractSphericalAcceleration(AbstractAcceleration3D): - """Spherical acceleration representation.""" - - @classproperty - @classmethod - @abstractmethod - def integral_cls(cls) -> type[SphericalVelocity]: ... - - @final -class SphericalAcceleration(AbstractAcceleration3D): +class SphericalAcceleration(AbstractSphericalAcceleration): """Spherical differential representation.""" d2_r: ct.BatchableAcc = eqx.field( @@ -735,94 +229,3 @@ class SphericalAcceleration(AbstractAcceleration3D): @classmethod def integral_cls(cls) -> type[SphericalVelocity]: return SphericalVelocity - - -@final -class MathSphericalAcceleration(AbstractAcceleration3D): - """Spherical acceleration representation.""" - - d2_r: ct.BatchableAcc = eqx.field( - converter=partial(Quantity["acceleration"].constructor, dtype=float) - ) - r"""Radial acceleration :math:`d^2r/dt^2 \in [-\infty, \infty].""" - - d2_theta: ct.BatchableAngularAcc = eqx.field( - converter=partial(Quantity["angular acceleration"].constructor, dtype=float) - ) - r"""Azimuthal acceleration :math:`d^2\theta/dt^2 \in [-\infty, \infty].""" - - d2_phi: ct.BatchableAngularAcc = eqx.field( - converter=partial(Quantity["angular acceleration"].constructor, dtype=float) - ) - r"""Inclination acceleration :math:`d^2\phi/dt^2 \in [-\infty, \infty].""" - - @classproperty - @classmethod - def integral_cls(cls) -> type[MathSphericalVelocity]: - return MathSphericalVelocity - - -@final -class LonLatSphericalAcceleration(AbstractAcceleration3D): - """Spherical acceleration representation.""" - - d2_lon: ct.BatchableAngularAcc = eqx.field( - converter=partial(Quantity["angular acceleration"].constructor, dtype=float) - ) - r"""Longitude acceleration :math:`d^2lon/dt^2 \in [-\infty, \infty].""" - - d2_lat: ct.BatchableAngularAcc = eqx.field( - converter=partial(Quantity["angular acceleration"].constructor, dtype=float) - ) - r"""Latitude acceleration :math:`d^2lat/dt^2 \in [-\infty, \infty].""" - - d2_distance: ct.BatchableAcc = eqx.field( - converter=partial(Quantity["acceleration"].constructor, dtype=float) - ) - r"""Radial acceleration :math:`d^2r/dt^2 \in [-\infty, \infty].""" - - @classproperty - @classmethod - def integral_cls(cls) -> type[LonLatSphericalVelocity]: - return LonLatSphericalVelocity - - -##################################################################### - - -@register(jax.lax.mul_p) # type: ignore[misc] -def _mul_p_vmsph( - lhs: ArrayLike, rhs: MathSphericalPosition, / -) -> MathSphericalPosition: - """Scale the polar position by a scalar. - - Examples - -------- - >>> from unxt import Quantity - >>> import coordinax as cx - >>> import quaxed.numpy as jnp - - >>> v = cx.MathSphericalPosition(r=Quantity(3, "kpc"), - ... theta=Quantity(90, "deg"), - ... phi=Quantity(0, "deg")) - - >>> jnp.linalg.vector_norm(v, axis=-1) - Quantity['length'](Array(3., dtype=float32), unit='kpc') - - >>> nv = jnp.multiply(2, v) - >>> nv - MathSphericalPosition( - r=Distance(value=f32[], unit=Unit("kpc")), - theta=Quantity[...](value=f32[], unit=Unit("deg")), - phi=Quantity[...](value=f32[], unit=Unit("deg")) - ) - >>> nv.r - Distance(Array(6., dtype=float32), unit='kpc') - - """ - # Validation - lhs = eqx.error_if( - lhs, any(jax.numpy.shape(lhs)), f"must be a scalar, not {type(lhs)}" - ) - # Scale the radial distance - return replace(rhs, r=lhs * rhs.r) diff --git a/src/coordinax/_src/d3/transform.py b/src/coordinax/_src/d3/transform.py index a837883..5c14c96 100644 --- a/src/coordinax/_src/d3/transform.py +++ b/src/coordinax/_src/d3/transform.py @@ -10,18 +10,16 @@ from unxt import Quantity from .base import AbstractPosition3D, AbstractVelocity3D +from .base_spherical import AbstractSphericalPosition from .cartesian import CartesianAcceleration3D, CartesianPosition3D, CartesianVelocity3D from .cylindrical import CylindricalPosition, CylindricalVelocity -from .spherical import ( - AbstractSphericalPosition, +from .lonlatspherical import ( LonCosLatSphericalVelocity, LonLatSphericalPosition, LonLatSphericalVelocity, - MathSphericalPosition, - MathSphericalVelocity, - SphericalPosition, - SphericalVelocity, ) +from .mathspherical import MathSphericalPosition, MathSphericalVelocity +from .spherical import SphericalPosition, SphericalVelocity from coordinax._src.base import AbstractPosition ############################################################################### diff --git a/src/coordinax/_src/transform/d1.py b/src/coordinax/_src/transform/d1.py index cba5114..ce7525d 100644 --- a/src/coordinax/_src/transform/d1.py +++ b/src/coordinax/_src/transform/d1.py @@ -15,7 +15,8 @@ from coordinax._src.d2.polar import PolarPosition from coordinax._src.d3.cartesian import CartesianPosition3D from coordinax._src.d3.cylindrical import CylindricalPosition -from coordinax._src.d3.spherical import MathSphericalPosition, SphericalPosition +from coordinax._src.d3.mathspherical import MathSphericalPosition +from coordinax._src.d3.spherical import SphericalPosition # ============================================================================= # CartesianPosition1D diff --git a/src/coordinax/_src/transform/d2.py b/src/coordinax/_src/transform/d2.py index e31083b..697a000 100644 --- a/src/coordinax/_src/transform/d2.py +++ b/src/coordinax/_src/transform/d2.py @@ -19,7 +19,8 @@ from coordinax._src.d3.base import AbstractPosition3D from coordinax._src.d3.cartesian import CartesianPosition3D from coordinax._src.d3.cylindrical import CylindricalPosition -from coordinax._src.d3.spherical import MathSphericalPosition, SphericalPosition +from coordinax._src.d3.mathspherical import MathSphericalPosition +from coordinax._src.d3.spherical import SphericalPosition from coordinax._src.exceptions import IrreversibleDimensionChange diff --git a/src/coordinax/_src/transform/d3.py b/src/coordinax/_src/transform/d3.py index 8a98e40..a94e0fb 100644 --- a/src/coordinax/_src/transform/d3.py +++ b/src/coordinax/_src/transform/d3.py @@ -16,7 +16,8 @@ from coordinax._src.d2.polar import PolarPosition from coordinax._src.d3.cartesian import CartesianPosition3D from coordinax._src.d3.cylindrical import CylindricalPosition -from coordinax._src.d3.spherical import MathSphericalPosition, SphericalPosition +from coordinax._src.d3.mathspherical import MathSphericalPosition +from coordinax._src.d3.spherical import SphericalPosition from coordinax._src.exceptions import IrreversibleDimensionChange # =============================================================================