Skip to content

Commit

Permalink
clean up of basic operators, added in-place add,sub,mul,div
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsignorelli committed Apr 18, 2024
1 parent fba8920 commit 5e5f235
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 218 deletions.
4 changes: 2 additions & 2 deletions docs/src/man/o_all.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ conj, angle, complex, promote_rule, getindex, setindex!, ==, <,
`zeros` and `ones` are overloaded from Base so that allocated `TPS`/`ComplexTPS`s are placed in each element. If we didn't explicity overload these functions, every element would correspond to the exact same heap-allocated TPS, which is problematic when setting individual monomial coefficients of the same TPS.

`GTPSA.jl` overloads (and exports) the following functions from the corresponding packages:
**`LinearAlgebra`**: `norm`
**`LinearAlgebra`**: `norm`, `mul!`
**`SpecialFunctions`**: `erf`, `erfc`

`GTPSA.jl` also provides the following functions NOT included in Base or any of the above packages:
```
unit, sinhc, asinc, asinhc, polar, rect
add!, sub!, div!, unit, sinhc, asinc, asinhc, polar, rect
```

If there is a mathematical function in Base which you'd like and is not included in the above list, feel free to submit an [issue](https://github.com/bmad-sim/GTPSA.jl/issues).
6 changes: 5 additions & 1 deletion src/GTPSA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ import Base: +,
show,
copy!

import LinearAlgebra: norm
import LinearAlgebra: norm, mul!
import SpecialFunctions: erf, erfc

using GTPSA_jll, Printf, PrettyTables
Expand All @@ -83,6 +83,10 @@ export
rect,
clear!,
complex!,
add!,
sub!,
mul!,
div!,

# Monomial as TPS creators:
vars,
Expand Down
301 changes: 89 additions & 212 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ function copy!(t::TPS, t1::TPS)
mad_tpsa_copy!(t1.tpsa, t.tpsa)
end

function copy!(ct::ComplexTPS, t1::TPS)
mad_ctpsa_cplx!(t1.tpsa, Base.unsafe_convert(Ptr{RTPSA}, C_NULL), ct.tpsa)
end

function copy!(ct::ComplexTPS, ct1::ComplexTPS)
mad_ctpsa_copy!(ct1.tpsa, ct.tpsa)
end
Expand Down Expand Up @@ -249,255 +253,128 @@ function isequal(t1::TPS, ct1::ComplexTPS)::Bool
return isequal(ct1, t1)
end


