Skip to content

Commit

Permalink
Merge pull request #71 from bmad-sim/dev4
Browse files Browse the repository at this point in the history
Fix major sparse monomial indexing bugs, improve printing substantially
  • Loading branch information
mattsignorelli committed Jan 16, 2024
2 parents d37c674 + 46200d7 commit 7e79e22
Showing 1 changed file with 130 additions and 43 deletions.
173 changes: 130 additions & 43 deletions src/GTPSA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -844,23 +844,31 @@ function complexparams(d::Descriptor)::Vector{ComplexTPS}
end

# Function to convert var=>ord, params=(param=>ord,) to sparse monomial format (varidx1, ord1, varidx2, ord2, paramidx, ordp1,...)
function pairs_to_sm(vars::Pair{<:Integer, <:Integer}...; params::Tuple{Vararg{Pair{<:Integer,<:Integer}}}=())::Tuple{Vector{Cint}, Cint}
function pairs_to_sm(t::Union{TPS,ComplexTPS}, vars::Pair{<:Integer, <:Integer}...; params::Tuple{Vararg{Pair{<:Integer,<:Integer}}}=())::Tuple{Vector{Cint}, Cint}
# WE MUST ORDER THE VARIABLES !!!
desc = unsafe_load(Base.unsafe_convert(Ptr{Desc}, unsafe_load(t.tpsa).d))
nv = desc.nv # TOTAL NUMBER OF VARS!!!!!!
numv = Cint(length(vars))
nump = Cint(length(params))
sm = Vector{Cint}(undef, 2*(numv+nump))
imin = min(minimum(x->x.first, vars,init=typemax(Int)), minimum(x->x.first+nv, params,init=typemax(Int)))
imax = max(maximum(x->x.first, vars,init=0), maximum(x->x.first+nv, params,init=0))
len = imax-imin+1
sm = zeros(Cint, 2*len)
sm[1:2:end] = imin:imax
for i=1:numv
sm[2*i-1] = convert(Cint, vars[i].first)
sm[2*i] = convert(Cint, vars[i].second)
sm[2*(vars[i].first-imin+1)] = convert(Cint, vars[i].second)
end
for i=numv+1:numv+nump
sm[2*i-1] = convert(Cint, params[i].first+numv)
sm[2*i] = convert(Cint, params[i].second)
for i=1:nump
sm[2*(params[i].first+nv-imin+1)] = convert(Cint, params[i].second)
end
return sm, Cint(2)*(numv+nump)

return sm, 2*len
end

# Function to convert var=>ord, params=(param=>ord,) to monomial format (byte array of orders)
function pairs_to_m(vars::Pair{<:Integer, <:Integer}...; params::Tuple{Vararg{Pair{<:Integer,<:Integer}}}=())::Tuple{Vector{UInt8}, Cint}
function pairs_to_m(t::Union{TPS,ComplexTPS}, vars::Pair{<:Integer, <:Integer}...; params::Tuple{Vararg{Pair{<:Integer,<:Integer}}}=())::Tuple{Vector{UInt8}, Cint}
desc = unsafe_load(Base.unsafe_convert(Ptr{Desc}, unsafe_load(t.tpsa).d))
nv = desc.nv
n = Cint(0)
if isempty(params)
n = Cint(maximum(map(x->x.first, vars)))
Expand All @@ -884,7 +892,7 @@ end

function getindex(t::TPS, vars::Pair{<:Integer, <:Integer}...; params::Tuple{Vararg{Pair{<:Integer,<:Integer}}}=())::Float64
# use sparse monomial getter
sm, n = pairs_to_sm(vars..., params=params)
sm, n = pairs_to_sm(t, vars..., params=params)
return mad_tpsa_getsm(t.tpsa, n, sm)
end

Expand All @@ -895,7 +903,7 @@ end

function getindex(ct::ComplexTPS, vars::Pair{<:Integer, <:Integer}...; params::Tuple{Vararg{Pair{<:Integer,<:Integer}}}=())::ComplexF64
# use sparse monomial getter
sm, n = pairs_to_sm(vars..., params=params)
sm, n = pairs_to_sm(ct, vars..., params=params)
return mad_ctpsa_getsm(ct.tpsa, n, sm)
end

Expand All @@ -906,18 +914,16 @@ function setindex!(t::TPS, v::Real, ords::Integer...)
end

function setindex!(t::TPS, v::Real, vars::Pair{<:Integer, <:Integer}...; params::Tuple{Vararg{Pair{<:Integer,<:Integer}}}=())
# use sparse monomial getter
sm, n = pairs_to_sm(vars..., params=params)
sm, n = pairs_to_sm(t, vars..., params=params)
mad_tpsa_setsm!(t.tpsa, n, sm, 0.0, convert(Cdouble, v))
end

