diff --git a/src/coordinax/_coordinax/base.py b/src/coordinax/_coordinax/base.py index da1f9b2..8af5fd8 100644 --- a/src/coordinax/_coordinax/base.py +++ b/src/coordinax/_coordinax/base.py @@ -15,7 +15,6 @@ from typing import TYPE_CHECKING, Any, Literal, NoReturn, TypeVar import astropy.units as u -import equinox as eqx import jax import jax.numpy as jnp import numpy as np @@ -807,7 +806,7 @@ def __str__(self) -> str: # TODO: move to the class in py3.11+ -@AbstractVector.constructor._f.dispatch # noqa: SLF001 +@AbstractVector.constructor._f.dispatch # type: ignore[attr-defined, misc] # noqa: SLF001 def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVector: """Construct a vector from another vector. @@ -888,94 +887,3 @@ def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVe return obj return cls(**dict(field_items(obj))) - - -@AbstractVector.constructor._f.dispatch # noqa: SLF001 -def constructor( - cls: type[AbstractVector], obj: Mapping[str, u.Quantity], / -) -> AbstractVector: - """Construct a vector from a mapping. - - Parameters - ---------- - cls : type[AbstractVector] - The vector class. - obj : Mapping[str, `astropy.units.Quantity`] - The mapping of components. - - Examples - -------- - >>> import jax.numpy as jnp - >>> from astropy.units import Quantity - >>> import coordinax as cx - - >>> xs = {"x": Quantity(1, "m"), "y": Quantity(2, "m"), "z": Quantity(3, "m")} - >>> vec = cx.CartesianPosition3D.constructor(xs) - >>> vec - CartesianPosition3D( - x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), - y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), - z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")) - ) - - >>> xs = {"x": Quantity([1, 2], "m"), "y": Quantity([3, 4], "m"), - ... "z": Quantity([5, 6], "m")} - >>> vec = cx.CartesianPosition3D.constructor(xs) - >>> vec - CartesianPosition3D( - x=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")), - y=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")), - z=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")) - ) - - """ - return cls(**obj) - - -# TODO: move to the class in py3.11+ -@AbstractVector.constructor._f.dispatch # noqa: SLF001 -def constructor(cls: type[AbstractVector], obj: u.Quantity, /) -> AbstractVector: - """Construct a vector from an Astropy Quantity array. - - The array is expected to have the components as the last dimension. - - Parameters - ---------- - cls : type[AbstractVector] - The vector class. - obj : Quantity[Any, (*#batch, N), "..."] - The array of components. - - Examples - -------- - >>> import jax.numpy as jnp - >>> from astropy.units import Quantity - >>> import coordinax as cx - - >>> xs = Quantity([1, 2, 3], "meter") - >>> vec = cx.CartesianPosition3D.constructor(xs) - >>> vec - CartesianPosition3D( - x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), - y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), - z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")) - ) - - >>> xs = Quantity(jnp.array([[1, 2, 3], [4, 5, 6]]), "meter") - >>> vec = cx.CartesianPosition3D.constructor(xs) - >>> vec - CartesianPosition3D( - x=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")), - y=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")), - z=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")) - ) - >>> vec.x - Quantity['length'](Array([1., 4.], dtype=float32), unit='m') - - """ - _ = eqx.error_if( - obj, - obj.shape[-1] != len(fields(cls)), - f"Cannot construct {cls} from array with shape {obj.shape}.", - ) - return cls(**{f.name: obj[..., i] for i, f in enumerate(fields(cls))}) diff --git a/src/coordinax/_interop/coordinax_interop_astropy/constructors.py b/src/coordinax/_interop/coordinax_interop_astropy/constructors.py index b40c89f..defa8de 100644 --- a/src/coordinax/_interop/coordinax_interop_astropy/constructors.py +++ b/src/coordinax/_interop/coordinax_interop_astropy/constructors.py @@ -3,6 +3,8 @@ __all__: list[str] = [] +from collections.abc import Mapping +from dataclasses import fields import astropy.coordinates as apyc import astropy.units as u @@ -14,6 +16,51 @@ ##################################################################### +@cx.AbstractVector.constructor._f.dispatch # noqa: SLF001 +def constructor( + cls: type[cx.AbstractVector], obj: Mapping[str, u.Quantity], / +) -> cx.AbstractVector: + """Construct a vector from a mapping. + + Parameters + ---------- + cls : type[AbstractVector] + The vector class. + obj : Mapping[str, `astropy.units.Quantity`] + The mapping of components. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from astropy.units import Quantity + >>> import coordinax as cx + + >>> xs = {"x": Quantity(1, "m"), "y": Quantity(2, "m"), "z": Quantity(3, "m")} + >>> vec = cx.CartesianPosition3D.constructor(xs) + >>> vec + CartesianPosition3D( + x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), + y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), + z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")) + ) + + >>> xs = {"x": Quantity([1, 2], "m"), "y": Quantity([3, 4], "m"), + ... "z": Quantity([5, 6], "m")} + >>> vec = cx.CartesianPosition3D.constructor(xs) + >>> vec + CartesianPosition3D( + x=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")), + y=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")), + z=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")) + ) + + """ + return cls(**obj) + + +##################################################################### + + @cx.AbstractPosition3D.constructor._f.dispatch(precedence=-1) # noqa: SLF001 def constructor( cls: type[cx.AbstractPosition3D], obj: apyc.CartesianRepresentation, / @@ -438,7 +485,54 @@ def constructor( ##################################################################### -# TODO: move to the class in py3.11+ +@cx.AbstractVector.constructor._f.dispatch # noqa: SLF001 +def constructor(cls: type[cx.AbstractVector], obj: u.Quantity, /) -> cx.AbstractVector: + """Construct a vector from an Astropy Quantity array. + + The array is expected to have the components as the last dimension. + + Parameters + ---------- + cls : type[AbstractVector] + The vector class. + obj : Quantity[Any, (*#batch, N), "..."] + The array of components. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from astropy.units import Quantity + >>> import coordinax as cx + + >>> xs = Quantity([1, 2, 3], "meter") + >>> vec = cx.CartesianPosition3D.constructor(xs) + >>> vec + CartesianPosition3D( + x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), + y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), + z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")) + ) + + >>> xs = Quantity(jnp.array([[1, 2, 3], [4, 5, 6]]), "meter") + >>> vec = cx.CartesianPosition3D.constructor(xs) + >>> vec + CartesianPosition3D( + x=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")), + y=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")), + z=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")) + ) + >>> vec.x + Quantity['length'](Array([1., 4.], dtype=float32), unit='m') + + """ + _ = eqx.error_if( + obj, + obj.shape[-1] != len(fields(cls)), + f"Cannot construct {cls} from array with shape {obj.shape}.", + ) + return cls(**{f.name: obj[..., i] for i, f in enumerate(fields(cls))}) + + @cx.FourVector.constructor._f.dispatch # noqa: SLF001 def constructor( cls: type[cx.FourVector], obj: Shaped[u.Quantity, "*batch 4"], /