Skip to content

Commit

Permalink
Tapir.jl Usage (#319)
Browse files Browse the repository at this point in the history
* Bump patch version and add Tapir ext

* Make Tapir available at test time

* Add Tapir runs to AD testing

* Add single rule to Tapir to handle bisection

* using Tapir

* Run on 1.6 only and add Tapir to AD tests

* Disable more tests

* Update ext/BijectorsTapirExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fix formatting

* Restrict version

* Remove Tapir from Project

* Do not run Tapir CI on 1.6

* Enable 1.6 tests in general

* Enable 1.6 on interface tests

* Tweak versioning

* Cancel when multiple things are pushed

* Add Tapir to extras

* Comment out tapir usage

* Try allowing more versions of Tapir

* Allow more versions of Tapir

* More tweaks

* Add Pkg to test deps

* Refine CI

* Use Tapir on 1.10

* Remove CI modifications

* Formatting

* add comment to Tapir installation

* Support a range of types

* Fix Project.toml

* Fix formatting

* Fix formatting

* Fix formatting

* Apply suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Sort out formatting

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 3, 2024
1 parent c3474b2 commit 2849aca
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/AD.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ jobs:
- x64
AD:
- ForwardDiff
- Tapir
- Tracker
- ReverseDiff
- Zygote
exclude:
- version: 1.6
AD: Tapir
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
Expand All @@ -37,6 +38,7 @@ BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsTrackerExt = "Tracker"
BijectorsTapirExt = "Tapir"
BijectorsZygoteExt = "Zygote"

[compat]
Expand All @@ -60,6 +62,7 @@ Requires = "0.5, 1"
ReverseDiff = "1"
Roots = "1.3.4, 2"
Statistics = "1"
Tapir = "0.2.23"
Tracker = "0.2"
Zygote = "0.6.63"
julia = "1.6"
Expand All @@ -69,5 +72,6 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
38 changes: 38 additions & 0 deletions ext/BijectorsTapirExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
module BijectorsTapirExt

if isdefined(Base, :get_extension)
using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule
using Bijectors: find_alpha, ChainRulesCore
else
using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule
using ..Bijectors: find_alpha, ChainRulesCore
end

for P in [Float16, Float32, Float64]
@from_rrule(MinimalCtx, Tuple{typeof(find_alpha),P,P,P})
end

# The final argument could be an Integer of some kind. This should be fine provided that
# it has tangent type equal to `NoTangent`, which means that it's non-differentiable and
# can be safely dropped. We verify that the concrete type of the Integer satisfies this
# constraint, and error if (for some reason) it does not. This should be fine unless a very
# unusual Integer type is encountered.
@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat})

function Tapir.rrule!!(
::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I}
) where {P<:Base.IEEEFloat,I<:Integer}
# Require that the integer is non-differentiable.
if tangent_type(I) != Tapir.NoTangent
msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent."
throw(ArgumentError(msg))
end
out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z))
function find_alpha_pb(dout::P)
_, dx, dy, _ = pb(dout)
return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData()
end
return Tapir.zero_fcodual(out), find_alpha_pb
end

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand Down
13 changes: 13 additions & 0 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ end
test_frule(Bijectors.find_alpha, x, y, z)
test_rrule(Bijectors.find_alpha, x, y, z)

if @isdefined Tapir
rng = Xoshiro(123456)
Tapir.TestUtils.test_rrule!!(
rng, Bijectors.find_alpha, x, y, z; is_primitive=true, perf_flag=:none
)
Tapir.TestUtils.test_rrule!!(
rng, Bijectors.find_alpha, x, y, 3; is_primitive=true, perf_flag=:none
)
Tapir.TestUtils.test_rrule!!(
rng, Bijectors.find_alpha, x, y, UInt32(3); is_primitive=true, perf_flag=:none
)
end

test_rrule(
Bijectors.combine,
Bijectors.PartitionMask(3, [1], [2]) ChainRulesTestUtils.NoTangent(),
Expand Down
23 changes: 23 additions & 0 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,28 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
end
end

if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10"
rule = Tapir.build_rrule(f, x; safety_on=false)
if :tapir in broken
@test_broken(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
)
)
else
@test(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
)
)
end
end

return nothing
end
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ if VERSION < v"1.9"
using Compat: stack
end

# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing
# on at least version 1.10.
if VERSION >= v"1.10"
using Pkg
Pkg.add("Tapir")
using Tapir
end

const GROUP = get(ENV, "GROUP", "All")

# Always include this since it can be useful for other tests.
Expand Down

0 comments on commit 2849aca

Please sign in to comment.