Skip to content

Commit

Permalink
fix MadNLPGPU on CUDA.jl 5.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
frapac authored and amontoison committed Jan 8, 2025
1 parent 9917706 commit f794d28
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
12 changes: 8 additions & 4 deletions lib/MadNLPGPU/src/KKT/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ function MadNLP.get_sparse_condensed_ext(
jt_map,
hess_map,
) where {T,VT<:CuVector{T}}
hess_com_ptr = map((i, j) -> (i, j), hess_map, 1:length(hess_map))
zvals = CuVector{Int}(1:length(hess_map))
hess_com_ptr = map((i, j) -> (i, j), hess_map, zvals)
if length(hess_com_ptr) > 0 # otherwise error is thrown
sort!(hess_com_ptr)
end

jt_csc_ptr = map((i, j) -> (i, j), jt_map, 1:length(jt_map))
jvals = CuVector{Int}(1:length(jt_map))
jt_csc_ptr = map((i, j) -> (i, j), jt_map, jvals)
if length(jt_csc_ptr) > 0 # otherwise error is thrown
sort!(jt_csc_ptr)
end
Expand Down Expand Up @@ -320,7 +322,8 @@ end
function MadNLP.coo_to_csc(
coo::MadNLP.SparseMatrixCOO{T,I,VT,VI},
) where {T,I,VT<:CuArray,VI<:CuArray}
coord = map((i, j, k) -> ((i, j), k), coo.I, coo.J, 1:length(coo.I))
zvals = CuVector{Int}(1:length(coo.I))
coord = map((i, j, k) -> ((i, j), k), coo.I, coo.J, zvals)
if length(coord) > 0
sort!(coord, lt = (((i, j), k), ((n, m), l)) -> (j, i) < (m, n))
end
Expand Down Expand Up @@ -505,7 +508,8 @@ function MadNLP._set_con_scale_sparse!(
jac_I,
jac_buffer,
) where {T,VT<:CuVector{T}}
inds = map((i, j) -> (i, j), jac_I, 1:length(jac_I))
ind_jac = CuVector{Int}(1:length(jac_I))
inds = map((i, j) -> (i, j), jac_I, ind_jac)
if !isempty(inds)
sort!(inds)
end
Expand Down
7 changes: 5 additions & 2 deletions src/KKT/Sparse/condensed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,18 @@ nzval(H) = H.nzval

n = size(H,2)

# N.B.: we should ensure that zind is allocated on the proper device.
zind = similar(nzval(H), Int, size(H, 2))
zind .= 1:size(H, 2)
map!(
i->(-1,i,0),
@view(sym[1:n]),
1:size(H,2)
zind,
)
map!(
i->(i,i),
@view(sym2[1:n]),
1:size(H,2)
zind,
)

_build_condensed_aug_symbolic_hess(
Expand Down

0 comments on commit f794d28

Please sign in to comment.