diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 2caa45be1..af00de5d7 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,3 +1,6 @@ +env: + SECRET_CODECOV_TOKEN: "nkcRFVXdaPNAbiI0x3qK/XUG8rWjBc8fU73YEyP35SeS465XORqrIYrHUbHuJTRyeyqNRdsHaBcV1P7TBbKAaTQAjHQ1Q0KYfd0uRMSWpZSCgTBz5AwttAxVfFrX+Ky3PzTi2TfDe0uPFZtFo0Asq6sUEr1on+Oo+j+q6br2NK6CrA5yKKuTX4Q2V/UPOIK4vNXY3+zDTKSNtr+HQOlcVEeRIk/0ZQ78Cjd52flEaVw8GWo/CC4YBzLtcOZgaFdgOTEDNHMr0mw6zLE4Y6nxq4lHVSoraSjxjhkB0pXTZ1c51yHX8Jc+q6HC5s87+2Zq5YtsuQSGao+eMtkTAYwfLw==;U2FsdGVkX18z27J3+gNgxsPNnXA0ad4LvZnXeohTam7/6UPqX5+3BYI0tAiVkCho4vlJyL7dd8JEyNtk9BFXsg==" + steps: - label: "Julia v1" plugins: @@ -5,7 +8,9 @@ steps: version: "1" - JuliaCI/julia-test#v1: ~ - JuliaCI/julia-coverage#v1: - codecov: true + dirs: + - src + - ext agents: queue: "juliagpu" cuda: "*" diff --git a/Project.toml b/Project.toml index a4c88649b..6ac178be7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.6" +version = "0.4.7" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/MooncakeCUDAExt.jl b/ext/MooncakeCUDAExt.jl index 7ed8bbf7c..0ab3eb1f2 100644 --- a/ext/MooncakeCUDAExt.jl +++ b/ext/MooncakeCUDAExt.jl @@ -1,71 +1,68 @@ module MooncakeCUDAExt - using LinearAlgebra, Random, Mooncake +using LinearAlgebra, Random, Mooncake - using Base: IEEEFloat - using CUDA: CuArray, cu +using Base: IEEEFloat +using CUDA: CuArray, cu - import Mooncake: - MinimalCtx, - rrule!!, - @is_primitive, - tangent_type, - zero_tangent, - randn_tangent, - increment!!, - set_to_zero!!, - _add_to_primal, - _diff, - _dot, - _scale, - TestUtils, - CoDual, - NoPullback +import Mooncake: + MinimalCtx, + rrule!!, + @is_primitive, + tangent_type, + zero_tangent, + randn_tangent, + increment!!, + set_to_zero!!, + _add_to_primal, + _diff, + _dot, + _scale, + TestUtils, + CoDual, + 
NoPullback - import Mooncake.TestUtils: populate_address_map!, AddressMap, __increment_should_allocate +import Mooncake.TestUtils: populate_address_map!, AddressMap, __increment_should_allocate - # Tell Mooncake.jl how to handle CuArrays. +# Tell Mooncake.jl how to handle CuArrays. - tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P - zero_tangent(x::CuArray{<:IEEEFloat}) = zero(x) - function randn_tangent(rng::AbstractRNG, x::CuArray{Float32}) - return cu(randn(rng, Float32, size(x)...)) - end - TestUtils.has_equal_data(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x == y - increment!!(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x .+= y - __increment_should_allocate(::Type{<:CuArray{<:IEEEFloat}}) = true - set_to_zero!!(x::CuArray{<:IEEEFloat}) = x .= 0 - _add_to_primal(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x + y - _diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y - _dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y)) - _scale(x::Float64, y::P) where {T<:IEEEFloat, P<:CuArray{T}} = T(x) * y - function populate_address_map!(m::AddressMap, p::CuArray, t::CuArray) - k = pointer_from_objref(p) - v = pointer_from_objref(t) - haskey(m, k) && (@assert m[k] == v) - m[k] = v - return m - end - function Mooncake._verify_fdata_value(p::CuArray, f::CuArray) - if size(p) != size(f) - throw(InvalidFDataException("p has size $(size(p)) but f has size $(size(f))")) - end - return nothing +tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P +zero_tangent(x::CuArray{<:IEEEFloat}) = zero(x) +function randn_tangent(rng::AbstractRNG, x::CuArray{Float32}) + return cu(randn(rng, Float32, size(x)...)) +end +TestUtils.has_equal_data(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x == y +increment!!(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x .+= y +__increment_should_allocate(::Type{<:CuArray{<:IEEEFloat}}) = true +set_to_zero!!(x::CuArray{<:IEEEFloat}) = x .= 0 +_add_to_primal(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x + y 
+_diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y +_dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y)) +_scale(x::Float64, y::P) where {T<:IEEEFloat, P<:CuArray{T}} = T(x) * y +function populate_address_map!(m::AddressMap, p::CuArray, t::CuArray) + k = pointer_from_objref(p) + v = pointer_from_objref(t) + haskey(m, k) && (@assert m[k] == v) + m[k] = v + return m +end +function Mooncake._verify_fdata_value(p::CuArray, f::CuArray) + if size(p) != size(f) + throw(Mooncake.InvalidFDataException("p has size $(size(p)) but f has size $(size(f))")) end + return nothing +end - # Basic rules for operating on CuArrays. +# Basic rules for operating on CuArrays. + +@is_primitive( + MinimalCtx, Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}} where {N}, +) +function rrule!!( + p::CoDual{Type{P}}, init::CoDual{UndefInitializer}, dims::CoDual{Int}... +) where {P<:CuArray{<:Base.IEEEFloat}} + _dims = map(primal, dims) + return CoDual(P(undef, _dims), P(undef, _dims)), NoPullback(p, init, dims...) +end - @is_primitive( - MinimalCtx, - Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}} where {N}, - ) - function rrule!!( - p::CoDual{Type{P}}, - init::CoDual{UndefInitializer}, - dims::CoDual{Int}... - ) where {P<:CuArray{<:Base.IEEEFloat}} - _dims = map(primal, dims) - y = CoDual(P(undef, _dims), P(undef, _dims)) - return y, NoPullback(p, init, dims...) - end end diff --git a/ext/MooncakeJETExt.jl b/ext/MooncakeJETExt.jl index 7e232bbf1..7b90550e5 100644 --- a/ext/MooncakeJETExt.jl +++ b/ext/MooncakeJETExt.jl @@ -1,7 +1,8 @@ module MooncakeJETExt - using JET, Mooncake +using JET, Mooncake + +Mooncake.TestUtils.test_opt(::Mooncake.TestUtils.Shim, args...) = JET.test_opt(args...) +Mooncake.TestUtils.report_opt(::Mooncake.TestUtils.Shim, tt) = JET.report_opt(tt) - Mooncake.TestUtils.test_opt(::Mooncake.TestUtils.Shim, args...) = JET.test_opt(args...) 
- Mooncake.TestUtils.report_opt(::Mooncake.TestUtils.Shim, tt) = JET.report_opt(tt) end diff --git a/ext/MooncakeLogDensityProblemsADExt.jl b/ext/MooncakeLogDensityProblemsADExt.jl index c2d39f8cf..79d756ac5 100644 --- a/ext/MooncakeLogDensityProblemsADExt.jl +++ b/ext/MooncakeLogDensityProblemsADExt.jl @@ -3,17 +3,10 @@ module MooncakeLogDensityProblemsADExt -if isdefined(Base, :get_extension) - using ADTypes - using LogDensityProblemsAD: ADGradientWrapper - import LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity - import Mooncake -else - using ADTypes - using ..LogDensityProblemsAD: ADGradientWrapper - import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity - import ..Mooncake -end +using ADTypes +using LogDensityProblemsAD: ADGradientWrapper +import LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity +import Mooncake struct MooncakeGradientLogDensity{Trule, L} <: ADGradientWrapper rule::Trule diff --git a/ext/MooncakeSpecialFunctionsExt.jl b/ext/MooncakeSpecialFunctionsExt.jl index ade619d4d..eca33b5e0 100644 --- a/ext/MooncakeSpecialFunctionsExt.jl +++ b/ext/MooncakeSpecialFunctionsExt.jl @@ -1,11 +1,12 @@ module MooncakeSpecialFunctionsExt - using SpecialFunctions, Mooncake +using SpecialFunctions, Mooncake - import Mooncake: @from_rrule, DefaultCtx +import Mooncake: @from_rrule, DefaultCtx + +@from_rrule DefaultCtx Tuple{typeof(airyai), Float64} +@from_rrule DefaultCtx Tuple{typeof(airyaix), Float64} +@from_rrule DefaultCtx Tuple{typeof(erfc), Float64} +@from_rrule DefaultCtx Tuple{typeof(erfcx), Float64} - @from_rrule DefaultCtx Tuple{typeof(airyai), Float64} - @from_rrule DefaultCtx Tuple{typeof(airyaix), Float64} - @from_rrule DefaultCtx Tuple{typeof(erfc), Float64} - @from_rrule DefaultCtx Tuple{typeof(erfcx), Float64} end diff --git a/src/test_utils.jl b/src/test_utils.jl index fb658d272..ec083f868 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ 
-499,9 +499,9 @@ function test_rule( # Construct the rule. rule = Mooncake.build_rrule(interp, _typeof(__get_primals(x)); debug_mode) - # If we're requiring `is_primitive`, then check that `rule == rrule!!`. + # If primitive, the rule should be `rrule!!` (wrapped in `DebugRRule` if `debug_mode`). if is_primitive - @test rule === rrule!! + @test rule == (debug_mode ? Mooncake.DebugRRule(rrule!!) : rrule!!) end # Generate random tangents for anything that is not already a CoDual. diff --git a/test/integration_testing/cuda.jl b/test/integration_testing/cuda.jl index 9df8ea62d..84c04e135 100644 --- a/test/integration_testing/cuda.jl +++ b/test/integration_testing/cuda.jl @@ -5,13 +5,13 @@ using CUDA # Check we can operate on CuArrays. test_tangent( Xoshiro(123456), - CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}(undef, 8, 8); + CuArray{Float32, 2, CUDA.DeviceMemory}(undef, 8, 8); interface_only=false, ) # Check we can instantiate a CuArray. test_rule( - sr(123456), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, undef, 256; - interface_only=true, is_primitive=true, + sr(123456), CuArray{Float32, 1, CUDA.DeviceMemory}, undef, 256; + interface_only=true, is_primitive=true, debug_mode=true, ) end