From 29563555bcc9c1fc40316c5eca0ae0baf38e0b1a Mon Sep 17 00:00:00 2001 From: mattsignorelli Date: Wed, 7 Aug 2024 13:37:00 -0400 Subject: [PATCH] Jacobian takes AbstractArray, broadcastable overload added --- src/getset.jl | 9 +++++---- src/tps.jl | 2 ++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/getset.jl b/src/getset.jl index a32c4e1..3227f2c 100644 --- a/src/getset.jl +++ b/src/getset.jl @@ -372,7 +372,7 @@ function gradient(t::TPS; include_params=false) end """ - GTPSA.jacobian!(result, m::AbstractVector{<:TPS}; include_params=false) + GTPSA.jacobian!(result, m::AbstractArray{<:TPS}; include_params=false) Extracts the first-order partial derivatives (evaluated at 0) from the Vector of TPSs. and fills the `result` matrix in-place. The partial derivatives wrt the parameters will @@ -387,12 +387,13 @@ in the TPSs. ### Output - `result` -- Matrix to fill with the Jacobian of `m`, must be 1-based indexing """ -function jacobian!(result, m::AbstractVector{<:TPS}; include_params=false) +function jacobian!(result, m::AbstractArray{<:TPS}; include_params=false) Base.require_one_based_indexing(result, m) n = numvars(first(m)) if include_params n += numparams(first(m)) end + if size(result) != (length(m), n) error("Incorrect size for result") end @@ -407,7 +408,7 @@ function jacobian!(result, m::AbstractVector{<:TPS}; include_params=false) end """ - GTPSA.jacobian(m::AbstractVector{<:TPS}; include_params=false) + GTPSA.jacobian(m::AbstractArray{<:TPS}; include_params=false) Extracts the first-order partial derivatives (evaluated at 0) from the Vector of TPSs. The partial derivatives wrt the parameters will also be extracted when the `include_params` @@ -421,7 +422,7 @@ the first-order monomial coefficients already in the TPSs. ### Output - `J` -- Jacobian of `m` """ -function jacobian(m::AbstractVector{<:TPS}; include_params=false) +function jacobian(m::AbstractArray{<:TPS}; include_params=false) Base.require_one_based_indexing(m) n = numvars(first(m)) if include_params diff --git a/src/tps.jl b/src/tps.jl index 5c8de3b..3f4ba92 100644 --- a/src/tps.jl +++ b/src/tps.jl @@ -117,6 +117,8 @@ eps(::Type{TPS{T}}) where {T} = eps(T) floatmin(::Type{TPS{T}}) where {T} = floatmin(T) floatmax(::Type{TPS{T}}) where {T} = floatmax(T) +Base.broadcastable(o::TPS) = Ref(o) +