Skip to content

Commit

Permalink
Jacobian takes AbstractArray, broadcastable overload added
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsignorelli committed Aug 7, 2024
1 parent 838ad34 commit 2956355
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/getset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ function gradient(t::TPS; include_params=false)
end

"""
GTPSA.jacobian!(result, m::AbstractVector{<:TPS}; include_params=false)
GTPSA.jacobian!(result, m::AbstractArray{<:TPS}; include_params=false)
Extracts the first-order partial derivatives (evaluated at 0) from the Vector of TPSs.
and fills the `result` matrix in-place. The partial derivatives wrt the parameters will
Expand All @@ -387,12 +387,13 @@ in the TPSs.
### Output
- `result` -- Matrix to fill with the Jacobian of `m`, must be 1-based indexing
"""
function jacobian!(result, m::AbstractVector{<:TPS}; include_params=false)
function jacobian!(result, m::AbstractArray{<:TPS}; include_params=false)
Base.require_one_based_indexing(result, m)
n = numvars(first(m))
if include_params
n += numparams(first(m))
end

if size(result) != (length(m), n)
error("Incorrect size for result")
end
Expand All @@ -407,7 +408,7 @@ function jacobian!(result, m::AbstractVector{<:TPS}; include_params=false)
end

"""
GTPSA.jacobian(m::AbstractVector{<:TPS}; include_params=false)
GTPSA.jacobian(m::AbstractArray{<:TPS}; include_params=false)
Extracts the first-order partial derivatives (evaluated at 0) from the Vector of TPSs.
The partial derivatives wrt the parameters will also be extracted when the `include_params`
Expand All @@ -421,7 +422,7 @@ the first-order monomial coefficients already in the TPSs.
### Output
- `J` -- Jacobian of `m`
"""
function jacobian(m::AbstractVector{<:TPS}; include_params=false)
function jacobian(m::AbstractArray{<:TPS}; include_params=false)
Base.require_one_based_indexing(m)
n = numvars(first(m))
if include_params
Expand Down
2 changes: 2 additions & 0 deletions src/tps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ eps(::Type{TPS{T}}) where {T} = eps(T)
floatmin(::Type{TPS{T}}) where {T} = floatmin(T)
floatmax(::Type{TPS{T}}) where {T} = floatmax(T)

Base.broadcastable(o::TPS) = Ref(o)




Expand Down

0 comments on commit 2956355

Please sign in to comment.