# --- add ---
# TPS:
function +(t1::TPS, t2::TPS)::TPS
t = zero(t1)
mad_tpsa_add!(t1.tpsa, t2.tpsa, t.tpsa)
return t
end
# TPS, TPS:
add!(a::Ptr{RTPSA}, b::Ptr{RTPSA}, c::Ptr{RTPSA}) = mad_tpsa_add!(a, b, c)
add!(a::Ptr{CTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_add!(a, b, c)
add!(a::Ptr{RTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_addt!(b, a, c)
add!(a::Ptr{CTPSA}, b::Ptr{RTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_addt!(a, b, c)

function +(t1::TPS, a::Real)::TPS
t = TPS(t1)
mad_tpsa_set0!(t.tpsa, 1., convert(Float64,a))
return t
end

function +(a::Real, t1::TPS)::TPS
return t1 + a
end

# ComplexTPS:
function +(ct1::ComplexTPS, ct2::ComplexTPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_add!(ct1.tpsa, ct2.tpsa, ct.tpsa)
return ct
end
add!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, t2::Union{TPS,ComplexTPS}) = add!(t1.tpsa, t2.tpsa, t.tpsa)

function +(ct1::ComplexTPS, a::Number)::ComplexTPS
ct = ComplexTPS(ct1)
mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64, a))
return ct
end
# TPS, scalar:
set0!(t::Ptr{RTPSA}, a::Float64, b::Float64) = mad_tpsa_set0!(t, a, b)
set0!(t::Ptr{CTPSA}, a::ComplexF64, b::ComplexF64) = mad_ctpsa_set0!(t, a, b)

function +(a::Number, ct1::ComplexTPS)::ComplexTPS
return ct1 + a
function add!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, a::Number)
copy!(t, t1)
set0!(t.tpsa, convert(numtype(t), 1), convert(numtype(t), a))
end

# TPS to ComplexTPS promotion, w/o creating temp ComplexTPS:
function +(ct1::ComplexTPS, t1::TPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_addt!(ct1.tpsa, t1.tpsa, ct.tpsa)
return ct
end
add!(t::Union{TPS,ComplexTPS}, a::Number, t1::Union{TPS,ComplexTPS}) = add!(t, t1, a)

function +(t1::TPS, ct1::ComplexTPS)::ComplexTPS
return ct1 + t1
for t = ((TPS,TPS),(TPS,Real),(Real,TPS),(TPS,Complex),(Complex,TPS),(ComplexTPS,TPS),(TPS,ComplexTPS),(ComplexTPS,ComplexTPS),(ComplexTPS, Number), (Number, ComplexTPS))
@eval begin
function +(t1::$t[1], t2::$t[2])
use = $(t[1] == TPS || t[1] == ComplexTPS ? :t1 : :t2)
t = (promote_type(typeof(t1),typeof(t2)))(use=use)
add!(t, t1, t2)
return t
end

function +(t1::TPS, a::Complex)::ComplexTPS
ct = ComplexTPS(t1)
mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64, a))
return ct
end

function +(a::Complex, t1::TPS)::ComplexTPS
return t1 + a
end


# --- sub ---
# TPS:
function -(t1::TPS, t2::TPS)::TPS
t = zero(t1)
mad_tpsa_sub!(t1.tpsa, t2.tpsa, t.tpsa)
return t
end

# TPS, TPS:
sub!(a::Ptr{RTPSA}, b::Ptr{RTPSA}, c::Ptr{RTPSA}) = mad_tpsa_sub!(a, b, c)
sub!(a::Ptr{CTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_sub!(a, b, c)
sub!(a::Ptr{CTPSA}, b::Ptr{RTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_subt!(a, b, c)
sub!(a::Ptr{RTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_tsub!(a, b, c)

function -(t1::TPS, a::Real)::TPS
t = TPS(t1)
mad_tpsa_set0!(t.tpsa, 1., convert(Float64, -a))
return t
end
sub!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, t2::Union{TPS,ComplexTPS}) = sub!(t1.tpsa, t2.tpsa, t.tpsa)

function -(a::Real, t1::TPS)::TPS
t = zero(t1)
mad_tpsa_scl!(t1.tpsa, -1., t.tpsa)
mad_tpsa_set0!(t.tpsa, 1., convert(Float64, a))
return t
# TPS, scalar:
scl!(a::Ptr{RTPSA}, v::Float64, c::Ptr{RTPSA}) = mad_tpsa_scl!(a, v, c)
scl!(a::Ptr{CTPSA}, v::ComplexF64, c::Ptr{CTPSA}) = mad_ctpsa_scl!(a, v, c)
function scl!(a::Ptr{RTPSA}, v::ComplexF64, c::Ptr{CTPSA})
mad_ctpsa_cplx!(a, Base.unsafe_convert(Ptr{RTPSA},C_NULL), c)
scl!(c, v, c)
end

# ComplexTPS:
function -(ct1::ComplexTPS, ct2::ComplexTPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_sub!(ct1.tpsa, ct2.tpsa, ct.tpsa)
return ct
end

function -(ct1::ComplexTPS, a::Number)::ComplexTPS
ct = ComplexTPS(ct1)
mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64, -a))
return ct
end

function -(a::Number, ct1::ComplexTPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_scl!(ct1.tpsa, convert(ComplexF64, -1), ct.tpsa)
mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64,a))
return ct
end
sub!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, a::Number) = add!(t, t1, -a)

# TPS to ComplexTPS promotion, w/o creating temp ComplexTPS:
function -(ct1::ComplexTPS, t1::TPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_subt!(ct1.tpsa, t1.tpsa, ct.tpsa)
return ct
function sub!(t::Union{TPS,ComplexTPS}, a::Number, t1::Union{TPS,ComplexTPS})
scl!(t1.tpsa, convert(numtype(t), -1.), t.tpsa)
set0!(t.tpsa, convert(numtype(t), 1.), convert(numtype(t), a))
end

function -(t1::TPS, ct1::ComplexTPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_tsub!(t1.tpsa, ct1.tpsa, ct.tpsa)
return ct
for t = ((TPS,TPS),(TPS,Real),(Real,TPS),(TPS,Complex),(Complex,TPS),(ComplexTPS,TPS),(TPS,ComplexTPS),(ComplexTPS,ComplexTPS),(ComplexTPS, Number), (Number, ComplexTPS))
@eval begin
function -(t1::$t[1], t2::$t[2])
use = $(t[1] == TPS || t[1] == ComplexTPS ? :t1 : :t2)
t = (promote_type(typeof(t1),typeof(t2)))(use=use)
sub!(t, t1, t2)
return t
end

function -(t1::TPS, a::Complex)::ComplexTPS
ct = ComplexTPS(t1)
mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64, -a))
return ct
end

function -(a::Complex, t1::TPS)::ComplexTPS
ct = ComplexTPS(t1)
mad_ctpsa_scl!(ct.tpsa, convert(ComplexF64, -1), ct.tpsa)
mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64,a))
return ct
end


# --- mul ---
# TPS:
function *(t1::TPS, t2::TPS)::TPS
t = zero(t1)
mad_tpsa_mul!(t1.tpsa, t2.tpsa, t.tpsa)
# TPS, TPS:
mul!(a::Ptr{RTPSA}, b::Ptr{RTPSA}, c::Ptr{RTPSA}) = mad_tpsa_mul!(a, b, c)
mul!(a::Ptr{CTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_mul!(a, b, c)
mul!(a::Ptr{RTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_mult!(b, a, c)
mul!(a::Ptr{CTPSA}, b::Ptr{RTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_mult!(a, b, c)

mul!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, t2::Union{TPS,ComplexTPS}) = mul!(t1.tpsa, t2.tpsa, t.tpsa)

# TPS, scalar:
mul!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, a::Number) = scl!(t1.tpsa, convert(numtype(t), a), t.tpsa)
mul!(t::Union{TPS,ComplexTPS}, a::Number, t1::Union{TPS,ComplexTPS}) = mul!(t, t1, a)

for t = ((TPS,TPS),(TPS,Real),(Real,TPS),(TPS,Complex),(Complex,TPS),(ComplexTPS,TPS),(TPS,ComplexTPS),(ComplexTPS,ComplexTPS),(ComplexTPS, Number), (Number, ComplexTPS))
@eval begin
function *(t1::$t[1], t2::$t[2])
use = $(t[1] == TPS || t[1] == ComplexTPS ? :t1 : :t2)
t = (promote_type(typeof(t1),typeof(t2)))(use=use)
mul!(t, t1, t2)
return t
end

function *(t1::TPS, a::Real)::TPS
t = zero(t1)
mad_tpsa_scl!(t1.tpsa, convert(Float64, a), t.tpsa)
return t
end

function *(a::Real, t1::TPS)::TPS
return t1 * a
end

# ComplexTPS:
function *(ct1::ComplexTPS, ct2::ComplexTPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_mul!(ct1.tpsa, ct2.tpsa, ct.tpsa)
return ct
end

function *(ct1::ComplexTPS, a::Number)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_scl!(ct1.tpsa, convert(ComplexF64,a), ct.tpsa)
return ct
end

function *(a::Number, ct1::ComplexTPS)::ComplexTPS
return ct1 * a
end

# TPS to ComplexTPS promotion, w/o creating temp ComplexTPS:
function *(ct1::ComplexTPS, t1::TPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_mult!(ct1.tpsa, t1.tpsa, ct.tpsa)
return ct
end

function *(t1::TPS, ct1::ComplexTPS)::ComplexTPS
return ct1 * t1
end

function *(t1::TPS, a::Complex)::ComplexTPS
ct = ComplexTPS(t1)
mad_ctpsa_scl!(ct.tpsa, convert(ComplexF64,a), ct.tpsa)
return ct
end

function *(a::Complex, t1::TPS)::ComplexTPS
return t1 * a
end


# --- div ---
# TPS:
function /(t1::TPS, t2::TPS)::TPS
t = zero(t1)
mad_tpsa_div!(t1.tpsa, t2.tpsa, t.tpsa)
div!(a::Ptr{RTPSA}, b::Ptr{RTPSA}, c::Ptr{RTPSA}) = mad_tpsa_div!(a, b, c)
div!(a::Ptr{CTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_div!(a, b, c)
div!(a::Ptr{CTPSA}, b::Ptr{RTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_divt!(a, b, c)
div!(a::Ptr{RTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_tdiv!(a, b, c)

div!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, t2::Union{TPS,ComplexTPS}) = div!(t1.tpsa, t2.tpsa, t.tpsa)

# TPS, scalar:
inv!(a::Ptr{RTPSA}, v::Float64, c::Ptr{RTPSA}) = mad_tpsa_inv!(a, v, c)
inv!(a::Ptr{CTPSA}, v::ComplexF64, c::Ptr{CTPSA}) = mad_ctpsa_inv!(a, v, c)
function inv!(a::Ptr{RTPSA}, v::ComplexF64, c::Ptr{CTPSA})
mad_ctpsa_cplx!(a, Base.unsafe_convert(Ptr{RTPSA}, C_NULL), c)
inv!(c, v, c)
end

div!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, a::Number) = mul!(t, t1, 1/a)
div!(t::Union{TPS,ComplexTPS}, a::Number, t1::Union{TPS,ComplexTPS}) = inv!(t1.tpsa, convert(numtype(t), a), t.tpsa)

for t = ((TPS,TPS),(TPS,Real),(Real,TPS),(TPS,Complex),(Complex,TPS),(ComplexTPS,TPS),(TPS,ComplexTPS),(ComplexTPS,ComplexTPS),(ComplexTPS, Number), (Number, ComplexTPS))
@eval begin
function /(t1::$t[1], t2::$t[2])
use = $(t[1] == TPS || t[1] == ComplexTPS ? :t1 : :t2)
t = (promote_type(typeof(t1),typeof(t2)))(use=use)
div!(t, t1, t2)
return t
end

function /(t1::TPS, a::Real)::TPS
t = zero(t1)
mad_tpsa_scl!(t1.tpsa, convert(Float64, 1/a), t.tpsa)
return t
end

function /(a::Real, t1::TPS)::TPS
t = zero(t1)
mad_tpsa_inv!(t1.tpsa, convert(Float64,a), t.tpsa)
return t
end

# ComplexTPS:
function /(ct1::ComplexTPS, ct2::ComplexTPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_div!(ct1.tpsa, ct2.tpsa, ct.tpsa)
return ct
end

function /(ct1::ComplexTPS, a::Number)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_scl!(ct1.tpsa, convert(ComplexF64, 1/a), ct.tpsa)
return ct
end

function /(a::Number, ct1::ComplexTPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_inv!(ct1.tpsa, convert(ComplexF64, a), ct.tpsa)
return ct
end

# TPS to ComplexTPS promotion, w/o creating temp ComplexTPS:
function /(ct1::ComplexTPS, t1::TPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_divt!(ct1.tpsa, t1.tpsa, ct.tpsa)
return ct
end

function /(t1::TPS, ct1::ComplexTPS)::ComplexTPS
ct = zero(ct1)
mad_ctpsa_tdiv!(t1.tpsa, ct1.tpsa, ct.tpsa)
return ct
end

function /(t1::TPS, a::Complex)::ComplexTPS
ct = ComplexTPS(t1)
mad_ctpsa_scl!(ct.tpsa, convert(ComplexF64, 1/a), ct.tpsa)
return ct
end

function /(a::Complex, t1::TPS)::ComplexTPS
ct = ComplexTPS(t1)
mad_ctpsa_inv!(ct.tpsa, convert(ComplexF64, a), ct.tpsa)
return ct
end


# --- pow ---
# TPS:
function ^(t1::TPS, t2::TPS)::TPS
Expand Down
Loading

0 comments on commit 5e5f235

Please sign in to comment.