function setindex!(ct::ComplexTPS, v::Number, ords::Integer...)
mad_ctpsa_setm!(ct.tpsa, convert(Cint, length(ords)), convert(Vector{Cuchar}, [ords...]), convert(ComplexF64, 0), convert(ComplexF64, v))
end

function getindex(ct::ComplexTPS, v::Number, vars::Pair{<:Integer, <:Integer}...; params::Tuple{Vararg{Pair{<:Integer,<:Integer}}}=())::ComplexF64
# use sparse monomial getter
sm, n = pairs_to_sm(vars..., params=params)
function setindex!(ct::ComplexTPS, v::Number, vars::Pair{<:Integer, <:Integer}...; params::Tuple{Vararg{Pair{<:Integer,<:Integer}}}=())
sm, n = pairs_to_sm(ct, vars..., params=params)
mad_ctpsa_setsm!(ct.tpsa, n, sm, convert(ComplexF64, 0), convert(ComplexF64, v))
end

Expand Down Expand Up @@ -1177,15 +1183,12 @@ end


# --- print ---
function show(io::IO, d::Descriptor)
desc = unsafe_load(d.desc)
function show_GTPSA_info(io::IO, desc::Desc)
nv = desc.nv
np = desc.np
nn = desc.nn
no_ = unsafe_wrap(Vector{Cuchar}, desc.no, nn)
no = convert(Vector{Int}, no_)
println(io, "GTPSA Descriptor")
println(io, "-----------------------")
if nv > 0
@printf(io, "%-18s %i\n", "# Variables: ", nv)
if all(no[1] .== no[1:nv])
Expand All @@ -1208,6 +1211,19 @@ function show(io::IO, d::Descriptor)
end
end

function show(io::IO, d::Descriptor)
println(io, "GTPSA Descriptor")
println(io, "-----------------------")
desc = unsafe_load(d.desc)
show_GTPSA_info(io, desc)
end

struct ParamPair
idx::Int
ord::Int
end
# Dumb workaround to get nice string output in PrettyTables
Base.show(io::IO, p::ParamPair) = print(io, "k", p.idx, "=>", p.ord)

function show(io::IO, t::TPS)
desc = unsafe_load(Base.unsafe_convert(Ptr{Desc}, unsafe_load(t.tpsa).d))
Expand All @@ -1216,21 +1232,57 @@ function show(io::IO, t::TPS)
nn = desc.nn
v = Ref{Cdouble}()
mono = Vector{UInt8}(undef, nn)
out = Matrix{Any}(undef, 0, (1+1+1+nn)) # First col is coefficient, rest are orders
idx = Cint(-1)
idx = mad_tpsa_cycle!(t.tpsa, idx, nn, mono, v)
while idx >= 0
order = Int(sum(mono))
out = vcat(out, Any[v[] order "" convert(Vector{Int}, mono)...])

# If nn > 6 (6 chosen arbitrarily), use sparse monomial format for print
if nn <= 6
out = Matrix{Any}(undef, 0, (1+1+1+nn)) # Coefficient, order, spacing, exponents
idx = Cint(-1)
idx = mad_tpsa_cycle!(t.tpsa, idx, nn, mono, v)
end
if size(out)[1] == 0
out = vcat(out, Any[0.0 Int(0) "" zeros(Int,nn)...])
while idx >= 0
order = Int(sum(mono))
out = vcat(out, Any[v[] order nothing convert(Vector{Int}, mono)...])
idx = mad_tpsa_cycle!(t.tpsa, idx, nn, mono, v)
end
if size(out)[1] == 0
out = vcat(out, Any[0.0 Int(0) nothing zeros(Int,nn)...])
end
formatters = (ft_printf("%23.16le", [1]), ft_printf("%2i", 2:3+nn), ft_nonothing)
cutcols = 1+1+1+nn
else
out = Matrix{Any}(nothing, 0, (1+1+1+desc.mo))
idx = Cint(-1)
idx = mad_tpsa_cycle!(t.tpsa, idx, nn, mono, v)
while idx >= 0
order = Int(sum(mono))
varpairs = Vector{Pair{Int,Int}}(undef,0)
parampairs = Vector{ParamPair}(undef,0)
# Create variable pairs
for var_idx in findall(x->x>0x0, mono)
if var_idx > nv
push!(parampairs, ParamPair(var_idx-nv,Int(mono[var_idx])))
else
push!(varpairs, var_idx=>Int(mono[var_idx]))
end
end
out = vcat(out, Any[v[] order nothing varpairs... parampairs... repeat([nothing], desc.mo - (length(varpairs)+length(parampairs)))...])
idx = mad_tpsa_cycle!(t.tpsa, idx, nn, mono, v)
end
if size(out)[1] == 0
out = vcat(out, Any[0.0 Int(0) repeat([nothing], desc.mo +1)...])
end
formatters = (ft_printf("%23.16le", [1]), ft_printf("%2i", 2), ft_nonothing)
# Refine out to get rid of extra trailing columns
cutcols = 4
while cutcols < size(out)[2]
if all(isnothing.(out[:,cutcols]))
break
end
cutcols = cutcols + 1
end
end
println(io, "TPS:")
println(io, " COEFFICIENT ORDER EXPONENTS")
formatters = (ft_printf("%23.16le", [1]), ft_printf("%2i", 2:3+nn))
pretty_table(io, out,tf=tf_borderless,formatters=formatters,show_header=false, alignment=:l)
pretty_table(io, out[:,1:cutcols],tf=tf_borderless,formatters=formatters,show_header=false, alignment=:l)
end

