Skip to content

Commit

Permalink
refactor: Constructors (#170)
Browse files Browse the repository at this point in the history
* feat: more constructors
* refactor: re-arrange
* refactor: consolidate interop

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Aug 29, 2024
1 parent b4aced9 commit 30e3b84
Show file tree
Hide file tree
Showing 16 changed files with 1,099 additions and 437 deletions.
6 changes: 2 additions & 4 deletions src/coordinax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,9 @@
__all__ += utils.__all__

# Interoperability
from . import _interop # noqa: E402
from ._coordinax import compat # noqa: E402

# Astropy
from ._interop import coordinax_interop_astropy # noqa: E402

# Runtime Typechecker
install_import_hook("coordinax", RUNTIME_TYPECHECKER)

Expand All @@ -87,6 +85,6 @@
dn,
funcs,
RUNTIME_TYPECHECKER,
coordinax_interop_astropy,
compat,
_interop,
)
130 changes: 58 additions & 72 deletions src/coordinax/_coordinax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import quaxed.array_api as xp
import quaxed.lax as qlax
from dataclassish import field_items, field_values, replace
from unxt import Quantity, unitsystem
from unxt import AbstractQuantity, Quantity, unitsystem

from .typing import Unit
from .utils import classproperty, full_shaped
Expand Down Expand Up @@ -62,7 +62,7 @@ class AbstractVector(ArrayValue): # type: ignore[misc]
@classmethod
@dispatch
def constructor(
cls: "type[AbstractVector]", obj: Mapping[str, Quantity], /
cls: "type[AbstractVector]", obj: Mapping[str, AbstractQuantity], /
) -> "AbstractVector":
"""Construct a vector from a mapping.
Expand Down Expand Up @@ -99,61 +99,6 @@ def constructor(
"""
return cls(**obj)

@classmethod
@dispatch
def constructor(cls: "type[AbstractVector]", obj: Quantity, /) -> "AbstractVector":
"""Construct a vector from a Quantity array.
The array is expected to have the components as the last axis.
Parameters
----------
obj : Quantity[Any, (*#batch, N), "..."]
The array with batches and N components.
Examples
--------
>>> import jax.numpy as jnp
>>> from unxt import Quantity
>>> import coordinax as cx
(The 1D cases are handled by a different dispatch)
>>> vec = cx.CartesianPosition2D.constructor(Quantity([1, 2], "meter"))
>>> vec
CartesianPosition2D(
x=Quantity[...](value=f32[], unit=Unit("m")),
y=Quantity[...](value=f32[], unit=Unit("m"))
)
>>> vec = cx.CartesianPosition3D.constructor(Quantity([1, 2, 3], "meter"))
>>> vec
CartesianPosition3D(
x=Quantity[...](value=f32[], unit=Unit("m")),
y=Quantity[...](value=f32[], unit=Unit("m")),
z=Quantity[...](value=f32[], unit=Unit("m"))
)
>>> xs = Quantity(jnp.array([[1, 2, 3], [4, 5, 6]]), "meter")
>>> vec = cx.CartesianPosition3D.constructor(xs)
>>> vec
CartesianPosition3D(
x=Quantity[...](value=f32[2], unit=Unit("m")),
y=Quantity[...](value=f32[2], unit=Unit("m")),
z=Quantity[...](value=f32[2], unit=Unit("m"))
)
>>> vec.x
Quantity['length'](Array([1., 4.], dtype=float32), unit='m')
"""
obj = eqx.error_if(
obj,
obj.shape[-1] != len(fields(cls)),
f"Cannot construct {cls} from array with shape {obj.shape}.",
)
comps = {f.name: obj[..., i] for i, f in enumerate(fields(cls))}
return cls(**comps)

