Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C interface for hsl_subset #221

Merged
merged 5 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ Quadmath = "be4d8f0f-7fa4-5f49-b795-2f01399ab2dd"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
HSL_jll = "4, 2024"
Libdl = "1.9"
LinearAlgebra = "1.9"
OpenBLAS32_jll = "0.3.9"
Quadmath = "0.5.10"
julia = "^1.6.0"
SparseArrays = "1.9"
julia = "1.9"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
2 changes: 1 addition & 1 deletion gen/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"

[compat]
julia = "1.6"
HSL_jll = "=2024.11.28"
HSL_jll = "=2024.12.10"
158 changes: 129 additions & 29 deletions gen/rewriter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,45 @@ structure_modifications = Dict("_control_s}" => "_control{Float32}}",
"_sinfo_s}" => "_sinfo{Float32}}",
"_sinfo_d}" => "_sinfo{Float64}}")

function rewrite!(path::String, name::String, optimized::Bool)
function rewrite!(library::String, path::String, name::String, optimized::Bool)
if library == "libhsl"
libhsl_rewrite!(path, name, optimized)
elseif library == "hsl_subset"
hsl_subset_rewrite!(path, name, optimized)
else
error("The library $library is not supported.")
end
end

function libhsl_rewrite!(path::String, name::String, optimized::Bool)
text = read(path, String)
if name == "libhsl"
updated_text = replace(text, "# no prototype is found for this function at libhsl.h:44:6, please use with caution\n" => "")
updated_text = replace(updated_text, "major, minor, patch)\n" => ")\n major = Ref{Cint}(0)\n minor = Ref{Cint}(0)\n patch = Ref{Cint}(0)\n")
updated_text = replace(updated_text, "Ptr{Cint}" => "Ref{Cint}")
updated_text = replace(updated_text, " @ccall" => " @ccall")
updated_text = replace(updated_text, "Cvoid\n" => "Cvoid\n return VersionNumber(major[], minor[], patch[])\n")
text = replace(text, "# no prototype is found for this function at libhsl.h:44:6, please use with caution\n" => "")
text = replace(text, "major, minor, patch)\n" => ")\n major = Ref{Cint}(0)\n minor = Ref{Cint}(0)\n patch = Ref{Cint}(0)\n")
text = replace(text, "Ptr{Cint}" => "Ref{Cint}")
text = replace(text, " @ccall" => " @ccall")
text = replace(text, "Cvoid\n" => "Cvoid\n return VersionNumber(major[], minor[], patch[])\n")
else
solver = split(name, "_")[2]
updated_text = replace(text, "struct $solver" => "mutable struct $solver")
text = replace(text, "struct $solver" => "mutable struct $solver")
if optimized
for (keys, vals) in type_modifications
updated_text = replace(updated_text, solver * keys => vals)
text = replace(text, solver * keys => vals)
end
for (keys, vals) in structure_modifications
updated_text = replace(updated_text, solver * keys => solver * vals)
text = replace(text, solver * keys => solver * vals)
end
for structure in ("control", "info", "solve_control", "ainfo", "sinfo", "finfo")
updated_text = replace(updated_text, "mutable struct $(solver)_$(structure)_s" => "mutable struct $(solver)_$(structure){T}")
updated_text = replace(updated_text, "mutable struct $(solver)_$(structure)_i" => "mutable struct $(solver)_$(structure){T}")
updated_text = replace(updated_text, "Ptr{$(solver)_$(structure)" => "Ref{$(solver)_$(structure)")
text = replace(text, "mutable struct $(solver)_$(structure)_s" => "mutable struct $(solver)_$(structure){T}")
text = replace(text, "mutable struct $(solver)_$(structure)_i" => "mutable struct $(solver)_$(structure){T}")
text = replace(text, "Ptr{$(solver)_$(structure)" => "Ref{$(solver)_$(structure)")
end
updated_text = replace(updated_text, "::Float32\n" => "::T\n")
updated_text = replace(updated_text, "Float32}\n" => "T}\n") # NTuple{N, Float32} → NTuple{N, T}
text = replace(text, "::Float32\n" => "::T\n")
text = replace(text, "Float32}\n" => "T}\n") # NTuple{N, Float32} → NTuple{N, T}

# Add two constructors for each structure
blocks = split(updated_text, "end\n", keepempty=false)
updated_text = ""
blocks = split(text, "end\n", keepempty=false)
text = ""
for code in blocks
if contains(code, "mutable struct")
structure = code * "end\n"
Expand All @@ -84,39 +94,129 @@ function rewrite!(path::String, name::String, optimized::Bool)
end
end
structure = replace(structure, "end\n" => "\n " * structure_name * "($arguments) where T = new($arguments)\nend\n")
updated_text = updated_text * structure
text = text * structure
else
updated_text = updated_text * code * "end\n"
text = text * code * "end\n"
end
end

# Special cases where the structures are not parameterized.
if name == "hsl_ma48"
for type in ("T", "Float32", "Float64")
updated_text = replace(updated_text, "$(solver)_sinfo{$type}" => "$(solver)_sinfo")
updated_text = replace(updated_text, Regex("$(solver)_sinfo(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_sinfo\\1"))
text = replace(text, "$(solver)_sinfo{$type}" => "$(solver)_sinfo")
text = replace(text, Regex("$(solver)_sinfo(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_sinfo\\1"))
end
end

