From 374f3e04a11f6f0b599e40661f989803ea9d6d8c Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 17:19:24 +0100 Subject: [PATCH 01/21] Sort out extension formattingg --- ext/TapirCUDAExt.jl | 121 ++++++++++++++-------------- ext/TapirDynamicPPLExt.jl | 9 +-- ext/TapirJETExt.jl | 7 +- ext/TapirLogDensityProblemsADExt.jl | 15 +--- ext/TapirSpecialFunctionsExt.jl | 13 +-- 5 files changed, 78 insertions(+), 87 deletions(-) diff --git a/ext/TapirCUDAExt.jl b/ext/TapirCUDAExt.jl index 7476428f8..68207eac3 100644 --- a/ext/TapirCUDAExt.jl +++ b/ext/TapirCUDAExt.jl @@ -1,71 +1,72 @@ module TapirCUDAExt - using LinearAlgebra, Random, Tapir +using LinearAlgebra, Random, Tapir - using Base: IEEEFloat - using CUDA: CuArray, cu +using Base: IEEEFloat +using CUDA: CuArray, cu - import Tapir: - MinimalCtx, - rrule!!, - @is_primitive, - tangent_type, - zero_tangent, - randn_tangent, - increment!!, - set_to_zero!!, - _add_to_primal, - _diff, - _dot, - _scale, - TestUtils, - CoDual, - NoPullback +import Tapir: + MinimalCtx, + rrule!!, + @is_primitive, + tangent_type, + zero_tangent, + randn_tangent, + increment!!, + set_to_zero!!, + _add_to_primal, + _diff, + _dot, + _scale, + TestUtils, + CoDual, + NoPullback - import Tapir.TestUtils: populate_address_map!, AddressMap, __increment_should_allocate +import Tapir.TestUtils: populate_address_map!, AddressMap, __increment_should_allocate - # Tell Tapir.jl how to handle CuArrays. +# Tell Tapir.jl how to handle CuArrays. - tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P - zero_tangent(x::CuArray{<:IEEEFloat}) = zero(x) - function randn_tangent(rng::AbstractRNG, x::CuArray{Float32}) - return cu(randn(rng, Float32, size(x)...)) - end - TestUtils.has_equal_data(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x == y - increment!!(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x .+= y - __increment_should_allocate(::Type{<:CuArray{<:IEEEFloat}}) = true - set_to_zero!!(x::CuArray{<:IEEEFloat}) = x .= 0 - _add_to_primal(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x + y - _diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y - _dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y)) - _scale(x::Float64, y::P) where {T<:IEEEFloat, P<:CuArray{T}} = T(x) * y - function populate_address_map!(m::AddressMap, p::CuArray, t::CuArray) - k = pointer_from_objref(p) - v = pointer_from_objref(t) - haskey(m, k) && (@assert m[k] == v) - m[k] = v - return m - end - function Tapir._verify_fdata_value(p::CuArray, f::CuArray) - if size(p) != size(f) - throw(InvalidFDataException("p has size $(size(p)) but f has size $(size(f))")) - end - return nothing +tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P +zero_tangent(x::CuArray{<:IEEEFloat}) = zero(x) +function randn_tangent(rng::AbstractRNG, x::CuArray{Float32}) + return cu(randn(rng, Float32, size(x)...)) +end +TestUtils.has_equal_data(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x == y +increment!!(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x .+= y +__increment_should_allocate(::Type{<:CuArray{<:IEEEFloat}}) = true +set_to_zero!!(x::CuArray{<:IEEEFloat}) = x .= 0 +_add_to_primal(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x + y +_diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y +_dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y)) +_scale(x::Float64, y::P) where {T<:IEEEFloat, P<:CuArray{T}} = T(x) * y +function populate_address_map!(m::AddressMap, p::CuArray, t::CuArray) + k = pointer_from_objref(p) + v = pointer_from_objref(t) + haskey(m, k) && (@assert m[k] == v) + m[k] = v + return m +end +function Tapir._verify_fdata_value(p::CuArray, f::CuArray) + if size(p) != size(f) + throw(InvalidFDataException("p has size $(size(p)) but f has size $(size(f))")) end + return nothing +end - # Basic rules for operating on CuArrays. +# Basic rules for operating on CuArrays. + +@is_primitive( + MinimalCtx, + Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}} where {N}, +) +function rrule!!( + p::CoDual{Type{P}}, + init::CoDual{UndefInitializer}, + dims::CoDual{Int}... +) where {P<:CuArray{<:Base.IEEEFloat}} + _dims = map(primal, dims) + y = CoDual(P(undef, _dims), P(undef, _dims)) + return y, NoPullback(p, init, dims...) +end - @is_primitive( - MinimalCtx, - Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}} where {N}, - ) - function rrule!!( - p::CoDual{Type{P}}, - init::CoDual{UndefInitializer}, - dims::CoDual{Int}... - ) where {P<:CuArray{<:Base.IEEEFloat}} - _dims = map(primal, dims) - y = CoDual(P(undef, _dims), P(undef, _dims)) - return y, NoPullback(p, init, dims...) - end end diff --git a/ext/TapirDynamicPPLExt.jl b/ext/TapirDynamicPPLExt.jl index 8348b9fb0..f82b38d4c 100644 --- a/ext/TapirDynamicPPLExt.jl +++ b/ext/TapirDynamicPPLExt.jl @@ -1,12 +1,7 @@ module TapirDynamicPPLExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL, istrans - using Tapir: Tapir -else - using ..DynamicPPL: DynamicPPL, istrans - using ..Tapir: Tapir -end +using DynamicPPL: DynamicPPL, istrans +using Tapir: Tapir using Tapir: DefaultCtx, CoDual, simple_zero_adjoint diff --git a/ext/TapirJETExt.jl b/ext/TapirJETExt.jl index bae7178f2..3cb6477f8 100644 --- a/ext/TapirJETExt.jl +++ b/ext/TapirJETExt.jl @@ -1,7 +1,8 @@ module TapirJETExt - using JET, Tapir +using JET, Tapir + +Tapir.TestUtils.test_opt(::Tapir.TestUtils.Shim, args...) = JET.test_opt(args...) +Tapir.TestUtils.report_opt(::Tapir.TestUtils.Shim, tt) = JET.report_opt(tt) - Tapir.TestUtils.test_opt(::Tapir.TestUtils.Shim, args...) = JET.test_opt(args...) - Tapir.TestUtils.report_opt(::Tapir.TestUtils.Shim, tt) = JET.report_opt(tt) end diff --git a/ext/TapirLogDensityProblemsADExt.jl b/ext/TapirLogDensityProblemsADExt.jl index 592cac06c..fcd91c497 100644 --- a/ext/TapirLogDensityProblemsADExt.jl +++ b/ext/TapirLogDensityProblemsADExt.jl @@ -3,17 +3,10 @@ module TapirLogDensityProblemsADExt -if isdefined(Base, :get_extension) - using ADTypes - using LogDensityProblemsAD: ADGradientWrapper - import LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity - import Tapir -else - using ADTypes - using ..LogDensityProblemsAD: ADGradientWrapper - import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity - import ..Tapir -end +using ADTypes +using LogDensityProblemsAD: ADGradientWrapper +import LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity +import Tapir struct TapirGradientLogDensity{Trule, L} <: ADGradientWrapper rule::Trule diff --git a/ext/TapirSpecialFunctionsExt.jl b/ext/TapirSpecialFunctionsExt.jl index 651fe0184..7a4d98e31 100644 --- a/ext/TapirSpecialFunctionsExt.jl +++ b/ext/TapirSpecialFunctionsExt.jl @@ -1,11 +1,12 @@ module TapirSpecialFunctionsExt - using SpecialFunctions, Tapir +using SpecialFunctions, Tapir - import Tapir: @from_rrule, DefaultCtx +import Tapir: @from_rrule, DefaultCtx + +@from_rrule DefaultCtx Tuple{typeof(airyai), Float64} +@from_rrule DefaultCtx Tuple{typeof(airyaix), Float64} +@from_rrule DefaultCtx Tuple{typeof(erfc), Float64} +@from_rrule DefaultCtx Tuple{typeof(erfcx), Float64} - @from_rrule DefaultCtx Tuple{typeof(airyai), Float64} - @from_rrule DefaultCtx Tuple{typeof(airyaix), Float64} - @from_rrule DefaultCtx Tuple{typeof(erfc), Float64} - @from_rrule DefaultCtx Tuple{typeof(erfcx), Float64} end From 6e797972f59568a9aae6a9cd5372800b970a4e0a Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sat, 14 Sep 2024 12:44:28 +0100 Subject: [PATCH 02/21] Add codecov key to buildkite --- .buildkite/pipeline.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 2caa45be1..4cd64f451 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -13,3 +13,5 @@ steps: timeout_in_minutes: 60 env: TEST_GROUP: "gpu" +env: + SECRET_CODECOV_SECRET: "O3JDNB8okS77naB4HEUJkS9jBDuCAAawjKp4wio5T5TAjacTjr+aSzx9Zp0Zp34XywZJjnQTIfBs7ryK46hA5oli20QBTBDBaEaxCDLWnrqpIe0N4KXiRMFthds8U3NT9rnCej6XehdhIU6qtbXjd8I0rAnJYJZQ1ffrlXzfhNX35zuKo2UZ4tl5aMFiMf5bTBLxe9d3F1xUthbLEQmePZYGbHfpJ7K/6op44UF9UL+GIV2MgKS0ZmkiPSBz2ICa/r6ny6iZsKU3ddM38eQfjIE50c42X/3JDjQhP4Pb/OTPbGk6udXISbjrascxgB6OC+oILfm53sME4ARdlZV7MA==;U2FsdGVkX19vGkyy2rbTqUEV8Q1cJK6muQec2lufYO8=" From ec86c1bbe344c6f777f09fb1cbcf46fca306961f Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sat, 14 Sep 2024 12:56:08 +0100 Subject: [PATCH 03/21] Update .buildkite/pipeline.yml --- .buildkite/pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 4cd64f451..a1c7c3322 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -14,4 +14,4 @@ steps: env: TEST_GROUP: "gpu" env: - SECRET_CODECOV_SECRET: "O3JDNB8okS77naB4HEUJkS9jBDuCAAawjKp4wio5T5TAjacTjr+aSzx9Zp0Zp34XywZJjnQTIfBs7ryK46hA5oli20QBTBDBaEaxCDLWnrqpIe0N4KXiRMFthds8U3NT9rnCej6XehdhIU6qtbXjd8I0rAnJYJZQ1ffrlXzfhNX35zuKo2UZ4tl5aMFiMf5bTBLxe9d3F1xUthbLEQmePZYGbHfpJ7K/6op44UF9UL+GIV2MgKS0ZmkiPSBz2ICa/r6ny6iZsKU3ddM38eQfjIE50c42X/3JDjQhP4Pb/OTPbGk6udXISbjrascxgB6OC+oILfm53sME4ARdlZV7MA==;U2FsdGVkX19vGkyy2rbTqUEV8Q1cJK6muQec2lufYO8=" + SECRET_CODECOV_TOKEN: "evD6ybRun2WW45mxNN8GcGySZi1/f9lZRlE2sv3tP/Vj8pU7QdGQ6mYrMauhnpU7RiVUBbeypi6YFGqMSbanbpmKZ32JOrlc841KYt4iJnpUsY3RrjXsRN39bJZ8b/qMcFiF5U0JcENb1/nLs8OzeFLZECxNKX0lK2OlrZ7ZcEGfa4xFsBSlnUAxWbUzTrg0xgfCrLIEp2oRN3HzvdV3kgNPyKgqRVOd/qbgLu+RuElkUj7T3AAfPgn6f57qiEIwPQ8gDIxwVfpxGyUWwcGitadfW4hGEHmliLxvPNAxCUz+FIyvTmviOV/5wJYHOYeOeymQO1qFYi8PpOg4odY8YA==;U2FsdGVkX1+W15397ySKjT/XFmT88oi/K542mpVC+4c=" From 7d8ea4363d343dbf3dbff6e55c342577a9d8908a Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sat, 14 Sep 2024 13:10:36 +0100 Subject: [PATCH 04/21] Update .buildkite/pipeline.yml --- .buildkite/pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index a1c7c3322..4aa0672dc 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -14,4 +14,4 @@ steps: env: TEST_GROUP: "gpu" env: - SECRET_CODECOV_TOKEN: "evD6ybRun2WW45mxNN8GcGySZi1/f9lZRlE2sv3tP/Vj8pU7QdGQ6mYrMauhnpU7RiVUBbeypi6YFGqMSbanbpmKZ32JOrlc841KYt4iJnpUsY3RrjXsRN39bJZ8b/qMcFiF5U0JcENb1/nLs8OzeFLZECxNKX0lK2OlrZ7ZcEGfa4xFsBSlnUAxWbUzTrg0xgfCrLIEp2oRN3HzvdV3kgNPyKgqRVOd/qbgLu+RuElkUj7T3AAfPgn6f57qiEIwPQ8gDIxwVfpxGyUWwcGitadfW4hGEHmliLxvPNAxCUz+FIyvTmviOV/5wJYHOYeOeymQO1qFYi8PpOg4odY8YA==;U2FsdGVkX1+W15397ySKjT/XFmT88oi/K542mpVC+4c=" + SECRET_CODECOV_TOKEN: "M+fbiuiAVojU6VIquhJ6+oC/EI4hL8Jfubdnc03HV1t0Fpn//CxfnTZi/lT8SoBromLOCIYhQdhCrKla8VgnWVPNmcmrpBkEHtJJ6egmclwwYQGsPAsFrubosSquhK3IGgr2JOxQo4Aa9t5tx1yahEeRxQOKKhgF4pKIscGCz7OYGTIivOMg69AFPKl91i/8rtwVjS4Q2cBj0+reA8jx0DiM70tQe0ceYdXOCczfSkE1iegpSzreYTPoKc6DDt0keTr6FqRlUPKr4TuqA+smYV91MY7OUwi1w7kb0nkzPyCES72T9i1RBAHhC0p28WVRUEkOn7X/V4nhiZOodaStRw==;U2FsdGVkX18+XJdcOzrVVhv3a6+jf7qxRz0TxBPNLyXzoom+BTQu+CKUjld70VdDUGSUseqiNXS9ZdnNIAV53w==" From a7a47868d13eaca1caefcf15646ab05660020e7d Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sat, 14 Sep 2024 13:29:47 +0100 Subject: [PATCH 05/21] Fix formatting --- .buildkite/pipeline.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 4aa0672dc..1db4f3d33 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,3 +1,6 @@ +env: + SECRET_CODECOV_TOKEN: "M+fbiuiAVojU6VIquhJ6+oC/EI4hL8Jfubdnc03HV1t0Fpn//CxfnTZi/lT8SoBromLOCIYhQdhCrKla8VgnWVPNmcmrpBkEHtJJ6egmclwwYQGsPAsFrubosSquhK3IGgr2JOxQo4Aa9t5tx1yahEeRxQOKKhgF4pKIscGCz7OYGTIivOMg69AFPKl91i/8rtwVjS4Q2cBj0+reA8jx0DiM70tQe0ceYdXOCczfSkE1iegpSzreYTPoKc6DDt0keTr6FqRlUPKr4TuqA+smYV91MY7OUwi1w7kb0nkzPyCES72T9i1RBAHhC0p28WVRUEkOn7X/V4nhiZOodaStRw==;U2FsdGVkX18+XJdcOzrVVhv3a6+jf7qxRz0TxBPNLyXzoom+BTQu+CKUjld70VdDUGSUseqiNXS9ZdnNIAV53w==" + steps: - label: "Julia v1" plugins: @@ -13,5 +16,3 @@ steps: timeout_in_minutes: 60 env: TEST_GROUP: "gpu" -env: - SECRET_CODECOV_TOKEN: "M+fbiuiAVojU6VIquhJ6+oC/EI4hL8Jfubdnc03HV1t0Fpn//CxfnTZi/lT8SoBromLOCIYhQdhCrKla8VgnWVPNmcmrpBkEHtJJ6egmclwwYQGsPAsFrubosSquhK3IGgr2JOxQo4Aa9t5tx1yahEeRxQOKKhgF4pKIscGCz7OYGTIivOMg69AFPKl91i/8rtwVjS4Q2cBj0+reA8jx0DiM70tQe0ceYdXOCczfSkE1iegpSzreYTPoKc6DDt0keTr6FqRlUPKr4TuqA+smYV91MY7OUwi1w7kb0nkzPyCES72T9i1RBAHhC0p28WVRUEkOn7X/V4nhiZOodaStRw==;U2FsdGVkX18+XJdcOzrVVhv3a6+jf7qxRz0TxBPNLyXzoom+BTQu+CKUjld70VdDUGSUseqiNXS9ZdnNIAV53w==" From 9cb0928473b429f29e4c041cdb503612c422b4ab Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 09:14:02 +0100 Subject: [PATCH 06/21] Bump patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f135556c0..65294a512 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.50" +version = "0.2.51" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 2991f0de6a4dffc9b58724daa7cb347ebab56471 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 09:58:29 +0100 Subject: [PATCH 07/21] Unbump patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 65294a512..f135556c0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.51" +version = "0.2.50" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From bf46c7bbf2a561839e343b880031a826635c6477 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 12:13:34 +0100 Subject: [PATCH 08/21] Bump patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f135556c0..65294a512 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.50" +version = "0.2.51" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From a8bb89aa1e8472c22e963f8df25f11db08b208f9 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 12:23:36 +0100 Subject: [PATCH 09/21] Unbump commit --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 65294a512..f135556c0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.51" +version = "0.2.50" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 54690fc5e6dbddce541c23748e121decab8e8cfb Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 12:38:35 +0100 Subject: [PATCH 10/21] Fix deprecations --- test/integration_testing/cuda.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/integration_testing/cuda.jl b/test/integration_testing/cuda.jl index 9df8ea62d..664e084c1 100644 --- a/test/integration_testing/cuda.jl +++ b/test/integration_testing/cuda.jl @@ -5,13 +5,13 @@ using CUDA # Check we can operate on CuArrays. test_tangent( Xoshiro(123456), - CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}(undef, 8, 8); + CuArray{Float32, 2, CUDA.DeviceMemory}(undef, 8, 8); interface_only=false, ) # Check we can instantiate a CuArray. test_rule( - sr(123456), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, undef, 256; + sr(123456), CuArray{Float32, 1, CUDA.DeviceMemory}, undef, 256; interface_only=true, is_primitive=true, ) end From f1f30148909723ba207d980c0bd5e4fc45a240e0 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 16:06:58 +0100 Subject: [PATCH 11/21] Add ext and src to coverage for buildkite --- .buildkite/pipeline.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 1db4f3d33..70cf9570a 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -9,6 +9,9 @@ steps: - JuliaCI/julia-test#v1: ~ - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext agents: queue: "juliagpu" cuda: "*" From b6feae53c6a258a4d7837d8c6b7bfef122f78607 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 17:14:24 +0100 Subject: [PATCH 12/21] Remove redundant line --- .buildkite/pipeline.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 70cf9570a..ef9914101 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -8,7 +8,6 @@ steps: version: "1" - JuliaCI/julia-test#v1: ~ - JuliaCI/julia-coverage#v1: - codecov: true dirs: - src - ext From 38370c1c201019fc9cc78172865ad6077fcf257d Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 19 Sep 2024 13:23:25 +0100 Subject: [PATCH 13/21] Tweak rule for more feedback --- ext/MooncakeCUDAExt.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/MooncakeCUDAExt.jl b/ext/MooncakeCUDAExt.jl index cec1e46f1..77644b6fc 100644 --- a/ext/MooncakeCUDAExt.jl +++ b/ext/MooncakeCUDAExt.jl @@ -64,6 +64,7 @@ function rrule!!( init::CoDual{UndefInitializer}, dims::CoDual{Int}... ) where {P<:CuArray{<:Base.IEEEFloat}} + @show "a rule?" _dims = map(primal, dims) y = CoDual(P(undef, _dims), P(undef, _dims)) return y, NoPullback(p, init, dims...) From 6f28f317a240a5e0379d75118f0cd8478c04f771 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 19 Sep 2024 13:34:34 +0100 Subject: [PATCH 14/21] Remove show command -- seems fine --- ext/MooncakeCUDAExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/MooncakeCUDAExt.jl b/ext/MooncakeCUDAExt.jl index 77644b6fc..cec1e46f1 100644 --- a/ext/MooncakeCUDAExt.jl +++ b/ext/MooncakeCUDAExt.jl @@ -64,7 +64,6 @@ function rrule!!( init::CoDual{UndefInitializer}, dims::CoDual{Int}... ) where {P<:CuArray{<:Base.IEEEFloat}} - @show "a rule?" _dims = map(primal, dims) y = CoDual(P(undef, _dims), P(undef, _dims)) return y, NoPullback(p, init, dims...) From cf5bd71346296251df2b989f21b2240ba74bcccd Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 17:20:14 +0100 Subject: [PATCH 15/21] Update pipeline.yml Try new key --- .buildkite/pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index ef9914101..c982545da 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,5 +1,5 @@ env: - SECRET_CODECOV_TOKEN: "M+fbiuiAVojU6VIquhJ6+oC/EI4hL8Jfubdnc03HV1t0Fpn//CxfnTZi/lT8SoBromLOCIYhQdhCrKla8VgnWVPNmcmrpBkEHtJJ6egmclwwYQGsPAsFrubosSquhK3IGgr2JOxQo4Aa9t5tx1yahEeRxQOKKhgF4pKIscGCz7OYGTIivOMg69AFPKl91i/8rtwVjS4Q2cBj0+reA8jx0DiM70tQe0ceYdXOCczfSkE1iegpSzreYTPoKc6DDt0keTr6FqRlUPKr4TuqA+smYV91MY7OUwi1w7kb0nkzPyCES72T9i1RBAHhC0p28WVRUEkOn7X/V4nhiZOodaStRw==;U2FsdGVkX18+XJdcOzrVVhv3a6+jf7qxRz0TxBPNLyXzoom+BTQu+CKUjld70VdDUGSUseqiNXS9ZdnNIAV53w==" + SECRET_CODECOV_TOKEN: "OQDm6ItKZvxShVmPvjHQjJsCg0tbXqkAE6iEJAT/pD9q6kxlgmEywlf3atV4puF7YkydNDrNivL2uhqy8QoK/bKJrW0wHQ7+9F1Q3NZg/OOkAIRZxDg5+90E7ucJwquk04iEq0W4sD5D4XXeKLRUe+vs0+kmID+Pe6L2i+K2UOAkPJ/JHPgvFzPUv3lbEt0qhv81hn9qtE3wrlp4SbsQWMbKixFSftdWm7Op8+mBA55pjoq+PVI2FOksFjsM/zSBvQn+oJbz6OVi7BRDr5LFm59DOMv/udKkE1YbkNCxL5ao4YeGIq+vJwJbLwAC74ZIeHcPpoFQAccburSue/JYJA==;U2FsdGVkX19J5BVBerBpwP0M6aA4aGNFV8QPDoP2/0EnM+4xawaahn2AkCfjf0ZJ9sB2KEJZLhz2BAxpNhh+pTRtPwehvfnpDWEtAMJTjWs=" steps: - label: "Julia v1" From 08257b1a7a6fa67e55be854ce17f730ed5aea351 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 17:42:38 +0100 Subject: [PATCH 16/21] Update pipeline.yml --- .buildkite/pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index c982545da..af00de5d7 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,5 +1,5 @@ env: - SECRET_CODECOV_TOKEN: "OQDm6ItKZvxShVmPvjHQjJsCg0tbXqkAE6iEJAT/pD9q6kxlgmEywlf3atV4puF7YkydNDrNivL2uhqy8QoK/bKJrW0wHQ7+9F1Q3NZg/OOkAIRZxDg5+90E7ucJwquk04iEq0W4sD5D4XXeKLRUe+vs0+kmID+Pe6L2i+K2UOAkPJ/JHPgvFzPUv3lbEt0qhv81hn9qtE3wrlp4SbsQWMbKixFSftdWm7Op8+mBA55pjoq+PVI2FOksFjsM/zSBvQn+oJbz6OVi7BRDr5LFm59DOMv/udKkE1YbkNCxL5ao4YeGIq+vJwJbLwAC74ZIeHcPpoFQAccburSue/JYJA==;U2FsdGVkX19J5BVBerBpwP0M6aA4aGNFV8QPDoP2/0EnM+4xawaahn2AkCfjf0ZJ9sB2KEJZLhz2BAxpNhh+pTRtPwehvfnpDWEtAMJTjWs=" + SECRET_CODECOV_TOKEN: "nkcRFVXdaPNAbiI0x3qK/XUG8rWjBc8fU73YEyP35SeS465XORqrIYrHUbHuJTRyeyqNRdsHaBcV1P7TBbKAaTQAjHQ1Q0KYfd0uRMSWpZSCgTBz5AwttAxVfFrX+Ky3PzTi2TfDe0uPFZtFo0Asq6sUEr1on+Oo+j+q6br2NK6CrA5yKKuTX4Q2V/UPOIK4vNXY3+zDTKSNtr+HQOlcVEeRIk/0ZQ78Cjd52flEaVw8GWo/CC4YBzLtcOZgaFdgOTEDNHMr0mw6zLE4Y6nxq4lHVSoraSjxjhkB0pXTZ1c51yHX8Jc+q6HC5s87+2Zq5YtsuQSGao+eMtkTAYwfLw==;U2FsdGVkX18z27J3+gNgxsPNnXA0ad4LvZnXeohTam7/6UPqX5+3BYI0tAiVkCho4vlJyL7dd8JEyNtk9BFXsg==" steps: - label: "Julia v1" From 0f405ef91edaf7bd624e3c3f9ede500b1bf6861c Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 4 Oct 2024 09:14:18 +0100 Subject: [PATCH 17/21] Update test/integration_testing/cuda.jl --- test/integration_testing/cuda.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration_testing/cuda.jl b/test/integration_testing/cuda.jl index 664e084c1..84c04e135 100644 --- a/test/integration_testing/cuda.jl +++ b/test/integration_testing/cuda.jl @@ -12,6 +12,6 @@ using CUDA # Check we can instantiate a CuArray. test_rule( sr(123456), CuArray{Float32, 1, CUDA.DeviceMemory}, undef, 256; - interface_only=true, is_primitive=true, + interface_only=true, is_primitive=true, debug_mode=true, ) end From d5366bbb1931bfde3a1059fd0a51cf3b61852e22 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 4 Oct 2024 09:41:35 +0100 Subject: [PATCH 18/21] Test is_primitive properly --- src/test_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index fb658d272..ec083f868 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -499,9 +499,9 @@ function test_rule( # Construct the rule. rule = Mooncake.build_rrule(interp, _typeof(__get_primals(x)); debug_mode) - # If we're requiring `is_primitive`, then check that `rule == rrule!!`. + # If something is primitive, then the rule should be `rrule!!`. if is_primitive - @test rule === rrule!! + @test rule == (debug_mode ? Mooncake.DebugRRule(rrule!!) : rrule!!) end # Generate random tangents for anything that is not already a CoDual. From 74f0c88a8c2c0fa1d1ddc6e630a8029bd2ad822a Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 4 Oct 2024 09:47:56 +0100 Subject: [PATCH 19/21] Fix formatting again --- ext/MooncakeJETExt.jl | 7 ++++--- ext/MooncakeLogDensityProblemsADExt.jl | 15 ++++----------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/ext/MooncakeJETExt.jl b/ext/MooncakeJETExt.jl index 7e232bbf1..7b90550e5 100644 --- a/ext/MooncakeJETExt.jl +++ b/ext/MooncakeJETExt.jl @@ -1,7 +1,8 @@ module MooncakeJETExt - using JET, Mooncake +using JET, Mooncake + +Mooncake.TestUtils.test_opt(::Mooncake.TestUtils.Shim, args...) = JET.test_opt(args...) +Mooncake.TestUtils.report_opt(::Mooncake.TestUtils.Shim, tt) = JET.report_opt(tt) - Mooncake.TestUtils.test_opt(::Mooncake.TestUtils.Shim, args...) = JET.test_opt(args...) - Mooncake.TestUtils.report_opt(::Mooncake.TestUtils.Shim, tt) = JET.report_opt(tt) end diff --git a/ext/MooncakeLogDensityProblemsADExt.jl b/ext/MooncakeLogDensityProblemsADExt.jl index 29f3870d2..886bb7921 100644 --- a/ext/MooncakeLogDensityProblemsADExt.jl +++ b/ext/MooncakeLogDensityProblemsADExt.jl @@ -3,17 +3,10 @@ module MooncakeLogDensityProblemsADExt -if isdefined(Base, :get_extension) - using ADTypes - using LogDensityProblemsAD: ADGradientWrapper - import LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity - import Mooncake -else - using ADTypes - using ..LogDensityProblemsAD: ADGradientWrapper - import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity - import ..Mooncake -end +using ADTypes +using LogDensityProblemsAD: ADGradientWrapper +import LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity +import Mooncake struct MooncakeGradientLogDensity{Trule, L} <: ADGradientWrapper rule::Trule From f6d8b7da28c1eb2a0e81fbff2439e236312ab005 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 4 Oct 2024 09:49:42 +0100 Subject: [PATCH 20/21] Formatting --- ext/MooncakeCUDAExt.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/ext/MooncakeCUDAExt.jl b/ext/MooncakeCUDAExt.jl index cec1e46f1..0ab3eb1f2 100644 --- a/ext/MooncakeCUDAExt.jl +++ b/ext/MooncakeCUDAExt.jl @@ -56,17 +56,13 @@ end # Basic rules for operating on CuArrays. @is_primitive( - MinimalCtx, - Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}} where {N}, + MinimalCtx, Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}} where {N}, ) function rrule!!( - p::CoDual{Type{P}}, - init::CoDual{UndefInitializer}, - dims::CoDual{Int}... + p::CoDual{Type{P}}, init::CoDual{UndefInitializer}, dims::CoDual{Int}... ) where {P<:CuArray{<:Base.IEEEFloat}} _dims = map(primal, dims) - y = CoDual(P(undef, _dims), P(undef, _dims)) - return y, NoPullback(p, init, dims...) + return CoDual(P(undef, _dims), P(undef, _dims)), NoPullback(p, init, dims...) end end From fc3bb10315be6cd054ad2bebe7fdb03e63a6607c Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 4 Oct 2024 10:39:40 +0100 Subject: [PATCH 21/21] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a4c88649b..6ac178be7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.6" +version = "0.4.7" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"