@classmethod
@dispatch
def constructor(
Expand Down Expand Up @@ -863,40 +808,81 @@ def __str__(self) -> str:

# TODO: move to the class in py3.11+
@AbstractVector.constructor._f.dispatch # noqa: SLF001
def constructor( # noqa: D417
cls: type[AbstractVector], obj: AbstractVector, /
) -> AbstractVector:
def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVector:
"""Construct a vector from another vector.
Parameters
----------
obj : :class:`coordinax.AbstractVector`
cls : type[AbstractVector], positional-only
The vector class.
obj : :class:`coordinax.AbstractVector`, positional-only
The vector to construct from.
Examples
--------
>>> import jax.numpy as jnp
>>> from unxt import Quantity
>>> import coordinax as cx
>>> x, y, z = Quantity(1, "meter"), Quantity(2, "meter"), Quantity(3, "meter")
>>> vec = cx.CartesianPosition3D(x=x, y=y, z=z)
>>> cart = cx.CartesianPosition3D.constructor(vec)
Positions:
>>> q = cx.CartesianPosition3D.constructor([1, 2, 3], "km")
>>> cart = cx.CartesianPosition3D.constructor(q)
>>> cart
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m"))
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("km")),
y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("km")),
z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("km"))
)
>>> cart.x
Quantity['length'](Array(1., dtype=float32), unit='m')
>>> cx.AbstractPosition3D.constructor(cart) is cart
True
>>> sph = cart.represent_as(cx.SphericalPosition)
>>> cx.AbstractPosition3D.constructor(sph) is sph
True
>>> cyl = cart.represent_as(cx.CylindricalPosition)
>>> cx.AbstractPosition3D.constructor(cyl) is cyl
True
Velocities:
>>> p = cx.CartesianVelocity3D.constructor([1, 2, 3], "km/s")
>>> cart = cx.CartesianVelocity3D.constructor(p)
>>> cx.AbstractVelocity3D.constructor(cart) is cart
True
>>> sph = cart.represent_as(cx.SphericalVelocity, q)
>>> cx.AbstractVelocity3D.constructor(sph) is sph
True
>>> cyl = cart.represent_as(cx.CylindricalVelocity, q)
>>> cx.AbstractVelocity3D.constructor(cyl) is cyl
True
Accelerations:
>>> p = cx.CartesianVelocity3D.constructor([1, 1, 1], "km/s")
>>> cart = cx.CartesianAcceleration3D.constructor([1, 2, 3], "km/s2")
>>> cx.AbstractAcceleration3D.constructor(cart) is cart
True
>>> sph = cart.represent_as(cx.SphericalAcceleration, p, q)
>>> cx.AbstractAcceleration3D.constructor(sph) is sph
True
>>> cyl = cart.represent_as(cx.CylindricalAcceleration, p, q)
>>> cx.AbstractAcceleration3D.constructor(cyl) is cyl
True
"""
if not isinstance(obj, cls):
msg = f"Cannot construct {cls} from {type(obj)}."
raise TypeError(msg)

# avoid copying if the types are the same. Isinstance is not strict
# Avoid copying if the types are the same. `isinstance` is not strict
# enough, so we use type() instead.
if type(obj) is cls: # pylint: disable=unidiomatic-typecheck
return obj
Expand Down
5 changes: 2 additions & 3 deletions src/coordinax/_coordinax/base_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def __neg__(self) -> "Self":
>>> d2r = cx.RadialAcceleration.constructor([1], "m/s2")
>>> -d2r
RadialAcceleration(
d2_r=Quantity[PhysicalType('acceleration')](value=i32[], unit=Unit("m / s2")) )
RadialAcceleration( d2_r=Quantity[...](value=i32[], unit=Unit("m / s2")) )
>>> d2p = cx.PolarAcceleration(Quantity(1, "m/s2"), Quantity(1, "mas/yr2"))
>>> negd2p = -d2p
Expand All @@ -107,7 +106,7 @@ def __neg__(self) -> "Self":
>>> negd2p.d2_phi
Quantity['angular acceleration'](Array(-1., dtype=float32), unit='mas / yr2')
""" # noqa: E501
"""
return replace(self, **{k: -v for k, v in field_items(self)})

# ===============================================================
Expand Down
94 changes: 83 additions & 11 deletions src/coordinax/_coordinax/d1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,37 @@ def differential_cls(cls) -> type["AbstractVelocity1D"]:
raise NotImplementedError


# TODO: move to the class in py3.11+
@AbstractVector.constructor._f.dispatch # type: ignore[attr-defined, misc] # noqa: SLF001
# -------------------------------------------------------------------


@AbstractPosition1D.constructor._f.dispatch # type: ignore[attr-defined, misc] # noqa: SLF001
def constructor(
cls: type[AbstractPosition1D],
x: Shaped[Quantity["length"], "*batch"] | Shaped[Quantity["length"], "*batch 1"],
obj: Shaped[Quantity["length"], "*batch"] | Shaped[Quantity["length"], "*batch 1"],
/,
) -> AbstractPosition1D:
"""Construct a 1D vector.
"""Construct a 1D position.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> cx.CartesianPosition1D.constructor(Quantity(1, "meter"))
CartesianPosition1D(
x=Quantity[...](value=f32[], unit=Unit("m"))
)
CartesianPosition1D( x=Quantity[...](value=f32[], unit=Unit("m")) )
>>> cx.CartesianPosition1D.constructor(Quantity([1], "meter"))
CartesianPosition1D(
x=Quantity[...](value=f32[], unit=Unit("m"))
)
CartesianPosition1D( x=Quantity[...](value=f32[], unit=Unit("m")) )
>>> cx.RadialPosition.constructor(Quantity(1, "meter"))
RadialPosition(r=Distance(value=f32[], unit=Unit("m")))
>>> cx.RadialPosition.constructor(Quantity([1], "meter"))
RadialPosition(r=Distance(value=f32[], unit=Unit("m")))
"""
return cls(**{fields(cls)[0].name: jnp.atleast_1d(x)[..., 0]})
comps = {f.name: jnp.atleast_1d(obj)[..., i] for i, f in enumerate(fields(cls))}
return cls(**comps)


#####################################################################
Expand Down Expand Up @@ -91,6 +96,39 @@ def differential_cls(cls) -> type[AbstractAcceleration]:
raise NotImplementedError


# -------------------------------------------------------------------


@AbstractVelocity1D.constructor._f.dispatch # type: ignore[attr-defined, misc] # noqa: SLF001
def constructor(
cls: type[AbstractVelocity1D],
obj: Shaped[Quantity["speed"], "*batch"] | Shaped[Quantity["speed"], "*batch 1"],
/,
) -> AbstractVelocity1D:
"""Construct a 1D velocity.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> cx.CartesianVelocity1D.constructor(Quantity(1, "m/s"))
CartesianVelocity1D( d_x=Quantity[...]( value=i32[], unit=Unit("m / s") ) )
>>> cx.CartesianVelocity1D.constructor(Quantity([1], "m/s"))
CartesianVelocity1D( d_x=Quantity[...]( value=i32[], unit=Unit("m / s") ) )
>>> cx.RadialVelocity.constructor(Quantity(1, "m/s"))
RadialVelocity( d_r=Quantity[...]( value=i32[], unit=Unit("m / s") ) )
>>> cx.RadialVelocity.constructor(Quantity([1], "m/s"))
RadialVelocity( d_r=Quantity[...]( value=i32[], unit=Unit("m / s") ) )
"""
comps = {f.name: jnp.atleast_1d(obj)[..., i] for i, f in enumerate(fields(cls))}
return cls(**comps)


#####################################################################


Expand All @@ -109,3 +147,37 @@ def _cartesian_cls(cls) -> type[AbstractVector]:
@abstractmethod
def integral_cls(cls) -> type[AbstractVelocity1D]:
raise NotImplementedError


# -------------------------------------------------------------------


@AbstractAcceleration1D.constructor._f.dispatch # type: ignore[attr-defined, misc] # noqa: SLF001
def constructor(
cls: type[AbstractAcceleration1D],
obj: Shaped[Quantity["acceleration"], "*batch"]
| Shaped[Quantity["acceleration"], "*batch 1"],
/,
) -> AbstractAcceleration1D:
"""Construct a 1D acceleration.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> cx.CartesianAcceleration1D.constructor(Quantity(1, "m/s2"))
CartesianAcceleration1D( d2_x=Quantity[...](value=i32[], unit=Unit("m / s2")) )
>>> cx.CartesianAcceleration1D.constructor(Quantity([1], "m/s2"))
CartesianAcceleration1D( d2_x=Quantity[...](value=i32[], unit=Unit("m / s2")) )
>>> cx.RadialAcceleration.constructor(Quantity(1, "m/s2"))
RadialAcceleration( d2_r=Quantity[...](value=i32[], unit=Unit("m / s2")) )
>>> cx.RadialAcceleration.constructor(Quantity([1], "m/s2"))
RadialAcceleration( d2_r=Quantity[...](value=i32[], unit=Unit("m / s2")) )
"""
comps = {f.name: jnp.atleast_1d(obj)[..., i] for i, f in enumerate(fields(cls))}
return cls(**comps)
6 changes: 6 additions & 0 deletions src/coordinax/_coordinax/d2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def differential_cls(cls) -> type["AbstractVelocity2D"]:
raise NotImplementedError


#####################################################################


class AbstractVelocity2D(AbstractVelocity):
"""Abstract representation of 2D vector differentials."""

Expand All @@ -52,6 +55,9 @@ def differential_cls(cls) -> type[AbstractAcceleration]:
raise NotImplementedError


#####################################################################


class AbstractAcceleration2D(AbstractAcceleration):
"""Abstract representation of 2D vector accelerations."""

Expand Down
Loading

0 comments on commit 30e3b84

Please sign in to comment.