diff --git a/Project.toml b/Project.toml index 04dbe1fe..6dce58b0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.0" +version = "0.4.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -32,7 +32,7 @@ MooncakeLogDensityProblemsADExt = "LogDensityProblemsAD" MooncakeSpecialFunctionsExt = "SpecialFunctions" [compat] -ADTypes = "1.2" +ADTypes = "1.9" BenchmarkTools = "1" CUDA = "5" ChainRulesCore = "1" diff --git a/README.md b/README.md index d77ba237..1b35d4b8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Mooncake +# Mooncake.jl (formerly Tapir.jl) [![Build Status](https://github.com/compintell/Mooncake.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/compintell/Mooncake.jl/actions/workflows/CI.yml?query=branch%3Amain) [![codecov](https://codecov.io/github/compintell/Mooncake.jl/graph/badge.svg?token=NUPWTB4IAP)](https://codecov.io/github/compintell/Mooncake.jl) @@ -8,41 +8,32 @@ The goal of the `Mooncake.jl` project is to produce a reverse-mode AD package which is written entirely in Julia, which improves over both `ReverseDiff.jl` and `Zygote.jl` in several ways, and is competitive with `Enzyme.jl`. -## Note on renaming - -On 18/09/2024 this package was renamed from Tapir.jl to Mooncake.jl. -The last version while the package was called Tapir.jl was 0.2.51. -Upon renaming, the version was bumped to 0.3.0. - -We are currently going through the process of updating the name of the package in the general registry and updating dependents to use the new package naming. -This should be largely complete in a few days. -During this time, there will be no new releases of Mooncake.jl, and there will be issues with its interaction with ADTypes.jl, LogDensityProblemsAD.jl, and possibly other things that we haven't thought of. - ## Note on project status `Mooncake.jl` is under active development. You should presently expect releases involving breaking changes on a semi-regular basis. -We are trying to keep this README as up to date as possible, particularly with regards to the best examples of code to look at to understand how to use Mooncake.jl. -If you encounter a new version of Mooncake.jl in the wild, please consult this README for the most up-to-date advice. +We are trying to keep this README as up to date as possible, particularly with regards to the best examples of code to look at to understand how to use `Mooncake.jl`. +If you encounter a new version of `Mooncake.jl` in the wild, please consult this README for the most up-to-date advice. # Getting Started -There are several ways to interact with Mooncake.jl. -The one that we recommend people begin with is [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl/). For example, use it as follows to compute the gradient of a function mapping a `Vector{Float64}` to `Float64`. +There are several ways to interact with `Mooncake.jl`. +The one that we recommend people begin with is [`DifferentiationInterface.jl`](https://github.com/gdalle/DifferentiationInterface.jl/). +For example, use it as follows to compute the gradient of a function mapping a `Vector{Float64}` to `Float64`. ```julia using DifferentiationInterface import Mooncake -f(x) = sum(abs2, x) -backend = AutoMooncake() -x = ones(3) -extras = prepare_gradient(f, backend, x) -gradient(f, backend, x, extras) +f(x) = sum(cos, x) +backend = AutoMooncake(; config=nothing) +x = ones(1_000) +prep = prepare_gradient(f, backend, x) +gradient(f, prep, backend, x) ``` -You should expect that the first time you run `gradient` that it will take a little bit of time, but subsequent runs should be fast. +You should expect that `prep` takes a little bit of time to run, but that `gradient` is fast. We are committed to ensuring support for DifferentiationInterface, which is why we recommend using that. -If you are interested in slightly more flexible functionality, you should consider `Mooncake.value_and_gradient!!`. See its docstring for more info. +If you are interested in interacting in a more direct fashion with `Mooncake.jl`, you should consider `Mooncake.value_and_gradient!!`. See its docstring for more info. # How it works @@ -117,6 +108,9 @@ For about 48 hours is was called `Phi.jl`, but the community guidelines state th We then chose `Tapir.jl`, and didn't initially feel that other work [of the same name](https://github.com/wsmoses/Tapir-LLVM) presented a serious name clash, as it isn't AD-specific or a Julia project. As it turns out, there has been significant work attempting to integrate the ideas from this work into the [Julia compiler](https://github.com/JuliaLang/julia/pull/39773), so the clash is something of a problem. +On 18/09/2024 this package was renamed from `Tapir.jl` to `Mooncake.jl`. +The last version while the package was called `Tapir.jl` was 0.2.51. +Upon renaming, the version was bumped to 0.3.0. We finally settled on `Mooncake.jl`. Hopefully this name will stick. # Project Status diff --git a/docs/src/debug_mode.md b/docs/src/debug_mode.md index d73692bb..0a62f607 100644 --- a/docs/src/debug_mode.md +++ b/docs/src/debug_mode.md @@ -43,10 +43,10 @@ Mooncake.build_rrule ``` When using ADTypes.jl, you can choose whether or not to use it via the `debug_mode` kwarg: -# ```jldoctest -# julia> AutoMooncake(Mooncake.Config()) -# AutoMooncake(Mooncake.Config()) -# ``` +```jldoctest +julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true)) +AutoMooncake{Mooncake.Config}(Mooncake.Config(true, false)) +``` ### When Should You Use Debug Mode? diff --git a/ext/MooncakeLogDensityProblemsADExt.jl b/ext/MooncakeLogDensityProblemsADExt.jl index ec4986fc..29f3870d 100644 --- a/ext/MooncakeLogDensityProblemsADExt.jl +++ b/ext/MooncakeLogDensityProblemsADExt.jl @@ -58,16 +58,18 @@ function logdensity_and_gradient(∇l::MooncakeGradientLogDensity, x::Vector{Flo return Mooncake.primal(y), dx end -# # Interop with ADTypes. -# function ADgradient(x::ADTypes.AutoMooncake, ℓ) -# if x.debug_mode -# msg = "Running Mooncake in debug mode. This mode is computationally expensive, " * -# "should only be used when debugging a problem with AD, and turned off in " * -# "general use. Do this by using AutoMooncake(debug_mode=false)." -# @info msg -# end -# return ADgradient(Val(:Mooncake), ℓ; debug_mode=x.debug_mode) -# end +# Interop with ADTypes. +function ADgradient(x::ADTypes.AutoMooncake, ℓ) + debug_mode = x.config.debug_mode + if debug_mode + msg = "Running Mooncake in debug mode. This mode is computationally expensive, " * + "should only be used when debugging a problem with AD, and turned off in " * + "general use. Do this by using " * + "AutoMooncake(; config=Mooncake.Config(debug_mode=false))." + @info msg + end + return ADgradient(Val(:Mooncake), ℓ; debug_mode) +end Base.parent(x::MooncakeGradientLogDensity) = Mooncake.primal(x.ℓ) diff --git a/src/rrules/lapack.jl b/src/rrules/lapack.jl index 4c3c9497..9578a07e 100644 --- a/src/rrules/lapack.jl +++ b/src/rrules/lapack.jl @@ -47,20 +47,22 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) dA = tangent(_A) function getrf_pb!!(::NoRData) - # Run reverse-pass. - L, U = UnitLowerTriangular(A_mat), UpperTriangular(A_mat) - dA_mat = wrap_ptr_as_view(dA, LDA_val, M_val, N_val) - dL, dU = tril(dA_mat, -1), UpperTriangular(dA_mat) + GC.@preserve args begin + # Run reverse-pass. + L, U = UnitLowerTriangular(A_mat), UpperTriangular(A_mat) + dA_mat = wrap_ptr_as_view(dA, LDA_val, M_val, N_val) + dL, dU = tril(dA_mat, -1), UpperTriangular(dA_mat) - # Figure out the pivot matrix used. - p = LinearAlgebra.ipiv2perm(ipiv_vec, N_val) + # Figure out the pivot matrix used. + p = LinearAlgebra.ipiv2perm(ipiv_vec, N_val) - # Compute pullback using Seth's method. - __dF = tril(L'dL, -1) + UpperTriangular(dU * U') - dA_mat .= (inv(L') * __dF * inv(U'))[invperm(p), :] + # Compute pullback using Seth's method. + __dF = tril(L'dL, -1) + UpperTriangular(dU * U') + dA_mat .= (inv(L') * __dF * inv(U'))[invperm(p), :] - # Restore initial state. - A_mat .= A_store + # Restore initial state. + A_mat .= A_store + end return tuple_fill(NoRData(), Val(12 + Nargs)) end @@ -91,48 +93,52 @@ for (fname, elty) in ((:dtrtrs_, :Float64), (:strtrs_, :Float32)) args::Vararg{Any, Nargs}, ) where {Nargs} - # Load in data. - ul_p, tA_p, diag_p = map(primal, (_ul, _tA, _diag)) - N_p, Nrhs_p, lda_p, ldb_p, info_p = map(primal, (_N, _Nrhs, _lda, _ldb, _info)) - ul, tA, diag, N, Nrhs, lda, ldb, info = map( - unsafe_load, (ul_p, tA_p, diag_p, N_p, Nrhs_p, lda_p, ldb_p, info_p), - ) - - A = wrap_ptr_as_view(primal(_A), lda, N, N) - B = wrap_ptr_as_view(primal(_B), ldb, N, Nrhs) - B_copy = copy(B) - - # Run the primal. - ccall( - $(blas_name(fname)), - Cvoid, - ( - Ptr{UInt8}, Ptr{UInt8}, Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, - Clong, Clong, Clong, - ), - ul_p, tA_p, diag_p, N_p, Nrhs_p, primal(_A), lda_p, primal(_B),ldb_p, info_p, - 1, 1, 1, - ) + GC.@preserve args begin + # Load in data. + ul_p, tA_p, diag_p = map(primal, (_ul, _tA, _diag)) + N_p, Nrhs_p, lda_p, ldb_p, info_p = map(primal, (_N, _Nrhs, _lda, _ldb, _info)) + ul, tA, diag, N, Nrhs, lda, ldb, info = map( + unsafe_load, (ul_p, tA_p, diag_p, N_p, Nrhs_p, lda_p, ldb_p, info_p), + ) + + A = wrap_ptr_as_view(primal(_A), lda, N, N) + B = wrap_ptr_as_view(primal(_B), ldb, N, Nrhs) + B_copy = copy(B) + + # Run the primal. + ccall( + $(blas_name(fname)), + Cvoid, + ( + Ptr{UInt8}, Ptr{UInt8}, Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt}, + Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, + Clong, Clong, Clong, + ), + ul_p, tA_p, diag_p, N_p, Nrhs_p, primal(_A), lda_p, primal(_B),ldb_p, info_p, + 1, 1, 1, + ) + end _dA = tangent(_A) _dB = tangent(_B) function trtrs_pb!!(::NoRData) - # Compute cotangent of B. - dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) - LAPACK.trtrs!(Char(ul), Char(tA) == 'N' ? 'T' : 'N', Char(diag), A, dB) + GC.@preserve args begin + # Compute cotangent of B. + dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) + LAPACK.trtrs!(Char(ul), Char(tA) == 'N' ? 'T' : 'N', Char(diag), A, dB) - # Compute cotangent of A. - dA = wrap_ptr_as_view(_dA, lda, N, N) - if Char(tA) == 'N' - dA .-= tri!(dB * B', Char(ul), Char(diag)) - else - dA .-= tri!(B * dB', Char(ul), Char(diag)) - end + # Compute cotangent of A. + dA = wrap_ptr_as_view(_dA, lda, N, N) + if Char(tA) == 'N' + dA .-= tri!(dB * B', Char(ul), Char(diag)) + else + dA .-= tri!(B * dB', Char(ul), Char(diag)) + end - # Restore initial state. - B .= B_copy + # Restore initial state. + B .= B_copy + end return tuple_fill(NoRData(), Val(16 + Nargs)) end @@ -160,80 +166,84 @@ for (fname, elty) in ((:dgetrs_, :Float64), (:sgetrs_, :Float32)) args::Vararg{Any, Nargs}, ) where {Nargs} - # Load in values. - tA = Char(unsafe_load(primal(_tA))) - N, Nrhs, lda, ldb, info = map(unsafe_load ∘ primal, (_N, _Nrhs, _lda, _ldb, _info)) - ipiv = unsafe_wrap(Vector{BlasInt}, primal(_ipiv), N) - A = wrap_ptr_as_view(primal(_A), lda, N, N) - B = wrap_ptr_as_view(primal(_B), ldb, N, Nrhs) - B0 = copy(B) - - # Pivot B. - p = LinearAlgebra.ipiv2perm(ipiv, N) - - if tA == 'N' - # Apply permutation matrix. - B .= B[p, :] - - # Run inv(L) * B and write result to B. - LAPACK.trtrs!('L', 'N', 'U', A, B) - B1 = copy(B) # record intermediate state for use in pullback. - - # Run inv(U) * B and write result to B. - LAPACK.trtrs!('U', 'N', 'N', A, B) - B2 = B - else - # Run inv(U)^T * B and write result to B. - LAPACK.trtrs!('U', 'T', 'N', A, B) - B1 = copy(B) # record intermediate state for use in pullback. - - # Run inv(L)^T * B and write result to B. - LAPACK.trtrs!('L', 'T', 'U', A, B) - B2 = B - - # Apply permutation matrix. - B2 .= B2[invperm(p), :] - end + GC.@preserve args begin + # Load in values. + tA = Char(unsafe_load(primal(_tA))) + N, Nrhs, lda, ldb, info = map(unsafe_load ∘ primal, (_N, _Nrhs, _lda, _ldb, _info)) + ipiv = unsafe_wrap(Vector{BlasInt}, primal(_ipiv), N) + A = wrap_ptr_as_view(primal(_A), lda, N, N) + B = wrap_ptr_as_view(primal(_B), ldb, N, Nrhs) + B0 = copy(B) - # We need to write to `info`. - unsafe_store!(primal(_info), 0) + # Pivot B. + p = LinearAlgebra.ipiv2perm(ipiv, N) + + if tA == 'N' + # Apply permutation matrix. + B .= B[p, :] + + # Run inv(L) * B and write result to B. + LAPACK.trtrs!('L', 'N', 'U', A, B) + B1 = copy(B) # record intermediate state for use in pullback. + + # Run inv(U) * B and write result to B. + LAPACK.trtrs!('U', 'N', 'N', A, B) + B2 = B + else + # Run inv(U)^T * B and write result to B. + LAPACK.trtrs!('U', 'T', 'N', A, B) + B1 = copy(B) # record intermediate state for use in pullback. + + # Run inv(L)^T * B and write result to B. + LAPACK.trtrs!('L', 'T', 'U', A, B) + B2 = B + + # Apply permutation matrix. + B2 .= B2[invperm(p), :] + end + + # We need to write to `info`. + unsafe_store!(primal(_info), 0) + end _dA = tangent(_A) _dB = tangent(_B) function getrs_pb!!(::NoRData) - dA = wrap_ptr_as_view(_dA, lda, N, N) - dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) + GC.@preserve args begin + dA = wrap_ptr_as_view(_dA, lda, N, N) + dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) - if tA == 'N' + if tA == 'N' - # Run pullback for inv(U) * B. - LAPACK.trtrs!('U', 'T', 'N', A, dB) - dA .-= tri!(dB * B2', 'U', 'N') + # Run pullback for inv(U) * B. + LAPACK.trtrs!('U', 'T', 'N', A, dB) + dA .-= tri!(dB * B2', 'U', 'N') - # Run pullback for inv(L) * B. - LAPACK.trtrs!('L', 'T', 'U', A, dB) - dA .-= tri!(dB * B1', 'L', 'U') + # Run pullback for inv(L) * B. + LAPACK.trtrs!('L', 'T', 'U', A, dB) + dA .-= tri!(dB * B1', 'L', 'U') - # Undo permutation. - dB .= dB[invperm(p), :] - else + # Undo permutation. + dB .= dB[invperm(p), :] + else - # Undo permutation. - dB .= dB[p, :] - B2 .= B2[p, :] + # Undo permutation. + dB .= dB[p, :] + B2 .= B2[p, :] - # Run pullback for inv(L^T) * B. - LAPACK.trtrs!('L', 'N', 'U', A, dB) - dA .-= tri!(B2 * dB', 'L', 'U') + # Run pullback for inv(L^T) * B. + LAPACK.trtrs!('L', 'N', 'U', A, dB) + dA .-= tri!(B2 * dB', 'L', 'U') - # Run pullback for inv(U^T) * B. - LAPACK.trtrs!('U', 'N', 'N', A, dB) - dA .-= tri!(B1 * dB', 'U', 'N') - end + # Run pullback for inv(U^T) * B. + LAPACK.trtrs!('U', 'N', 'N', A, dB) + dA .-= tri!(B1 * dB', 'U', 'N') + end - # Restore initial state. - B .= B0 + # Restore initial state. + B .= B0 + end return tuple_fill(NoRData(), Val(15 + Nargs)) end @@ -259,39 +269,43 @@ for (fname, elty) in ((:dgetri_, :Float64), (:sgetri_, :Float32)) args::Vararg{Any, Nargs}, ) where {Nargs} - # Pull out data. - N_p, lda_p, lwork_p, info_p = map(primal, (_N, _lda, _lwork, _info)) - N, lda, lwork, info = map(unsafe_load, (N_p, lda_p, lwork_p, info_p)) - A_p = primal(_A) - A = wrap_ptr_as_view(A_p, lda, N, N) - A_copy = copy(A) - - # Run forwards-pass. - ccall( - $(blas_name(fname)), Cvoid, - ( - Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, - Ptr{BlasInt}, Ptr{BlasInt}, - ), - N_p, A_p, lda_p, primal(_ipiv), primal(_work), lwork_p, info_p, - ) - - p = LinearAlgebra.ipiv2perm(unsafe_wrap(Array, primal(_ipiv), N), N) + GC.@preserve args begin + # Pull out data. + N_p, lda_p, lwork_p, info_p = map(primal, (_N, _lda, _lwork, _info)) + N, lda, lwork, info = map(unsafe_load, (N_p, lda_p, lwork_p, info_p)) + A_p = primal(_A) + A = wrap_ptr_as_view(A_p, lda, N, N) + A_copy = copy(A) + + # Run forwards-pass. + ccall( + $(blas_name(fname)), Cvoid, + ( + Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, + Ptr{BlasInt}, Ptr{BlasInt}, + ), + N_p, A_p, lda_p, primal(_ipiv), primal(_work), lwork_p, info_p, + ) + + p = LinearAlgebra.ipiv2perm(unsafe_wrap(Array, primal(_ipiv), N), N) + end _dA = tangent(_A) function getri_pb!!(::NoRData) - if lwork != -1 - dA = wrap_ptr_as_view(_dA, lda, N, N) - A .= A[:, p] - dA .= dA[:, p] - - # Cotangent w.r.t. L. - dL = -(A' * dA) / UnitLowerTriangular(A_copy)' - dU = -(UpperTriangular(A_copy)' \ (dA * A')) - dA .= tri!(dL, 'L', 'U') .+ tri!(dU, 'U', 'N') - - # Restore initial state. - A .= A_copy + GC.@preserve args begin + if lwork != -1 + dA = wrap_ptr_as_view(_dA, lda, N, N) + A .= A[:, p] + dA .= dA[:, p] + + # Cotangent w.r.t. L. + dL = -(A' * dA) / UnitLowerTriangular(A_copy)' + dU = -(UpperTriangular(A_copy)' \ (dA * A')) + dA .= tri!(dL, 'L', 'U') .+ tri!(dU, 'U', 'N') + + # Restore initial state. + A .= A_copy + end end return tuple_fill(NoRData(), Val(13 + Nargs)) @@ -318,43 +332,47 @@ for (fname, elty) in ((:dpotrf_, :Float64), (:spotrf_, :Float32)) args::Vararg{Any, Nargs}, ) where {Nargs} - # Pull out the data. - uplo_p, N_p, A_p, lda_p, info_p = map(primal, (_uplo, _N, _A, _lda, _info)) - uplo, lda, N = map(unsafe_load, (uplo_p, lda_p, N_p)) + GC.@preserve args begin + # Pull out the data. + uplo_p, N_p, A_p, lda_p, info_p = map(primal, (_uplo, _N, _A, _lda, _info)) + uplo, lda, N = map(unsafe_load, (uplo_p, lda_p, N_p)) - # Make a copy of the initial state for later restoration. - A = wrap_ptr_as_view(A_p, lda, N, N) - A_copy = copy(A) + # Make a copy of the initial state for later restoration. + A = wrap_ptr_as_view(A_p, lda, N, N) + A_copy = copy(A) - # Run forwards-pass. - ccall( - $(blas_name(fname)), Cvoid, - (Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}), - uplo_p, N_p, A_p, lda_p, info_p, - ) + # Run forwards-pass. + ccall( + $(blas_name(fname)), Cvoid, + (Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}), + uplo_p, N_p, A_p, lda_p, info_p, + ) + end _dA = tangent(_A) function potrf_pb!!(::NoRData) - dA = wrap_ptr_as_view(_dA, lda, N, N) - dA2 = dA + GC.@preserve args begin + dA = wrap_ptr_as_view(_dA, lda, N, N) + dA2 = dA + + # Compute cotangents. + if Char(uplo) == 'L' + E = LowerTriangular(2 * ones(N, N)) - Diagonal(ones(N)) + L = LowerTriangular(A) + B = L' \ (E' .* (dA2'L)) / L + dA .= 0.5 * __sym(B) .* E .+ triu!(dA2, 1) + else + E = UpperTriangular(2 * ones(N, N) - Diagonal(ones(N))) + U = UpperTriangular(A) + B = U \ ((U * dA2') .* E') / U' + dA .= 0.5 * __sym(B) .* E .+ tril!(dA2, -1) + end - # Compute cotangents. - if Char(uplo) == 'L' - E = LowerTriangular(2 * ones(N, N)) - Diagonal(ones(N)) - L = LowerTriangular(A) - B = L' \ (E' .* (dA2'L)) / L - dA .= 0.5 * __sym(B) .* E .+ triu!(dA2, 1) - else - E = UpperTriangular(2 * ones(N, N) - Diagonal(ones(N))) - U = UpperTriangular(A) - B = U \ ((U * dA2') .* E') / U' - dA .= 0.5 * __sym(B) .* E .+ tril!(dA2, -1) + # Restore initial state. + A .= A_copy end - # Restore initial state. - A .= A_copy - return tuple_fill(NoRData(), Val(11 + Nargs)) end return zero_fcodual(Cvoid()), potrf_pb!! @@ -379,48 +397,53 @@ for (fname, elty) in ((:dpotrs_, :Float64), (:spotrs_, :Float32)) _info::CoDual{Ptr{BlasInt}}, args::Vararg{Any, Nargs}, ) where {Nargs} - # Pull out the data. - uplo_p, N_p, Nrhs_p, A_p, lda_p, B_p, ldb_p, info_p = map( - primal, (_uplo, _N, _Nrhs, _A, _lda, _B, _ldb, _info) - ) - uplo, lda, N, ldb, Nrhs = map(unsafe_load, (uplo_p, lda_p, N_p, ldb_p, Nrhs_p)) - - # Make a copy of the initial state for later restoration. - A = wrap_ptr_as_view(A_p, lda, N, N) - B = wrap_ptr_as_view(B_p, ldb, N, Nrhs) - B_copy = copy(B) - - # Run forwards-pass. - ccall( - $(blas_name(fname)), Cvoid, - ( - Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, - Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, - ), - uplo_p, N_p, Nrhs_p, A_p, lda_p, B_p, ldb_p, info_p, - ) + + GC.@preserve args begin + # Pull out the data. + uplo_p, N_p, Nrhs_p, A_p, lda_p, B_p, ldb_p, info_p = map( + primal, (_uplo, _N, _Nrhs, _A, _lda, _B, _ldb, _info) + ) + uplo, lda, N, ldb, Nrhs = map(unsafe_load, (uplo_p, lda_p, N_p, ldb_p, Nrhs_p)) + + # Make a copy of the initial state for later restoration. + A = wrap_ptr_as_view(A_p, lda, N, N) + B = wrap_ptr_as_view(B_p, ldb, N, Nrhs) + B_copy = copy(B) + + # Run forwards-pass. + ccall( + $(blas_name(fname)), Cvoid, + ( + Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, + Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, + ), + uplo_p, N_p, Nrhs_p, A_p, lda_p, B_p, ldb_p, info_p, + ) + end _dA = tangent(_A) _dB = tangent(_B) function potrs_pb!!(::NoRData) - dA = wrap_ptr_as_view(_dA, lda, N, N) - dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) + GC.@preserve args begin + dA = wrap_ptr_as_view(_dA, lda, N, N) + dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) + + # Compute cotangents. + if Char(uplo) == 'L' + tmp = __sym(B_copy * dB') / LowerTriangular(A)' + dA .-= 2 .* tril!(LinearAlgebra.LAPACK.potrs!('L', A, tmp)) + LinearAlgebra.LAPACK.potrs!('L', A, dB) + else + tmp = UpperTriangular(A)' \ __sym(B_copy * dB') + dA .-= 2 .* triu!((tmp / UpperTriangular(A)) / UpperTriangular(A)') + LinearAlgebra.LAPACK.potrs!('U', A, dB) + end - # Compute cotangents. - if Char(uplo) == 'L' - tmp = __sym(B_copy * dB') / LowerTriangular(A)' - dA .-= 2 .* tril!(LinearAlgebra.LAPACK.potrs!('L', A, tmp)) - LinearAlgebra.LAPACK.potrs!('L', A, dB) - else - tmp = UpperTriangular(A)' \ __sym(B_copy * dB') - dA .-= 2 .* triu!((tmp / UpperTriangular(A)) / UpperTriangular(A)') - LinearAlgebra.LAPACK.potrs!('U', A, dB) + # Restore initial state. + B .= B_copy end - # Restore initial state. - B .= B_copy - return tuple_fill(NoRData(), Val(14 + Nargs)) end return zero_fcodual(Cvoid()), potrs_pb!! diff --git a/test/integration_testing/logdensityproblemsad_interop.jl b/test/integration_testing/logdensityproblemsad_interop.jl index 22ac2507..31dcba27 100644 --- a/test/integration_testing/logdensityproblemsad_interop.jl +++ b/test/integration_testing/logdensityproblemsad_interop.jl @@ -7,19 +7,24 @@ LogDensityProblemsAD.logdensity(::TestLogDensity2, x) = -sum(abs2, x) LogDensityProblemsAD.dimension(::TestLogDensity2) = 20 test_gradient(x) = -2 .* x -# @testset "AD via Mooncake" begin -# l = TestLogDensity2() -# ∇l = ADgradient(Val(:Mooncake), l) +@testset "AD via Mooncake" begin + l = TestLogDensity2() + ∇l = ADgradient(Val(:Mooncake), l) -# @test dimension(∇l) == 20 -# @test capabilities(∇l) == LogDensityProblemsAD.LogDensityOrder(1) -# for _ in 1:100 -# x = randn(20) -# @test isapprox(@inferred(logdensity(∇l, x)), logdensity(l, x)) -# @test isapprox(logdensity_and_gradient(∇l, x)[1], logdensity(TestLogDensity2(), x)) -# @test isapprox(logdensity_and_gradient(∇l, x)[2], test_gradient(x)) -# end + @test dimension(∇l) == 20 + @test capabilities(∇l) == LogDensityProblemsAD.LogDensityOrder(1) + for _ in 1:100 + x = randn(20) + @test isapprox(@inferred(logdensity(∇l, x)), logdensity(l, x)) + @test isapprox(logdensity_and_gradient(∇l, x)[1], logdensity(TestLogDensity2(), x)) + @test isapprox(logdensity_and_gradient(∇l, x)[2], test_gradient(x)) + end -# @test ADgradient(ADTypes.AutoMooncake(debug_mode=false), l) isa typeof(∇l) -# @test parent(∇l) === l -# end + config = Mooncake.Config(; debug_mode=false) + @test ADgradient(ADTypes.AutoMooncake(; config), l) isa typeof(∇l) + @test parent(∇l) === l + + # Run in debug mode. + debug_config = Mooncake.Config(; debug_mode=true) + @test parent(ADgradient(ADTypes.AutoMooncake(; config=debug_config), l)) === l +end