if name == "hsl_mc64"
for type in ("T", "Float32", "Float64")
updated_text = replace(updated_text, "$(solver)_control{$type}" => "$(solver)_control")
updated_text = replace(updated_text, Regex("$(solver)_control(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_control\\1"))
updated_text = replace(updated_text, "$(solver)_info{$type}" => "$(solver)_info")
updated_text = replace(updated_text, Regex("$(solver)_info(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_info\\1"))
text = replace(text, "$(solver)_control{$type}" => "$(solver)_control")
text = replace(text, Regex("$(solver)_control(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_control\\1"))
text = replace(text, "$(solver)_info{$type}" => "$(solver)_info")
text = replace(text, Regex("$(solver)_info(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_info\\1"))
end
end

if name == "hsl_mc68" || name == "hsl_mc78" || name == "hsl_mc79"
for type in ("T", "Cint", "Clong")
updated_text = replace(updated_text, "$(solver)_control{$type}" => "$(solver)_control")
updated_text = replace(updated_text, Regex("$(solver)_control(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_control\\1"))
updated_text = replace(updated_text, "$(solver)_info{$type}" => "$(solver)_info")
updated_text = replace(updated_text, Regex("$(solver)_info(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_info\\1"))
text = replace(text, "$(solver)_control{$type}" => "$(solver)_control")
text = replace(text, Regex("$(solver)_control(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_control\\1"))
text = replace(text, "$(solver)_info{$type}" => "$(solver)_info")
text = replace(text, Regex("$(solver)_info(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_info\\1"))
end
end
end
end
write(path, updated_text)
write(path, text)
(name != "libhsl") && format_file(path, YASStyle())
end

function hsl_subset_rewrite!(path::String, name::String, optimized::Bool)
text = read(path, String)
structures = ""
info_structures = Tuple{String, String, Bool}[]
if optimized
text = replace(text, "struct " => "mutable struct ")
text = replace(text, "hsl_longc_" => Int64)

blocks = split(text, "end\n")
text = ""
for (index, code) in enumerate(blocks)
if contains(code, "function")
for (ipc_, rpc_, suffix, lib) in (("Int32", "Float32" , "_s" , "libhsl_subset" ),
("Int32", "Float64" , "_d" , "libhsl_subset" ),
("Int32", "Float128", "_q" , "libhsl_subset" ),
("Int64", "Float32" , "_s_64", "libhsl_subset_64"),
("Int64", "Float64" , "_d_64", "libhsl_subset_64"),
("Int64", "Float128", "_q_64", "libhsl_subset_64"))
# We only want to generate two methods (Int32 / Int64) for hsl_mc68
(name == "hsl_mc68") && (rpc_ != "Float64") && continue

fname = split(split(code, "function ")[2], "(")[1]
fname_generic = fname[1:end-2]
pp_fname = fname[1:end-2] * suffix
routine = code * "end\n"
if name == "hsl_mc68"
endswith(fname, "_i") || error("The symbol $fname should have the suffix _i")
routine = replace(routine, "function $fname(" => "function $(fname_generic)(::Type{$ipc_}, ")
else
endswith(fname, "_d") || error("The symbol $fname should have the suffix _d")
routine = replace(routine, "function $fname(" => "function $(fname_generic)(::Type{$rpc_}, ::Type{$ipc_}, ")
end
routine = replace(routine, "libhsl.$fname(" => "$lib.$(pp_fname)(")
routine = replace(routine, "ipc_" => ipc_)
routine = replace(routine, "rpc_" => rpc_)

# Update the type of the structures
routine = replace(routine, "_d}" => "_d{$rpc_,$ipc_}}")
routine = replace(routine, "_i}" => "_i{$rpc_,$ipc_}}")

# Float128 should be passed by value as a Cfloat128
routine = replace(routine, "::Float128" => "::Cfloat128")

text = text * routine * "\n"
end
elseif contains(code, "struct")
structure = code * "end\n"
structure_name = split(split(code, "struct ")[2], "\n")[1] |> String
generic_structure_name = structure_name[1:end-2] |> String
generic_structure_name = 'M' * generic_structure_name[2:end]
generic_structure_name = replace(generic_structure_name, "_solve_control" => "SolveControl")
generic_structure_name = replace(generic_structure_name, "_control" => "Control")
generic_structure_name = replace(generic_structure_name, "_ainfo" => "Ainfo")
generic_structure_name = replace(generic_structure_name, "_finfo" => "Finfo")
generic_structure_name = replace(generic_structure_name, "_sinfo" => "Sinfo")
generic_structure_name = replace(generic_structure_name, "_info" => "Info")
structure = replace(structure, "rpc_" => "T")
structure = replace(structure, "ipc_" => "INT")
if !contains(code, "rpc_")
structure = replace(structure, structure_name => generic_structure_name * "{INT}")
push!(info_structures, (structure_name, generic_structure_name, false))
else
structure = replace(structure, structure_name => generic_structure_name * "{T,INT}")
push!(info_structures, (structure_name, generic_structure_name, true))
end
structures = structures * structure * "\n"
else
text = text * code
end
end
end
text = structures * "\n" * text
startswith(text, '\n') && (text = text[2:end])

# Rename the structures in the wrappers
for (old_struct, new_struct, bool) in info_structures
if bool
text = replace(text, "Ptr{$old_struct" => "Ref{$new_struct")
else
for precision in ("Float32", "Float64", "Float128")
text = replace(text, "Ptr{$old_struct{$precision,Int32}}" => "Ref{$new_struct{Int32}}")
text = replace(text, "Ptr{$old_struct{$precision,Int64}}" => "Ref{$new_struct{Int64}}")
end
end
end

write(path, text)
format_file(path, YASStyle())
end
Loading
Loading