diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 541ccdd9..fafabd00 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -22,9 +22,13 @@ jobs: - x64 AD: - ForwardDiff + - Tapir - Tracker - ReverseDiff - Zygote + exclude: + - version: 1.6 + AD: Tapir steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index e0f4f874..b78b7817 100644 --- a/Project.toml +++ b/Project.toml @@ -29,6 +29,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -37,6 +38,7 @@ BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsReverseDiffExt = "ReverseDiff" BijectorsTrackerExt = "Tracker" +BijectorsTapirExt = "Tapir" BijectorsZygoteExt = "Zygote" [compat] @@ -60,6 +62,7 @@ Requires = "0.5, 1" ReverseDiff = "1" Roots = "1.3.4, 2" Statistics = "1" +Tapir = "0.2.23" Tracker = "0.2" Zygote = "0.6.63" julia = "1.6" @@ -69,5 +72,6 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/BijectorsTapirExt.jl b/ext/BijectorsTapirExt.jl new file mode 100644 index 00000000..70805a82 --- /dev/null +++ b/ext/BijectorsTapirExt.jl @@ -0,0 +1,38 @@ +module BijectorsTapirExt + +if isdefined(Base, :get_extension) + using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule + using Bijectors: find_alpha, ChainRulesCore +else + using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule + using ..Bijectors: find_alpha, ChainRulesCore +end + +for P in [Float16, Float32, Float64] + @from_rrule(MinimalCtx, Tuple{typeof(find_alpha),P,P,P}) +end + +# The final argument could be an Integer of some kind. This should be fine provided that +# it has tangent type equal to `NoTangent`, which means that it's non-differentiable and +# can be safely dropped. We verify that the concrete type of the Integer satisfies this +# constraint, and error if (for some reason) it does not. This should be fine unless a very +# unusual Integer type is encountered. +@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) + +function Tapir.rrule!!( + ::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} +) where {P<:Base.IEEEFloat,I<:Integer} + # Require that the integer is non-differentiable. + if tangent_type(I) != Tapir.NoTangent + msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent." + throw(ArgumentError(msg)) + end + out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z)) + function find_alpha_pb(dout::P) + _, dx, dy, _ = pb(dout) + return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData() + end + return Tapir.zero_fcodual(out), find_alpha_pb +end + +end diff --git a/test/Project.toml b/test/Project.toml index c36de954..22c517f9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -15,6 +15,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index f206a0da..e21f644c 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -27,6 +27,19 @@ end test_frule(Bijectors.find_alpha, x, y, z) test_rrule(Bijectors.find_alpha, x, y, z) + if @isdefined Tapir + rng = Xoshiro(123456) + Tapir.TestUtils.test_rrule!!( + rng, Bijectors.find_alpha, x, y, z; is_primitive=true, perf_flag=:none + ) + Tapir.TestUtils.test_rrule!!( + rng, Bijectors.find_alpha, x, y, 3; is_primitive=true, perf_flag=:none + ) + Tapir.TestUtils.test_rrule!!( + rng, Bijectors.find_alpha, x, y, UInt32(3); is_primitive=true, perf_flag=:none + ) + end + test_rrule( Bijectors.combine, Bijectors.PartitionMask(3, [1], [2]) ⊢ ChainRulesTestUtils.NoTangent(), diff --git a/test/ad/utils.jl b/test/ad/utils.jl index bea823fd..d88d2cd9 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -30,5 +30,28 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end + if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10" + rule = Tapir.build_rrule(f, x; safety_on=false) + if :tapir in broken + @test_broken( + isapprox( + Tapir.value_and_gradient!!(rule, f, x)[2][2], + finitediff; + rtol=rtol, + atol=atol, + ) + ) + else + @test( + isapprox( + Tapir.value_and_gradient!!(rule, f, x)[2][2], + finitediff; + rtol=rtol, + atol=atol, + ) + ) + end + end + return nothing end diff --git a/test/runtests.jl b/test/runtests.jl index 052b6056..eafc34bc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,14 @@ if VERSION < v"1.9" using Compat: stack end +# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing +# on at least version 1.10. +if VERSION >= v"1.10" + using Pkg + Pkg.add("Tapir") + using Tapir +end + const GROUP = get(ENV, "GROUP", "All") # Always include this since it can be useful for other tests.