function show(io::IO, t::ComplexTPS)
Expand All @@ -1240,22 +1292,57 @@ function show(io::IO, t::ComplexTPS)
nn = desc.nn
v = Ref{ComplexF64}()
mono = Vector{UInt8}(undef, nn)
out = Matrix{Any}(undef, 0, (1+1+1+1+nn)) # First col is coefficient, rest are orders
idx = Cint(-1)
idx = mad_ctpsa_cycle!(t.tpsa, idx, nn, mono, v)
while idx >= 0
order = Int(sum(mono))
out = vcat(out, Any[real(v[]) imag(v[]) order "" convert(Vector{Int}, mono)...])

# If nn > 6 (6 chosen arbitrarily), use sparse monomial format for print
if nn <= 6
out = Matrix{Any}(undef, 0, (1+1+1+1+nn)) # First col is coefficient, rest are orders
idx = Cint(-1)
idx = mad_ctpsa_cycle!(t.tpsa, idx, nn, mono, v)
end
if size(out)[1] == 0
out = vcat(out, Any[0.0 0.0 Int(0) "" zeros(Int,nn)...])
while idx >= 0
order = Int(sum(mono))
out = vcat(out, Any[real(v[]) imag(v[]) order nothing convert(Vector{Int}, mono)...])
idx = mad_ctpsa_cycle!(t.tpsa, idx, nn, mono, v)
end
if size(out)[1] == 0
out = vcat(out, Any[0.0 0.0 Int(0) nothing zeros(Int,nn)...])
end
formatters = (ft_printf("%23.16le", [1]),ft_printf("%23.16le", [2]), ft_printf("%2i", 3:4+nn), ft_nonothing)
cutcols = 1+1+1+1+nn
else
out = Matrix{Any}(nothing, 0, (1+1+1+1+desc.mo))
idx = Cint(-1)
idx = mad_ctpsa_cycle!(t.tpsa, idx, nn, mono, v)
while idx >= 0
order = Int(sum(mono))
varpairs = Vector{Pair{Int,Int}}(undef,0)
parampairs = Vector{ParamPair}(undef,0)
# Create variable pairs
for var_idx in findall(x->x>0x0, mono)
if var_idx > nv
push!(parampairs, ParamPair(var_idx-nv,Int(mono[var_idx])))
else
push!(varpairs, var_idx=>Int(mono[var_idx]))
end
end
out = vcat(out, Any[real(v[]) imag(v[]) order nothing varpairs... parampairs... repeat([nothing], desc.mo - (length(varpairs)+length(parampairs)))...])
idx = mad_ctpsa_cycle!(t.tpsa, idx, nn, mono, v)
end
if size(out)[1] == 0
out = vcat(out, Any[0.0 0.0 Int(0) repeat([nothing], desc.mo +1)...])
end
formatters = (ft_printf("%23.16le", [1]),ft_printf("%23.16le", [2]), ft_printf("%2i", 3), ft_nonothing)
# Refine out to get rid of extra trailing columns
cutcols = 5
while cutcols < size(out)[2]
if all(isnothing.(out[:,cutcols]))
break
end
cutcols = cutcols + 1
end
end
println(io, "ComplexTPS:")
#println(io, " COEFFICIENT")
println(io, " REAL IMAG ORDER EXPONENTS")
formatters = (ft_printf("%23.16le", [1]),ft_printf("%23.16le", [2]), ft_printf("%2i", 3:4+nn))
pretty_table(io, out,tf=tf_borderless,formatters=formatters,show_header=false, alignment=:l)
pretty_table(io, out[:,1:cutcols],tf=tf_borderless,formatters=formatters,show_header=false, alignment=:l)
end


Expand Down

0 comments on commit 7e79e22

Please sign in to comment.