From 965b98bdafbbc82de94398fb5f66d7ff51996679 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 29 Apr 2024 12:57:36 +0100 Subject: [PATCH] Wct/refactor api and ad implementation (#121) * Separate forwards- and reverse-data * Move low level maths rules over to new system * Extend tuple_map to handle named tuples * Get most builtins working * Fix low_level_maths rules * Excise register-related code * Excise more register-related code * Include interpreter code for now * Single block code works + remove redundant stacks * Add function to get phi nodes from bbcode * Add online type checking functionality * Improve safety error messages * Remove redundant stack code * Fix phi_nodes bbcode function * Safety checks in test utils * Add reflection to utils * Some work * Tweaks * Improve cos and sin rules * Incorporate test from slack * Move generic functionality from tangnets to utils and document * Add tests to code move to utils from tangents * Remove redundant function * Don't run performance tests in s2s tests * Remove redundant code * Remove redundant code * Remove commented out code * Remove redundant code * Reorganise tangents file * Move generic functionality to utils from tangents * Move generic functionality to utils from tangents * Start tidying up tangent test cases * Simplify tangent testing * Test perf for all tangents and fix perf bug * Unify all type testing * Remove redundant test info * Remove redundant alias * Fix comment * Enable more correctness tests * Tidy up test_utils further * Add more tuple tests * Add more tuple test cases * Run fwds_rvs_data tests on all types * Improve fwds and rvs implementation and get all related tests passing * Simplify function names * Rename zero_reverse_data_from_type to zero_rdata_from_type * Rename zero_reverse_data to zero_rdata * Remove redundant code * Get NamedTuples working * Add tests for structs in new * Remove arbitrary limit on number of arguments to _new_ * Formatting * Formatting * Get all existing _new_ tests passing * Newline at end of file * Add remainder of standard tangent test cases to _new_ * Add triangular test cases to new * Remove unused code * Enable more integration tests * Check number of cotangents in safe mode * Tweak error message * Rename safety to safe_mode * Improve comment * Rename safety to safe_mode in runtests * Fix bug in new for partially initialised structs * Require same lengths of args to tuple_map * Improve docstring for tuple_map * Improve docstring * Fix vararg bug * Fix typo in cglobal rule * Split out test types construction * Fix edge case for _new_ rrule * Improve some builtin rules slightly * Use alias for block stack * Active more integration tests * All existing lgetfield tests pass * Support order keyword * Fix up lsetfield * Active more integration tests * Add more test cases to lgetfield * Tidy up lgetfield rrule implementation * Reenable getfield and setfield rules * Fix avoiding_non_diff_code rules * Enable foreigncall stuff * Enable more tests and fix increment perf * Improve arrayref and set -- fix pointerref and set * Sort out more things * Enable all foreigncall tests * Fix safe_mode recursion * Fix safe mode compilation times * Refactor combine_data to binary tangent and tangent_type * Fix increment inference bug * Fix deprecations in benchmarks script * combine_data to tangent and tangent_type * Name zero_rdata_from_type to zero_like_rdata_from_type * Fix dynamic rrule safe mode * Move zero_like_rdata to its own files * Fix zerordata and friends * Include IdDict tests * Move 
tuple_fill to utils * Fix gemv * Update chainrules macro * Update trmv * Update remainder of blas rules * Update lapack rules * Add additional builtins test * Add NoPullback rule to misc * Add additional options to test utils * Make DimensionMismatch tangent NoTangent * Turn safety off for misc tests * Minor codegen performance tweaks * Reactivate remainder of tests * Fix abstract types * Bump patch * Tweak contributor list * Try running the GC before benchmarking * Update performance bounds * Tweak plotting range * Rename functions * Tidy up safe mode implementation * Improve safe mode * Improve safe mode error messages * Fix rrule for arrayset * Fix pointer fdata type bug * Fix safe_mode tests * Add increment_rdata functionality * Simplify arrayref and pointerref implementations * Simplify lgetfield rrule implementation * Use lazy rdata in getfield * Simplify lsetfield implementation * Remove some specialisation to reduce compile times * Improve codual * Fix tangent_type for Tuple * Fix up iddict tests * Remove GC calls * Loosen performance bounds on some type unstable tests * Hopefully fix allocations in CI * Fix typo * Remove intrinsic test case * Remove redundant code * Fix performance of rdata creation * Remove redundant code and tidy up variable naming * Fix peformance of zero_rdata_from_type * Add regression tests * Fix performance bug * Remove allocation tests which don't make sense * Fix randn_tangent perf and improve test reliability * Add pprof code to turing integration tests * Force inline rules for getfield and setfield * Increased stability for non-literal getfield and setfield * Add signature to type of LazyDerivedRule * Add unique predecessor computation * Reduce block usage * Revert flawed performance fix * Enable all tests * Fix performance regression * Improve arrayref and arrayset performance * Add erfcx rule to special functions ext * Specialised getfield rules for nondifferentiable stuff * Make eltypes non-differentiable * Test type with LazyRZero * Improve pullback codegen for unreachable blocks * Fix test cases * Tangent type for Method is NoTangent * Revert Turing performance test again * Add more lsetfield tests * More tests for lgetfield and bug fix * Add comment noting choice of code layout * Optimise safe mode compile times * Import more things during testing * Make NoPullback use lazy rdata * Support NoRData in instantiate * Preserve NoPullback in RRuleZeroWrapper * Update NoPullback uses to use new version * Reduce problem sizes to make CI happier * Bump Turing compat * Add fast increment_field for homogeneous Tuple types * Optimise for homogeneous tuples * Optimise homogeneous named tuples increment_field * Optimise getfield for homogeneously-typed NamedTuples * Reduce test problem sizes further to reduce CI burden * Add TemporalGPs integration test * Force-inline more things * Ensure to_benchmark is compiled before running the profiler * Check that getfield works with tuple of types * Fix test utils for tuple of types * Add functionality to determine if a node is used * Do not AD getfield when not used * Optimise safe mode * Add getfield regression integration test * Force-inling forwards-pass IR for calls and invokes * Add additional small-union test * Optimise for un-used ssa nodes and things with NoPullback pullbacks * Improve documentation for ADInfo * Document BlockStack const * Improve documentation for ADInfo outer constructors * Improve documentation of misc functions in reverse mode ad transformations * Tidy up RRuleZeroWrapper and ReturnNode * 
Tidy up gotoifnot implementation * Remove commented-out line of code * Tidy up reverse-mode code further * Explain special handling for unused getfield calls carefully * Add directions to bbcode file * More documentation for reverse-mode * Improve transformation documentation further * Improve fwds_rvs documentation * More informative name than __convert * Improve comment * Rename fwds_codual_type to fcodual_type * Improve unique pred characterisation documentation --- Project.toml | 10 +- bench/run_benchmarks.jl | 7 +- ext/TapirSpecialFunctionsExt.jl | 1 + src/Tapir.jl | 7 +- src/chain_rules_macro.jl | 16 +- src/codual.jl | 56 +- src/fwds_rvs_data.jl | 580 ++++++++++ src/interpreter/bbcode.jl | 168 ++- src/interpreter/interpreted_function.jl | 588 ---------- src/interpreter/ir_utils.jl | 19 +- src/interpreter/registers.jl | 46 - src/interpreter/reverse_mode_ad.jl | 589 ---------- src/interpreter/s2s_reverse_mode_ad.jl | 1018 ++++++++++------- src/interpreter/zero_like_rdata.jl | 37 + .../avoiding_non_differentiable_code.jl | 14 +- src/rrules/blas.jl | 178 +-- src/rrules/builtins.jl | 816 ++++++------- src/rrules/foreigncall.jl | 400 +++---- src/rrules/iddict.jl | 44 +- src/rrules/lapack.jl | 137 +-- src/rrules/low_level_maths.jl | 39 +- src/rrules/misc.jl | 236 ++-- src/rrules/new.jl | 125 +- src/safe_mode.jl | 145 +++ src/stack.jl | 78 +- src/tangents.jl | 365 +++--- src/test_utils.jl | 358 ++++-- src/utils.jl | 135 ++- test/codual.jl | 5 + test/front_matter.jl | 24 +- test/fwds_rvs_data.jl | 35 + test/integration_testing/misc.jl | 3 +- test/integration_testing/special_functions.jl | 3 + test/integration_testing/temporalgps.jl | 44 + test/integration_testing/turing.jl | 6 + test/interpreter/bbcode.jl | 145 +++ test/interpreter/interpreted_function.jl | 230 ---- test/interpreter/registers.jl | 8 - test/interpreter/reverse_mode_ad.jl | 365 ------ test/interpreter/s2s_reverse_mode_ad.jl | 57 +- test/interpreter/zero_like_rdata.jl | 13 + test/rrules/iddict.jl | 5 +- test/runtests.jl | 8 +- test/safe_mode.jl | 36 + test/stack.jl | 13 - test/tangents.jl | 136 +-- test/utils.jl | 61 + 47 files changed, 3689 insertions(+), 3720 deletions(-) create mode 100644 src/fwds_rvs_data.jl delete mode 100644 src/interpreter/interpreted_function.jl delete mode 100644 src/interpreter/registers.jl delete mode 100644 src/interpreter/reverse_mode_ad.jl create mode 100644 src/interpreter/zero_like_rdata.jl create mode 100644 src/safe_mode.jl create mode 100644 test/fwds_rvs_data.jl create mode 100644 test/integration_testing/temporalgps.jl delete mode 100644 test/interpreter/interpreted_function.jl delete mode 100644 test/interpreter/registers.jl delete mode 100644 test/interpreter/reverse_mode_ad.jl create mode 100644 test/interpreter/zero_like_rdata.jl create mode 100644 test/safe_mode.jl diff --git a/Project.toml b/Project.toml index 62c8e76e..c45f5a2f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" -authors = ["Will Tebbutt and contributors"] -version = "0.1.2" +authors = ["Will Tebbutt, Hong Ge, and contributors"] +version = "0.2.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -36,7 +36,8 @@ PDMats = "0.11" Setfield = "1" SpecialFunctions = "2" StableRNGs = "1" -Turing = "0.29" +TemporalGPs = "0.6" +Turing = "0.31" julia = "1" [extras] @@ -50,8 +51,9 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" 
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["AbstractGPs", "BenchmarkTools", "DiffTests", "Distributions", "FillArrays", "KernelFunctions", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing"] +test = ["AbstractGPs", "BenchmarkTools", "DiffTests", "Distributions", "FillArrays", "KernelFunctions", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] diff --git a/bench/run_benchmarks.jl b/bench/run_benchmarks.jl index 212b5fde..3ee4a1c2 100644 --- a/bench/run_benchmarks.jl +++ b/bench/run_benchmarks.jl @@ -22,7 +22,6 @@ using Tapir: CoDual, generate_hand_written_rrule!!_test_cases, generate_derived_rrule!!_test_cases, - InterpretedFunction, TestUtils, PInterp, _typeof @@ -100,8 +99,8 @@ function _generate_gp_inputs() end @model broadcast_demo(x) = begin - μ ~ TruncatedNormal(1, 2, 0.1, 10) - σ ~ TruncatedNormal(1, 2, 0.1, 10) + μ ~ truncated(Normal(1, 2), 0.1, 10) + σ ~ truncated(Normal(1, 2), 0.1, 10) x .~ LogNormal(μ, σ) end @@ -295,7 +294,7 @@ Constructs a histogram of the `tapir_ratio` field of `df`, with formatting that well-suited to the numbers typically found in this field. """ function plot_ratio_histogram!(df::DataFrame) - bin = 10.0 .^ (0.0:0.05:6.0) + bin = 10.0 .^ (-1.0:0.05:4.0) xlim = extrema(bin) histogram(df.Tapir; xscale=:log10, xlim, bin, title="log", label="") end diff --git a/ext/TapirSpecialFunctionsExt.jl b/ext/TapirSpecialFunctionsExt.jl index 85b12723..651fe018 100644 --- a/ext/TapirSpecialFunctionsExt.jl +++ b/ext/TapirSpecialFunctionsExt.jl @@ -7,4 +7,5 @@ module TapirSpecialFunctionsExt @from_rrule DefaultCtx Tuple{typeof(airyai), Float64} @from_rrule DefaultCtx Tuple{typeof(airyaix), Float64} @from_rrule DefaultCtx Tuple{typeof(erfc), Float64} + @from_rrule DefaultCtx Tuple{typeof(erfcx), Float64} end diff --git a/src/Tapir.jl b/src/Tapir.jl index d0c4e7d5..1e4183a4 100644 --- a/src/Tapir.jl +++ b/src/Tapir.jl @@ -28,10 +28,13 @@ using LinearAlgebra.LAPACK: getrf!, getrs!, getri!, trtrs!, potrf!, potrs! # Needs to be defined before various other things. function _foreigncall_ end +function rrule!! 
end include("utils.jl") include("tangents.jl") +include("fwds_rvs_data.jl") include("codual.jl") +include("safe_mode.jl") include("stack.jl") include(joinpath("interpreter", "contexts.jl")) @@ -39,9 +42,7 @@ include(joinpath("interpreter", "abstract_interpretation.jl")) include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_utils.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) -include(joinpath("interpreter", "registers.jl")) -include(joinpath("interpreter", "interpreted_function.jl")) -include(joinpath("interpreter", "reverse_mode_ad.jl")) +include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) include("test_utils.jl") diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index cf2bae57..085e70c2 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -1,5 +1,5 @@ -__increment_shim!!(::NoTangent, ::ChainRulesCore.NoTangent) = NoTangent() -__increment_shim!!(x, y) = increment!!(x, y) +_to_rdata(::ChainRulesCore.NoTangent) = NoRData() +_to_rdata(dx::Float64) = dx """ @from_rrule ctx sig @@ -38,21 +38,15 @@ macro from_rrule(ctx, sig) map(n -> :(Tapir.primal($n)), arg_names)..., ) - pb_arg_names = map(n -> Symbol("dx_$(n)"), eachindex(arg_names)) pb_output_names = map(n -> Symbol("dx_$(n)_inc"), eachindex(arg_names)) call_pb = Expr(:(=), Expr(:tuple, pb_output_names...), :(pb(dy))) - incrementers = Expr( - :tuple, - map(pb_arg_names, pb_output_names) do a, b - :(Tapir.__increment_shim!!($a, $b)) - end..., - ) + incrementers = Expr(:tuple, map(b -> :(Tapir._to_rdata($b)), pb_output_names)...) pb = ExprTools.combinedef(Dict( :head => :function, :name => :pb!!, - :args => [:dy, pb_arg_names...], + :args => [:dy], :body => quote $call_pb return $incrementers @@ -67,7 +61,7 @@ macro from_rrule(ctx, sig) :body => quote y, pb = $call_rrule $pb - return Tapir.zero_codual(y), pb!! + return Tapir.zero_fcodual(y), pb!! end, ) ) diff --git a/src/codual.jl b/src/codual.jl index 2e92c4a3..a64fc8b5 100644 --- a/src/codual.jl +++ b/src/codual.jl @@ -3,7 +3,11 @@ struct CoDual{Tx, Tdx} dx::Tdx end -# Always sharpen the first thing if it's a type, in order to preserve dispatch possibility. +# Always sharpen the first thing if it's a type so static dispatch remains possible. +function CoDual(x::Type{P}, dx::NoFData) where {P} + return CoDual{@isdefined(P) ? Type{P} : typeof(x), NoFData}(P, dx) +end + function CoDual(x::Type{P}, dx::NoTangent) where {P} return CoDual{@isdefined(P) ? Type{P} : typeof(x), NoTangent}(P, dx) end @@ -26,21 +30,63 @@ See implementation for details, as this function is subject to change. """ @inline uninit_codual(x::P) where {P} = CoDual(x, uninit_tangent(x)) +@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x)) + """ codual_type(P::Type) -Shorthand for `CoDual{P, tangent_type(P}}` when `P` is concrete, equal to `CoDual` if not. +The type of the `CoDual` which contains instances of `P` and associated tangents. """ function codual_type(::Type{P}) where {P} P == DataType && return CoDual P isa Union && return Union{codual_type(P.a), codual_type(P.b)} + P <: UnionAll && return CoDual return isconcretetype(P) ? CoDual{P, tangent_type(P)} : CoDual end codual_type(::Type{Type{P}}) where {P} = CoDual{Type{P}, NoTangent} -struct NoPullback end +struct NoPullback{R<:Tuple} + r::R +end + +""" + NoPullback(args::CoDual...) + +Construct a `NoPullback` from the arguments passed to an `rrule!!`. 
For each argument, +extracts the primal value, and constructs a `LazyZeroRData`. These are stored in a +`NoPullback` which, in the reverse-pass of AD, instantiates these `LazyZeroRData`s and +returns them in order to perform the reverse-pass of AD. + +The advantage of this approach is that if it is possible to construct the zero rdata element +for each of the arguments lazily, the `NoPullback` generated will be a singleton type. This +means that AD can avoid generating a stack to store this pullback, which can result in +significant performance improvements. +""" +function NoPullback(args::Vararg{CoDual, N}) where {N} + return NoPullback(tuple_map(LazyZeroRData ∘ primal, args)) +end + +@inline (pb::NoPullback)(_) = tuple_map(instantiate, pb.r) + +to_fwds(x::CoDual) = CoDual(primal(x), fdata(tangent(x))) + +to_fwds(x::CoDual{Type{P}}) where {P} = CoDual{Type{P}, NoFData}(primal(x), NoFData()) + +zero_fcodual(p) = to_fwds(zero_codual(p)) + +""" + fcodual_type(P::Type) + +The type of the `CoDual` which contains instances of `P` and its fdata. +""" +function fcodual_type(::Type{P}) where {P} + P == DataType && return CoDual + P isa Union && return Union{fcodual_type(P.a), fcodual_type(P.b)} + P <: UnionAll && return CoDual + return isconcretetype(P) ? CoDual{P, fdata_type(tangent_type(P))} : CoDual +end -@inline (::NoPullback)(dy, dx...) = dx +fcodual_type(::Type{Type{P}}) where {P} = CoDual{Type{P}, NoFData} -might_be_active(args) = any(might_be_active ∘ _typeof, args) +zero_rdata(x::CoDual) = zero_rdata(primal(x)) diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl new file mode 100644 index 00000000..b61e56bd --- /dev/null +++ b/src/fwds_rvs_data.jl @@ -0,0 +1,580 @@ +""" + NoFData + +Singleton type which indicates that there is nothing to be propagated on the forwards-pass +in addition to the primal data. +""" +struct NoFData end + +increment!!(::NoFData, ::NoFData) = NoFData() + +""" + FData(data::NamedTuple) + +The component of a `struct` which is propagated alongside the primal on the forwards-pass of +AD. For example, the tangents for `Float64`s do not need to be propagated on the forwards- +pass of reverse-mode AD, so any `Float64` fields of `Tangent` do not need to appear in the +associated `FData`. +""" +struct FData{T<:NamedTuple} + data::T +end + +fields_type(::Type{FData{T}}) where {T<:NamedTuple} = T + +increment!!(x::F, y::F) where {F<:FData} = F(tuple_map(increment!!, x.data, y.data)) + +""" + fdata_type(T) + +Returns the type of the forwards data associated to a tangent of type `T`. +""" +fdata_type(T) + +fdata_type(x) = throw(error("$x is not a type. Perhaps you meant typeof(x)?")) + +fdata_type(::Type{T}) where {T<:IEEEFloat} = NoFData + +function fdata_type(::Type{PossiblyUninitTangent{T}}) where {T} + Tfields = fdata_type(T) + return PossiblyUninitTangent{Tfields} +end + +@generated function fdata_type(::Type{T}) where {T} + + # If the tangent type is NoTangent, then the forwards-component must be `NoFData`. + T == NoTangent && return NoFData + + # This method can only handle struct types. Tell user to implement their own method. + isprimitivetype(T) && throw(error( + "$T is a primitive type. Implement a method of `fdata_type` for it." + )) + + # If the type is a Union, then take the union type of its arguments. + T isa Union && return Union{fdata_type(T.a), fdata_type(T.b)} + + # If `P` is a mutable type, then its forwards data is its tangent. + ismutabletype(T) && return T + + # If the type is itself abstract, it's forward data could be anything. 
+ # The same goes for if the type has any undetermined type parameters. + (isabstracttype(T) || !isconcretetype(T)) && return Any + + # If `P` is an immutable type, then some of its fields may not need to be propagated + # on the forwards-pass. + if T <: Tangent + Tfields = fields_type(T) + fwds_data_field_types = map(1:fieldcount(Tfields)) do n + return fdata_type(fieldtype(Tfields, n)) + end + all(==(NoFData), fwds_data_field_types) && return NoFData + return FData{NamedTuple{fieldnames(Tfields), Tuple{fwds_data_field_types...}}} + end + + return :(error("Unhandled type $T")) +end + +fdata_type(::Type{T}) where {T<:Ptr} = T + +@generated function fdata_type(::Type{P}) where {P<:Tuple} + isa(P, Union) && return Union{fdata_type(P.a), fdata_type(P.b)} + isempty(P.parameters) && return NoFData + isa(last(P.parameters), Core.TypeofVararg) && return Any + all(p -> fdata_type(p) == NoFData, P.parameters) && return NoFData + return Tuple{map(fdata_type, fieldtypes(P))...} +end + +@generated function fdata_type(::Type{NamedTuple{names, T}}) where {names, T<:Tuple} + if fdata_type(T) == NoFData + return NoFData + elseif isconcretetype(fdata_type(T)) + return NamedTuple{names, fdata_type(T)} + else + return Any + end +end + +""" + fdata_field_type(::Type{P}, n::Int) where {P} + +Returns the type of to the nth field of the fdata type associated to `P`. Will be a +`PossiblyUninitTangent` if said field can be undefined. +""" +function fdata_field_type(::Type{P}, n::Int) where {P} + Tf = tangent_type(fieldtype(P, n)) + f = ismutabletype(P) ? Tf : fdata_type(Tf) + return is_always_initialised(P, n) ? f : _wrap_type(f) +end + +""" + fdata(t)::fdata_type(typeof(t)) + +Extract the forwards data from tangent `t`. +""" +@generated function fdata(t::T) where {T} + + # Ask for the forwards-data type. Useful catch-all error checking for unexpected types. + F = fdata_type(T) + + # Catch-all for anything with no forwards-data. + F == NoFData && return :(NoFData()) + + # Catch-all for anything where we return the whole object (mutable structs, arrays...). + F == T && return :(t) + + # T must be a `Tangent` by now. If it's not, something has gone wrong. + !(T <: Tangent) && return :(error("Unhandled type $T")) + return :($F(fdata(t.fields))) +end + +function fdata(t::T) where {T<:PossiblyUninitTangent} + F = fdata_type(T) + return is_init(t) ? F(fdata(val(t))) : F() +end + +@generated function fdata(t::Union{Tuple, NamedTuple}) + fdata_type(t) == NoFData && return NoFData() + return :(tuple_map(fdata, t)) +end + +uninit_fdata(p) = fdata(uninit_tangent(p)) + +""" + NoRData() + +Nothing to propagate backwards on the reverse-pass. +""" +struct NoRData end + +@inline increment!!(::NoRData, ::NoRData) = NoRData() + +@inline increment_field!!(::NoRData, y, ::Val) = NoRData() + +struct RData{T<:NamedTuple} + data::T +end + +fields_type(::Type{RData{T}}) where {T<:NamedTuple} = T + +@inline increment!!(x::RData{T}, y::RData{T}) where {T} = RData(increment!!(x.data, y.data)) + +@inline function increment_field!!(x::RData{T}, y, ::Val{f}) where {T, f} + y isa NoRData && return x + new_val = fieldtype(T, f) <: PossiblyUninitTangent ? fieldtype(T, f)(y) : y + return RData(increment_field!!(x.data, new_val, Val(f))) +end + +""" + rdata_type(T) + +Returns the type of the reverse data of a tangent of type T. +""" +rdata_type(T) + +rdata_type(x) = throw(error("$x is not a type. 
Perhaps you meant typeof(x)?")) + +rdata_type(::Type{T}) where {T<:IEEEFloat} = T + +function rdata_type(::Type{PossiblyUninitTangent{T}}) where {T} + return PossiblyUninitTangent{rdata_type(T)} +end + +@generated function rdata_type(::Type{T}) where {T} + + # If the tangent type is NoTangent, then the reverse-component must be `NoRData`. + T == NoTangent && return NoRData + + # This method can only handle struct types. Tell user to implement their own method. + isprimitivetype(T) && throw(error( + "$T is a primitive type. Implement a method of `rdata_type` for it." + )) + + # If the type is a Union, then take the union type of its arguments. + T isa Union && return Union{rdata_type(T.a), rdata_type(T.b)} + + # If `P` is a mutable type, then all tangent info is propagated on the forwards-pass. + ismutabletype(T) && return NoRData + + # If the type is itself abstract, its reverse data could be anything. + # The same goes for if the type has any undetermined type parameters. + (isabstracttype(T) || !isconcretetype(T)) && return Any + + # If `T` is an immutable type, then some of its fields may not have been propagated on + # the forwards-pass. + if T <: Tangent + Tfs = fields_type(T) + rvs_types = map(n -> rdata_type(fieldtype(Tfs, n)), 1:fieldcount(Tfs)) + all(==(NoRData), rvs_types) && return NoRData + return RData{NamedTuple{fieldnames(Tfs), Tuple{rvs_types...}}} + end +end + +rdata_type(::Type{<:Ptr}) = NoRData + +@generated function rdata_type(::Type{P}) where {P<:Tuple} + isa(P, Union) && return Union{rdata_type(P.a), rdata_type(P.b)} + isempty(P.parameters) && return NoRData + isa(last(P.parameters), Core.TypeofVararg) && return Any + all(p -> rdata_type(p) == NoRData, P.parameters) && return NoRData + return Tuple{map(rdata_type, fieldtypes(P))...} +end + +function rdata_type(::Type{NamedTuple{names, T}}) where {names, T<:Tuple} + if rdata_type(T) == NoRData + return NoRData + elseif isconcretetype(rdata_type(T)) + return NamedTuple{names, rdata_type(T)} + else + return Any + end +end + +""" + rdata_field_type(::Type{P}, n::Int) where {P} + +Returns the type of the nth field of the rdata type associated to `P`. Will be a +`PossiblyUninitTangent` if said field can be undefined. +""" +function rdata_field_type(::Type{P}, n::Int) where {P} + r = rdata_type(tangent_type(fieldtype(P, n))) + return is_always_initialised(P, n) ? r : _wrap_type(r) +end + +""" + rdata(t)::rdata_type(typeof(t)) + +Extract the reverse data from tangent `t`. +""" +@generated function rdata(t::T) where {T} + + # Ask for the reverse-data type. Useful catch-all error checking for unexpected types. + R = rdata_type(T) + + # Catch-all for anything with no reverse-data. + R == NoRData && return :(NoRData()) + + # Catch-all for anything where we return the whole object (Float64, isbits structs, ...) + R == T && return :(t) + + # T must be a `Tangent` by now. If it's not, something has gone wrong. + !(T <: Tangent) && return :(error("Unhandled type $T")) + return :($(rdata_type(T))(rdata(t.fields))) +end + +function rdata(t::T) where {T<:PossiblyUninitTangent} + R = rdata_type(T) + return is_init(t) ?
R(rdata(val(t))) : R() +end + +@generated function rdata(t::Union{Tuple, NamedTuple}) + rdata_type(t) == NoRData && return NoRData() + return :(tuple_map(rdata, t)) +end + +function rdata_backing_type(::Type{P}) where {P} + rdata_field_types = map(n -> rdata_field_type(P, n), 1:fieldcount(P)) + all(==(NoRData), rdata_field_types) && return NoRData + return NamedTuple{fieldnames(P), Tuple{rdata_field_types...}} +end + +""" + zero_rdata(p) + +Given value `p`, return the zero element associated to its reverse data type. +""" +zero_rdata(p) + +zero_rdata(p::IEEEFloat) = zero(p) + +@generated function zero_rdata(p::P) where {P} + + # Get types associated to primal. + T = tangent_type(P) + R = rdata_type(T) + + # If there's no reverse data, return no reverse data, e.g. for mutable types. + R == NoRData && return :(NoRData()) + + # T ought to be a `Tangent`. If it's not, something has gone wrong. + !(T <: Tangent) && Expr(:call, error, "Unhandled type $T") + rdata_field_zeros_exprs = ntuple(fieldcount(P)) do n + R_field = rdata_field_type(P, n) + if R_field <: PossiblyUninitTangent + return :(isdefined(p, $n) ? $R_field(zero_rdata(getfield(p, $n))) : $R_field()) + else + return :(zero_rdata(getfield(p, $n))) + end + end + backing_data_expr = Expr(:call, :tuple, rdata_field_zeros_exprs...) + backing_expr = :($(rdata_backing_type(P))($backing_data_expr)) + return Expr(:call, R, backing_expr) +end + +@generated function zero_rdata(p::Union{Tuple, NamedTuple}) + rdata_type(tangent_type(p)) == NoRData && return NoRData() + return :(tuple_map(zero_rdata, p)) +end + +""" + can_produce_zero_rdata_from_type(::Type{P}) where {P} + +Returns whether or not the zero element of the rdata type for primal type `P` can be +obtained from `P` alone. +""" +@generated function can_produce_zero_rdata_from_type(::Type{P}) where {P} + R = rdata_type(tangent_type(P)) + R == NoRData && return true + isconcretetype(P) || return false + + # For general structs, just look at their fields. + return isstructtype(P) ? all(can_produce_zero_rdata_from_type, fieldtypes(P)) : false +end + +can_produce_zero_rdata_from_type(::Type{<:IEEEFloat}) = true + +""" + CannotProduceZeroRDataFromType() + +Returned by `zero_rdata_from_type` if is not possible to construct the zero rdata element +for a given type. See `zero_rdata_from_type` for more info. +""" +struct CannotProduceZeroRDataFromType end + +""" + zero_rdata_from_type(::Type{P}) where {P} + +Returns the zero element of `rdata_type(tangent_type(P))` if this is possible given only +`P`. If not possible, returns an instance of `CannotProduceZeroRDataFromType`. + +For example, the zero rdata associated to any primal of type `Float64` is `0.0`, so for +`Float64`s this function is simple. Similarly, if the rdata type for `P` is `NoRData`, that +can simply be returned. + +However, it is not possible to return the zero rdata element for abstract types e.g. `Real` +as the type does not uniquely determine the zero element -- the rdata type for `Real` is +`Any`. + +These considerations apply recursively to tuples / namedtuples / structs, etc. + +If you encounter a type which this function returns `CannotProduceZeroRDataFromType`, but +you believe this is done in error, please open an issue. This kind of problem does not +constitute a correctness problem, but can be detrimental to performance, so should be dealt +with. +""" +@generated function zero_rdata_from_type(::Type{P}) where {P} + R = rdata_type(tangent_type(P)) + + # If we know we can't produce a tangent, say so. 
+ can_produce_zero_rdata_from_type(P) || return CannotProduceZeroRDataFromType() + + # Simple case. + R == NoRData && return NoRData() + + # If `P` is a struct type, attempt to derive the zero rdata for it. We cannot derive + # the zero rdata if it is not possible to derive the zero rdata for any of its fields. + if isstructtype(P) + names = fieldnames(P) + types = fieldtypes(P) + wrapped_field_zeros = tuple_map(ntuple(identity, length(names))) do n + fzero = :(zero_rdata_from_type($(types[n]))) + if tangent_field_type(P, n) <: PossiblyUninitTangent + Q = rdata_type(tangent_type(fieldtype(P, n))) + return :(_wrap_field($Q, $fzero)) + else + return fzero + end + end + wrapped_field_zeros_tuple = Expr(:call, :tuple, wrapped_field_zeros...) + return :($R(NamedTuple{$names}($wrapped_field_zeros_tuple))) + end + + # Fallback -- we've not been able to figure out how to produce an instance of zero rdata + # so report that it cannot be done. + return throw(error("Unhandled type $P")) +end + +@generated function zero_rdata_from_type(::Type{P}) where {P<:Tuple} + can_produce_zero_rdata_from_type(P) || return CannotProduceZeroRDataFromType() + rdata_type(tangent_type(P)) == NoRData && return NoRData() + return tuple_map(zero_rdata_from_type, fieldtypes(P)) +end + +function zero_rdata_from_type(::Type{P}) where {P<:NamedTuple} + can_produce_zero_rdata_from_type(P) || return CannotProduceZeroRDataFromType() + rdata_type(tangent_type(P)) == NoRData && return NoRData() + return NamedTuple{fieldnames(P)}(tuple_map(zero_rdata_from_type, fieldtypes(P))) +end + +zero_rdata_from_type(::Type{P}) where {P<:IEEEFloat} = zero(P) + +""" + LazyZeroRData{P, Tdata}() + +This type is a lazy placeholder for `zero_like_rdata_from_type`. This is used to defer +construction of zero data to the reverse pass. Calling `instantiate` on an instance of this +will construct a zero data. + +Users should construct using `LazyZeroRData(p)`, where `p` is an value of type `P`. This +constructor, and `instantiate`, are specialised to minimise the amount of data which must +be stored. For example, `Float64`s do not need any data, so `LazyZeroRData(0.0)` produces +an instance of a singleton type, meaning that various important optimisations can be +performed in AD. +""" +struct LazyZeroRData{P, Tdata} + data::Tdata +end + +# Be lazy if we can compute the zero element given only the type, otherwise just store the +# zero element and use it later. +@inline function LazyZeroRData(p::P) where {P} + if zero_rdata_from_type(P) isa CannotProduceZeroRDataFromType + rdata = zero_rdata(p) + return LazyZeroRData{P, _typeof(rdata)}(rdata) + else + return LazyZeroRData{P, Nothing}(nothing) + end +end + +@inline instantiate(::LazyZeroRData{P, Nothing}) where {P} = zero_rdata_from_type(P) +@inline instantiate(r::LazyZeroRData) = r.data +@inline instantiate(::NoRData) = NoRData() + +""" + tangent_type(F::Type, R::Type)::Type + +Given the type of the fdata and rdata, `F` and `R` resp., for some primal type, compute its +tangent type. This method must be equivalent to `tangent_type(_typeof(primal))`. 
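A minimal sketch of the relationship, using only methods defined immediately below:

```julia
tangent_type(NoFData, NoRData) == NoTangent                  # nothing to propagate either way
tangent_type(NoFData, Float64) == Float64                    # Float64s carry only reverse-data
tangent_type(Vector{Float64}, NoRData) == Vector{Float64}    # Arrays carry only forwards-data
```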
+""" +tangent_type(::Type{NoFData}, ::Type{NoRData}) = NoTangent +tangent_type(::Type{NoFData}, ::Type{R}) where {R<:IEEEFloat} = R +tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Array} = F + +# Tuples +function tangent_type(::Type{F}, ::Type{R}) where {F<:Tuple, R<:Tuple} + return Tuple{tuple_map(tangent_type, Tuple(F.parameters), Tuple(R.parameters))...} +end +function tangent_type(::Type{NoFData}, ::Type{R}) where {R<:Tuple} + F_tuple = Tuple{tuple_fill(NoFData, Val(length(R.parameters)))...} + return tangent_type(F_tuple, R) +end +function tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Tuple} + R_tuple = Tuple{tuple_fill(NoRData, Val(length(F.parameters)))...} + return tangent_type(F, R_tuple) +end + +# NamedTuples +function tangent_type(::Type{F}, ::Type{R}) where {ns, F<:NamedTuple{ns}, R<:NamedTuple{ns}} + return NamedTuple{ns, tangent_type(tuple_type(F), tuple_type(R))} +end +function tangent_type(::Type{NoFData}, ::Type{R}) where {ns, R<:NamedTuple{ns}} + return NamedTuple{ns, tangent_type(NoFData, tuple_type(R))} +end +function tangent_type(::Type{F}, ::Type{NoRData}) where {ns, F<:NamedTuple{ns}} + return NamedTuple{ns, tangent_type(tuple_type(F), NoRData)} +end +tuple_type(::Type{<:NamedTuple{<:Any, T}}) where {T<:Tuple} = T + +# mutable structs +tangent_type(::Type{F}, ::Type{NoRData}) where {F<:MutableTangent} = F + +# structs +function tangent_type(::Type{F}, ::Type{R}) where {F<:FData, R<:RData} + return Tangent{tangent_type(fields_type(F), fields_type(R))} +end +function tangent_type(::Type{NoFData}, ::Type{R}) where {R<:RData} + return Tangent{tangent_type(NoFData, fields_type(R))} +end +function tangent_type(::Type{F}, ::Type{NoRData}) where {F<:FData} + return Tangent{tangent_type(fields_type(F), NoRData)} +end + +function tangent_type( + ::Type{PossiblyUninitTangent{F}}, ::Type{PossiblyUninitTangent{R}} +) where {F, R} + return PossiblyUninitTangent{tangent_type(F, R)} +end + +# Abstract types. +tangent_type(::Type{Any}, ::Type{Any}) = Any + + +""" + tangent(f, r) + +Reconstruct the tangent `t` for which `fdata(t) == f` and `rdata(t) == r`. 
+""" +tangent(::NoFData, ::NoRData) = NoTangent() +tangent(::NoFData, r::IEEEFloat) = r +tangent(f::Array, ::NoRData) = f + +# Tuples +tangent(f::Tuple, r::Tuple) = tuple_map(tangent, f, r) +tangent(::NoFData, r::Tuple) = tuple_map(_r -> tangent(NoFData(), _r), r) +tangent(f::Tuple, ::NoRData) = tuple_map(_f -> tangent(_f, NoRData()), f) + +# NamedTuples +function tangent(f::NamedTuple{n}, r::NamedTuple{n}) where {n} + return NamedTuple{n}(tangent(Tuple(f), Tuple(r))) +end +function tangent(::NoFData, r::NamedTuple{ns}) where {ns} + return NamedTuple{ns}(tangent(NoFData(), Tuple(r))) +end +function tangent(f::NamedTuple{ns}, ::NoRData) where {ns} + return NamedTuple{ns}(tangent(Tuple(f), NoRData())) +end + +# mutable structs +tangent(f::MutableTangent, r::NoRData) = f + +# structs +function tangent(f::F, r::R) where {F<:FData, R<:RData} + return tangent_type(F, R)(tangent(f.data, r.data)) +end +function tangent(::NoFData, r::R) where {R<:RData} + return tangent_type(NoFData, R)(tangent(NoFData(), r.data)) +end +function tangent(f::F, ::NoRData) where {F<:FData} + return tangent_type(F, NoRData)(tangent(f.data, NoRData())) +end + +function tangent(f::PossiblyUninitTangent{F}, r::PossiblyUninitTangent{R}) where {F, R} + T = PossiblyUninitTangent{tangent_type(F, R)} + is_init(f) && is_init(r) && return T(tangent(val(f), val(r))) + !is_init(f) && !is_init(r) && return T() + throw(ArgumentError("Initialisation mismatch")) +end +function tangent(f::PossiblyUninitTangent{F}, ::PossiblyUninitTangent{NoRData}) where {F} + T = PossiblyUninitTangent{tangent_type(F, NoRData)} + return is_init(f) ? T(tangent(val(f), NoRData())) : T() +end +function tangent(::PossiblyUninitTangent{NoFData}, r::PossiblyUninitTangent{R}) where {R} + T = PossiblyUninitTangent{tangent_type(NoFData, R)} + return is_init(r) ? T(tangent(NoFData(), val(r))) : T() +end +function tangent(::PossiblyUninitTangent{NoFData}, ::PossiblyUninitTangent{NoRData}) + return PossiblyUninitTangent(NoTangent()) +end + +""" + increment_rdata!!(t::T, r)::T where {T} + +Increment the rdata component of tangent `t` by `r`, and return the updated tangent. +Useful for implementation getfield-like rules for mutable structs, pointers, dicts, etc. +""" +increment_rdata!!(t::T, r) where {T} = tangent(fdata(t), increment!!(rdata(t), r))::T + +""" + zero_tangent(p, ::NoFData) + + +""" +zero_tangent(p, ::NoFData) = zero_tangent(p) + +function zero_tangent(p::P, f::F) where {P, F} + T = tangent_type(P) + T == F && return f + r = rdata(zero_tangent(p)) + return tangent(f, r) +end + +zero_tangent(p::Tuple, f::Union{Tuple, NamedTuple}) = tuple_map(zero_tangent, p, f) diff --git a/src/interpreter/bbcode.jl b/src/interpreter/bbcode.jl index 78d1d658..bb7402bf 100644 --- a/src/interpreter/bbcode.jl +++ b/src/interpreter/bbcode.jl @@ -1,3 +1,5 @@ +# See the docstring for `BBCode` for some context on this file. + _id_count::Int32 = 0 """ @@ -153,6 +155,20 @@ Base.length(bb::BBlock) = length(bb.inst_ids) Base.copy(bb::BBlock) = BBlock(bb.id, copy(bb.inst_ids), copy(bb.insts)) +""" + phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}} + +Returns all of the `IDPhiNode`s at the start of `bb`, along with their `ID`s. If there are +no `IDPhiNode`s at the start of `bb`, then both vectors will be empty. 
+""" +function phi_nodes(bb::BBlock) + n_phi_nodes = findlast(x -> x.stmt isa IDPhiNode, bb.insts) + if n_phi_nodes === nothing + n_phi_nodes = 0 + end + return bb.inst_ids[1:n_phi_nodes], bb.insts[1:n_phi_nodes] +end + """ Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing @@ -248,21 +264,28 @@ Base.copy(ir::BBCode) = BBCode(ir, copy(ir.blocks)) Compute a map from the `ID of each `BBlock` in `ir` to its possible successors. """ function compute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}} - succs = map(enumerate(ir.blocks)) do (n, blk) - return successors(terminator(blk), n, ir, n == length(ir.blocks)) + return _compute_all_successors(ir.blocks) +end + +# Internal method. Just requires that a Vector of BBlocks are passed. This method is easier +# to construct test cases for because there is no need to construct all the other stuff that +# goes into a `BBCode`. +function _compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} + succs = map(enumerate(blks)) do (n, blk) + return successors(terminator(blk), n, blks, n == length(blks)) end - return Dict{ID, Vector{ID}}(zip(map(b -> b.id, ir.blocks), succs)) + return Dict{ID, Vector{ID}}(zip(map(b -> b.id, blks), succs)) end -function successors(::Nothing, n::Int, ir::BBCode, is_final_block::Bool) - return is_final_block ? ID[] : ID[ir.blocks[n+1].id] +function successors(::Nothing, n::Int, blks::Vector{BBlock}, is_final_block::Bool) + return is_final_block ? ID[] : ID[blks[n+1].id] end -successors(t::IDGotoNode, ::Int, ::BBCode, ::Bool) = [t.label] -function successors(t::IDGotoIfNot, n::Int, ir::BBCode, is_final_block::Bool) - return is_final_block ? ID[t.dest] : ID[t.dest, ir.blocks[n + 1].id] +successors(t::IDGotoNode, ::Int, ::Vector{BBlock}, ::Bool) = [t.label] +function successors(t::IDGotoIfNot, n::Int, blks::Vector{BBlock}, is_final_block::Bool) + return is_final_block ? ID[t.dest] : ID[t.dest, blks[n + 1].id] end -successors(::ReturnNode, ::Int, ::BBCode, ::Bool) = ID[] -successors(t::Switch, ::Int, ::BBCode, ::Bool) = vcat(t.dests, t.fallthrough_dest) +successors(::ReturnNode, ::Int, ::Vector{BBlock}, ::Bool) = ID[] +successors(t::Switch, ::Int, ::Vector{BBlock}, ::Bool) = vcat(t.dests, t.fallthrough_dest) """ compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}} @@ -270,8 +293,12 @@ successors(t::Switch, ::Int, ::BBCode, ::Bool) = vcat(t.dests, t.fallthrough_des Compute a map from the `ID of each `BBlock` in `ir` to its possible predecessors. """ function compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}} + return _compute_all_predecessors(ir.blocks) +end - successor_map = compute_all_successors(ir) +function _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} + + successor_map = _compute_all_successors(blks) # Initialise predecessor map to be empty. ks = collect(keys(successor_map)) @@ -582,3 +609,122 @@ function _sort_blocks!(ir::BBCode) ir.blocks .= ir.blocks[I] return ir end + +#= + characterise_unique_predecessor_blocks(blks::Vector{BBlock}) -> + Tuple{Dict{ID, Bool}, Dict{ID, Bool}} + +We call a block `b` a _unique_ _predecessor_ in the control flow graph associated to `blks` +if it is the only predecessor to all of its successors. Put differently we call `b` a unique +predecessor if, whenever control flow arrives in any of the successors of `b`, we know for +certain that the previous block must have been `b`. + +Returns two `Dict`s. A value in the first `Dict` is `true` if the block associated to its +key is a unique precessor, and is `false` if not. 
A value in the second `Dict` is `true` if +the corresponding block has a single predecessor, and that predecessor is a unique predecessor. + +*Context*: + +This information is important for optimising AD because knowing that `b` is a unique +predecessor means that +1. on the forwards-pass, there is no need to push the ID of `b` to the block stack when + passing through it, and +2. on the reverse-pass, there is no need to pop the block stack when passing through one of + the successors to `b`. + +Utilising this reduces the overhead associated to doing AD. It is quite important when +working with cheap loops -- loops where the operations performed at each iteration +are inexpensive -- for which minimising memory pressure is critical to performance. It is +also important for single-block functions, because it can be used to entirely avoid using a +block stack at all. +=# +function characterise_unique_predecessor_blocks( + blks::Vector{BBlock} +)::Tuple{Dict{ID, Bool}, Dict{ID, Bool}} + + # Obtain the block IDs in order -- this ensures that we get the entry block first. + blk_ids = ID[b.id for b in blks] + preds = _compute_all_predecessors(blks) + succs = _compute_all_successors(blks) + + # The bulk of blocks can be handled by this general loop. + is_unique_pred = Dict{ID, Bool}() + for id in blk_ids + ss = succs[id] + is_unique_pred[id] = !isempty(ss) && all(s -> length(preds[s]) == 1, ss) + end + + # If there is a single reachable return node, then that block is treated as a unique + # pred, since control can only pass "out" of the function via this block. Conversely, + # if there are multiple reachable return nodes, then execution can return to the calling + # function via any of them, so they are not unique predecessors. + # Note that the loop above sets is_unique_pred[id] to false for all nodes which + # end with a reachable return node, so the value only needs changing if there is a + # unique reachable return node. + reachable_return_blocks = filter(blks) do blk + is_reachable_return_node(terminator(blk)) + end + if length(reachable_return_blocks) == 1 + is_unique_pred[only(reachable_return_blocks).id] = true + end + + # pred_is_unique_pred is true if the unique predecessor to a block is a unique pred. + pred_is_unique_pred = Dict{ID, Bool}() + for id in blk_ids + pred_is_unique_pred[id] = length(preds[id]) == 1 && is_unique_pred[only(preds[id])] + end + + # If the entry block has no predecessors, then it can only be entered once, when the + # function is first entered. In this case, we treat it as having a unique predecessor. + entry_id = blk_ids[1] + pred_is_unique_pred[entry_id] = isempty(preds[entry_id]) + + return is_unique_pred, pred_is_unique_pred +end + +""" + characterise_used_ids(stmts::Vector{Tuple{ID, NewInstruction}})::Dict{ID, Bool} + +For each statement in `stmts`, determine whether its `ID` is referenced anywhere else in +the code. Returns a dictionary containing the results. An element is `false` if the +corresponding `ID` is unused, and `true` if it is used. +""" +function characterise_used_ids(stmts::Vector{Tuple{ID, NewInstruction}})::Dict{ID, Bool} + ids = first.(stmts) + insts = last.(stmts) + + # Initialise to false. + is_used = Dict{ID, Bool}(zip(ids, fill(false, length(ids)))) + + # Hunt through the instructions, flipping a value in is_used to true whenever an ID + # is encountered which corresponds to an SSA. + for inst in insts + _find_id_uses!(is_used, inst.stmt) + end + return is_used +end + +# Helper function used in characterise_used_ids.
+function _find_id_uses!(d::Dict{ID, Bool}, x::Expr) + for arg in x.args + in(arg, keys(d)) && setindex!(d, true, arg) + end +end +function _find_id_uses!(d::Dict{ID, Bool}, x::IDGotoIfNot) + return in(x.cond, keys(d)) && setindex!(d, true, x.cond) +end +_find_id_uses!(::Dict{ID, Bool}, ::IDGotoNode) = nothing +function _find_id_uses!(d::Dict{ID, Bool}, x::PiNode) + return in(x.val, keys(d)) && setindex!(d, true, x.val) +end +function _find_id_uses!(d::Dict{ID, Bool}, x::IDPhiNode) + v = x.values + for n in eachindex(v) + isassigned(v, n) && in(v[n], keys(d)) && setindex!(d, true, v[n]) + end +end +function _find_id_uses!(d::Dict{ID, Bool}, x::ReturnNode) + return isdefined(x, :val) && in(x.val, keys(d)) && setindex!(d, true, x.val) +end +_find_id_uses!(d::Dict{ID, Bool}, x::QuoteNode) = nothing +_find_id_uses!(d::Dict{ID, Bool}, x) = nothing diff --git a/src/interpreter/interpreted_function.jl b/src/interpreter/interpreted_function.jl deleted file mode 100644 index 4a0f39f8..00000000 --- a/src/interpreter/interpreted_function.jl +++ /dev/null @@ -1,588 +0,0 @@ -# Special types to represent data in an IRCode and a InterpretedFunction. - -abstract type AbstractSlot{T} end - -Base.eltype(::AbstractSlot{T}) where {T} = T - -""" - SlotRef{T}() - -Constructs a reference to a slot of type `T` whose value is unassigned. - - SlotRef(x::T) where {T} - -Constructs a reference to a slot of type `T` whose value is `x`. - - SlotRef{T}(x) - -Constructs a reference to a slot of type `T` whose value is `x`. Valid provided that -`typeof(x) <: T`. -""" -mutable struct SlotRef{T} <: AbstractSlot{T} - x::T - SlotRef{T}() where {T} = new{T}() - SlotRef(x::T) where {T} = new{T}(x) - SlotRef{T}(x) where {T} = new{T}(x) -end - -Base.getindex(x::SlotRef) = getfield(x, :x) -Base.setindex!(x::SlotRef, val) = setfield!(x, :x, val) -Base.isassigned(x::SlotRef) = isdefined(x, :x) -Base.copy(x::SlotRef{T}) where {T} = isassigned(x) ? SlotRef{T}(x[]) : SlotRef{T}() - -""" - ConstSlot(x) - -Represents a constant, and it type-stable, and is therefore stored inline. -""" -struct ConstSlot{T} <: AbstractSlot{T} - x::T - ConstSlot(x::T) where {T} = new{T}(x) - ConstSlot(::Type{T}) where {T} = new{Type{T}}(T) - ConstSlot{T}(x) where {T} = new{T}(x) -end - -Base.getindex(x::ConstSlot) = getfield(x, :x) -Base.setindex!(::ConstSlot, val) = nothing -Base.isassigned(::ConstSlot) = true -Base.copy(x::ConstSlot{T}) where {T} = ConstSlot{T}(x[]) - -""" - TypedGlobalRef(x::GlobalRef) - -A (potentially) type-stable getter for a `GlobalRef`. In particular, if the `GlobalRef` is -declared to be a concrete type, this `getindex(::TypedGlobalRef)` will be type-stable. If no -declaration was made, then `getindex(::TypedGlobalRef)` will infer to `Any`. - -If a `GlobalRef` is declared to be constant, prefer to represent it using a `ConstSlot`, -rather than a `TypedGlobalRef`. -""" -struct TypedGlobalRef{T} <: AbstractSlot{T} - mod::Module - name::Symbol - TypedGlobalRef(x::GlobalRef) = new{x.binding.ty}(x.mod, x.name) -end - -TypedGlobalRef(mod::Module, name::Symbol) = TypedGlobalRef(GlobalRef(mod, name)) - -Base.getindex(x::TypedGlobalRef{T}) where {T} = getglobal(x.mod, x.name)::T -Base.setindex!(x::TypedGlobalRef, val) = setglobal!(x.mod, x.name, val) -Base.isassigned(::TypedGlobalRef) = true - -#= -Returns either a `ConstSlot` or a `TypedGlobalRef`, both of which are `AbstractSlot`s. -In particular, a `ConstSlot` is returned only if the `ex` is declared to be constant. -=# -function _globalref_to_slot(ex::GlobalRef) - return isconst(ex) ? 
ConstSlot(getglobal(ex.mod, ex.name)) : TypedGlobalRef(ex) -end - -# Utility functionality used through instruction construction. - -const Inst = Core.OpaqueClosure{Tuple{Int}, Int} - -# Standard handling for next-block returns for non control flow related instructions. -_standard_next_block(is_blk_end::Bool, current_blk::Int) = is_blk_end ? current_blk + 1 : 0 - - -# IR node handlers -- translates Julia SSAIR nodes into executable `Inst`s (see above). -# Each node may have several methods of `build_inst`. One will always be a method which -# accepts a variety of arguments, including an `InterpretedFunction`, extracts only the data -# that it needs, and calls another method of `build_inst`. This other method of -# `build_inst` will actually build the instruction. This structure is used to make it -# easy to construct unit test cases for the second method (if you inspect the tests for this -# code, you will find that the second method is usually called). - -## ReturnNode -function build_inst(inst::ReturnNode, @nospecialize(in_f), ::Int, ::Int, ::Bool)::Inst - return build_inst(ReturnNode, in_f.return_slot, _get_slot(inst.val, in_f)) -end -function build_inst(::Type{ReturnNode}, ret_slot::SlotRef, val_slot::AbstractSlot) - return @opaque (prev_block::Int) -> (ret_slot[] = val_slot[]; return -1) -end - -## GotoNode -function build_inst(inst::GotoNode, @nospecialize(in_f), ::Int, ::Int, ::Bool)::Inst - return build_inst(GotoNode, inst.label) -end -build_inst(::Type{GotoNode}, label::Int) = @opaque (p::Int) -> label - -## GotoIfNot -function build_inst(x::GotoIfNot, @nospecialize(in_f), ::Int, b::Int, ::Bool)::Inst - return build_inst(GotoIfNot, _get_slot(x.cond, in_f), b + 1, x.dest) -end -function build_inst(::Type{GotoIfNot}, cond::AbstractSlot, next_blk::Int, dest::Int) - if !(Bool <: eltype(cond)) - throw(ArgumentError("cond $cond has eltype $(eltype(cond)), not a supertype of Bool")) - end - return @opaque (p::Int) -> cond[] ? next_blk : dest -end - -## PhiNode - -struct TypedPhiNode{Tr<:AbstractSlot, Tt<:AbstractSlot, Te<:Tuple, Tv<:Tuple} - tmp_slot::Tt - ret_slot::Tr - edges::Te - values::Tv -end - -# Runs a collection of PhiNodes (semantically) simulataneously. Does this by first writing -# the value associated to each PhiNode to its `tmp_slot`. Once all values have been written, -# copies the `tmp_slot` value across to the `ret_slot`. This ensures that if e.g. -# PhiNode B takes the value associated to PhiNode A, it gets the value _before_ this -# collection of PhiNodes started to run, rather than after. See SSAIR docs for more info. -function build_phinode_insts( - ir_insts::Vector{PhiNode}, in_f, n_first::Int, b::Int, is_blk_end::Bool -)::Inst - nodes = build_typed_phi_nodes(ir_insts, in_f, n_first) - next_blk = _standard_next_block(is_blk_end, b) - return build_inst(Vector{PhiNode}, (nodes..., ), next_blk) -end - -struct UndefRef end - -function build_typed_phi_nodes(ir_insts::Vector{PhiNode}, in_f, n_first::Int) - return map(enumerate(ir_insts)) do (j, ir_inst) - ret_slot = in_f.slots[n_first + j - 1] - edges = map(Int, (ir_inst.edges..., )) - vals = ir_inst.values - _init = map(eachindex(vals)) do j - return isassigned(vals, j) ? _get_slot(vals[j], in_f) : UndefRef() - end - T = eltype(ret_slot) - values_vec = map(n -> _init[n] isa UndefRef ? 
SlotRef{T}() : _init[n], eachindex(_init)) - return TypedPhiNode(SlotRef{T}(), ret_slot, edges, (values_vec..., )) - end -end - -function build_inst(::Type{Vector{PhiNode}}, nodes::Tuple, next_blk::Int)::Inst - return @opaque function (prev_blk::Int) - map(Base.Fix2(store_tmp_value!, prev_blk), nodes) - map(transfer_tmp_value!, nodes) - return next_blk - end -end - -function store_tmp_value!(node::TypedPhiNode, prev_blk::Int) - map(node.edges, node.values) do edge, val - (edge == prev_blk) && isassigned(val) && (node.tmp_slot[] = val[]) - end - return nothing -end - -function transfer_tmp_value!(node::TypedPhiNode) - isassigned(node.tmp_slot) && (node.ret_slot[] = node.tmp_slot[]) - return nothing -end - -## PiNode -function build_inst(x::PiNode, @nospecialize(in_f), n::Int, b::Int, is_blk_end::Bool)::Inst - next_blk = _standard_next_block(is_blk_end, b) - return build_inst(PiNode, _get_slot(x.val, in_f), in_f.slots[n], next_blk) -end -function build_inst(::Type{PiNode}, input::AbstractSlot, out::AbstractSlot, next_blk::Int) - return @opaque (prev_blk::Int) -> (out[] = input[]; return next_blk) -end - -## GlobalRef -function build_inst(x::GlobalRef, @nospecialize(in_f), n::Int, b::Int, is_blk_end::Bool)::Inst - next_blk = _standard_next_block(is_blk_end, b) - return build_inst(GlobalRef, _globalref_to_slot(x), in_f.slots[n], next_blk) -end -function build_inst(::Type{GlobalRef}, x::AbstractSlot, out::AbstractSlot, next_blk::Int) - return @opaque (prev_blk::Int) -> (out[] = x[]; return next_blk) -end - -## QuoteNode and literals -function build_inst(node, @nospecialize(in_f), n::Int, b::Int, is_blk_end::Bool)::Inst - x = ConstSlot(node isa QuoteNode ? node.value : node) - return build_inst(nothing, x, in_f.slots[n], _standard_next_block(is_blk_end, b)) -end -function build_inst(::Nothing, x::ConstSlot, out_slot::AbstractSlot, next_blk::Int) - return @opaque (prev_blk::Int) -> (out_slot[] = x[]; return next_blk) -end - -## Expr - -@inline _eval(f::F, args::Vararg{Any, N}) where {F, N} = f(args...) - -tangent_type(::Type{typeof(_eval)}) = NoTangent - -function build_inst(x::Expr, @nospecialize(in_f), n::Int, b::Int, is_blk_end::Bool)::Inst - next_blk = _standard_next_block(is_blk_end, b) - val_slot = in_f.slots[n] - if Meta.isexpr(x, :boundscheck) - return build_inst(Val(:boundscheck), val_slot, next_blk) - elseif Meta.isexpr(x, :invoke) || Meta.isexpr(x, :call) - is_invoke = Meta.isexpr(x, :invoke) - __args = is_invoke ? 
x.args[2:end] : x.args - arg_refs = map(arg -> _get_slot(arg, in_f), (__args..., )) - sig = Tuple{map(eltype, arg_refs)...} - evaluator = get_evaluator(in_f.ctx, sig, in_f.interp, is_invoke) - return build_inst(Val(:call), arg_refs, evaluator, val_slot, next_blk) - elseif x.head in [ - :code_coverage_effect, :gc_preserve_begin, :gc_preserve_end, :loopinfo, :leave, - :pop_exception, - ] - return build_inst(Val(:skipped_expression), next_blk) - elseif Meta.isexpr(x, :throw_undef_if_not) - slot_to_check = _get_slot(x.args[2], in_f) - return build_inst(Val(:throw_undef_if_not), slot_to_check, next_blk) - else - throw(error("Unrecognised expression $x")) - end -end - -function get_evaluator(ctx::T, sig, interp, is_invoke::Bool) where {T} - is_primitive(T, sig) && return _eval - is_invoke && return InterpretedFunction(ctx, sig, interp) - return DelayedInterpretedFunction(ctx, Dict(), interp) -end - -function build_inst(::Val{:boundscheck}, val_slot::AbstractSlot, next_blk::Int)::Inst - return @opaque (prev_blk::Int) -> (val_slot[] = true; return next_blk) -end - -function build_inst( - ::Val{:call}, - arg_slots::Targ_slots, - ev::Teval, - val_slot::AbstractSlot, - next_blk::Int, -)::Inst where {Teval, Targ_slots} - return @opaque function (prev_blk::Int) - val_slot[] = ev(tuple_map(getindex, arg_slots)...) - return next_blk - end -end - -build_inst(::Val{:skipped_expression}, next_blk::Int)::Inst = @opaque (prev_blk::Int) -> next_blk - -function build_inst(::Val{:throw_undef_if_not}, slot_to_check::AbstractSlot, next_blk::Int)::Inst - return @opaque function (prev_blk::Int) - !isassigned(slot_to_check) && throw(error("Boooo, not assigned")) - return next_blk - end -end - -# -# Code execution -# - -_get_slot(x::Argument, _, arg_info, _) = arg_info.arg_slots[x.n] -_get_slot(x::GlobalRef, _, _, _) = _globalref_to_slot(x) -_get_slot(x::QuoteNode, _, _, _) = ConstSlot(x.value) -_get_slot(x::SSAValue, slots, _, _) = slots[x.id] -_get_slot(x::AbstractSlot, _, _, _) = throw(error("Already a slot!")) -_get_slot(x, _, _, _) = ConstSlot(x) -function _get_slot(x::Expr, _, _, sptypes) - # There are only a couple of `Expr`s possible as arguments to `Expr`s. - if Meta.isexpr(x, :boundscheck) - return ConstSlot(true) - elseif Meta.isexpr(x, :static_parameter) - return ConstSlot(sptypes[x.args[1]].typ) - else - throw(ArgumentError("Found unexpected expr $x")) - end -end - - - -# -# Loading arguments into slots. -# - -# Data structure to handle arguments to functions. Comprises a collection of slots, and -# knows whether or not it represents the arguments of a varargs function. -struct ArgInfo{Targ_slots<:NTuple{N, Any} where {N}, is_vararg} - arg_slots::Targ_slots -end - -function ArgInfo(::Type{T}, is_vararg::Bool) where {T<:Tuple} - Targ_slots = Tuple{map(t -> SlotRef{t}, T.parameters)...} - return ArgInfo{Targ_slots, is_vararg}((map(t -> SlotRef{t}(), T.parameters)..., )) -end - -function unflatten_vararg(::Val{is_va}, args::Tuple, ::Val{nargs}) where {is_va, nargs} - return is_va ? (args[1:nargs]..., (args[nargs+1:end]..., )) : args -end - -function load_args!(ai::ArgInfo{T, is_vararg}, args::Tuple) where {T, is_vararg} - # There is a difference between the varargs that we recieve, and the varargs of the - # original function. This section sorts that out. - # For example if the original function is `f(x...)`, then the `argtypes` field of its - # `IRCode` when calling e.g. `f(5.0)` will be `Tuple{typeof(f), Tuple{Float64}}`, where - # the second tuple contains the vararg. 
- # However, the `argtypes` field of the corresponding `InterpretedFunction` will - # be `Tuple{<:InterpretedFunction, Tuple{typeof(f), Float64}}`. - # Therefore, the `args` field of this function will be a `Tuple{typeof(f), Float64}`. - # We must therefore transform it into a `Tuple` of type - # `Tuple{typeof(f), Tuple{Float64}}` before attempting to load it into `ai.arg_slots`. - refined_args = unflatten_vararg(Val(is_vararg), args, Val(length(ai.arg_slots) - 1)) - return __load_args!(ai.arg_slots, refined_args) -end - -@generated function __load_args!(arg_slots::Tuple, args::Tuple) - Ts = args.parameters - ns = filter(n -> !Base.issingletontype(Ts[n]), eachindex(Ts)) - loaders = map(n -> :(arg_slots[$n][] = args[$n]), ns) - return Expr(:block, loaders..., :(return nothing)) -end - -# -# Construct and run an InterpretedFunction. -# - -struct InterpretedFunction{sig<:Tuple, C, Treturn, Targ_info<:ArgInfo} - ctx::C - return_slot::SlotRef{Treturn} - arg_info::Targ_info - slots::Vector{AbstractSlot} - instructions::Vector{Inst} - bb_starts::Vector{Int} - bb_ends::Vector{Int} - ir::IRCode - interp::TapirInterpreter - spnames::Any -end - -# See above for other `_get_slot` methods. -function _get_slot(x, in_f::InterpretedFunction) - return _get_slot(x, in_f.slots, in_f.arg_info, in_f.ir.sptypes) -end - -function is_vararg_sig_and_sparam_names(sig) - world = Base.get_world_counter() - min = Base.RefValue{UInt}(typemin(UInt)) - max = Base.RefValue{UInt}(typemax(UInt)) - ms = Base._methods_by_ftype(sig, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))::Vector - m = only(ms).method - return m.isva, sparam_names(m) -end - -function sparam_names(m::Core.Method)::Vector{Symbol} - whereparams = ExprTools.where_parameters(m.sig) - whereparams === nothing && return Symbol[] - return map(whereparams) do name - name isa Symbol && return name - Meta.isexpr(name, :(<:)) && return name.args[1] - Meta.isexpr(name, :(>:)) && return name.args[1] - error("unrecognised type param $name") - end -end - -make_slot(x::Type{T}) where {T} = (@isdefined T) ? SlotRef{T}() : SlotRef{DataType}() -make_slot(x::CC.Const) = ConstSlot{_typeof(x.val)}(x.val) -make_slot(x::CC.PartialStruct) = SlotRef{x.typ}() -make_slot(::CC.PartialTypeVar) = SlotRef{TypeVar}() - -make_dummy_instruction(next_blk::Int) = @opaque (p::Int) -> next_blk - -# Special handling is required for PhiNodes, because their semantics require that when -# more than one PhiNode appears at the start of a basic block, they are run simulataneously -# rather than in sequence. See the SSAIR docs for an explanation of why this is the case. -function make_phi_instructions!(in_f::InterpretedFunction) - ir = in_f.ir - insts = in_f.instructions - for (b, bb) in enumerate(ir.cfg.blocks) - - # Find any phi nodes at the start of the block. - phi_node_inds = Int[] - foreach(n -> (ir.stmts.inst[n] isa PhiNode) && push!(phi_node_inds, n), bb.stmts) - isempty(phi_node_inds) && continue - - # Make a single instruction which runs all of the PhiNodes "simulataneously". - # Specifically, this instruction runs all of the phi nodes, storing the results of - # this into temporary storage, then writing from the temporary slots to the - # final slots. This has the effect of ensuring that phi nodes that depend on other - # phi nodes get the "old" values, not the new updated values. This was a - # surprisingly hard bug to catch and resolve. 
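To make the two-phase update implemented by `store_tmp_value!` and `transfer_tmp_value!` above concrete, here is a minimal stand-alone sketch (editorial illustration, not part of the patch; the variable names are invented) of why phi nodes must read the "old" values before any of them is overwritten:

```julia
# SSA-style pseudocode for two mutually-dependent phi nodes at the top of a block:
#   %a = φ(#1 => 1, #2 => %b)
#   %b = φ(#1 => 2, #2 => %a)
# Suppose the block is re-entered from block #2 with a == 1 and b == 2.
a, b = 1, 2

# Phase 1: evaluate every phi node into a temporary slot, reading only the old values.
tmp_a = b
tmp_b = a

# Phase 2: transfer the temporaries into the result slots.
a, b = tmp_a, tmp_b

@assert (a, b) == (2, 1)  # sequential updates (a = b; then b = a) would give (2, 2)
```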
- nodes = [ir.stmts.inst[n] for n in phi_node_inds] - n_first = first(phi_node_inds) - is_blk_end = length(phi_node_inds) == length(bb.stmts) - insts[phi_node_inds[1]] = build_phinode_insts(nodes, in_f, n_first, b, is_blk_end) - - # Create dummy instructions for the remainder of the nodes. - for n in phi_node_inds[2:end] - insts[n] = make_dummy_instruction(_standard_next_block(is_blk_end, b)) - end - end - return nothing -end - -""" - InterpretedFunction(ctx::C, sig::Type{<:Tuple}, interp) where {C} - -Construct a data structure which can be used to execute the instruction specified by `sig`. -For example, -```julia -in_f = InterpretedFunction(DefaultCtx(), Tuple{typeof(sin), Float64}, Tapir.PInterp()) -in_f(sin, 5.0) -``` -will yield exactly the same result as running `sin(5.0)`. The advantage of this data -structure is that `build_rrule!!` is implemented for it, meaning that it can be -differentiated. - -The performance of `InterpretedFunction` largely depends on what the functions are that it -operates on, but it definitely adds a notable amount of overhead when compared to regular -Julia code. Typically this overhead is on the order 10ns per operation (on a modern CPU). - -For example, running on low-level code involving small scalar operations will -_typically_ take 10-100 times longer than running the original Julia function, but BLAS -calls on moderately large matrices has negligible overhead when compared with the original -function. - -## Caching - -`InterpretedFunction`s are cached by `interp` -- as a consequence, if you call -`InterpretedFunction` twice with the same arguments, the second call will just return a -cached result. - -## Known-Limitations - -While much of the language is supported, there are a few things that `InterpretedFunction` -_cannot_ execute. These include anything to do with threading, and exception handling. -The ability to handling threading may improve in future versions of `InterpretedFunction`, -but exception handling is unlikely to be supported, as it is not at all clear how it would -be handled in reverse-mode AD. - -Note that `InterpretedFunction`s should be fine with the constructs involved in exception -handling _provided_ that no exceptions are actually thrown. - -# Implementation - -An `InterpretedFunction` operates by first looking up the _optimised_ IRCode associated to -`sig` under `interp`. It associates each instruction in the IR with a `Core.OpaqueClosure`, -and each `Argument` / `SSAValue` in the IR with a (heap-allocated) `AbstractSlot` (for the -most part, these slots are `Ref`s). - -While the details of what each kind of `OpaqueClosure` can be found in the corresponding -`Tapir.build_inst` method, they generally have the following structure: -- load data from argument / ssa slots, -- do computation, -- write result to the instruction's ssa slot, -- return an integer indicating which instruction to execute next. - -The returned integer is permitted to take one of the following values: -- `-1`, in which case we should return, -- `0`, in which case the next instruction should be run, -- a positive integer, in which case execution jumps to the start of that block. - -The only argument to each `Core.OpaqueClosure` is an integer corresponding to the index of -the previous block that was run. As a result, each instruction has the _same_ signature, -meaning that while each instruction tends to do quite different things, we do not see an -explosion of types. Moreover, type-stability it maintained. 
-""" -function InterpretedFunction(ctx::C, sig::Type{<:Tuple}, interp) where {C} - - # Grab code associated to this function. - ir, Treturn = lookup_ir(interp, sig) - - # Slot into which the output of this function will be placed. - return_slot = SlotRef{Treturn}() - - # Construct argument reference references. - arg_types = Tuple{map(_get_type, ir.argtypes)..., } - is_vararg, spnames = is_vararg_sig_and_sparam_names(sig) - arg_info = ArgInfo(arg_types, is_vararg) - ir = normalise!(ir, spnames) - - # Create slots. In most cases, these are instances of `SlotRef`s, which can be read from - # and written to by instructions (they are essentially `Base.RefValue`s with a - # different name. Very occassionally the compiler will deduce that a particular slot has - # a constant value. In these cases, we instead create an instance of `ConstSlot`, which - # cannot be written to. - slots = AbstractSlot[make_slot(T) for T in ir.stmts.type] - - # Allocate memory for instructions and argument loading instructions. - insts = Vector{Inst}(undef, length(slots)) - - # Compute the index of the instruction associated with the start of each basic block - # in `ir`. This is used to know where to jump to when we hit a `Core.GotoNode` or - # `Core.GotoIfNot`. The `ir.cfg` very nearly gives this to us for free. - bb_starts = vcat(1, ir.cfg.index) - bb_ends = vcat(ir.cfg.index .- 1, length(slots)) - - # Extract the starting location of each basic block from the CFG and build IF. - in_f = InterpretedFunction{sig, C, Treturn, _typeof(arg_info)}( - ctx, return_slot, arg_info, slots, insts, bb_starts, bb_ends, ir, interp, spnames, - ) - - # Eagerly create PhiNode instructions, as this requires special handling. - make_phi_instructions!(in_f) - - # Cache this InterpretedFunction so that we don't have tobuild it again. - return in_f -end - -function (in_f::InterpretedFunction)(args::Vararg{Any, N}) where {N} - load_args!(in_f, args) - return __barrier(in_f) -end - -load_args!(in_f::InterpretedFunction, args::Targs) where {Targs} = load_args!(in_f.arg_info, args) - -# Execute an interpreted function, having already loaded the arguments into their slots. -function __barrier(in_f::Tf) where {Tf<:InterpretedFunction} - prev_block = 0 - next_block = 0 - current_block = 1 - n = 1 - instructions = in_f.instructions - while next_block != -1 - if !isassigned(instructions, n) - instructions[n] = build_inst(in_f, n) - end - next_block = instructions[n](prev_block) - if next_block == 0 - n += 1 - elseif next_block > 0 - n = in_f.bb_starts[next_block] - prev_block = current_block - current_block = next_block - next_block = 0 - end - end - return in_f.return_slot[] -end - -# Produce a `Dict` mapping from block numbers to line number of their first statement. -function block_map(cfg::CC.CFG) - line_to_blk_maps = map(((n, blk),) -> tuple.(blk.stmts, n), enumerate(cfg.blocks)) - return Dict(reduce(vcat, line_to_blk_maps)) -end - -function build_inst(in_f::InterpretedFunction{sig}, n::Int) where {sig} - @nospecialize in_f - ir_inst = in_f.ir.stmts.inst[n] - b = block_map(in_f.ir.cfg)[n] - is_blk_end = n in in_f.bb_ends - return build_inst(ir_inst, in_f, n, b, is_blk_end) -end - -# Use to handle dynamic dispatch inside `InterpretedFunction`s. -# -# `InterpretedFunction`s operate recursively -- if the types associated to the `args` field -# of a `:call` expression have not been inferred successfully, then we must wait until -# runtime to determine what code to run. The `DelayedInterpretedFunction` does exactly this. 
-struct DelayedInterpretedFunction{C, Tlocal_cache, T<:TapirInterpreter} - ctx::C - local_cache::Tlocal_cache - interp::T -end - -function (din_f::DelayedInterpretedFunction{C})(fargs::Vararg{Any, N}) where {C, N} - k = map(_typeof, fargs) - _evaluator = get(din_f.local_cache, k, nothing) - if _evaluator === nothing - sig = _typeof(fargs) - _evaluator = if is_primitive(C, sig) - _eval - else - InterpretedFunction(din_f.ctx, sig, din_f.interp) - end - din_f.local_cache[k] = _evaluator - end - return _evaluator(fargs...) -end diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index d9358a94..466cdedd 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -92,20 +92,12 @@ function infer_ir!(ir::IRCode) return __infer_ir!(ir, CC.NativeInterpreter(), __get_toplevel_mi_from_ir(ir, Tapir)) end -# Sometimes types in `IRCode` have been replaced by constants, or partially-completed -# structs. A particularly common case is that the first element of the `argtypes` field of -# an `IRCode` is a `Core.Const` containing the function to be called. `_get_type` recovers -# the type in such situations. -_get_type(x::Core.PartialStruct) = x.typ -_get_type(x::Core.Const) = _typeof(x.val) -_get_type(T) = T - # Given some IR, generates a MethodInstance suitable for passing to infer_ir!, if you don't # already have one with the right argument types. Credit to @oxinabox: # https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54 function __get_toplevel_mi_from_ir(ir, _module::Module) mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ()); - mi.specTypes = Tuple{map(_get_type, ir.argtypes)...} + mi.specTypes = Tuple{map(_type, ir.argtypes)...} mi.def = _module return mi end @@ -255,6 +247,15 @@ purely a function of whether or not its `val` field is defined or not. is_reachable_return_node(x::ReturnNode) = isdefined(x, :val) is_reachable_return_node(x) = false +""" + is_unreachable_return_node(x::ReturnNode) + +Determine whether `x` is a `ReturnNode`, and, if it is, whether it is also unreachable. This is +purely a function of whether or not its `val` field is defined or not. +""" +is_unreachable_return_node(x::ReturnNode) = !isdefined(x, :val) +is_unreachable_return_node(x) = false + """ globalref_type(x::GlobaRef) diff --git a/src/interpreter/registers.jl b/src/interpreter/registers.jl deleted file mode 100644 index 68e70866..00000000 --- a/src/interpreter/registers.jl +++ /dev/null @@ -1,46 +0,0 @@ -""" - AugmentedRegister(codual::CoDual, tangent_stack) - -A wrapper data structure for bundling together a codual and a tangent stack. These appear -in the code associated to active values in the primal. - -For example, a statment in the primal such as -```julia -%5 = sin(%4)::Float64 -``` -which provably returns a `Float64` in the primal, would return an `register_type(Float64)` -in the forwards-pass, where `register_type` will return an `AugmentedRegister` when the -primal type is `Float64`. -""" -struct AugmentedRegister{T<:CoDual, V} - codual::T - tangent_ref::V -end - -@inline primal(reg::AugmentedRegister) = primal(reg.codual) - -""" - register_type(::Type{P}) where {P} - -If `P` is the type associated to a primal register, the corresponding register in the -forwards-pass must be a `register_type(P)`.
-""" -function register_type(::Type{P}) where {P} - P == DataType && return Any - P == UnionAll && return Any - P isa Union && return __union_register_type(P) - if isconcretetype(P) - return AugmentedRegister{codual_type(P), tangent_ref_type_ub(P)} - else - return AugmentedRegister - end -end - -# Specialised method for unions. -function __union_register_type(::Type{P}) where {P} - if P isa Union - CC.tmerge(AugmentedRegister{codual_type(P.a)}, __union_register_type(P.b)) - else - return AugmentedRegister{codual_type(P)} - end -end diff --git a/src/interpreter/reverse_mode_ad.jl b/src/interpreter/reverse_mode_ad.jl deleted file mode 100644 index 1cfc5358..00000000 --- a/src/interpreter/reverse_mode_ad.jl +++ /dev/null @@ -1,589 +0,0 @@ -# FwdsInsts have the same signature as Insts, but have different side-effects. -const FwdsInst = Core.OpaqueClosure{Tuple{Int}, Int} - -# The backwards instructions don't actually need to return an Int, however, there is -# currently a performance bug in OpaqueClosures which means that an allocation is produced -# if a constant is returned. Consequently, we have to return something non-constant. -# See https://github.com/JuliaLang/julia/issues/52620 for info. -# By convention, any backwards instruction which "does nothing" just returns its input, and -# does no other work. -const BwdsInst = Core.OpaqueClosure{Tuple{Int}, Int} - -const RuleSlot{V} = Union{SlotRef{V}, ConstSlot{V}} where {V<:Tuple{CoDual, Ref}} - -__primal_type(::Type{<:Tuple{<:CoDual{P}, <:Any}}) where {P} = @isdefined(P) ? P : Any -function __primal_type(::Type{P}) where {P<:Tuple{<:CoDual, <:Any}} - P isa Union && return Union{__primal_type(P.a), __primal_type(P.b)} - return Any -end - -primal_type(::AbstractSlot{P}) where {P} = __primal_type(P) - -function rule_slot_type(::Type{P}) where {P} - return Tuple{codual_type(P), tangent_ref_type_ub(P)} -end - -function make_rule_slot(::SlotRef{P}, ::Any) where {P} - return SlotRef{rule_slot_type(P)}() -end -function make_rule_slot(x::ConstSlot{P}, ::Any) where {P} - cd = uninit_codual(x[]) - stack = make_tangent_stack(P) - push!(stack, tangent(cd)) - return ConstSlot((cd, top_ref(stack))) -end - -make_tangent_stack(::Type{P}) where {P} = tangent_stack_type(P)() - -make_tangent_ref_stack(::Type{P}) where {P} = Stack{P}() -make_tangent_ref_stack(::Type{NoTangentRef}) = NoTangentRefStack() - -get_codual(x::RuleSlot) = x[][1] -get_tangent_stack(x::RuleSlot) = x[][2] - -increment_ref!(x::Ref, t) = setindex!(x, increment!!(x[], t)) - -## ReturnNode -function build_coinsts(node::ReturnNode, _, _, _rrule!!, ::Int, ::Int, ::Bool) - return build_coinsts( - ReturnNode, _rrule!!.ret, _rrule!!.ret_tangent, _get_slot(node.val, _rrule!!), - ) -end -function build_coinsts( - ::Type{ReturnNode}, ret::SlotRef{<:CoDual}, ret_tangent::SlotRef, val::RuleSlot -) - tangent_ref_stack = Stack{tangent_ref_type_ub(primal_type(val))}() - fwds_inst = @opaque function (p::Int) - push!(tangent_ref_stack, get_tangent_stack(val)) - ret[] = get_codual(val) - return -1 - end - bwds_inst = @opaque function (j::Int) - increment_ref!(pop!(tangent_ref_stack), ret_tangent[]) - return j - end - return fwds_inst::FwdsInst, bwds_inst::BwdsInst -end - -## GotoNode -build_coinsts(x::GotoNode, _, _, _, ::Int, ::Int, ::Bool) = build_coinsts(GotoNode, x.label) -function build_coinsts(::Type{GotoNode}, dest::Int) - return build_inst(GotoNode, dest)::FwdsInst, (@opaque (j::Int) -> j)::BwdsInst -end - -## GotoIfNot -function build_coinsts(x::GotoIfNot, _, _, _rrule!!, ::Int, b::Int, 
is_blk_end::Bool) - return build_coinsts(GotoIfNot, x.dest, b + 1, _get_slot(x.cond, _rrule!!)) -end -function build_coinsts(::Type{GotoIfNot}, dest::Int, next_blk::Int, cond::RuleSlot) - fwds_inst = @opaque (p::Int) -> primal(get_codual(cond)) ? next_blk : dest - bwds_inst = @opaque (j::Int) -> j - return fwds_inst::FwdsInst, bwds_inst::BwdsInst -end - -## PhiNode -function build_coinsts(ir_insts::Vector{PhiNode}, _, _rrule!!, n_first::Int, b::Int, is_blk_end::Bool) - nodes = (build_typed_phi_nodes(ir_insts, _rrule!!, n_first)..., ) - next_blk = _standard_next_block(is_blk_end, b) - return build_coinsts(Vector{PhiNode}, nodes, next_blk) -end -function build_coinsts( - ::Type{Vector{PhiNode}}, nodes::NTuple{N, TypedPhiNode}, next_blk::Int, -) where {N} - # Check that we're operating on CoDuals. - @assert all(x -> x.ret_slot isa RuleSlot, nodes) - @assert all(x -> all(y -> isa(y, RuleSlot), x.values), nodes) - - # Construct instructions. - fwds_inst = build_inst(Vector{PhiNode}, nodes, next_blk) - bwds_inst = @opaque (j::Int) -> j - return fwds_inst::FwdsInst, bwds_inst::BwdsInst -end - -## PiNode -function build_coinsts(x::PiNode, P, _, _rrule!!, n::Int, b::Int, is_blk_end::Bool) - val = _get_slot(x.val, _rrule!!) - ret = _rrule!!.slots[n] - return build_coinsts(PiNode, P, val, ret, _standard_next_block(is_blk_end, b)) -end -function build_coinsts( - ::Type{PiNode}, - ::Type{P}, - val::RuleSlot, - ret::RuleSlot{<:Tuple{R, <:Any}}, - next_blk::Int, -) where {R, P} - - my_tangent_stack = make_tangent_stack(P) - tangent_stack_stack = make_tangent_ref_stack(tangent_ref_type_ub(primal_type(val))) - - make_fwds(v) = R(primal(v), tangent(v)) - function fwds_run() - v, tangent_stack = val[] - push!(my_tangent_stack, tangent(v)) - push!(tangent_stack_stack, tangent_stack) - ret[] = (make_fwds(v), top_ref(my_tangent_stack)) - return next_blk - end - fwds_inst = @opaque (p::Int) -> fwds_run() - function bwds_run() - increment_ref!(pop!(tangent_stack_stack), pop!(my_tangent_stack)) - end - bwds_inst = @opaque (j::Int) -> (bwds_run(); return j) - return fwds_inst::FwdsInst, bwds_inst::BwdsInst -end - -## GlobalRef -function build_coinsts(x::GlobalRef, P, _, _rrule!!, n::Int, b::Int, is_blk_end::Bool) - next_blk = _standard_next_block(is_blk_end, b) - return build_coinsts(GlobalRef, P, _globalref_to_slot(x), _rrule!!.slots[n], next_blk) -end -function build_coinsts( - ::Type{GlobalRef}, ::Type{P}, global_ref::AbstractSlot, out::RuleSlot, next_blk::Int -) where {P} - my_tangent_stack = make_tangent_stack(P) - fwds_inst = @opaque function (p::Int) - v = uninit_codual(global_ref[]) - push!(my_tangent_stack, tangent(v)) - out[] = (v, top_ref(my_tangent_stack)) - return next_blk - end - bwds_inst = @opaque function (j::Int) - pop!(my_tangent_stack) - return j - end - return fwds_inst::FwdsInst, bwds_inst::BwdsInst -end - -## QuoteNode and literals -function build_coinsts(node, _, _, _rrule!!, n::Int, b::Int, is_blk_end::Bool) - x = ConstSlot(zero_codual(node isa QuoteNode ? 
node.value : node)) - next_blk = _standard_next_block(is_blk_end, b) - return build_coinsts(nothing, x, _rrule!!.slots[n], next_blk) -end -function build_coinsts(::Nothing, x::ConstSlot, out::RuleSlot, next_blk::Int) - my_tangent_stack = make_tangent_stack(primal_type(out)) - push!(my_tangent_stack, tangent(x[])) - fwds_inst = @opaque function (p::Int) - out[] = (x[], top_ref(my_tangent_stack)) - return next_blk - end - bwds_inst = @opaque (j::Int) -> j - return fwds_inst::FwdsInst, bwds_inst::BwdsInst -end - -## Expr - -get_rrule!!_evaluator(::typeof(_eval)) = rrule!! -get_rrule!!_evaluator(in_f::InterpretedFunction) = build_rrule!!(in_f) -get_rrule!!_evaluator(::DelayedInterpretedFunction) = rrule!! - -# Constructs a Vector which can holds instances of the pullback associated to -# `__rrule!!` when applied to the types in `codual_sig`. If `__rrule!!` infers for these -# types, then we should get a concretely-typed containers. Conversely, if inference fails, -# we fallback to `Any`. -function build_pb_stack(__rrule!!, evaluator, arg_slots) - deval = zero_codual(evaluator) - codual_sig = Tuple{_typeof(deval), map(codual_type ∘ primal_type, arg_slots)...} - possible_output_types = Base.return_types(__rrule!!, codual_sig) - if length(possible_output_types) == 0 - throw(error("No return type inferred for __rrule!! with sig $codual_sig")) - elseif length(possible_output_types) > 1 - @warn "Too many output types inferred" - display(possible_output_types) - println() - throw(error("> 1 return type inferred for __rrule!! with sig $codual_sig ")) - end - T_pb!! = only(possible_output_types) - if T_pb!! <: Tuple && T_pb!! !== Union{} - F = T_pb!!.parameters[2] - return Base.issingletontype(F) ? SingletonStack{F}() : Stack{F}() - else - return Stack{Any}() - end -end - -function build_coinsts(ir_inst::Expr, P, in_f, _rrule!!, n::Int, b::Int, is_blk_end::Bool) - is_invoke = Meta.isexpr(ir_inst, :invoke) - next_blk = _standard_next_block(is_blk_end, b) - val_slot = _rrule!!.slots[n] - if Meta.isexpr(ir_inst, :boundscheck) - return build_coinsts(Val(:boundscheck), val_slot, next_blk) - elseif is_invoke || Meta.isexpr(ir_inst, :call) - - # Extract args refs. - __args = is_invoke ? ir_inst.args[2:end] : ir_inst.args - arg_slots = map(arg -> _get_slot(arg, _rrule!!), (__args..., )) - - # Construct signature, and determine how the rrule is to be computed. - primal_sig = Tuple{map(arg -> eltype(_get_slot(arg, in_f)), (__args..., ))...} - evaluator = get_evaluator(in_f.ctx, primal_sig, in_f.interp, is_invoke) - __rrule!! = get_rrule!!_evaluator(evaluator) - - # Create stack for storing pullbacks. - pb_stack = build_pb_stack(__rrule!!, evaluator, arg_slots) - - return build_coinsts( - Val(:call), P, val_slot, arg_slots, evaluator, __rrule!!, pb_stack, next_blk - ) - elseif ir_inst.head in [ - :code_coverage_effect, :gc_preserve_begin, :gc_preserve_end, :loopinfo, - :leave, :pop_exception, - ] - return build_coinsts(Val(:skipped_expression), next_blk) - elseif Meta.isexpr(ir_inst, :throw_undef_if_not) - slot_to_check = _get_slot(ir_inst.args[2], _rrule!!) 
- return build_coinsts(Val(:throw_undef_if_not), slot_to_check, next_blk) - else - throw(error("Unrecognised expression $ir_inst")) - end -end - -function build_coinsts(::Val{:boundscheck}, out::RuleSlot, next_blk::Int) - @assert eltype(out) == Tuple{CoDual{Bool, NoTangent}, NoTangentRef} - fwds_inst = @opaque function (p::Int) - out[] = (zero_codual(true), NoTangentRef()) - return next_blk - end - bwds_inst = @opaque (j::Int) -> j - return fwds_inst::FwdsInst, bwds_inst::BwdsInst -end - -function build_coinsts( - ::Val{:call}, - ::Type{P}, - out::RuleSlot, - arg_slots::NTuple{N, RuleSlot} where {N}, - evaluator::Teval, - __rrule!!::Trrule!!, - pb_stack, - next_blk::Int, -) where {P, Teval, Trrule!!} - - my_tangent_stack = make_tangent_stack(P) - - tangent_stack_stacks = map(arg_slots) do arg_slot - make_tangent_ref_stack(tangent_ref_type_ub(primal_type(arg_slot))) - end - - function fwds_pass() - args = tuple_map(get_codual, arg_slots) - map(tangent_stack_stacks, arg_slots) do tangent_stack_stack, arg - push!(tangent_stack_stack, get_tangent_stack(arg)) - end - _out, pb!! = __rrule!!(zero_codual(evaluator), args...) - push!(my_tangent_stack, tangent(_out)) - push!(pb_stack, pb!!) - out[] = (_out, top_ref(my_tangent_stack)) - return nothing - end - fwds_inst = @opaque function (p::Int) - fwds_pass() - return next_blk - end - - function bwds_pass() - pb!! = pop!(pb_stack) - dout = pop!(my_tangent_stack) - tangent_stacks = map(pop!, tangent_stack_stacks) - dargs = tuple_map(set_immutable_to_zero ∘ getindex, tangent_stacks) - new_dargs = pb!!(dout, NoTangent(), dargs...) - map(increment_ref!, tangent_stacks, new_dargs[2:end]) - return nothing - end - bwds_inst = @opaque function (j::Int) - bwds_pass() - return j - end - # display(Base.code_ircode(fwds_pass, Tuple{})) - # display(Base.code_ircode(bwds_pass, Tuple{})) - return fwds_inst::FwdsInst, bwds_inst::BwdsInst -end - -function build_coinsts(::Val{:skipped_expression}, next_blk::Int) - return (@opaque (p::Int) -> next_blk), (@opaque (j::Int) -> j) -end - -function build_coinsts(::Val{:throw_undef_if_not}, slot::AbstractSlot, next_blk::Int) - fwds_inst = @opaque function (prev_blk::Int) - !isassigned(slot) && throw(error("Boooo, not assigned")) - return next_blk - end - bwds_inst = @opaque (j::Int) -> j - return fwds_inst::FwdsInst, bwds_inst::BwdsInst -end - -# -# Code execution -# - -function rrule!!(::CoDual{typeof(_eval)}, fargs::Vararg{CoDual, N}) where {N} - out, pb!! = rrule!!(fargs...) - _eval_pb!!(dout, d_eval, dfargs...) = d_eval, pb!!(dout, dfargs...)... - return out, _eval_pb!! -end - -function rrule!!(_f::CoDual{<:DelayedInterpretedFunction{C, F}}, args::CoDual...) where {C, F} - f = primal(_f) - s = _typeof(map(primal, args)) - if is_primitive(C, s) - return rrule!!(zero_codual(_eval), args...) - else - in_f = InterpretedFunction(f.ctx, s, f.interp) - return build_rrule!!(in_f)(zero_codual(in_f), args...) 
- end -end - -tangent_type(::Type{<:InterpretedFunction}) = NoTangent -tangent_type(::Type{<:DelayedInterpretedFunction}) = NoTangent - -function make_codual_arginfo(ai::ArgInfo{T, is_vararg}) where {T, is_vararg} - arg_slots = map(Base.Fix2(make_rule_slot, nothing), ai.arg_slots) - return ArgInfo{_typeof(arg_slots), is_vararg}(arg_slots) -end - -function make_arg_tangent_stacks(argtypes::Vector{Any}) - return map(a -> tangent_stack_type(a)(), (map(_get_type, argtypes)...,)) -end - -function load_rrule_args!( - ai::ArgInfo{T, is_vararg}, args::Tuple, arg_tangent_stacks::Tuple -) where {T, is_vararg} - # There is a difference between the varargs that we recieve, and the varargs of the - # original function. This section sorts that out. - # For example if the original function is `f(x...)`, then the `argtypes` field of its - # `IRCode` when calling e.g. `f(5.0)` will be `Tuple{typeof(f), Tuple{Float64}}`, where - # the second tuple contains the vararg. - # However, the `argtypes` field of the corresponding `InterpretedFunction` will - # be `Tuple{<:InterpretedFunction, Tuple{typeof(f), Float64}}`. - # Therefore, the `args` field of this function will be a `Tuple{typeof(f), Float64}`. - # We must therefore transform it into a `Tuple` of type - # `Tuple{typeof(f), Tuple{Float64}}` before attempting to load it into `ai.arg_slots`. - if is_vararg - num_args = length(ai.arg_slots) - 1 # once for first arg, once for vararg - primals = map(primal, args) - tangents = map(tangent, args) - refined_primal_args = (primals[1:num_args]..., (primals[num_args+1:end]..., )) - refined_tangent_args = (tangents[1:num_args]..., (tangents[num_args+1:end]..., )) - refined_args = map(CoDual, refined_primal_args, refined_tangent_args) - else - refined_args = args - end - - # Load the arguments into `ai.arg_slots`. - map(refined_args, arg_tangent_stacks) do arg, arg_tangent_stack - push!(arg_tangent_stack, tangent(arg)) - end - args = map((a, b) -> (a, top_ref(b)), refined_args, arg_tangent_stacks) - return __load_args!(ai.arg_slots, args) -end - -struct InterpretedFunctionRRule{ - sig<:Tuple, Tret, Tret_tangent, Targ_info<:ArgInfo, Targ_tangent_stacks -} - ret::SlotRef{Tret} - ret_tangent::SlotRef{Tret_tangent} - arg_info::Targ_info - arg_tangent_stacks::Targ_tangent_stacks - slots::Vector{RuleSlot} - fwds_instructions::Vector{FwdsInst} - bwds_instructions::Vector{BwdsInst} - n_stack::Stack{Int} - ir::IRCode -end - -function _get_slot(x, in_f::InterpretedFunctionRRule) - return _wrap_rule_slot(_get_slot(x, in_f.slots, in_f.arg_info, in_f.ir)) -end - -_wrap_rule_slot(x::RuleSlot) = x -function _wrap_rule_slot(x::ConstSlot{<:CoDual}) - stack = make_tangent_stack(primal_type(x)) - push!(stack, tangent(x[])) - return ConstSlot((x[], top_ref(stack))) -end -function _wrap_rule_slot(x::ConstSlot{P}) where {P} - T = tangent_type(P) - Tref = tangent_ref_type_ub(P) - stack = make_tangent_stack(P) - push!(stack, zero_tangent(x[])) - return ConstSlot{Tuple{codual_type(P), Tref}}((zero_codual(x[]), top_ref(stack))) -end - -# Special handling is required for PhiNodes, because their semantics require that when -# more than one PhiNode appears at the start of a basic block, they are run simulataneously -# rather than in sequence. See the SSAIR docs for an explanation of why this is the case. 
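Before the phi-node handling below, a small editorial sketch of the vararg reshaping that `load_args!` and `load_rrule_args!` above describe (a simplified `unflatten`, without the `Val`-typed arguments of the real `unflatten_vararg`):

```julia
# Gather the trailing arguments back into a single tuple so that they line up with the
# single vararg slot expected by the IR of the original function.
unflatten(args::Tuple, nargs::Int) = (args[1:nargs]..., (args[nargs+1:end]...,))

f(x...) = sum(x)

# A call `f(5.0, 6.0)` arrives as the flat tuple `(f, 5.0, 6.0)`, but the argument slots
# expect `(f, (5.0, 6.0))`.
@assert unflatten((f, 5.0, 6.0), 1) == (f, (5.0, 6.0))
```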
-function make_phi_instructions!( - in_f::InterpretedFunction, __rrule!!::InterpretedFunctionRRule -) - ir = in_f.ir - fwds_insts = __rrule!!.fwds_instructions - bwds_insts = __rrule!!.bwds_instructions - - for (b, bb) in enumerate(ir.cfg.blocks) - - # Find any phi nodes at the start of the block. - phi_node_inds = Int[] - foreach(n -> (ir.stmts.inst[n] isa PhiNode) && push!(phi_node_inds, n), bb.stmts) - isempty(phi_node_inds) && continue - - # Make a single instruction which runs all of the PhiNodes "simultaneously". - # Specifically, this instruction runs all of the phi nodes, storing the results of - # this into temporary storage, then writing from the temporary slots to the - # final slots. This has the effect of ensuring that phi nodes that depend on other - # phi nodes get the "old" values, not the new updated values. This was a - # surprisingly hard bug to catch and resolve. - nodes = [ir.stmts.inst[n] for n in phi_node_inds] - n_first = first(phi_node_inds) - is_blk_end = length(phi_node_inds) == length(bb.stmts) - fwds_inst, bwds_inst = build_coinsts(nodes, in_f, __rrule!!, n_first, b, is_blk_end) - fwds_insts[phi_node_inds[1]] = fwds_inst - bwds_insts[phi_node_inds[1]] = bwds_inst - - # Create dummy instructions for the remainder of the nodes. - for n in phi_node_inds[2:end] - fwds_insts[n] = make_dummy_instruction(_standard_next_block(is_blk_end, b)) - bwds_insts[n] = make_dummy_instruction(n) - end - end - return nothing -end - -function build_rrule!!(in_f::InterpretedFunction{sig}) where {sig} - - return_slot = SlotRef{codual_type(eltype(in_f.return_slot))}() - return_tangent_slot = SlotRef{tangent_type(eltype(in_f.return_slot))}() - arg_info = make_codual_arginfo(in_f.arg_info) - arg_tangent_stacks = make_arg_tangent_stacks(in_f.ir.argtypes) - - # Construct rrule!! for in_f. - Tret = eltype(return_slot) - Tret_tangent = eltype(return_tangent_slot) - __rrule!! = InterpretedFunctionRRule{ - sig, Tret, Tret_tangent, _typeof(arg_info), _typeof(arg_tangent_stacks) - }( - return_slot, - return_tangent_slot, - arg_info, - arg_tangent_stacks, - RuleSlot[ - make_rule_slot(primal_slot, inst) for - (primal_slot, inst) in zip(in_f.slots, in_f.ir.stmts.inst) - ], # SlotRefs - Vector{FwdsInst}(undef, length(in_f.instructions)), # fwds_instructions - Vector{BwdsInst}(undef, length(in_f.instructions)), # bwds_instructions - Stack{Int}(), - in_f.ir, - ) - - # Set PhiNodes. - make_phi_instructions!(in_f, __rrule!!) - - return __rrule!! -end - -struct InterpretedFunctionPb{Tret_tangent<:SlotRef, Targ_info, Tbwds_f, V, Q} - j::Int - bwds_instructions::Tbwds_f - ret_tangent::Tret_tangent - n_stack::Stack{Int} - arg_info::Targ_info - arg_tangent_stacks::V - arg_tangent_stack_refs::Q -end - -function (in_f_rrule!!::InterpretedFunctionRRule{sig})( - _in_f::CoDual{<:InterpretedFunction{sig}}, args::Vararg{CoDual, N} -) where {sig, N} - - # Load in variables. - return_slot = in_f_rrule!!.ret - arg_info = in_f_rrule!!.arg_info - arg_tangent_stacks = in_f_rrule!!.arg_tangent_stacks - n_stack = in_f_rrule!!.n_stack - - # Initialise variables. - load_rrule_args!(arg_info, args, arg_tangent_stacks) - in_f = primal(_in_f) - prev_block = 0 - next_block = 0 - current_block = 1 - n = 1 - j = length(n_stack) - - # Get references to top of tangent stacks for use on reverse-pass. - arg_tangent_stack_refs = map(top_ref, arg_tangent_stacks) - - # Run instructions until done. 
- while next_block != -1 - if !isassigned(in_f_rrule!!.fwds_instructions, n) - fwds, bwds = generate_coinstructions(in_f, in_f_rrule!!, n) - in_f_rrule!!.fwds_instructions[n] = fwds - in_f_rrule!!.bwds_instructions[n] = bwds - end - next_block = in_f_rrule!!.fwds_instructions[n](prev_block) - push!(n_stack, n) - if next_block == 0 - n += 1 - elseif next_block > 0 - n = in_f.bb_starts[next_block] - prev_block = current_block - current_block = next_block - next_block = 0 - end - end - - return_val = return_slot[] - interpreted_function_pb!! = InterpretedFunctionPb( - j, - in_f_rrule!!.bwds_instructions, - in_f_rrule!!.ret_tangent, - n_stack, - arg_info, - arg_tangent_stacks, - arg_tangent_stack_refs, - ) - return return_val, interpreted_function_pb!! -end - -function (if_pb!!::InterpretedFunctionPb)(dout, ::NoTangent, dargs::Vararg{Any, N}) where {N} - - # Update the output cotangent value to whatever is provided. - if_pb!!.ret_tangent[] = dout - tangent_stack_refs = if_pb!!.arg_tangent_stack_refs # this can go when we refactor - set_tangent_stacks!(tangent_stack_refs, dargs, if_pb!!.arg_info) - - # Run the instructions in reverse. Present assumes linear instruction ordering. - n_stack = if_pb!!.n_stack - bwds_instructions = if_pb!!.bwds_instructions - while length(n_stack) > if_pb!!.j - inst = bwds_instructions[pop!(n_stack)] - inst(0) - end - - # Return resulting tangents from slots. - return NoTangent(), assemble_dout(if_pb!!.arg_tangent_stacks, if_pb!!.arg_info)... -end - -function set_tangent_stacks!(tangent_stacks, dargs, ai::ArgInfo{<:Any, is_va}) where {is_va} - refined_dargs = unflatten_vararg(Val(is_va), dargs, Val(length(ai.arg_slots) - 1)) - map(setindex!, tangent_stacks, refined_dargs) -end - -function assemble_dout(tangent_stacks, ::ArgInfo{<:Any, is_va}) where {is_va} - dargs = map(pop!, tangent_stacks) - return is_va ? (dargs[1:end-1]..., dargs[end]...) : dargs -end - -function generate_coinstructions(in_f, in_f_rrule!!, n) - ir_inst = in_f.ir.stmts.inst[n] - ir_type = _get_type(in_f.ir.stmts.type[n]) - b = block_map(in_f.ir.cfg)[n] - is_blk_end = n in in_f.bb_ends - return build_coinsts(ir_inst, ir_type, in_f, in_f_rrule!!, n, b, is_blk_end) -end - -# Slow implementation, but useful for testing correctness. -function rrule!!(f_in::CoDual{<:InterpretedFunction}, args::CoDual...) - return build_rrule!!(primal(f_in))(f_in, args...) -end diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index c4666a3e..560d4365 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -69,6 +69,15 @@ function shared_data_stmts(p::SharedDataPairs)::Vector{Tuple{ID, NewInstruction} end end +#= +The block stack is the stack used to keep track of which basic blocks are visited on the +forwards pass, and therefore which blocks need to be visited on the reverse pass. There is +one block stack per derived rule. +By using Int32, we assume that there aren't more than `typemax(Int32)` unique basic blocks +in a given function, which ought to be reasonable. +=# +const BlockStack = Stack{Int32} + #= ADInfo @@ -82,42 +91,50 @@ codegen which produces the forwards- and reverse-passes. to determine which blocks to visit. - `block_stack`: the block stack. Can always be found at `block_stack_id` in the forwards- and reverse-passes. -- `entry_id`: special ID associated to the block inserted at the start of execution in the - the forwards-pass, and the end of execution in the pullback. 
+- `entry_id`: ID associated to the block inserted at the start of execution in the the + forwards-pass, and the end of execution in the pullback. - `shared_data_pairs`: the `SharedDataPairs` used to define the captured variables passed - to both the forwards- and reverse-passes.. + to both the forwards- and reverse-passes. - `arg_types`: a map from `Argument` to its static type. -- `ssa_insts`: a map from `ID` associated to lines to the primal `NewInstruction`. -- `arg_tangent_stacks`: a map from primal `Argument`s to their tangent stacks. If the stack - associated to an `Argument` is a bits type then this will just be the tangent stack. - Otherwise, it will be the `ID` associated to the stack, and the stack itself will be put - in the `shared_data_pairs`. -- `tangent_stacks`: a map from `ID` to tangent stacks. If the tangent stack associated to - the `ID` is a bits type, then this will actually be the tangent stack. Otherwise it will - be the `ID` associated to the stack, and the stack itself will be put in the - `shared_data_pairs`. +- `ssa_insts`: a map from `ID` associated to lines to the primal `NewInstruction`. This + contains the line of code, its static / inferred type, and some other detailss. See + `Core.Compiler.NewInstruction` for a full list of fields. +- `arg_rdata_ref_ids`: the dict mapping from arguments to the `ID` which creates and + initialises the `Ref` which contains the reverse data associated to that argument. + Recall that the heap allocations associated to this `Ref` are always optimised away in + the final programme. +- `ssa_rdata_ref_ids`: the same as `arg_rdata_ref_ids`, but for each `ID` associated to an + ssa rather than each argument. +- `safety_on`: if `true`, run in "safe mode" -- wraps all rule calls in `SafeRRule`. This is + applied recursively, so that safe mode is also switched on in derived rules. +- `is_used_dict`: for each `ID` associated to a line of code, is `false` if line is not used + anywhere in any other line of code. =# struct ADInfo interp::PInterp block_stack_id::ID - block_stack::Stack{Int32} + block_stack::BlockStack entry_id::ID shared_data_pairs::SharedDataPairs arg_types::Dict{Argument, Any} ssa_insts::Dict{ID, NewInstruction} - arg_tangent_stacks::Dict{Argument, Any} - tangent_stacks::Dict{ID, Any} + arg_rdata_ref_ids::Dict{Argument, ID} + ssa_rdata_ref_ids::Dict{ID, ID} + safety_on::Bool + is_used_dict::Dict{ID, Bool} end -# The constructor that you should use for ADInfo. +# The constructor that you should use for ADInfo if you don't have a BBCode lying around. +# See the definition of the ADInfo struct for info on the arguments. function ADInfo( interp::PInterp, arg_types::Dict{Argument, Any}, ssa_insts::Dict{ID, NewInstruction}, - arg_tangent_stacks, + is_used_dict::Dict{ID, Bool}, + safety_on::Bool, ) shared_data_pairs = SharedDataPairs() - block_stack = Stack{Int32}() + block_stack = BlockStack() return ADInfo( interp, add_data!(shared_data_pairs, block_stack), @@ -126,48 +143,97 @@ function ADInfo( shared_data_pairs, arg_types, ssa_insts, - make_arg_tangent_stacks!(shared_data_pairs, arg_tangent_stacks), - make_tangent_stacks!(shared_data_pairs, ssa_insts), + Dict((k, ID()) for k in keys(arg_types)), + Dict((k, ID()) for k in keys(ssa_insts)), + safety_on, + is_used_dict, ) end -function __log_data(p::Union{ADInfo, SharedDataPairs}, x) - return Base.issingletontype(_typeof(x)) ? x : add_data!(p, x) +# The constructor you should use for ADInfo if you _do_ have a BBCode lying around. 
See the +# ADInfo struct for information regarding `interp` and `safety_on`. +function ADInfo(interp::PInterp, ir::BBCode, safety_on::Bool) + arg_types = Dict{Argument, Any}( + map(((n, t),) -> (Argument(n) => _type(t)), enumerate(ir.argtypes)) + ) + stmts = collect_stmts(ir) + ssa_insts = Dict{ID, NewInstruction}(stmts) + is_used_dict = characterise_used_ids(stmts) + return ADInfo(interp, arg_types, ssa_insts, is_used_dict, safety_on) end -# Construct a map from primal `Argument`s to the location of its tangent stack in the -# forwards-pass and pullback. If tangent stack is a singleton, just yields the tangent -# stack itself. -function make_arg_tangent_stacks!(p::SharedDataPairs, arg_tangent_stacks) - arguments = Argument.(eachindex(arg_tangent_stacks)) - stack_ids = map(Base.Fix1(__log_data, p), arg_tangent_stacks) - return Dict{Argument, Any}(zip(arguments, stack_ids)) -end +# Shortcut for `add_data!(info.shared_data_pairs, data)`. +add_data!(info::ADInfo, data)::ID = add_data!(info.shared_data_pairs, data) -# Construct a map from primal `ID`s corresponding to lines in the IR, to the location of -# their tangent stacks in the forwards-pass and pullback. If tangent stacks is a singleton, -# just yields the tangent stack itself. -function make_tangent_stacks!(p::SharedDataPairs, ssa_insts::Dict{ID, NewInstruction}) - tangent_stacks = Dict{ID, Any}() - for (k, inst) in ssa_insts - Meta.isexpr(inst.stmt, :call) || Meta.isexpr(inst.stmt, :invoke) || continue - tangent_stacks[k] = __log_data(p, make_tangent_stack(_get_type(inst.type))) - end - return tangent_stacks +# Returns `x` if it is a singleton, or the `ID` of the ssa which will contain it on the +# forwards- and reverse-passes. The reason for this is that if something is a singleton, it +# can be placed directly in the IR. +function add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x) + return Base.issingletontype(_typeof(x)) ? x : add_data!(p, x) end -# Shortcut for `add_data!(info.shared_data_pairs, data)`. -add_data!(info::ADInfo, data) = add_data!(info.shared_data_pairs, data) +# Returns `true` if `id` is used by any of the lines in the ir, false otherwise. +is_used(info::ADInfo, id::ID)::Bool = info.is_used_dict[id] # Returns the static / inferred type associated to `x`. get_primal_type(info::ADInfo, x::Argument) = info.arg_types[x] -get_primal_type(info::ADInfo, x::ID) = _get_type(info.ssa_insts[x].type) +get_primal_type(info::ADInfo, x::ID) = _type(info.ssa_insts[x].type) get_primal_type(::ADInfo, x::QuoteNode) = _typeof(x.value) get_primal_type(::ADInfo, x) = _typeof(x) function get_primal_type(::ADInfo, x::GlobalRef) return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty end +# Returns the `ID` associated to the line in the reverse pass which will contain the +# reverse data for `x`. If `x` is not an `Argument` or `ID`, then `nothing` is returned. +get_rev_data_id(info::ADInfo, x::Argument) = info.arg_rdata_ref_ids[x] +get_rev_data_id(info::ADInfo, x::ID) = info.ssa_rdata_ref_ids[x] +get_rev_data_id(::ADInfo, ::Any) = nothing + +# Create the statements which initialise the reverse-data `Ref`s. +function reverse_data_ref_stmts(info::ADInfo) + arg_stmts = [(id, __ref(_type(info.arg_types[k]))) for (k, id) in info.arg_rdata_ref_ids] + ssa_stmts = [(id, __ref(_type(info.ssa_insts[k].type))) for (k, id) in info.ssa_rdata_ref_ids] + return vcat(arg_stmts, ssa_stmts) +end + +# Helper for reverse_data_ref_stmts. +__ref(P) = new_inst(Expr(:call, __make_ref, P)) + +# Helper for reverse_data_ref_stmts. 
+@inline @generated function __make_ref(::Type{P}) where {P} + R = zero_like_rdata_type(P) + return :(Ref{$R}(Tapir.zero_like_rdata_from_type(P))) +end + +@inline __make_ref(::Type{Union{}}) = nothing + +# Returns the number of arguments that the primal function has. +num_args(info::ADInfo) = length(info.arg_types) + +# This struct is used to ensure that `ZeroRData`s, which are used as placeholder zero +# elements whenever an actual instance of a zero rdata for a particular primal type cannot +# be constructed without also having an instance of said type, never reach rules. +# On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures +# that if it is a `ZeroRData`, we instead get an actual zero of the correct type. If it is +# not a zero rdata, the computation _should_ be elided via inlining + constant prop. +struct RRuleZeroWrapper{Trule} + rule::Trule +end + +struct RRuleWrapperPb{Tpb!!, Tl} + pb!!::Tpb!! + l::Tl +end + +(rule::RRuleWrapperPb)(dy) = rule.pb!!(increment!!(dy, instantiate(rule.l))) + +@inline function (rule::RRuleZeroWrapper{R})(f::F, args::Vararg{CoDual, N}) where {R, F, N} + y, pb!! = rule.rule(f, args...) + l = LazyZeroRData(primal(y)) + return y::CoDual, (pb!! isa NoPullback ? pb!! : RRuleWrapperPb(pb!!, l)) +end + #= ADStmtInfo @@ -214,20 +280,25 @@ make_ad_stmts!(::Nothing, line::ID, ::ADInfo) = ad_stmt_info(line, nothing, noth # `ReturnNode`s have a single field, `val`, for which there are three cases to consider: # -# 1. `val isa Union{Argument, ID}`: this is an active bit of data. Consequently, we know -# that it will be an `AugmentedRegister` already, and can just return it. Therefore `stmt` -# is returned as the forwards-pass (with any `Argument`s incremented), and nothing happens -# in the pullback. -# 2. `val` is undefined: this `ReturnNode` is unreachable. Consequently, we'll never hit the +# 1. `val` is undefined: this `ReturnNode` is unreachable. Consequently, we'll never hit the # associated statements on the forwards-pass of pullback. We just return the original # statement on the forwards-pass, and `nothing` on the reverse-pass. +# 2. `val isa Union{Argument, ID}`: this is an active piece of data. Consequently, we know +# that it will be an `CoDual` already, and can just return it. Therefore `stmt` +# is returned as the forwards-pass (with any `Argument`s incremented). On the reverse-pass +# the associated rdata ref should be incremented with the rdata passed to the pullback, +# which lives in argument 2. # 3. `val` is defined, but not a `Union{Argument, ID}`: in this case we're returning a -# constant -- build a constant register and return that. +# constant -- build a constant CoDual and return that. There is nothing to do on the +# reverse pass. function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo) - if !isdefined(stmt, :val) || is_active(stmt.val) - return ad_stmt_info(line, inc_args(stmt), nothing) + is_reachable_return_node(stmt) || return ad_stmt_info(line, inc_args(stmt), nothing) + if is_active(stmt.val) + rdata_id = get_rev_data_id(info, stmt.val) + rvs = new_inst(Expr(:call, increment_ref!, rdata_id, Argument(2))) + return ad_stmt_info(line, inc_args(stmt), rvs) else - return ad_stmt_info(line, ReturnNode(const_register(stmt.val, info)), nothing) + return ad_stmt_info(line, ReturnNode(const_codual(stmt.val, info)), nothing) end end @@ -239,18 +310,17 @@ end # Identity forwards-pass, no-op reverse. No shared data. 
function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo) stmt = inc_args(stmt) - if stmt.cond isa Union{Argument, ID} - # If cond refers to a register, then the primal must be extracted. - cond_id = ID() - fwds = [ - (cond_id, new_inst(Expr(:call, primal, stmt.cond))), - (line, new_inst(IDGotoIfNot(cond_id, stmt.dest), Any)), - ] - return ad_stmt_info(line, fwds, nothing) - else - # If something other than a register, then there is nothing to do. - return ad_stmt_info(line, stmt, nothing) - end + + # If cond is not going to be wrapped in a `CoDual`, just return the stmt. + is_active(stmt.cond) || return ad_stmt_info(line, stmt, nothing) + + # stmt.cond is active, so the primal must be extracted from the `CoDual`. + cond_id = ID() + fwds = [ + (cond_id, new_inst(Expr(:call, primal, stmt.cond))), + (line, new_inst(IDGotoIfNot(cond_id, stmt.dest), Any)), + ] + return ad_stmt_info(line, fwds, nothing) end # Identity forwards-pass, no-op reverse. No shared data. @@ -259,116 +329,85 @@ function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo) new_vals = Vector{Any}(undef, length(vals)) for n in eachindex(vals) isassigned(vals, n) || continue - new_vals[n] = is_active(vals[n]) ? __inc(vals[n]) : const_register(vals[n], info) + new_vals[n] = is_active(vals[n]) ? __inc(vals[n]) : const_codual(vals[n], info) end # It turns out to be really very important to do type inference correctly for PhiNodes. # For some reason, type inference really doesn't like it when you encounter mutually- # dependent PhiNodes whose types are unknown and for which you set the flag to - # CC.IR_FLAG_REFINED. - new_type = register_type(get_primal_type(info, line)) + # CC.IR_FLAG_REFINED. To avoid this we directly tell the compiler what the type is. + new_type = fcodual_type(get_primal_type(info, line)) _inst = new_inst(IDPhiNode(stmt.edges, new_vals), new_type, info.ssa_insts[line].flag) return ad_stmt_info(line, _inst, nothing) end function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) - isa(stmt.val, Union{Argument, ID}) || unhandled_feature("PiNode: $stmt") - - # Create line which sharpens the register type as much as possible. - sharp_primal_type = _get_type(stmt.typ) - sharpened_register_type = AugmentedRegister{codual_type(_get_type(sharp_primal_type))} - new_pi_line = ID() - new_pi = PiNode(__inc(stmt.val), sharpened_register_type) - - # Create a statement which moves data from the loosely-typed register to a more - # strictly typed one, which is possible because of the `PiNode`. - tangent_stack = make_tangent_stack(sharp_primal_type) - tangent_stack_id = add_data!(info, tangent_stack) - val_type = get_primal_type(info, stmt.val) - tangent_ref_stack = make_tangent_ref_stack(tangent_ref_type_ub(val_type)) - tangent_ref_stack_id = add_data!(info, tangent_ref_stack) - new_line = Expr(:call, __pi_fwds!, tangent_stack_id, tangent_ref_stack_id, new_pi_line) + + # Assume that the PiNode contains active data -- it's hard to see why a PiNode would be + # created for e.g. a constant. Error if code is encountered where this doesn't hold. + is_active(stmt.val) || unhandled_feature("PiNode: $stmt") + + # Get the primal type of this line, and the rdata refs for the `val` of this `PiNode` + # and this line itself. + P = get_primal_type(info, line) + val_rdata_ref_id = get_rev_data_id(info, stmt.val) + output_rdata_ref_id = get_rev_data_id(info, line) # Assemble the above lines and construct reverse-pass.
return ad_stmt_info( line, - [(new_pi_line, new_inst(new_pi)), (line, new_inst(new_line))], - Expr(:call, __pi_rvs!, tangent_stack_id, tangent_ref_stack_id), + PiNode(stmt.val, fcodual_type(_type(stmt.typ))), + Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id), ) end -@inline function __pi_fwds!(tangent_stack, tangent_ref_stack, reg::AugmentedRegister) - push!(tangent_ref_stack, reg.tangent_ref) - push!(tangent_stack, tangent(reg.codual)) - return AugmentedRegister(reg.codual, top_ref(tangent_stack)) -end - -@inline function __pi_rvs!(tangent_stack, tangent_ref_stack) - increment_ref!(pop!(tangent_ref_stack), pop!(tangent_stack)) +@inline function __pi_rvs!(::Type{P}, val_rdata_ref::Ref, output_rdata_ref::Ref) where {P} + increment_ref!(val_rdata_ref, __deref_and_zero(P, output_rdata_ref)) return nothing end -# Constant GlobalRefs are handled. See const_register. Non-constant -# GlobalRefs are handled by assuming that they are constant, and creating a register with -# the value. We then check at run-time that the value has not changed. +# Constant GlobalRefs are handled. See const_codual. Non-constant GlobalRefs are handled by +# assuming that they are constant, and creating a CoDual with the value. We then check at +# run-time that the value has not changed. function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo) - if isconst(stmt) - return const_ad_stmt(stmt, line, info) - else - reg = const_register(getglobal(stmt.mod, stmt.name), info) - gref_id = ID() - fwds = [ - (gref_id, new_inst(stmt)), - (line, new_inst(Expr(:call, __verify_const, gref_id, reg))), - ] - return ad_stmt_info(line, fwds, nothing) - end + isconst(stmt) && return const_ad_stmt(stmt, line, info) + + x = const_codual(getglobal(stmt.mod, stmt.name), info) + globalref_id = ID() + fwds = [ + (globalref_id, new_inst(stmt)), + (line, new_inst(Expr(:call, __verify_const, globalref_id, x))), + ] + return ad_stmt_info(line, fwds, nothing) end -# Helper used by `make_ad_stmts! ` for `GlobalRef`. +# Helper used by `make_ad_stmts! ` for `GlobalRef`. Noinline to avoid IR bloat. @noinline function __verify_const(global_ref, stored_value) @assert global_ref == primal(stored_value) - return stored_value + return uninit_fcodual(global_ref) end -# QuoteNodes are constant. See make_const_register for details. +# QuoteNodes are constant. make_ad_stmts!(stmt::QuoteNode, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info) -# Literal constant. See const_register for details. +# Literal constant. make_ad_stmts!(stmt, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info) # `make_ad_stmts!` for constants. function const_ad_stmt(stmt, line::ID, info::ADInfo) - reg = const_register(stmt, info) - return ad_stmt_info(line, reg isa ID ? Expr(:call, identity, reg) : reg, nothing) + x = const_codual(stmt, info) + return ad_stmt_info(line, x isa ID ? Expr(:call, identity, x) : x, nothing) end -# Build an `AugmentedRegister` from `stmt`, which will be checked to ensure that its value -# is constant. If the resulting register is a bits type, then it is returned. If it is not, -# then the register is put into shared data, and the ID associated to it in the forwards- -# and reverse-passes returned. -function const_register(stmt, info::ADInfo) - reg = build_const_reg(stmt) - return isbitstype(_typeof(reg)) ? reg : add_data!(info, reg) +# Build a `CoDual` from `stmt`, with zero / uninitialised fdata. If the resulting CoDual is +# a bits type, then it is returned. 
If it is not, then the CoDual is put into shared data, +# and the ID associated to it in the forwards- and reverse-passes returned. +function const_codual(stmt, info::ADInfo) + x = uninit_fcodual(get_const_primal_value(stmt)) + return isbitstype(_typeof(x)) ? x : add_data!(info, x) end -# Create a constant augmented register which lives in the shared data. Returns the `ID` -# which will be associated to this data in the forwards- and reverse-passes. -shared_data_const_reg(stmt, info::ADInfo) = add_data!(info, build_const_reg(stmt)) - -# Create an `AugmentedRegister` containing the values associated to `stmt`, a zero tangent. -# Pushes a single element onto the stack, and puts a reference to that stack in the -# register. -function build_const_reg(stmt) - primal_value = get_const_primal_value(stmt) - tangent_stack = make_tangent_stack(_typeof(primal_value)) - tangent = uninit_tangent(primal_value) - push!(tangent_stack, tangent) - return AugmentedRegister(CoDual(primal_value, tangent), top_ref(tangent_stack)) -end - -# Get the value associated to `x`. For `GlobalRef`s, verify that `x` is indeed a constant, -# and error if it is not. +# Get the value associated to `x`. For `GlobalRef`s, verify that `x` is indeed a constant. function get_const_primal_value(x::GlobalRef) isconst(x) || unhandled_feature("Non-constant GlobalRef not supported: $x") return getglobal(x.mod, x.name) @@ -396,96 +435,129 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) args = ((is_invoke ? stmt.args[2:end] : stmt.args)..., ) arg_types = map(arg -> get_primal_type(info, arg), args) + # Special case: if the result of a call to getfield is un-used, then leave the + # primal statment alone (just increment arguments as usual). This was causing + # performance problems in a couple of situations where the field being requested is + # not known at compile time. `getfield` cannot be dead-code eliminated, because it + # can throw an error if the requested field does not exist. Everything _other_ than + # the boundscheck is eliminated in LLVM codegen, so it's important that AD doesn't + # get in the way of this. + # + # This might need to be generalised to more things than just `getfield`, but at the + # time of writing this comment, it's unclear whether or not this is the case. + if !is_used(info, line) && get_const_primal_value(args[1]) == getfield + fwds = new_inst(Expr(:call, __fwds_pass_no_ad!, map(__inc, args)...)) + return ad_stmt_info(line, fwds, nothing) + end + # Construct signature, and determine how the rrule is to be computed. sig = Tuple{arg_types...} - rule = if is_primitive(context_type(info.interp), sig) + raw_rule = if is_primitive(context_type(info.interp), sig) rrule!! # intrinsic / builtin / thing we provably have rule for elseif is_invoke - LazyDerivedRule(info.interp, sig) # Static dispatch + LazyDerivedRule(info.interp, sig, info.safety_on) # Static dispatch else - DynamicDerivedRule(info.interp) # Dynamic dispatch + DynamicDerivedRule(info.interp, info.safety_on) # Dynamic dispatch end + # Wrap the raw rule in a struct which ensures that any `ZeroRData`s are stripped + # away before the raw_rule is called. + zero_safe_rule = RRuleZeroWrapper(raw_rule) + + # If safe mode has been requested, use a safe rule. + rule = info.safety_on ? SafeRRule(zero_safe_rule) : zero_safe_rule + # If the rule is `rrule!!` (i.e. `sig` is primitive), then don't bother putting # the rule into shared data, because it's safe to put it directly into the code. 
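The test behind `add_data_if_not_singleton!` is just `Base.issingletontype`; a short editorial sketch (the `StatefulRule` type is invented for illustration):

```julia
# Singleton callables carry no runtime state, so they can be interpolated directly into
# the generated IR; anything with fields must instead be captured via shared data.
struct StatefulRule
    scale::Float64
end

@assert Base.issingletontype(typeof(sin))                 # no state: splice into the IR
@assert !Base.issingletontype(typeof(StatefulRule(2.0)))  # has state: goes into shared data
```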
- rule_ref = __log_data(info, rule) - - # Tangent stacks are allocated in build_rrule, and stored in the `info`. Just - # retrieve the stack associated to the tangent returned from this line. - ret_tangent_stack_id = info.tangent_stacks[line] + rule_ref = add_data_if_not_singleton!(info, rule) # If the type of the pullback is a singleton type, then there is no need to store it # in the shared data, it can be interpolated directly into the generated IR. - pb_stack = build_pb_stack(_typeof(rule), arg_types) - pb_stack_id = __log_data(info, pb_stack) - - # if the pullback is a `NoPullback`, then there is no need to log the references to - # the tangent stacks associated to the inputs to this call, because there will never - # need to be any incrementing done. There are functions called within - # `__fwds_pass!` and `__rvs_pass!` that specialise on the type of the pullback to - # avoid ever using the arg tangent ref stacks, so we just need to create a default - # value here (`nothing`), as it will never be used. - arg_tangent_ref_stacks_id = ID() - if pb_stack isa SingletonStack{NoPullback} - arg_tangent_ref_stacks = nothing - else - ref_stacks = map(arg_types, args) do arg_type, arg - stack = __make_arg_tangent_ref_stack(arg_type, arg) - if Base.issingletontype(_typeof(stack)) - return stack - elseif haskey(info.tangent_stacks, arg) - return info.tangent_stacks[arg] - elseif arg isa Argument - return info.arg_tangent_stacks[arg] - else - return add_data!(info, stack) - end - end - arg_tangent_ref_stacks = Expr(:call, __tangent_ref_stacks, ref_stacks...) - end - - # Create calls to `__fwds_pass!` and `__rvs_pass!`, which run the forwards pass and - # pullback associated to a call / invoke. - fwds_pass_call = Expr( - :call, - __fwds_pass!, - arg_tangent_ref_stacks_id, - rule_ref, - ret_tangent_stack_id, - pb_stack_id, - register_type(get_primal_type(info, line)), - map(__inc, args)..., - ) + T_pb!! = pullback_type(_typeof(rule), arg_types) + pb_stack_id = add_data_if_not_singleton!(info, build_pb_stack(T_pb!!)) + + # + # Write forwards-pass. These statements are written out manually, as writing them + # out in a function would prevent inlining in some (all?) type-unstable situations. + # + + # Make arguments to rrule call. Things which are not already CoDual must be made so. + codual_arg_ids = map(_ -> ID(), args) + __codual_args = map(arg -> Expr(:call, __make_codual, __inc(arg)), args) + codual_args = Tuple{ID, NewInstruction}[ + (id, new_inst(arg)) for (id, arg) in zip(codual_arg_ids, __codual_args) + ] - rvs_pass_call = Expr( - :call, __rvs_pass!, arg_tangent_ref_stacks_id, ret_tangent_stack_id, pb_stack_id + # Make call to rule. + rule_call_id = ID() + rule_call = Expr(:call, rule_ref, codual_arg_ids...) + + # Extract the output-codual from the returned tuple. + raw_output_id = ID() + raw_output = Expr(:call, getfield, rule_call_id, 1) + + # Extract the pullback from the returned tuple. + pb_id = ID() + pb = Expr(:call, getfield, rule_call_id, 2) + + # Push the pullback stack. + push_pb_stack_id = ID() + push_pb_stack = Expr(:call, __push_pb_stack!, pb_stack_id, pb_id) + + # Provide a type assertion to help the compiler out. Doing it this way, rather than + # directly changing the inferred type of the instruction associated to raw_output, + # has the advantage of not introducing the possibility of segfaults. It will still + # be optimised away in situations where the compiler is able to successfully infer + # the type, so performance in performance-critical situations is unaffected. 
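A toy illustration of the `Core.typeassert` trick described above (the `assert_float` name is invented): an assertion that inference can prove is optimised away, while a violated one throws a `TypeError` rather than risking a segfault.
```julia
assert_float(x) = Core.typeassert(x, Float64) + 1.0

assert_float(2.0)      # 3.0; the assert is free once `x::Float64` is proven
try
    assert_float(1)    # TypeError at run time, never memory corruption
catch err
    @show err
end
```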
+ output_id = line + F = fcodual_type(get_primal_type(info, line)) + output = Expr(:call, Core.typeassert, raw_output_id, F) + + # Create statements associated to forwards-pass. + fwds = vcat( + codual_args, + Tuple{ID, NewInstruction}[ + (rule_call_id, new_inst(rule_call)), + (raw_output_id, new_inst(raw_output)), + (pb_id, new_inst(pb)), + (push_pb_stack_id, new_inst(push_pb_stack)), + (output_id, new_inst(output)), + ], ) - fwds = [ - (arg_tangent_ref_stacks_id, new_inst(arg_tangent_ref_stacks)), - (line, new_inst(fwds_pass_call)), - ] - rvs = Tuple{ID, NewInstruction}[ - (arg_tangent_ref_stacks_id, new_inst(arg_tangent_ref_stacks)), - (line, new_inst(rvs_pass_call)), - ] - return ad_stmt_info(line, fwds, rvs) + # Make statement associated to reverse-pass. If the reverse-pass is provably a + # NoPullback, then don't bother doing anything at all. + rvs_pass = if T_pb!! <: NoPullback + nothing + else + Expr( + :call, + __rvs_pass!, + get_primal_type(info, line), + pb_stack_id, + get_rev_data_id(info, line), + map(Base.Fix1(get_rev_data_id, info), args)..., + ) + end + return ad_stmt_info(line, fwds, new_inst(rvs_pass)) elseif Meta.isexpr(stmt, :boundscheck) # For some reason the compiler cannot handle boundscheck statements when we run it # again. Consequently, emit `true` to be safe. Ideally we would handle this in a # more natural way, but I'm not sure how to do that. - tmp = AugmentedRegister(zero_codual(true), NoTangentStack()) - return ad_stmt_info(line, tmp, nothing) + return ad_stmt_info(line, zero_fcodual(true), nothing) elseif Meta.isexpr(stmt, :code_coverage_effect) - # Code coverage irrelevant for derived code. + # Code coverage irrelevant for derived code, and really inflates it in some + # situations. Since code coverage is usually only requrested during CI, including + # these effects also creates differences between the code generated when developing + # and the code generated in CI, which occassionally creates hard-to-debug issues. return ad_stmt_info(line, nothing, nothing) elseif Meta.isexpr(stmt, :copyast) # Get constant out and shove it in shared storage. - reg = const_register(stmt.args[1], info) - return ad_stmt_info(line, Expr(:call, identity, reg), nothing) + x = const_codual(stmt.args[1], info) + return ad_stmt_info(line, Expr(:call, identity, x), nothing) elseif Meta.isexpr(stmt, :loopinfo) # Cannot pass loopinfo back through the optimiser for some reason. @@ -508,162 +580,115 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) end end -# Used in `make_ad_stmts!` for call and invoke exprs. If an argument to the stmt is active, -# then we grab its tangnet ref stack. If it's inactive (a constant of some kind -- really -# anything that's not an `Argument` or an `ID`), then we create a dummy stack that will get -# optimised away. -function __make_arg_tangent_ref_stack(arg_type, arg) - is_active(arg) || return InactiveStack(InactiveRef(__zero_tangent(arg))) - return make_tangent_ref_stack(tangent_ref_type_ub(arg_type)) -end - -@inline function __tangent_ref_stacks(args::Vararg{Any, N}) where {N} - return tuple_map(___tangent_ref_stacks_helper, args) -end - -# Distinguish between tangent stacks and tangent ref stacks based on their type. If we see -# a type which looks like a tangent ref stack, just return it. If we see any other type, -# assume it is a tangent stack, meaning that the tangent stack is fixed. 
-# This is bit of a hack -- ideally we would get the code construction in `make_ad_stmts!` to -# determine this, as doing this based on type is potentially flakey. It will have to do for -# now though. -@inline @generated function ___tangent_ref_stacks_helper(arg::P) where {P} - if P <: Union{InactiveStack, Stack{<:Ref}, NoTangentRefStack} && !(P <: Stack{<:Ptr}) - return :(arg) - else - return :(FixedStackTangentRefStack(arg)) - end -end - is_active(::Union{Argument, ID}) = true is_active(::Any) = false -__zero_tangent(arg) = zero_tangent(arg) -__zero_tangent(arg::GlobalRef) = zero_tangent(getglobal(arg.mod, arg.name)) -__zero_tangent(arg::QuoteNode) = zero_tangent(arg.value) +# Get a bound on the pullback type, given a rule and associated primal types. +function pullback_type(Trule, arg_types) + T = Core.Compiler.return_type(Tuple{Trule, map(fcodual_type, arg_types)...}) + return (T <: Tuple && T !== Union{} && !(T isa Union)) ? T.parameters[2] : Any +end # Build a stack to contain the pullback. Specialises on whether the pullback is a singleton, # and whether we get to know the concrete type of the pullback or not. -function build_pb_stack(Trule, arg_types) - T_pb!! = Core.Compiler.return_type(Tuple{Trule, map(codual_type, arg_types)...}) - if T_pb!! <: Tuple && T_pb!! !== Union{} && !(T_pb!! isa Union) - F = T_pb!!.parameters[2] - return Base.issingletontype(F) ? SingletonStack{F}() : Stack{F}() - else - return Stack{Any}() - end -end +build_pb_stack(Tpb) = Base.issingletontype(Tpb) ? SingletonStack{Tpb}() : Stack{Tpb}() -# Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`. -@inline function __fwds_pass!( - arg_tangent_ref_stacks, - rule, - ret_tangent_stack, - pb_stack, - ::Type{R}, - f::F, - raw_args::Vararg{Any, N}, -) where {R, F, N} - - raw_args = (f, raw_args...) - __log_tangent_refs!(pb_stack, raw_args, arg_tangent_ref_stacks) - - # Run the rule. - args = tuple_map(x -> isa(x, AugmentedRegister) ? x.codual : uninit_codual(x), raw_args) - out, pb!! = rule(args...) - - # Log the results and return. - __push_tangent_stack!(ret_tangent_stack, tangent(out)) - __push_pb_stack!(pb_stack, pb!!) - return AugmentedRegister(out, top_ref(ret_tangent_stack))::R -end - -@inline function __log_tangent_refs!(::Any, raw_args, arg_tangent_ref_stacks) - tangent_refs = map(x -> isa(x, AugmentedRegister) ? x.tangent_ref : nothing, raw_args) - tuple_map(__push_ref_stack, arg_tangent_ref_stacks, tangent_refs) +# Used by the getfield special-case in call / invoke statments. +@inline function __fwds_pass_no_ad!(f::F, raw_args::Vararg{Any, N}) where {F, N} + return tuple_splat(__get_primal(f), tuple_map(__get_primal, raw_args)) end -@inline __log_tangent_refs!(::SingletonStack{NoPullback}, ::Any, ::Any) = nothing +__get_primal(x::CoDual) = primal(x) +__get_primal(x) = x -@inline __push_ref_stack(tangent_ref_stack, ref) = push!(tangent_ref_stack, ref) -@inline __push_ref_stack(::InactiveStack, ref) = nothing -@inline __push_ref_stack(::NoTangentRefStack, ref) = nothing +# Helper used in make_ad_stmts! for call / invoke. +__make_codual(x::P) where {P} = (P <: CoDual ? x : uninit_fcodual(x))::CoDual -@inline __push_tangent_stack!(stack, t) = push!(stack, t) +# Useful to have this function call for debugging when looking at the generated IRCode. @inline __push_pb_stack!(stack, pb!!) = push!(stack, pb!!) # Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`. 
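The bound that `pullback_type` computes can be mimicked in plain Julia with a toy rule. Nothing below comes from Tapir; the rule and its pullback are invented for illustration.
```julia
toy_rule(x::Float64) = (2x, dy -> (nothing, 2dy))

# Infer the return type of a call to the rule; if it is usefully concrete, the second
# element of the returned tuple bounds the pullback type, otherwise fall back to `Any`.
T = Core.Compiler.return_type(toy_rule, Tuple{Float64})
T_pb = (T <: Tuple && T !== Union{} && !(T isa Union)) ? T.parameters[2] : Any

Base.issingletontype(T_pb)   # true here, so a singleton-stack style optimisation applies
```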
-@inline function __rvs_pass!(arg_tangent_ref_stacks, ret_tangent_stack, pb_stack)::Nothing - pb = __pop_pb_stack!(pb_stack) - tngt = __pop_tangent_stack!(ret_tangent_stack) - __execute_reverse_pass!(pb, tngt, arg_tangent_ref_stacks) +@inline function __rvs_pass!(P, pb_stack, ret_rev_data_ref, arg_rev_data_refs...)::Nothing + __run_rvs_pass!(P, __pop_pb_stack!(pb_stack), ret_rev_data_ref, arg_rev_data_refs...) end -@inline __execute_reverse_pass!(::NoPullback, ::Any, ::Any) = nothing +# If `NoPullback` is the pullback, then there is nothing to do. Moreover, since the +# reverse-data accumulated in the `ret_rev_data_ref` is never used, we don't even need to +# bother reseting it's value to zero. +@inline __run_rvs_pass!(::Any, ::NoPullback, ::Ref, arg_rev_data_refs...) = nothing -@inline function __execute_reverse_pass!(pb!!, dout, arg_tangent_ref_stacks) - # Get the tangent w.r.t. each argument of the primal. - tangent_refs = tuple_map(pop!, arg_tangent_ref_stacks) - - # Run the pullback and increment the argument tangents. - dargs = tuple_map(set_immutable_to_zero ∘ getindex, tangent_refs) - new_dargs = pb!!(dout, dargs...) - tuple_map(increment_ref!, tangent_refs, new_dargs) +@inline function __run_rvs_pass!(P, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...) + tuple_map(increment_if_ref!, arg_rev_data_refs, pb!!(ret_rev_data_ref[])) + set_ret_ref_to_zero!!(P, ret_rev_data_ref) return nothing end +@inline increment_if_ref!(ref::Ref, rvs_data) = increment_ref!(ref, rvs_data) +@inline increment_if_ref!(::Nothing, ::Any) = nothing + +@inline increment_ref!(x::Ref, t) = setindex!(x, increment!!(x[], t)) +@inline increment_ref!(::Base.RefValue{NoRData}, t) = nothing + +# Useful to have this function call for debugging when looking at the generated IRCode. @inline __pop_pb_stack!(stack) = pop!(stack) -@inline __pop_tangent_stack!(tangent_stack) = pop!(tangent_stack) + +@inline function set_ret_ref_to_zero!!(::Type{P}, r::Ref{R}) where {P, R} + r[] = zero_like_rdata_from_type(P) +end +@inline set_ret_ref_to_zero!!(::Type{P}, r::Base.RefValue{NoRData}) where {P} = nothing # -# Runners for generated code. +# Runners for generated code. The main job of these functions is to handle the translation +# between differing varargs conventions. # -struct Pullback{Tpb, Tret_ref, Targ_tangent_stacks, Tisva, Tnargs} - pb_oc::Tpb - ret_ref::Tret_ref - arg_tangent_stacks::Targ_tangent_stacks +struct Pullback{Tpb_oc, Tisva<:Val, Tnvargs<:Val} + pb_oc::Tpb_oc isva::Tisva - nargs::Tnargs + nvargs::Tnvargs end -@inline function (pb::Pullback{P, Q})(dy, dargs::Vararg{Any, N}) where {P, Q, N} - unflattened_dargs = __unflatten_varargs(pb.isva, dargs, pb.nargs) - map(setindex!, map(top_ref, pb.arg_tangent_stacks), unflattened_dargs) - increment_ref!(pb.ret_ref, dy) - pb.pb_oc(dy, unflattened_dargs...) 
- out = __flatten_varargs(pb.isva, map(pop!, pb.arg_tangent_stacks), nvargs(length(dargs), pb.nargs)) - return out::_typeof(dargs) -end +@inline (pb::Pullback)(dy) = __flatten_varargs(pb.isva, pb.pb_oc(dy), pb.nvargs) -@inline nvargs(n_flat, ::Val{nargs}) where {nargs} = Val(n_flat - nargs + 1) - -struct DerivedRule{Tfwds_oc, Targ_tangent_stacks, Tpb_oc, Tisva<:Val, Tnargs<:Val} +struct DerivedRule{Tfwds_oc, Tpb_oc, Tisva<:Val, Tnargs<:Val} fwds_oc::Tfwds_oc pb_oc::Tpb_oc - arg_tangent_stacks::Targ_tangent_stacks - block_stack::Stack{Int32} isva::Tisva nargs::Tnargs end @inline function (fwds::DerivedRule{P, Q, S})(args::Vararg{CoDual, N}) where {P, Q, S, N} + uf_args = __unflatten_codual_varargs(fwds.isva, args, fwds.nargs) + pb!! = Pullback(fwds.pb_oc, fwds.isva, nvargs(length(args), fwds.nargs)) + return fwds.fwds_oc(uf_args...)::CoDual, pb!! +end - # Load arguments in to stacks, and create tuples. - args = __unflatten_codual_varargs(fwds.isva, args, fwds.nargs) - args_with_tangent_stacks = map(args, fwds.arg_tangent_stacks) do arg, arg_tangent_stack - push!(arg_tangent_stack, tangent(arg)) - return AugmentedRegister(arg, top_ref(arg_tangent_stack)) - end +@inline nvargs(n_flat, ::Val{nargs}) where {nargs} = Val(n_flat - nargs + 1) - # Run forwards-pass. - reg = fwds.fwds_oc(args_with_tangent_stacks...)::AugmentedRegister +# If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0). +function __flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs} + isva || return args + last_el = isa(args[end], NoRData) ? ntuple(n -> NoRData(), nvargs) : args[end] + return (args[1:end-1]..., last_el...) +end - # Extract result and assemble pullback. - pb!! = Pullback(fwds.pb_oc, reg.tangent_ref, fwds.arg_tangent_stacks, fwds.isva, fwds.nargs) - return reg.codual, pb!! +# If isva and nargs=2, then inputs `(CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0))` +# are transformed into `(CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0)))`. +function __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs} + isva || return args + group_primal = map(primal, args[nargs:end]) + if fdata_type(tangent_type(_typeof(group_primal))) == NoFData + grouped_args = zero_fcodual(group_primal) + else + grouped_args = CoDual(group_primal, map(tangent, args[nargs:end])) + end + return (args[1:nargs-1]..., grouped_args) end +# +# Rule derivation. +# + # Compute the concrete type of the rule that will be returned from `build_rrule`. This is # important for performance in dynamic dispatch, and to ensure that recursion works # properly. 
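The vararg conventions handled by `__flatten_varargs` and `__unflatten_codual_varargs` above can be pictured with plain tuples. This sketch ignores the `NoRData` / `CoDual` handling and just shows the index arithmetic, using invented helper names.
```julia
# nargs = 2: the final "argument" of the rule groups everything from position 2 onwards.
unflatten(args, nargs) = (args[1:nargs-1]..., args[nargs:end])
flatten(args)          = (args[1:end-1]..., args[end]...)

unflatten((5.0, 4.0, 3.0), 2) == (5.0, (4.0, 3.0))   # grouping on the way in
flatten((5.0, (4.0, 3.0)))    == (5.0, 4.0, 3.0)     # splatting on the way out
```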
@@ -674,84 +699,51 @@ function rule_type(interp::TapirInterpreter{C}, ::Type{sig}) where {C, sig} Treturn = Base.Experimental.compute_ir_rettype(ir) isva, _ = is_vararg_sig_and_sparam_names(sig) - arg_types = map(_get_type, ir.argtypes) - arg_tangent_types = map(tangent_type, arg_types) - Targ_registers = Tuple{map(tangent_stack_type ∘ _get_type, ir.argtypes)...} - Treturn_register = register_type(Treturn) - if isconcretetype(Treturn_register) + arg_types = map(_type, ir.argtypes) + arg_fwds_types = Tuple{map(fcodual_type, arg_types)...} + arg_rvs_types = Tuple{map(rdata_type ∘ tangent_type, arg_types)...} + fwds_return_codual = fcodual_type(Treturn) + rvs_return_type = rdata_type(tangent_type(Treturn)) + if isconcretetype(fwds_return_codual) return DerivedRule{ - Core.OpaqueClosure{Tuple{map(register_type, arg_types)...}, Treturn_register}, - Targ_registers, - Core.OpaqueClosure{Tuple{tangent_type(Treturn), arg_tangent_types...}, Nothing}, + Core.OpaqueClosure{arg_fwds_types, fwds_return_codual}, + Core.OpaqueClosure{Tuple{rvs_return_type}, arg_rvs_types}, Val{isva}, Val{length(ir.argtypes)}, } else return DerivedRule{ - Core.OpaqueClosure{Tuple{map(register_type, arg_types)...}, T} where {T<:Treturn_register}, - Targ_registers, - Core.OpaqueClosure{Tuple{tangent_type(Treturn), arg_tangent_types...}, Nothing}, + Core.OpaqueClosure{arg_fwds_types, P} where {P<:fwds_return_codual}, + Core.OpaqueClosure{Tuple{rvs_return_type}, arg_rvs_types}, Val{isva}, Val{length(ir.argtypes)}, } end end -# if isva and nargs=2, then inputs (5.0, 4.0, 3.0) are transformed into (5.0, (4.0, 3.0)). -function __unflatten_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs} - isva || return args - if all(t -> t isa NoTangent, args[nargs:end]) - return (args[1:nargs-1]..., NoTangent()) - else - return (args[1:nargs-1]..., args[nargs:end]) - end -end - -# If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0). -function __flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs} - isva || return args - if args[end] isa NoTangent - return (args[1:end-1]..., ntuple(n -> NoTangent(), nvargs)...) - else - return (args[1:end-1]..., args[end]...) - end -end - -# If isva and nargs=2, then inputs `(CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0))` -# are transformed into `(CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0)))`. -function __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs} - isva || return args - group_primal = map(primal, args[nargs:end]) - if tangent_type(_typeof(group_primal)) == NoTangent - grouped_args = zero_codual(group_primal) - else - grouped_args = CoDual(group_primal, map(tangent, args[nargs:end])) - end - return (args[1:nargs-1]..., grouped_args) -end - """ build_rrule(args...) Helper method. Only uses static information from `args`. """ -function build_rrule(args...) - return build_rrule(TapirInterpreter(), _typeof(TestUtils.__get_primals(args))) -end +build_rrule(args...) = build_rrule(PInterp(), _typeof(TestUtils.__get_primals(args))) """ - build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}) where {C} + build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C} Returns a `DerivedRule` which is an `rrule!!` for `sig` in context `C`. See the docstring for `rrule!!` for more info. + +If `safety_on` is `true`, then all calls to rules are replaced with calls to `SafeRRule`s. 
""" -function build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}) where {C} +function build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C} - # Reset id count. This ensures that everything in this function is deterministic. + # Reset id count. This ensures that the IDs generated are the same each time this + # function runs. seed_id!() # If we have a hand-coded rule, just use that. - is_primitive(C, sig) && return rrule!! + is_primitive(C, sig) && return (safety_on ? SafeRRule(rrule!!) : rrule!!) # Grab code associated to the primal. ir, _ = lookup_ir(interp, sig) @@ -763,15 +755,10 @@ function build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}) where {C} primal_ir = BBCode(ir) # Compute global info. - arg_types = Dict{Argument, Any}( - map(((n, t),) -> (Argument(n) => _get_type(t)), enumerate(ir.argtypes)) - ) - insts = new_inst_vec(ir.stmts) - ssa_types = Dict{ID, NewInstruction}(zip(concatenate_ids(primal_ir), insts)) - arg_tangent_stacks = (map(make_tangent_stack ∘ _get_type, primal_ir.argtypes)..., ) - info = ADInfo(interp, arg_types, ssa_types, arg_tangent_stacks) + info = ADInfo(interp, primal_ir, safety_on) - # For each block in the fwds and pullback BBCode, translate all statements. + # For each block in the fwds and pullback BBCode, translate all statements. Running this + # will, in general, push items to `info.shared_data_pairs`. ad_stmts_blocks = map(primal_ir.blocks) do primal_blk ids = primal_blk.inst_ids primal_stmts = map(x -> x.stmt, primal_blk.insts) @@ -780,40 +767,34 @@ function build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}) where {C} # Make shared data, and construct BBCode for forwards-pass and pullback. shared_data = shared_data_tuple(info.shared_data_pairs) - # display(sig) - # @show length(shared_data) - # @show length(ir.stmts.inst) - # display(collect(_typeof(shared_data).parameters)) # If we've already derived the OpaqueClosures and info, do not re-derive, just create a # copy and pass in new shared data. - if !haskey(interp.oc_cache, sig) + if haskey(interp.oc_cache, (sig, safety_on)) + existing_fwds_oc, existing_pb_oc = interp.oc_cache[(sig, safety_on)] + fwds_oc = replace_captures(existing_fwds_oc, shared_data) + pb_oc = replace_captures(existing_pb_oc, shared_data) + else fwds_ir = forwards_pass_ir(primal_ir, ad_stmts_blocks, info, _typeof(shared_data)) pb_ir = pullback_ir(primal_ir, Treturn, ad_stmts_blocks, info, _typeof(shared_data)) + # @show sig, safety_on + # display(ir) + # display(IRCode(fwds_ir)) + # display(IRCode(pb_ir)) optimised_fwds_ir = optimise_ir!(IRCode(fwds_ir); do_inline=true) optimised_pb_ir = optimise_ir!(IRCode(pb_ir); do_inline=true) + # @show length(ir.stmts.inst) # @show length(optimised_fwds_ir.stmts.inst) # @show length(optimised_pb_ir.stmts.inst) - # display(ir) # display(optimised_fwds_ir) # display(optimised_pb_ir) fwds_oc = OpaqueClosure(optimised_fwds_ir, shared_data...; do_compile=true) pb_oc = OpaqueClosure(optimised_pb_ir, shared_data...; do_compile=true) - interp.oc_cache[sig] = (fwds_oc, pb_oc) - else - existing_fwds_oc, existing_pb_oc = interp.oc_cache[sig] - fwds_oc = replace_captures(existing_fwds_oc, shared_data) - pb_oc = replace_captures(existing_pb_oc, shared_data) + interp.oc_cache[(sig, safety_on)] = (fwds_oc, pb_oc) end - return rule_type(interp, sig)( - fwds_oc, - pb_oc, - arg_tangent_stacks, - info.block_stack, - Val(isva), - Val(length(ir.argtypes)), - ) + raw_rule = rule_type(interp, sig)(fwds_oc, pb_oc, Val(isva), Val(num_args(info))) + return safety_on ? 
SafeRRule(raw_rule) : raw_rule end # Given an `OpaqueClosure` `oc`, create a new `OpaqueClosure` of the same type, but with new @@ -836,35 +817,46 @@ Produce the IR associated to the `OpaqueClosure` which runs most of the forwards =# function forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) + is_unique_pred, pred_is_unique_pred = characterise_unique_predecessor_blocks(ir.blocks) + # Insert a block at the start which extracts all items from the captures field of the # `OpaqueClosure`, which contains all of the data shared between the forwards- and # reverse-passes. These are assigned to the `ID`s given by the `SharedDataPairs`. - # Additionally, push the entry id onto the block stack. - push_block_stack_stmt = Expr(:call, push!, info.block_stack_id, info.entry_id.id) - entry_stmts = vcat( - shared_data_stmts(info.shared_data_pairs), - (ID(), new_inst(push_block_stack_stmt)), - ) + # Additionally, push the entry id onto the block stack if needed. + sds = shared_data_stmts(info.shared_data_pairs) + if pred_is_unique_pred[ir.blocks[1].id] + entry_stmts = sds + else + push_block_stack_stmt = Expr( + :call, __push_blk_stack!, info.block_stack_id, info.entry_id.id + ) + entry_stmts = vcat(sds, (ID(), new_inst(push_block_stack_stmt))) + end entry_block = BBlock(info.entry_id, entry_stmts) # Construct augmented version of each basic block from the primal. For each block: - # 1. pull the translated basic block statements from ad_stmts_blocks. - # 2. insert a statement which logs the ID of the current block to the block stack. + # 1. pull the translated basic block statements from ad_stmts_blocks, + # 2. insert a statement which logs the ID of the current block if necessary, and # 3. construct and return a BBlock. blocks = map(ad_stmts_blocks) do (block_id, ad_stmts) fwds_stmts = reduce(vcat, map(x -> x.fwds, ad_stmts)) - ins_loc = length(fwds_stmts) + (isa(fwds_stmts[end][2].stmt, Terminator) ? 0 : 1) - ins_stmt = Expr(:call, __push_blk_stack!, info.block_stack_id, block_id.id) - ins_inst = (ID(), new_inst(ins_stmt)) - return BBlock(block_id, insert!(fwds_stmts, ins_loc, ins_inst)) + if !is_unique_pred[block_id] + ins_loc = length(fwds_stmts) + (isa(fwds_stmts[end][2].stmt, Terminator) ? 0 : 1) + ins_stmt = Expr(:call, __push_blk_stack!, info.block_stack_id, block_id.id) + ins_inst = (ID(), new_inst(ins_stmt)) + insert!(fwds_stmts, ins_loc, ins_inst) + end + return BBlock(block_id, fwds_stmts) end # Create and return the `BBCode` for the forwards-pass. - arg_types = vcat(Tshared_data, map(register_type ∘ _get_type, ir.argtypes)) + arg_types = vcat(Tshared_data, map(fcodual_type ∘ _type, ir.argtypes)) return BBCode(vcat(entry_block, blocks), arg_types, ir.sptypes, ir.linetable, ir.meta) end -@noinline __push_blk_stack!(block_stack::Stack{Int32}, id::Int32) = push!(block_stack, id) +# Going via this function, rather than just calling push!, makes it very straightforward to +# figure out much time is spent pushing to the block stack when profiling AD. +@inline __push_blk_stack!(block_stack::BlockStack, id::Int32) = push!(block_stack, id) #= pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) @@ -874,8 +866,7 @@ Produce the IR associated to the `OpaqueClosure` which runs most of the pullback function pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) # Compute the argument types associated to the reverse-pass. 
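For intuition, the block-stack bookkeeping shared by `forwards_pass_ir` and `pullback_ir` can be caricatured in plain Julia. All names are invented; the real implementation stores `Int32` block IDs and skips the push entirely when a block's predecessor is provably unique.
```julia
block_stack = Int32[]

function forwards_toy(x, stack)
    if x > 0
        push!(stack, Int32(2))   # record that we went via (the analogue of) block #2
        return sin(x)
    else
        push!(stack, Int32(3))   # ... or via block #3
        return cos(x)
    end
end

forwards_toy(1.0, block_stack)
pop!(block_stack) == Int32(2)    # the pullback pops this to retrace the path in reverse
```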
- darg_types = map(tangent_type ∘ _get_type, ir.argtypes) - arg_types = vcat(Tshared_data, tangent_type(Tret), darg_types) + arg_types = vcat(Tshared_data, rdata_type(tangent_type(Tret))) # Compute the blocks which return in the primal. primal_exit_blocks_inds = findall(is_reachable_return_node ∘ terminator, ir.blocks) @@ -893,33 +884,57 @@ function pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, T end # - # Standard path pullback generation -- applied to 99% of primals: + # Standard path pullback generation -- applies to 99% of primals: # - # Create entry block, which pops the block_stack, and switches to whichever block we - # were in at the end of the forwards-pass. - exit_blocks_ids = map(n -> ir.blocks[n].id, primal_exit_blocks_inds) + # Create entry block which: + # 1. extracts items from shared data to the correct IDs, + # 2. creates `Ref`s (which will be optimised away later) to hold rdata for all ssas, + # 3. create switch statement to block which terminated the forwards pass. If there is + # only a single block in the primal containing a reachable ReturnNode, then there is + # no need to pop the block stack. data_stmts = shared_data_stmts(info.shared_data_pairs) - switch_stmts = make_switch_stmts(exit_blocks_ids, info) - entry_block = BBlock(ID(), vcat(data_stmts, switch_stmts)) + rev_data_ref_stmts = reverse_data_ref_stmts(info) + exit_blocks_ids = map(n -> ir.blocks[n].id, primal_exit_blocks_inds) + switch_stmts = make_switch_stmts(exit_blocks_ids, length(exit_blocks_ids) == 1, info) + entry_block = BBlock(ID(), vcat(data_stmts, rev_data_ref_stmts, switch_stmts)) # For each basic block in the primal: - # 1. pull the translated basic block statements from ad_stmts_blocks - # 2. reverse the statements - # 3. pop block stack to get the predecessor block - # 4. insert a switch statement to determine which block to jump to. Restrict blocks - # considered to only those which are predecessors of this one. If in the first block, - # check whether or not the block stack is empty. If empty, jump to the exit block. + # 1. if the block is reachable on the reverse-pass, the bulk of its statements are the + # translated basic block statements, in reverse. + # 2. if, on the other hand, the block is provably not reachable on the reverse-pass, + # return a block with nothing in it. At present we only assert that a block is not + # reachable if it ends with an unreachable return node. + # 3. if we need to pop the predecessor stack, pop it. We don't need to pop it if there + # is only a single predecessor to this block, and said predecessor is a _unique_ + # _predecessor_ (see characterise_unique_predecessor_blocks for more info), as its + # ID is uniquely determined, and nothing will have been put on to the block stack + # during the forwards-pass (see how the output of + # characterise_unique_predecessor_blocks is used in forwards_pass_ir). + # 4. if the block began with one or more PhiNodes, then handle their tangents. + # 5. jump to the predecessor block ps = compute_all_predecessors(ir) + _, pred_is_unique_pred = characterise_unique_predecessor_blocks(ir.blocks) main_blocks = map(ad_stmts_blocks, enumerate(ir.blocks)) do (blk_id, ad_stmts), (n, blk) - rvs_stmts = reduce(vcat, [x.rvs for x in reverse(ad_stmts)]) + if is_unreachable_return_node(terminator(blk)) + rvs_stmts = [(ID(), new_inst(nothing))] + else + rvs_stmts = reduce(vcat, [x.rvs for x in reverse(ad_stmts)]) + end pred_ids = vcat(ps[blk.id], n == 1 ? 
[info.entry_id] : ID[]) - switch_stmts = make_switch_stmts(pred_ids, info) - return BBlock(blk_id, vcat(rvs_stmts, switch_stmts)) + tmp = pred_is_unique_pred[blk_id] + additional_stmts, new_blocks = conclude_rvs_block(blk, pred_ids, tmp, info) + rvs_block = BBlock(blk_id, vcat(rvs_stmts, additional_stmts)) + return vcat(rvs_block, new_blocks) end + main_blocks = vcat(main_blocks...) - # Create an exit block. Simply returns nothing. - exit_block = BBlock(info.entry_id, [(ID(), new_inst(ReturnNode(nothing)))]) + # Create an exit block. Dereferences reverse-data for arguments and returns it. + arg_rdata_ref_ids = map(n -> info.arg_rdata_ref_ids[Argument(n)], 1:num_args(info)) + deref_id = ID() + deref = new_inst(Expr(:call, __deref_arg_rev_data_refs, arg_rdata_ref_ids...)) + ret = new_inst(ReturnNode(deref_id)) + exit_block = BBlock(info.entry_id, [(deref_id, deref), (ID(), ret)]) # Create and return `BBCode` for the pullback. blks = vcat(entry_block, main_blocks, exit_block) @@ -927,7 +942,99 @@ function pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, T end #= - make_switch_stmts(pred_ids::Vector{ID}, info::ADInfo) + conclude_rvs_block( + blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo + ) + +Generates code which is inserted at the end of each counterpart block in the reverse-pass. +Handles phi nodes, and choosing the correct next block to switch to. +=# +function conclude_rvs_block( + blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo +) + # Get the PhiNodes and their IDs. + phi_ids, phis = phi_nodes(blk) + + # If there are no PhiNodes in this block, switch directly to the predecessor. + if length(phi_ids) == 0 + return make_switch_stmts(pred_ids, pred_is_unique_pred, info), BBlock[] + end + + # Create statements which extract + zero the rdata refs associated to them. + rdata_ids = map(_ -> ID(), phi_ids) + deref_stmts = map(phi_ids, rdata_ids) do phi_id, deref_id + P = get_primal_type(info, phi_id) + r = get_rev_data_id(info, phi_id) + return (deref_id, new_inst(Expr(:call, __deref_and_zero, P, r))) + end + + # For each predecessor, create a `BBlock` which processes its corresponding edge in + # each of the `PhiNode`s. + new_blocks = map(pred_ids) do pred_id + values = Any[__get_value(pred_id, p.stmt) for p in phis] + return rvs_phi_block(pred_id, rdata_ids, values, info) + end + new_pred_ids = map(blk -> blk.id, new_blocks) + switch = make_switch_stmts(pred_ids, new_pred_ids, pred_is_unique_pred, info) + return vcat(deref_stmts, switch), new_blocks +end + +# Helper functionality for conclude_rvs_block. +function __get_value(edge::ID, x::IDPhiNode) + edge in x.edges || return nothing + n = only(findall(==(edge), x.edges)) + return isassigned(x.values, n) ? x.values[n] : nothing +end + +# Helper, used in conclude_rvs_block. +@inline function __deref_and_zero(::Type{P}, x::Ref) where {P} + t = x[] + x[] = Tapir.zero_like_rdata_from_type(P) + return t +end + +#= + rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo) + +Produces a `BBlock` which runs the reverse-pass for the edge associated to `pred_id` in a +collection of `IDPhiNode`s, and then goes to the block associated to `pred_id`. 
+ +For example, suppose that we encounter the following collection of `PhiNode`s at the start +of some block: +```julia +%6 = φ (#2 => _1, #3 => %5) +%7 = φ (#2 => 5., #3 => _2) +``` +Let the tangent refs associated to `%6`, `%7`, and `_1`` be denoted `t%6`, `t%7`, and `t_1` +resp., and let `pred_id` be `#2`, then this function will produce a basic block of the form +```julia +increment_ref!(t_1, t%6) +nothing +goto #2 +``` +The call to `increment_ref!` appears because `_1` is the value associated to`%6` when the +primal code comes from `#2`. Similarly, the `goto #2` statement appears because we came from +`#2` on the forwards-pass. There is no `increment_ref!` associated to `%7` because `5.` is a +constant. We emit a `nothing` statement, which the compiler will happily optimise away later +on. + +The same ideas apply if `pred_id` were `#3`. The block would end with `#3`, and there would +be two `increment_ref!` calls because both `%5` and `_2` are not constants. +=# +function rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo) + @assert length(rdata_ids) == length(values) + inc_stmts = map(rdata_ids, values) do id, val + stmt = Expr(:call, increment_if_ref!, get_rev_data_id(info, val), id) + return (ID(), new_inst(stmt)) + end + goto_stmt = (ID(), new_inst(IDGotoNode(pred_id))) + return BBlock(ID(), vcat(inc_stmts, goto_stmt)) +end + +#= + make_switch_stmts( + pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo + ) `preds_ids` comprises the `ID`s associated to all possible predecessor blocks to the primal block under consideration. Suppose its value is `[ID(1), ID(2), ID(3)]`, then @@ -947,9 +1054,10 @@ switch( In words: `make_switch_stmts` emits code which jumps to whichever block preceded the current block during the forwards-pass. =# -function make_switch_stmts(pred_ids::Vector{ID}, info::ADInfo) - - # If there are no predecessors, then we can't possible have hit this block. This can +function make_switch_stmts( + pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo +) + # If there are no predecessors, then we can't possibly have hit this block. This can # happen when all of the statements in a block have been eliminated, but the Julia # optimiser has not removed the block entirely from the `IRCode`. This often presents as # a block containing only a single `nothing` statement. @@ -959,7 +1067,11 @@ function make_switch_stmts(pred_ids::Vector{ID}, info::ADInfo) # Get the predecessor that we actually had in the primal. prev_blk_id = ID() - prev_blk = new_inst(Expr(:call, __pop_blk_stack!, info.block_stack_id)) + if pred_is_unique_pred + prev_blk = new_inst(QuoteNode(only(pred_ids))) + else + prev_blk = new_inst(Expr(:call, __pop_blk_stack!, info.block_stack_id)) + end # Compare predecessor from primal with all possible predecessors. conds = map(pred_ids[1:end-1]) do id @@ -967,17 +1079,25 @@ function make_switch_stmts(pred_ids::Vector{ID}, info::ADInfo) end # Switch statement to change to the predecessor. 
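The run-time behaviour that `make_switch_stmts` encodes can be summarised with a small plain-Julia function. This is illustrative only: the generated IR uses a `Switch` node rather than a loop, and pops nothing at all when the predecessor is unique.
```julia
function which_predecessor(block_stack::Vector{Int32}, pred_ids::Vector{Int32})
    prev = pop!(block_stack)                  # the __pop_blk_stack! analogue
    for (n, id) in enumerate(pred_ids[1:end-1])
        prev === id && return n               # the __switch_case comparisons
    end
    return length(pred_ids)                   # the Switch's fallthrough destination
end

which_predecessor(Int32[2], Int32[1, 2, 3])   # 2: we came from the second predecessor
```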
- switch_stmt = Switch(Any[c[1] for c in conds], pred_ids[1:end-1], pred_ids[end]) + switch_stmt = Switch(Any[c[1] for c in conds], target_ids[1:end-1], target_ids[end]) switch = (ID(), new_inst(switch_stmt)) return vcat((prev_blk_id, prev_blk), conds, switch) end -@noinline __pop_blk_stack!(block_stack::Stack{Int32}) = pop!(block_stack) +function make_switch_stmts(pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo) + return make_switch_stmts(pred_ids, pred_ids, pred_is_unique_pred, info) +end + +# Going via this function, rather than just calling pop! directly, makes it easy to figure +# out how much time is spent popping the block stack when profiling performance. +@inline __pop_blk_stack!(block_stack::BlockStack) = pop!(block_stack) # Helper function emitted by `make_switch_stmts`. __switch_case(id::Int32, predecessor_id::Int32) = !(id === predecessor_id) +# Helper function used by `pullback_ir`. +@inline __deref_arg_rev_data_refs(arg_rev_data_refs...) = map(getindex, arg_rev_data_refs) #= DynamicDerivedRule(interp::TapirInterpreter) @@ -992,42 +1112,50 @@ This is used to implement dynamic dispatch. struct DynamicDerivedRule{T, V} interp::T cache::V + safety_on::Bool end -DynamicDerivedRule(interp::TapirInterpreter) = DynamicDerivedRule(interp, Dict{Any, Any}()) +function DynamicDerivedRule(interp::TapirInterpreter, safety_on::Bool) + return DynamicDerivedRule(interp, Dict{Any, Any}(), safety_on) +end function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N} - sig = Tuple{map(_typeof, map(primal, args))...} + sig = Tuple{tuple_map(_typeof, tuple_map(primal, args))...} is_primitive(context_type(dynamic_rule.interp), sig) && return rrule!!(args...) rule = get(dynamic_rule.cache, sig, nothing) if rule === nothing - rule = build_rrule(dynamic_rule.interp, sig) + rule = build_rrule(dynamic_rule.interp, sig; safety_on=dynamic_rule.safety_on) dynamic_rule.cache[sig] = rule end return rule(args...) end #= - LazyDerivedRule(interp, sig) + LazyDerivedRule(interp, sig, safety_on::Bool) For internal use only. A type-stable wrapper around a `DerivedRule`, which only instantiates the `DerivedRule` when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived. + +If `safety_on` is `true`, then the rule constructed will be a `SafeRRule`. This is useful +when debugging, but should usually be switched off for production code as it (in general) +incurs some runtime overhead. =# -mutable struct LazyDerivedRule{Trule, T, V} - interp::T - sig::V +mutable struct LazyDerivedRule{sig, Tinterp<:TapirInterpreter, Trule} + interp::Tinterp + safety_on::Bool rule::Trule - function LazyDerivedRule(interp::T, sig::V) where {T<:PInterp, V<:Type{<:Tuple}} - return new{rule_type(interp, sig), T, V}(interp, sig) + function LazyDerivedRule(interp::A, ::Type{sig}, safety_on::Bool) where {A, sig} + rt = safety_on ? SafeRRule{rule_type(interp, sig)} : rule_type(interp, sig) + return new{sig, A, rt}(interp, safety_on) end end -function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N} +function (rule::LazyDerivedRule{sig})(args::Vararg{Any, N}) where {N, sig} if !isdefined(rule, :rule) - rule.rule = build_rrule(rule.interp, rule.sig) + rule.rule = build_rrule(rule.interp, sig; safety_on=rule.safety_on) end return rule.rule(args...) 
end diff --git a/src/interpreter/zero_like_rdata.jl b/src/interpreter/zero_like_rdata.jl new file mode 100644 index 00000000..38a990cf --- /dev/null +++ b/src/interpreter/zero_like_rdata.jl @@ -0,0 +1,37 @@ +""" + ZeroRData() + +Singleton type indicating zero-valued rdata. This should only ever appear as an +intermediate quantity in the reverse-pass of AD when the type of the primal is not fully +inferable, or a field of a type is abstractly typed. + +If you see this anywhere in actual code, or if it appears in a hand-written rule, this is an +error -- please open an issue in such a situation. +""" +struct ZeroRData end + +@inline increment!!(::ZeroRData, r::R) where {R} = r + +""" + zero_like_rdata_type(::Type{P}) where {P} + +Indicates the type which will be returned by `zero_like_rdata_from_type`. Will be the rdata +type for `P` if we can produce the zero rdata element given only `P`, and will be the union +of `R` and `ZeroRData` if an instance of `P` is needed. +""" +function zero_like_rdata_type(::Type{P}) where {P} + R = rdata_type(tangent_type(P)) + return can_produce_zero_rdata_from_type(P) ? R : Union{R, ZeroRData} +end + +""" + zero_like_rdata_from_type(::Type{P}) where {P} + +This is an internal implementation detail -- you should generally not use this function. + +Returns _either_ the zero element of type `rdata_type(tangent_type(P))`, or a `ZeroRData`. +It is always valid to return a `ZeroRData`, +""" +function zero_like_rdata_from_type(::Type{P}) where {P} + return can_produce_zero_rdata_from_type(P) ? zero_rdata_from_type(P) : ZeroRData() +end diff --git a/src/rrules/avoiding_non_differentiable_code.jl b/src/rrules/avoiding_non_differentiable_code.jl index 3f45bb0c..ef88d188 100644 --- a/src/rrules/avoiding_non_differentiable_code.jl +++ b/src/rrules/avoiding_non_differentiable_code.jl @@ -2,14 +2,14 @@ # because we drop the gradient, because the tangent type of integers is NoTangent. # https://github.com/JuliaLang/julia/blob/9f9e989f241fad1ae03c3920c20a93d8017a5b8f/base/pointer.jl#L282 @is_primitive MinimalCtx Tuple{typeof(Base.:(+)), Ptr, Integer} -function rrule!!(::CoDual{typeof(Base.:(+))}, x::CoDual{<:Ptr}, y::CoDual{<:Integer}) - return CoDual(primal(x) + primal(y), tangent(x) + primal(y)), NoPullback() +function rrule!!(f::CoDual{typeof(Base.:(+))}, x::CoDual{<:Ptr}, y::CoDual{<:Integer}) + return CoDual(primal(x) + primal(y), tangent(x) + primal(y)), NoPullback(f, x, y) end @is_primitive MinimalCtx Tuple{typeof(randn), Xoshiro, Vararg} -function rrule!!(::CoDual{typeof(randn)}, rng::CoDual{Xoshiro}, args::CoDual...) +function rrule!!(f::CoDual{typeof(randn)}, rng::CoDual{Xoshiro}, args::CoDual...) x = randn(primal(rng), map(primal, args)...) - return CoDual(x, zero(x)), NoPullback() + return zero_fcodual(x), NoPullback(f, rng, args...) end function generate_hand_written_rrule!!_test_cases( @@ -20,9 +20,7 @@ function generate_hand_written_rrule!!_test_cases( test_cases = Any[ # Rules to avoid pointer type conversions. 
( - true, - :stability, - nothing, + true, :stability_and_allocs, nothing, +, CoDual( bitcast(Ptr{Float64}, pointer_from_objref(_x)), @@ -32,7 +30,7 @@ function generate_hand_written_rrule!!_test_cases( ), # Rule to avoid llvmcall - (true, :stability, nothing, randn, Xoshiro(1)), + (true, :stability_and_allocs, nothing, randn, Xoshiro(1)), (true, :stability, nothing, randn, Xoshiro(1), 2), (true, :stability, nothing, randn, Xoshiro(1), 3, 2), ] diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index c785a6f1..dd7a46b3 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -25,8 +25,8 @@ end # LEVEL 1 # -for (fname, elty) in ((:cblas_ddot,:Float64), (:cblas_sdot,:Float32)) - @eval function rrule!!( +for (fname, elty) in ((:cblas_ddot, :Float64), (:cblas_sdot, :Float32)) + @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(fname))}}, ::CoDual, # return type @@ -38,8 +38,8 @@ for (fname, elty) in ((:cblas_ddot,:Float64), (:cblas_sdot,:Float32)) _incx::CoDual{BLAS.BlasInt}, _DY::CoDual{Ptr{$elty}}, _incy::CoDual{BLAS.BlasInt}, - args..., - ) + args::Vararg{Any, N}, + ) where {N} # Load in values from pointers. n, incx, incy = map(primal, (_n, _incx, _incy)) xinds = 1:incx:incx * n @@ -47,21 +47,22 @@ for (fname, elty) in ((:cblas_ddot,:Float64), (:cblas_sdot,:Float32)) DX = view(unsafe_wrap(Vector{$elty}, primal(_DX), n * incx), xinds) DY = view(unsafe_wrap(Vector{$elty}, primal(_DY), n * incy), yinds) - function ddot_pb!!(dv, d1, d2, d3, d4, d5, d6, dn, dDX, dincx, dDY, dincy, dargs...) - _dDX = view(unsafe_wrap(Vector{$elty}, dDX, n * incx), xinds) - _dDY = view(unsafe_wrap(Vector{$elty}, dDY, n * incy), yinds) + _dDX = view(unsafe_wrap(Vector{$elty}, tangent(_DX), n * incx), xinds) + _dDY = view(unsafe_wrap(Vector{$elty}, tangent(_DY), n * incy), yinds) + + function ddot_pb!!(dv) _dDX .+= DY .* dv _dDY .+= DX .* dv - return d1, d2, d3, d4, d5, d6, dn, dDX, dincx, dDY, dincy, dargs... + return tuple_fill(NoRData(), Val(N + 11)) end # Run primal computation. - return CoDual(dot(DX, DY), zero($elty)), ddot_pb!! + return zero_fcodual(dot(DX, DY)), ddot_pb!! end end for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32)) - @eval function Tapir.rrule!!( + @eval @inline function Tapir.rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(fname))}}, ::CoDual, # return type @@ -72,8 +73,8 @@ for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32)) DA::CoDual{Ptr{$elty}}, DX::CoDual{Ptr{$elty}}, incx::CoDual{Ptr{BLAS.BlasInt}}, - args..., - ) + args::Vararg{Any, N}, + ) where {N} # Load in values from pointers, and turn pointers to memory buffers into Vectors. _n = unsafe_load(primal(n)) _incx = unsafe_load(primal(incx)) @@ -85,7 +86,9 @@ for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32)) DX_copy = _DX[inds] BLAS.scal!(_n, _DA, _DX, _incx) - function dscal_pullback!!(_, a, b, c, d, e, f, dn, dDA, dDX, dincx, dargs...) + dDA = tangent(DA) + dDX = tangent(DX) + function dscal_pullback!!(::NoRData) # Set primal to previous state. _DX[inds] .= DX_copy @@ -96,9 +99,9 @@ for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32)) # Compute cotangent w.r.t. DX. BLAS.scal!(_n, _DA, _DX_s, _incx) - return a, b, c, d, e, f, dn, dDA, dDX, dincx, dargs... + return tuple_fill(NoRData(), Val(10 + N)) end - return zero_codual(Cvoid()), dscal_pullback!! + return zero_fcodual(Cvoid()), dscal_pullback!! 
end end @@ -109,7 +112,7 @@ end # for (gemv, elty) in ((:dgemv_, :Float64), (:sgemm_, :Float32)) - @eval function rrule!!( + @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(gemv))}}, ::CoDual, @@ -127,8 +130,9 @@ for (gemv, elty) in ((:dgemv_, :Float64), (:sgemm_, :Float32)) _beta::CoDual{Ptr{$elty}}, _y::CoDual{Ptr{$elty}}, _incy::CoDual{Ptr{BLAS.BlasInt}}, - args... - ) + args::Vararg{Any, Nargs} + ) where {Nargs} + # Load in data. tA = Char(unsafe_load(primal(_tA))) M, N, lda, incx, incy = map(unsafe_load ∘ primal, (_M, _N, _lda, _incx, _incy)) @@ -145,10 +149,13 @@ for (gemv, elty) in ((:dgemv_, :Float64), (:sgemm_, :Float32)) BLAS.gemv!(tA, alpha, A, x, beta, y) - function gemv_pb!!( - _, d1, d2, d3, d4, d5, d6, - dt, dM, dN, dalpha, _dA, dlda, _dx, dincx, dbeta, _dy, dincy, dargs... - ) + dalpha = tangent(_alpha) + dbeta = tangent(_beta) + _dA = tangent(_A) + _dx = tangent(_x) + _dy = tangent(_y) + function gemv_pb!!(::NoRData) + # Load up the tangents. dA = wrap_ptr_as_view(_dA, lda, M, N) dx = view(unsafe_wrap(Vector{$elty}, _dx, incx * Nx), 1:incx:incx * Nx) @@ -164,15 +171,14 @@ for (gemv, elty) in ((:dgemv_, :Float64), (:sgemm_, :Float32)) # Restore the original value of `y`. y .= y_copy - return d1, d2, d3, d4, d5, d6, dt, dM, dN, dalpha, _dA, dlda, _dx, dincx, dbeta, - _dy, dincy, dargs... + return tuple_fill(NoRData(), Val(17 + Nargs)) end - return zero_codual(Cvoid()), gemv_pb!! + return zero_fcodual(Cvoid()), gemv_pb!! end end for (trmv, elty) in ((:dtrmv_, :Float64), (:strmv_, :Float32)) - @eval function rrule!!( + @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(trmv))}}, ::CoDual, @@ -187,8 +193,8 @@ for (trmv, elty) in ((:dtrmv_, :Float64), (:strmv_, :Float32)) _lda::CoDual{Ptr{BLAS.BlasInt}}, _x::CoDual{Ptr{$elty}}, _incx::CoDual{Ptr{BLAS.BlasInt}}, - args... - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} # Load in data. uplo, trans, diag = map(Char ∘ unsafe_load ∘ primal, (_uplo, _trans, _diag)) N, lda, incx = map(unsafe_load ∘ primal, (_N, _lda, _incx)) @@ -199,9 +205,10 @@ for (trmv, elty) in ((:dtrmv_, :Float64), (:strmv_, :Float32)) # Run primal computation. BLAS.trmv!(uplo, trans, diag, A, x) - function trmv_pb!!( - _, d1, d2, d3, d4, d5, d6, du, dt, ddiag, dN, _dA, dlda, _dx, dincx, dargs... - ) + _dA = tangent(_A) + _dx = tangent(_x) + function trmv_pb!!(::NoRData) + # Load up the tangents. dA = wrap_ptr_as_view(_dA, lda, N, N) dx = wrap_ptr_as_view(_dx, N, incx) @@ -213,10 +220,9 @@ for (trmv, elty) in ((:dtrmv_, :Float64), (:strmv_, :Float32)) dA .+= tri!(trans == 'N' ? dx * x' : x * dx', uplo, diag) BLAS.trmv!(uplo, trans == 'N' ? 'T' : 'N', diag, A, dx) - return d1, d2, d3, d4, d5, d6, du, dt, ddiag, dN, _dA, dlda, _dx, dincx, - dargs... + return tuple_fill(NoRData(), Val(14 + Nargs)) end - return zero_codual(Cvoid()), trmv_pb!! + return zero_fcodual(Cvoid()), trmv_pb!! 
end end @@ -230,8 +236,8 @@ for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32)) @eval function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(gemm))}}, - RT::CoDual{Val{Cvoid}}, - AT::CoDual, # arg types + ::CoDual{Val{Cvoid}}, + ::CoDual, # arg types ::CoDual, # nreq ::CoDual, # calling convention tA::CoDual{Ptr{UInt8}}, @@ -247,8 +253,8 @@ for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32)) beta::CoDual{Ptr{$elty}}, C::CoDual{Ptr{$elty}}, LDC::CoDual{Ptr{BLAS.BlasInt}}, - args..., - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} _tA = Char(unsafe_load(primal(tA))) _tB = Char(unsafe_load(primal(tB))) _m = unsafe_load(primal(m)) @@ -270,10 +276,13 @@ for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32)) BLAS.gemm!(_tA, _tB, _alpha, A_mat, B_mat, _beta, C_mat) - function gemm!_pullback!!( - _, df, dname, dRT, dAT, dnreq, dconvention, - dtA, dtB, dm, dn, dka, dalpha, dA, dLDA, dB, dLDB, dbeta, dC, dLDC, dargs..., - ) + dalpha = tangent(alpha) + dA = tangent(A) + dB = tangent(B) + dbeta = tangent(beta) + dC = tangent(C) + function gemm!_pullback!!(::NoRData) + # Restore previous state. C_mat .= C_copy @@ -290,10 +299,9 @@ for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32)) dB_mat .+= _alpha * transpose(_trans(_tB, transpose(dC_mat) * _trans(_tA, A_mat))) dC_mat .*= _beta - return df, dname, dRT, dAT, dnreq, dconvention, - dtA, dtB, dm, dn, dka, dalpha, dA, dLDA, dB, dLDB, dbeta, dC, dLDC, dargs... + return tuple_fill(NoRData(), Val(19 + Nargs)) end - return zero_codual(Cvoid()), gemm!_pullback!! + return zero_fcodual(Cvoid()), gemm!_pullback!! end end @@ -302,8 +310,8 @@ for (syrk, elty) in ((:dsyrk_, :Float64), (:ssyrk_, :Float32)) @eval function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(syrk))}}, - RT::CoDual{Val{Cvoid}}, - AT::CoDual, # arg types + ::CoDual{Val{Cvoid}}, + ::CoDual, # arg types ::CoDual, # nreq ::CoDual, # calling convention uplo::CoDual{Ptr{UInt8}}, @@ -316,8 +324,8 @@ for (syrk, elty) in ((:dsyrk_, :Float64), (:ssyrk_, :Float32)) beta::CoDual{Ptr{$elty}}, C::CoDual{Ptr{$elty}}, LDC::CoDual{Ptr{BLAS.BlasInt}}, - args..., - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} _uplo = Char(unsafe_load(primal(uplo))) _t = Char(unsafe_load(primal(trans))) _n = unsafe_load(primal(n)) @@ -335,10 +343,12 @@ for (syrk, elty) in ((:dsyrk_, :Float64), (:ssyrk_, :Float32)) BLAS.syrk!(_uplo, _t, _alpha, A_mat, _beta, C_mat) - function syrk!_pullback!!( - _, df, dname, dRT, dAT, dnreq, dconvention, - duplo, dtrans, dn, dk, dalpha, dA, dLDA, dbeta, dC, dLDC, dargs..., - ) + dalpha = tangent(alpha) + dA = tangent(A) + dbeta = tangent(beta) + dC = tangent(C) + function syrk!_pullback!!(::NoRData) + # Restore previous state. C_mat .= C_copy @@ -354,10 +364,9 @@ for (syrk, elty) in ((:dsyrk_, :Float64), (:ssyrk_, :Float32)) dA_mat .+= _alpha * (_t == 'N' ? (B + B') * A_mat : A_mat * (B + B')) dC_mat .= (_uplo == 'U' ? tril!(dC_mat, -1) : triu!(dC_mat, 1)) .+ _beta .* B - return df, dname, dRT, dAT, dnreq, dconvention, - duplo, dtrans, dn, dk, dalpha, dA, dLDA, dbeta, dC, dLDC, dargs... + return tuple_fill(NoRData(), Val(16 + Nargs)) end - return zero_codual(Cvoid()), syrk!_pullback!! + return zero_fcodual(Cvoid()), syrk!_pullback!! 
end end @@ -380,8 +389,8 @@ for (trmm, elty) in ((:dtrmm_, :Float64), (:strmm_, :Float32)) _lda::CoDual{Ptr{BLAS.BlasInt}}, _B::CoDual{Ptr{$elty}}, _ldb::CoDual{Ptr{BLAS.BlasInt}}, - args..., - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} # Load in data and store B for the reverse-pass. side, ul, tA, diag = map(Char ∘ unsafe_load ∘ primal, (_side, _uplo, _trans, _diag)) M, N, lda, ldb = map(unsafe_load ∘ primal, (_M, _N, _lda, _ldb)) @@ -394,10 +403,11 @@ for (trmm, elty) in ((:dtrmm_, :Float64), (:strmm_, :Float32)) # Run primal. BLAS.trmm!(side, ul, tA, diag, alpha, A, B) - function trmm!_pullback!!( - _, d1, d2, d3, d4, d5, d6, - dside, duplo, dtrans, ddiag, dM, dN, dalpha, _dA, dlda, _dB, dlbd, dargs..., - ) + dalpha = tangent(_alpha) + _dA = tangent(_A) + _dB = tangent(_B) + function trmm!_pullback!!(::NoRData) + # Convert pointers to views. dA = wrap_ptr_as_view(_dA, lda, R, R) dB = wrap_ptr_as_view(_dB, ldb, M, N) @@ -418,11 +428,10 @@ for (trmm, elty) in ((:dtrmm_, :Float64), (:strmm_, :Float32)) # Compute dB tangent. BLAS.trmm!(side, ul, tA == 'N' ? 'T' : 'N', diag, alpha, A, dB) - return d1, d2, d3, d4, d5, d6, - dside, duplo, dtrans, ddiag, dM, dN, dalpha, _dA, dlda, _dB, dlbd, dargs... + return tuple_fill(NoRData(), Val(17 + Nargs)) end - return zero_codual(Cvoid()), trmm!_pullback!! + return zero_fcodual(Cvoid()), trmm!_pullback!! end end @@ -445,8 +454,8 @@ for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32)) _lda::CoDual{Ptr{BLAS.BlasInt}}, _B::CoDual{Ptr{$elty}}, _ldb::CoDual{Ptr{BLAS.BlasInt}}, - args..., - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} side = Char(unsafe_load(primal(_side))) uplo = Char(unsafe_load(primal(_uplo))) trans = Char(unsafe_load(primal(_trans))) @@ -463,10 +472,11 @@ for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32)) trsm!(side, uplo, trans, diag, alpha, A, B) - function trsm_pb!!( - _, d1, d2, d3, d4, d5, d6, - dside, duplo, dtrans, ddiag, dM, dN, dalpha, _dA, dlda, _dB, dlbd, dargs..., - ) + dalpha = tangent(_alpha) + _dA = tangent(_A) + _dB = tangent(_B) + function trsm_pb!!(::NoRData) + # Convert pointers to views. dA = wrap_ptr_as_view(_dA, lda, R, R) dB = wrap_ptr_as_view(_dB, ldb, M, N) @@ -499,10 +509,9 @@ for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32)) # Compute dB tangent. BLAS.trsm!(side, uplo, trans == 'N' ? 'T' : 'N', diag, alpha, A, dB) - return d1, d2, d3, d4, d5, d6, - dside, duplo, dtrans, ddiag, dM, dN, dalpha, _dA, dlda, _dB, dlbd, dargs... + return tuple_fill(NoRData(), Val(17 + Nargs)) end - return zero_codual(Cvoid()), trsm_pb!! + return zero_fcodual(Cvoid()), trsm_pb!! 
end end @@ -518,12 +527,12 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) # BLAS LEVEL 1 # - [ - Any[false, :none, nothing, BLAS.dot, 3, randn(5), 1, randn(4), 1], - Any[false, :none, nothing, BLAS.dot, 3, randn(6), 2, randn(4), 1], - Any[false, :none, nothing, BLAS.dot, 3, randn(6), 1, randn(9), 3], - Any[false, :none, nothing, BLAS.dot, 3, randn(12), 3, randn(9), 2], - Any[false, :none, nothing, BLAS.scal!, 10, 2.4, randn(30), 2], + Any[ + (false, :none, nothing, BLAS.dot, 3, randn(5), 1, randn(4), 1), + (false, :none, nothing, BLAS.dot, 3, randn(6), 2, randn(4), 1), + (false, :none, nothing, BLAS.dot, 3, randn(6), 1, randn(9), 3), + (false, :none, nothing, BLAS.dot, 3, randn(12), 3, randn(9), 2), + (false, :none, nothing, BLAS.scal!, 10, 2.4, randn(30), 2), ], # @@ -542,7 +551,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) xs = [randn(N), view(randn(15), 3:N+2), view(randn(30), 1:2:2N)] ys = [randn(M), view(randn(15), 2:M+1), view(randn(30), 2:2:2M)] return map(Iterators.product(As, xs, ys)) do (A, x, y) - Any[false, :none, nothing, BLAS.gemv!, tA, randn(), A, x, randn(), y] + (false, :none, nothing, BLAS.gemv!, tA, randn(), A, x, randn(), y) end end, )), @@ -554,7 +563,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) As = [randn(N, N), view(randn(15, 15), 3:N+2, 4:N+3)] bs = [randn(N), view(randn(14), 4:N+3)] return map(product(As, bs)) do (A, b) - Any[false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b] + (false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b) end end, )), @@ -567,13 +576,14 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) vec(map(product(t_flags, t_flags)) do (tA, tB) A = tA == 'N' ? randn(3, 4) : randn(4, 3) B = tB == 'N' ? randn(4, 5) : randn(5, 4) - Any[false, :none, nothing, BLAS.gemm!, tA, tB, randn(), A, B, randn(), randn(3, 5)] + (false, :none, nothing, BLAS.gemm!, tA, tB, randn(), A, B, randn(), randn(3, 5)) end), + # aliased gemm! vec(map(product(t_flags, t_flags)) do (tA, tB) A = randn(5, 5) B = randn(5, 5) - Any[false, :none, nothing, aliased_gemm!, tA, tB, randn(), randn(), A, B] + (false, :none, nothing, aliased_gemm!, tA, tB, randn(), randn(), A, B) end), # syrk! diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index e8fea287..230c33f6 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -14,11 +14,13 @@ module IntrinsicsWrappers +using Base: IEEEFloat using Core: Intrinsics using Tapir import ..Tapir: rrule!!, CoDual, primal, tangent, zero_tangent, NoPullback, - tangent_type, increment!!, @is_primitive, MinimalCtx, is_primitive + tangent_type, increment!!, @is_primitive, MinimalCtx, is_primitive, NoFData, + zero_rdata, NoRData, tuple_map, fdata, NoRData, rdata, increment_rdata!!, zero_fcodual # Note: performance is not considered _at_ _all_ in this implementation. function rrule!!(f::CoDual{<:Core.IntrinsicFunction}, args...) @@ -39,9 +41,8 @@ macro inactive_intrinsic(name) $name(x...) = Intrinsics.$name(x...) (is_primitive)(::Type{MinimalCtx}, ::Type{<:Tuple{typeof($name), Vararg}}) = true translate(::Val{Intrinsics.$name}) = $name - function rrule!!(::CoDual{typeof($name)}, args...) - y = $name(map(primal, args)...) - return CoDual(y, zero_tangent(y)), NoPullback() + function rrule!!(f::CoDual{typeof($name)}, args::Vararg{Any, N}) where {N} + return zero_fcodual($name(map(primal, args)...)), NoPullback(f, args...) 
end end return esc(expr) @@ -49,23 +50,23 @@ end @intrinsic abs_float function rrule!!(::CoDual{typeof(abs_float)}, x) - abs_float_pullback!!(dy, df, dx) = df, dx + sign(primal(x)) * dy + abs_float_pullback!!(dy) = NoRData(), sign(primal(x)) * dy y = abs_float(primal(x)) - return CoDual(y, zero_tangent(y)), abs_float_pullback!! + return CoDual(y, NoFData()), abs_float_pullback!! end @intrinsic add_float function rrule!!(::CoDual{typeof(add_float)}, a, b) - add_float_pb!!(c̄, f̄, ā, b̄) = f̄, c̄ + ā, c̄ + b̄ + add_float_pb!!(c̄) = NoRData(), c̄, c̄ c = add_float(primal(a), primal(b)) - return CoDual(c, zero_tangent(c)), add_float_pb!! + return CoDual(c, NoFData()), add_float_pb!! end @intrinsic add_float_fast function rrule!!(::CoDual{typeof(add_float_fast)}, a, b) - add_float_fast_pb!!(c̄, f̄, ā, b̄) = f̄, add_float_fast(c̄, ā), add_float_fast(c̄, b̄) + add_float_fast_pb!!(c̄) = NoRData(), c̄, c̄ c = add_float_fast(primal(a), primal(b)) - return CoDual(c, zero_tangent(c)), add_float_fast_pb!! + return CoDual(c, NoFData()), add_float_fast_pb!! end @inactive_intrinsic add_int @@ -87,15 +88,15 @@ end # atomic_pointerswap @intrinsic bitcast -function rrule!!(::CoDual{typeof(bitcast)}, ::CoDual{Type{T}}, x) where {T} +function rrule!!(f::CoDual{typeof(bitcast)}, t::CoDual{Type{T}}, x) where {T} _x = primal(x) v = bitcast(T, _x) if T <: Ptr && _x isa Ptr dv = bitcast(Ptr{tangent_type(eltype(T))}, tangent(x)) else - dv = zero_tangent(v) + dv = NoFData() end - return CoDual(v, dv), NoPullback() + return CoDual(v, dv), NoPullback(f, t, x) end @inactive_intrinsic bswap_int @@ -119,8 +120,8 @@ __cglobal(::Val{s}, x::Vararg{Any, N}) where {s, N} = cglobal(s, x...) translate(::Val{Intrinsics.cglobal}) = __cglobal Tapir.is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(__cglobal), Vararg}}) = true -function rrule!!(::CoDual{typeof(__cglobal)}, args...) - return Tapir.uninit_codual(__cglobal(map(primal, args)...)), NoPullback() +function rrule!!(f::CoDual{typeof(__cglobal)}, args...) + return Tapir.uninit_fcodual(__cglobal(map(primal, args)...)), NoPullback(f, args...) end @inactive_intrinsic checked_sadd_int @@ -138,9 +139,9 @@ end function rrule!!(::CoDual{typeof(copysign_float)}, x, y) _x = primal(x) _y = primal(y) - copysign_float_pullback!!(dz, df, dx, dy) = df, dx + dz * sign(_y), dy + copysign_float_pullback!!(dz) = NoRData(), dz * sign(_y), zero_rdata(_y) z = copysign_float(_x, _y) - return CoDual(z, zero_tangent(z)), copysign_float_pullback!! + return CoDual(z, NoFData()), copysign_float_pullback!! end @inactive_intrinsic ctlz_int @@ -152,12 +153,8 @@ function rrule!!(::CoDual{typeof(div_float)}, a, b) _a = primal(a) _b = primal(b) _y = div_float(_a, _b) - function div_float_pullback!!(dy, df, da, db) - da += div_float(dy, _b) - db -= dy * _a / _b^2 - return df, da, db - end - return CoDual(_y, zero_tangent(_y)), div_float_pullback!! + div_float_pullback!!(dy) = NoRData(), div_float(dy, _b), -dy * _a / _b^2 + return CoDual(_y, NoFData()), div_float_pullback!! end @intrinsic div_float_fast @@ -165,12 +162,10 @@ function rrule!!(::CoDual{typeof(div_float_fast)}, a, b) _a = primal(a) _b = primal(b) _y = div_float_fast(_a, _b) - function div_float_pullback!!(dy, df, da, db) - da += div_float_fast(dy, _b) - db -= dy * div_float_fast(_a, _b^2) - return df, da, db + function div_float_pullback!!(dy) + return NoRData(), div_float_fast(dy, _b), -dy * div_float_fast(_a, _b^2) end - return CoDual(_y, zero_tangent(_y)), div_float_pullback!! + return CoDual(_y, NoFData()), div_float_pullback!! 
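# Quick sanity check (plain Julia, independent of Tapir) of the sensitivities
# returned by the rewritten `div_float` / `div_float_fast` pullbacks above: for
# y = a / b the reverse-mode contributions are dy / b for `a` and -dy * a / b^2
# for `b`. A finite-difference comparison confirms both expressions.
a, b, dy = 3.0, 5.0, 1.0
da, db = dy / b, -dy * a / b^2
eps_fd = 1e-6
@assert isapprox(da, ((a + eps_fd) / b - a / b) / eps_fd; atol=1e-5)
@assert isapprox(db, (a / (b + eps_fd) - a / b) / eps_fd; atol=1e-5)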
end @inactive_intrinsic eq_float @@ -183,12 +178,8 @@ end function rrule!!(::CoDual{typeof(fma_float)}, x, y, z) _x = primal(x) _y = primal(y) - _z = primal(z) - function fma_float_pullback!!(da, df, dx, dy, dz) - return df, fma_float(da, _y, dx), fma_float(da, _x, dy), dz + da - end - a = fma_float(_x, _y, _z) - return CoDual(a, zero_tangent(a)), fma_float_pullback!! + fma_float_pullback!!(da) = NoRData(), da * _y, da * _x, da + return CoDual(fma_float(_x, _y, primal(z)), NoFData()), fma_float_pullback!! end # fpext -- maybe interesting @@ -213,18 +204,16 @@ end function rrule!!(::CoDual{typeof(mul_float)}, a, b) _a = primal(a) _b = primal(b) - mul_float_pb!!(dc, df, da, db) = df, fma_float(dc, _b, da), fma_float(_a, dc, db) - c = mul_float(_a, _b) - return CoDual(c, zero_tangent(c)), mul_float_pb!! + mul_float_pb!!(dc) = NoRData(), dc * _b, _a * dc + return CoDual(mul_float(_a, _b), NoFData()), mul_float_pb!! end @intrinsic mul_float_fast function rrule!!(::CoDual{typeof(mul_float_fast)}, a, b) _a = primal(a) _b = primal(b) - mul_float_pb!!(dc, df, da, db) = df, fma_float(dc, _b, da), fma_float(_a, dc, db) - c = mul_float_fast(_a, _b) - return CoDual(c, zero_tangent(c)), mul_float_pb!! + mul_float_fast_pb!!(dc) = NoRData(), dc * _b, _a * dc + return CoDual(mul_float_fast(_a, _b), NoFData()), mul_float_fast_pb!! end @inactive_intrinsic mul_int @@ -234,9 +223,8 @@ function rrule!!(::CoDual{typeof(muladd_float)}, x, y, z) _x = primal(x) _y = primal(y) _z = primal(z) - muladd_float_pullback!!(da, df, dx, dy, dz) = df, dx + da * _y, dy + da * _x, dz + da - a = muladd_float(_x, _y, _z) - return CoDual(a, zero_tangent(a)), muladd_float_pullback!! + muladd_float_pullback!!(da) = NoRData(), da * _y, da * _x, da + return CoDual(muladd_float(_x, _y, _z), NoFData()), muladd_float_pullback!! end @inactive_intrinsic ne_float @@ -246,17 +234,15 @@ end @intrinsic neg_float function rrule!!(::CoDual{typeof(neg_float)}, x) _x = primal(x) - neg_float_pullback!!(dy, df, dx) = df, sub_float(dx, dy) - y = neg_float(_x) - return CoDual(y, zero_tangent(y)), neg_float_pullback!! + neg_float_pullback!!(dy) = NoRData(), -dy + return CoDual(neg_float(_x), NoFData()), neg_float_pullback!! end @intrinsic neg_float_fast function rrule!!(::CoDual{typeof(neg_float_fast)}, x) _x = primal(x) - neg_float_fast_pullback!!(dy, df, dx) = df, sub_float_fast(dx, dy) - y = neg_float_fast(_x) - return CoDual(y, zero_tangent(y)), neg_float_fast_pullback!! + neg_float_fast_pullback!!(dy) = NoRData(), -dy + return CoDual(neg_float_fast(_x), NoFData()), neg_float_fast_pullback!! end @inactive_intrinsic neg_int @@ -268,15 +254,17 @@ function rrule!!(::CoDual{typeof(pointerref)}, x, y, z) _x = primal(x) _y = primal(y) _z = primal(z) - x_s = tangent(x) - a = CoDual(pointerref(_x, _y, _z), pointerref(x_s, _y, _z)) - function pointerref_pullback!!(da, df, dx, dy, dz) - dx_v = pointerref(dx, _y, _z) - new_dx_v = increment!!(dx_v, da) - pointerset(dx, new_dx_v, _y, _z) - return df, dx, dy, dz + dx = tangent(x) + a = CoDual(pointerref(_x, _y, _z), fdata(pointerref(dx, _y, _z))) + if Tapir.rdata_type(tangent_type(Tapir._typeof(primal(a)))) == NoRData + return a, NoPullback((NoRData(), NoRData(), NoRData(), NoRData())) + else + function pointerref_pullback!!(da) + pointerset(dx, increment_rdata!!(pointerref(dx, _y, _z), da), _y, _z) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return a, pointerref_pullback!! end - return a, pointerref_pullback!! 
end @intrinsic pointerset @@ -286,14 +274,15 @@ function rrule!!(::CoDual{typeof(pointerset)}, p, x, idx, z) _z = primal(z) old_value = pointerref(_p, _idx, _z) old_tangent = pointerref(tangent(p), _idx, _z) - function pointerset_pullback!!(_, df, dp, dx, didx, dz) - dx_new = increment!!(dx, pointerref(dp, _idx, _z)) + dp = tangent(p) + function pointerset_pullback!!(::NoRData) + dx_r = pointerref(dp, _idx, _z) pointerset(_p, old_value, _idx, _z) pointerset(dp, old_tangent, _idx, _z) - return df, dp, dx_new, didx, dz + return NoRData(), NoRData(), rdata(dx_r), NoRData(), NoRData() end pointerset(_p, primal(x), _idx, _z) - pointerset(tangent(p), tangent(x), _idx, _z) + pointerset(dp, zero_tangent(primal(x)), _idx, _z) return p, pointerset_pullback!! end @@ -310,15 +299,15 @@ end @intrinsic sqrt_llvm function rrule!!(::CoDual{typeof(sqrt_llvm)}, x) _x = primal(x) - llvm_sqrt_pullback!!(dy, df, dx) = df, dx + dy * inv(2 * sqrt(_x)) - return CoDual(sqrt_llvm(_x), zero(_x)), llvm_sqrt_pullback!! + llvm_sqrt_pullback!!(dy) = NoRData(), dy * inv(2 * sqrt(_x)) + return CoDual(sqrt_llvm(_x), NoFData()), llvm_sqrt_pullback!! end @intrinsic sqrt_llvm_fast function rrule!!(::CoDual{typeof(sqrt_llvm_fast)}, x) _x = primal(x) - llvm_sqrt_pullback!!(dy, df, dx) = df, dx + dy * inv(2 * sqrt(_x)) - return CoDual(sqrt_llvm_fast(_x), zero(_x)), llvm_sqrt_pullback!! + llvm_sqrt_fast_pullback!!(dy) = NoRData(), dy * inv(2 * sqrt(_x)) + return CoDual(sqrt_llvm_fast(_x), NoFData()), llvm_sqrt_fast_pullback!! end @inactive_intrinsic srem_int @@ -327,20 +316,16 @@ end function rrule!!(::CoDual{typeof(sub_float)}, a, b) _a = primal(a) _b = primal(b) - sub_float_pullback!!(dc, df, da, db) = df, add_float(da, dc), sub_float(db, dc) - c = sub_float(_a, _b) - return CoDual(c, zero_tangent(c)), sub_float_pullback!! + sub_float_pullback!!(dc) = NoRData(), dc, -dc + return CoDual(sub_float(_a, _b), NoFData()), sub_float_pullback!! end @intrinsic sub_float_fast function rrule!!(::CoDual{typeof(sub_float_fast)}, a, b) _a = primal(a) _b = primal(b) - function sub_float_fast_pullback!!(dc, df, da, db) - return df, add_float_fast(da, dc), sub_float_fast(db, dc) - end - c = sub_float_fast(_a, _b) - return CoDual(c, zero_tangent(c)), sub_float_fast_pullback!! + sub_float_fast_pullback!!(dc) = NoRData(), dc, -dc + return CoDual(sub_float_fast(_a, _b), NoFData()), sub_float_fast_pullback!! end @inactive_intrinsic sub_int @@ -362,12 +347,12 @@ end end # IntrinsicsWrappers -function rrule!!(::CoDual{typeof(<:)}, T1, T2) - return CoDual(<:(primal(T1), primal(T2)), NoTangent()), NoPullback() +function rrule!!(f::CoDual{typeof(<:)}, T1, T2) + return zero_fcodual(<:(primal(T1), primal(T2))), NoPullback(f, T1, T2) end -function rrule!!(::CoDual{typeof(===)}, args...) - return CoDual(===(map(primal, args)...), NoTangent()), NoPullback() +function rrule!!(f::CoDual{typeof(===)}, x, y) + return zero_fcodual(primal(x) === primal(y)), NoPullback(f, x, y) end # Core._abstracttype @@ -386,32 +371,35 @@ end # Core._typebody! # Core._typevar -function rrule!!(::CoDual{typeof(Core._typevar)}, args...) - y = Core._typevar(map(primal, args)...) - return CoDual(y, zero_tangent(y)), NoPullback() +function rrule!!(f::CoDual{typeof(Core._typevar)}, args...) + return zero_fcodual(Core._typevar(map(primal, args)...)), NoPullback(f, args...) end -function rrule!!(::CoDual{typeof(Core.apply_type)}, args...) - arg_primals = map(primal, args) - T = Core.apply_type(arg_primals...) 
- return CoDual(T, zero_tangent(T)), NoPullback() +function rrule!!(f::CoDual{typeof(Core.apply_type)}, args...) + T = Core.apply_type(tuple_map(primal, args)...) + return CoDual{_typeof(T), NoFData}(T, NoFData()), NoPullback(f, args...) end -function rrule!!( +Base.@propagate_inbounds function rrule!!( ::CoDual{typeof(Core.arrayref)}, - inbounds::CoDual{Bool}, + checkbounds::CoDual{Bool}, x::CoDual{<:Array}, - inds::CoDual{Int}..., -) - _inbounds = primal(inbounds) - _inds = map(primal, inds) - function arrayref_pullback!!(dy, df, dinbounds, dx, dinds...) - current_val = arrayref(_inbounds, dx, _inds...) - arrayset(_inbounds, dx, increment!!(current_val, dy), _inds...) - return df, dinbounds, dx, dinds... + inds::Vararg{CoDual{Int}, N}, +) where {N} + + # Convert to linear indices to reduce amount of data required on the reverse-pass, to + # avoid converting from cartesian to linear indices multiple times, and to perform a + # bounds check if required by the calling context. + lin_inds = LinearIndices(size(primal(x)))[tuple_map(primal, inds)...] + + dx = tangent(x) + function arrayref_pullback!!(dy) + new_tangent = increment_rdata!!(arrayref(false, dx, lin_inds), dy) + arrayset(false, dx, new_tangent, lin_inds) + return NoRData(), NoRData(), NoRData(), ntuple(_ -> NoRData(), N)... end - _y = arrayref(_inbounds, primal(x), _inds...) - dy = arrayref(_inbounds, tangent(x), _inds...) + _y = arrayref(false, primal(x), lin_inds) + dy = fdata(arrayref(false, tangent(x), lin_inds)) return CoDual(_y, dy), arrayref_pullback!! end @@ -439,38 +427,44 @@ function rrule!!( end arrayset(_inbounds, primal(A), primal(v), _inds...) - arrayset(_inbounds, tangent(A), tangent(v), _inds...) - function arrayset_pullback!!(dA::TdA, df, dinbounds, dA2::TdA, dv, dinds::NoTangent...) - dv_new = increment!!(dv, arrayref(_inbounds, dA, _inds...)) + dA = tangent(A) + arrayset(_inbounds, dA, tangent(tangent(v), zero_rdata(primal(v))), _inds...) + function arrayset_pullback!!(::NoRData) + dv = rdata(arrayref(_inbounds, dA, _inds...)) if to_save arrayset(_inbounds, primal(A), old_A[][1], _inds...) arrayset(_inbounds, dA, old_A[][2], _inds...) end - return df, dinbounds, dA, dv_new, dinds... + return NoRData(), NoRData(), NoRData(), dv, tuple_map(_ -> NoRData(), _inds)... end return A, arrayset_pullback!! end function isbits_arrayset_rrule( - _inbounds, _inds, A::CoDual{<:Array{P}, TdA}, v + boundscheck, _inds, A::CoDual{<:Array{P}, TdA}, v::CoDual{P} ) where {P, V, TdA <: Array{V}} - old_A = ( - arrayref(_inbounds, primal(A), _inds...), - arrayref(_inbounds, tangent(A), _inds...), - ) - arrayset(_inbounds, primal(A), primal(v), _inds...) - arrayset(_inbounds, tangent(A), tangent(v), _inds...) - function isbits_arrayset_pullback!!(dA::TdA, df, dinbounds, dA2::TdA, dv, dinds::NoTangent...) - dv_new = increment!!(dv, arrayref(_inbounds, dA, _inds...)) - arrayset(_inbounds, primal(A), old_A[1], _inds...) - arrayset(_inbounds, dA, old_A[2], _inds...) - return df, dinbounds, dA, dv_new, dinds... + + # Convert to linear indices + lin_inds = LinearIndices(size(primal(A)))[_inds...] 
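# Plain-Julia sketch of the index handling used by the `arrayref` / arrayset
# rules above: the cartesian indices are collapsed to a single linear index up
# front, which (a) is the only piece of index data the pullback needs to keep,
# and (b) performs the bounds check at the point of conversion. The arrays and
# names below are illustrative only.
A  = reshape(collect(1.0:12.0), 3, 4)
dA = zeros(3, 4)
i, j = 2, 3
lin = LinearIndices(size(A))[i, j]   # bounds-checked cartesian -> linear

y = A[lin]                           # forward read via the linear index
dA[lin] += 1.0                       # reverse accumulation reuses the same index
@assert y == A[i, j] && dA[i, j] == 1.0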
+ + old_A = (arrayref(false, primal(A), lin_inds), arrayref(false, tangent(A), lin_inds)) + arrayset(false, primal(A), primal(v), lin_inds) + + _A = primal(A) + dA = tangent(A) + arrayset(false, dA, zero_tangent(primal(v)), lin_inds) + ninds = Val(length(_inds)) + function isbits_arrayset_pullback!!(::NoRData) + dv = rdata(arrayref(false, dA, lin_inds)) + arrayset(false, _A, old_A[1], lin_inds) + arrayset(false, dA, old_A[2], lin_inds) + return NoRData(), NoRData(), NoRData(), dv, tuple_fill(NoRData(), ninds)... end return A, isbits_arrayset_pullback!! end -function rrule!!(::CoDual{typeof(Core.arraysize)}, X, dim) - return CoDual(Core.arraysize(primal(X), primal(dim)), NoTangent()), NoPullback() +function rrule!!(f::CoDual{typeof(Core.arraysize)}, X, dim) + return zero_fcodual(Core.arraysize(primal(X), primal(dim))), NoPullback(f, X, dim) end # Core.compilerbarrier @@ -479,169 +473,164 @@ end # Core.finalizer # Core.get_binding_type -function rrule!!(::CoDual{typeof(Core.ifelse)}, cond, a, b) +function rrule!!(f::CoDual{typeof(Core.ifelse)}, cond, a::A, b::B) where {A, B} _cond = primal(cond) - function ifelse_pullback!!(dc, df, ::NoTangent, da, db) - da = _cond ? increment!!(da, dc) : da - db = _cond ? db : increment!!(db, dc) - return df, NoTangent(), da, db + p_a = primal(a) + p_b = primal(b) + pb!! = if rdata_type(tangent_type(A)) == NoRData && rdata_type(tangent_type(B)) == NoRData + NoPullback(f, cond, a, b) + else + lazy_da = LazyZeroRData(p_a) + lazy_db = LazyZeroRData(p_b) + function ifelse_pullback!!(dc) + da = ifelse(_cond, dc, instantiate(lazy_da)) + db = ifelse(_cond, instantiate(lazy_db), dc) + return NoRData(), NoRData(), da, db + end end - return ifelse(_cond, a, b), ifelse_pullback!! -end -function rrule!!( - ::CoDual{typeof(Core.ifelse)}, - cond, - a::CoDual{<:Any, NoTangent}, - b::CoDual{<:Any, NoTangent}, -) - return ifelse(primal(cond), a, b), NoPullback() + # It's a good idea to split up applying ifelse to the primal and tangent. This is + # because if you push a `CoDual` through ifelse, it _forces_ the construction of the + # CoDual. Conversely, if you pass through the primal and tangents separately, the + # compiler will often be able to avoid constructing the CoDual at all by inlining lots + # of stuff away. + return CoDual(ifelse(_cond, p_a, p_b), ifelse(_cond, tangent(a), tangent(b))), pb!! end # Core.set_binding_type! -function rrule!!(::CoDual{typeof(Core.sizeof)}, x) - return CoDual(Core.sizeof(primal(x)), NoTangent()), NoPullback() +function rrule!!(f::CoDual{typeof(Core.sizeof)}, x) + return zero_fcodual(Core.sizeof(primal(x))), NoPullback(f, x) end # Core.svec -function rrule!!(::CoDual{typeof(applicable)}, f, args...) - return CoDual(applicable(primal(f), map(primal, args)...), NoTangent()), NoPullback() +function rrule!!(_f::CoDual{typeof(applicable)}, f, args...) + pb!! = NoPullback(_f, f, args...) + return zero_fcodual(applicable(primal(f), map(primal, args)...)), pb!! end -function rrule!!(::CoDual{typeof(Core.fieldtype)}, args...) - arg_primals = map(primal, args) - return CoDual(Core.fieldtype(arg_primals...), NoTangent()), NoPullback() +function rrule!!(f::CoDual{typeof(Core.fieldtype)}, args::Vararg{Any, N}) where {N} + arg_primals = tuple_map(primal, args) + return CoDual(Core.fieldtype(arg_primals...), NoFData()), NoPullback(f, args...) 
end -function rrule!!(::CoDual{typeof(getfield)}, value::CoDual, name::CoDual) - _name = primal(name) - function getfield_pullback(dy, ::NoTangent, dvalue, ::NoTangent) - new_dvalue = _increment_field!!(dvalue, dy, _name) - return NoTangent(), new_dvalue, NoTangent() +function rrule!!(f::CoDual{typeof(getfield)}, x::CoDual{P}, name::CoDual) where {P} + if tangent_type(P) == NoTangent + y = uninit_fcodual(getfield(primal(x), primal(name))) + return y, NoPullback(f, x, name) + elseif is_homogeneous_and_immutable(primal(x)) + dx_r = LazyZeroRData(primal(x)) + _name = primal(name) + function immutable_lgetfield_pb!!(dy) + return NoRData(), increment_field!!(instantiate(dx_r), dy, _name), NoRData() + end + yp = getfield(primal(x), _name) + y = CoDual(yp, _get_fdata_field(primal(x), tangent(x), _name)) + return y, immutable_lgetfield_pb!! + else + return rrule!!(uninit_fcodual(lgetfield), x, uninit_fcodual(Val(primal(name)))) end - y = CoDual( - getfield(primal(value), _name), - _get_tangent_field(primal(value), tangent(value), _name), - ) - return y, getfield_pullback -end - -@inline function rrule!!(::CoDual{typeof(getfield)}, value::CoDual{<:Any, NoTangent}, name::CoDual) - return uninit_codual(getfield(primal(value), primal(name))), NoPullback() end -function rrule!!(::CoDual{typeof(getfield)}, value::CoDual, name::CoDual, order::CoDual) - _name = primal(name) - _order = primal(order) - function getfield_pullback(dy, df, dvalue, dname, dorder) - new_dvalue = _increment_field!!(dvalue, dy, _name) - return df, new_dvalue, dname, dorder +function rrule!!(f::CoDual{typeof(getfield)}, x::CoDual{P}, name::CoDual, order::CoDual) where {P} + if tangent_type(P) == NoTangent + y = uninit_fcodual(getfield(primal(x), primal(name))) + return y, NoPullback(f, x, name, order) + elseif is_homogeneous_and_immutable(primal(x)) + dx_r = LazyZeroRData(primal(x)) + _name = primal(name) + function immutable_lgetfield_pb!!(dy) + tmp = increment_field!!(instantiate(dx_r), dy, _name) + return NoRData(), tmp, NoRData(), NoRData() + end + yp = getfield(primal(x), _name, primal(order)) + y = CoDual(yp, _get_fdata_field(primal(x), tangent(x), _name)) + return y, immutable_lgetfield_pb!! + else + literal_name = uninit_fcodual(Val(primal(name))) + literal_order = uninit_fcodual(Val(primal(order))) + return rrule!!(uninit_fcodual(lgetfield), x, literal_name, literal_order) end - _order = _order isa Expr ? true : _order - y = CoDual( - getfield(primal(value), _name, _order), - _get_tangent_field(primal(value), tangent(value), _name, _order), - ) - return y, getfield_pullback end -@inline function rrule!!( - ::CoDual{typeof(getfield)}, value::CoDual{<:Any, NoTangent}, name::CoDual, order::CoDual -) - return uninit_codual(getfield(primal(value), primal(name), primal(order))), NoPullback() -end +@generated is_homogeneous_and_immutable(::P) where {P<:Tuple} = allequal(P.parameters) +@inline is_homogeneous_and_immutable(p::NamedTuple) = is_homogeneous_and_immutable(Tuple(p)) +is_homogeneous_and_immutable(::Any) = false -_get_tangent_field(_, tangent, f...) = getfield(tangent, f...) -function _get_tangent_field(_, tangent::Union{Tangent, MutableTangent}, f...) - return _value(getfield(tangent.fields, f...)) -end -_get_tangent_field(primal, ::NoTangent, f...) = uninit_tangent(getfield(primal, f...)) - -_increment_field!!(x, y, f) = increment_field!!(x, y, f) -_increment_field!!(x::NoTangent, y, f) = x +# # Highly specialised rrule to handle tuples of DataTypes. 
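# Plain-Julia sketch of the point made in the `Core.ifelse` comment above:
# selecting the primal and the tangent separately gives the compiler two
# independent selects to optimise, whereas pushing the packaged pair through
# `ifelse` forces that pair to be constructed first. `Dual` here is only an
# illustrative stand-in, not Tapir's `CoDual`.
struct Dual{P, T}
    primal::P
    tangent::T
end

select_packed(cond, a::Dual, b::Dual) = ifelse(cond, a, b)
select_split(cond, a::Dual, b::Dual) =
    Dual(ifelse(cond, a.primal, b.primal), ifelse(cond, a.tangent, b.tangent))

a, b = Dual(1.0, 0.1), Dual(2.0, 0.2)
@assert select_packed(true, a, b) == select_split(true, a, b)
@assert select_packed(false, a, b) == select_split(false, a, b)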
+# function rrule!!(::CoDual{typeof(getfield)}, value::CoDual{P}, name::CoDual) where {P<:NTuple{<:Any, DataType}} +# pb!! = NoPullback((NoRData(), NoRData(), NoRData(), NoRData())) +# y = CoDual{DataType, NoFData}(getfield(primal(value), primal(name)), NoFData()) +# return y, pb!! +# end +# function rrule!!(::CoDual{typeof(getfield)}, value::CoDual{P}, name::CoDual, order::CoDual) where {P<:NTuple{<:Any, DataType}} +# pb!! = NoPullback((NoRData(), NoRData(), NoRData(), NoRData())) +# y = CoDual{DataType, NoFData}(getfield(primal(value), primal(name), primal(order)), NoFData()) +# return y, pb!! +# end -function rrule!!(::CoDual{typeof(getglobal)}, a, b) - v = getglobal(primal(a), primal(b)) - return CoDual(v, zero_tangent(v)), NoPullback() +function rrule!!(f::CoDual{typeof(getglobal)}, a, b) + return zero_fcodual(getglobal(primal(a), primal(b))), NoPullback(f, a, b) end # invoke -function rrule!!(::CoDual{typeof(isa)}, x, T) - return CoDual(isa(primal(x), primal(T)), NoTangent()), NoPullback() +function rrule!!(f::CoDual{typeof(isa)}, x, T) + return zero_fcodual(isa(primal(x), primal(T))), NoPullback(f, x, T) end -function rrule!!(::CoDual{typeof(isdefined)}, args...) - return CoDual(isdefined(map(primal, args)...), NoTangent()), NoPullback() +function rrule!!(f::CoDual{typeof(isdefined)}, args...) + return zero_fcodual(isdefined(map(primal, args)...)), NoPullback(f, args...) end # modifyfield! -function rrule!!(::CoDual{typeof(nfields)}, x) - return CoDual(nfields(primal(x)), NoTangent()), NoPullback() -end +rrule!!(f::CoDual{typeof(nfields)}, x) = zero_fcodual(nfields(primal(x))), NoPullback(f, x) # replacefield! function rrule!!(::CoDual{typeof(setfield!)}, value, name, x) - _name = primal(name) - save = isdefined(primal(value), _name) - old_x = save ? getfield(primal(value), _name) : nothing - old_dx = save ? val(getfield(tangent(value).fields, _name)) : nothing - function setfield!_pullback(dy, df, dvalue, ::NoTangent, dx) - new_dx = increment!!(dx, val(getfield(dvalue.fields, _name))) - new_dx = increment!!(new_dx, dy) - old_x !== nothing && setfield!(primal(value), _name, old_x) - old_x !== nothing && set_tangent_field!(tangent(value), _name, old_dx) - return df, dvalue, NoTangent(), new_dx - end - y = CoDual( - setfield!(primal(value), _name, primal(x)), - set_tangent_field!(tangent(value), _name, tangent(x)), - ) - return y, setfield!_pullback -end - -function rrule!!( - ::CoDual{typeof(setfield!)}, value::CoDual{<:Any, NoTangent}, name, x -) - _name = primal(name) - save = isdefined(primal(value), _name) - old_x = save ? getfield(primal(value), _name) : nothing - function setfield!_pullback(dy, df, dvalue, ::NoTangent, dx) - old_x !== nothing && setfield!(primal(value), _name, old_x) - return df, dvalue, NoTangent(), dx - end - y = CoDual(setfield!(primal(value), _name, primal(x)), NoTangent()) - return y, setfield!_pullback + literal_name = uninit_fcodual(Val(primal(name))) + return rrule!!(uninit_fcodual(lsetfield!), value, literal_name, x) end # swapfield! # throw -@inline function tuple_pullback(dy, ::NoTangent, dargs...) - return NoTangent(), tuple_map(increment!!, dargs, dy)... +struct TuplePullback{N} end + +@inline (::TuplePullback{N})(dy::Tuple) where {N} = NoRData(), dy... + +@inline function (::TuplePullback{N})(::NoRData) where {N} + return NoRData(), ntuple(_ -> NoRData(), N)... end -function rrule!!(::CoDual{typeof(tuple)}, args::Vararg{Any, N}) where {N} +@inline tuple_pullback(dy) = NoRData(), dy... 
+ +@inline tuple_pullback(dy::NoRData) = NoRData() + +function rrule!!(f::CoDual{typeof(tuple)}, args::Vararg{Any, N}) where {N} primal_output = tuple(map(primal, args)...) if tangent_type(_typeof(primal_output)) == NoTangent - return zero_codual(primal_output), NoPullback() + return zero_fcodual(primal_output), NoPullback(f, args...) else - return CoDual(primal_output, tuple(map(tangent, args)...)), tuple_pullback + if fdata_type(tangent_type(_typeof(primal_output))) == NoFData + return zero_fcodual(primal_output), TuplePullback{N}() + else + return CoDual(primal_output, tuple(map(tangent, args)...)), TuplePullback{N}() + end end end -function rrule!!(::CoDual{typeof(typeassert)}, x, type) - function typeassert_pullback(dy, ::NoTangent, dx, ::NoTangent) - return NoTangent(), increment!!(dx, dy), NoTangent() - end +function rrule!!(::CoDual{typeof(typeassert)}, x::CoDual, type::CoDual) + typeassert_pullback(dy) = NoRData(), dy, NoRData() return CoDual(typeassert(primal(x), primal(type)), tangent(x)), typeassert_pullback end -rrule!!(::CoDual{typeof(typeof)}, x) = CoDual(typeof(primal(x)), NoTangent()), NoPullback() +function rrule!!(f::CoDual{typeof(typeof)}, x::CoDual) + return zero_fcodual(typeof(primal(x))), NoPullback(f, x) +end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) @@ -650,115 +639,127 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) _a = Vector{Vector{Float64}}(undef, 3) _a[1] = [5.4, 4.23, -0.1, 2.1] + x = randn(5) + p = pointer(x) + dx = randn(5) + dp = pointer(dx) + + y = [1, 2, 3] + q = pointer(y) + dy = zero_tangent(y) + dq = pointer(dy) + # Slightly wider range for builtins whose performance is known not to be great. _range = (lb=1e-3, ub=200.0) test_cases = Any[ # Core.Intrinsics: - [false, :stability, nothing, IntrinsicsWrappers.abs_float, 5.0], - [false, :stability, nothing, IntrinsicsWrappers.add_float, 4.0, 5.0], - [false, :stability, nothing, IntrinsicsWrappers.add_float_fast, 4.0, 5.0], - [false, :stability, nothing, IntrinsicsWrappers.add_int, 1, 2], - [false, :stability, nothing, IntrinsicsWrappers.and_int, 2, 3], - [false, :stability, nothing, IntrinsicsWrappers.arraylen, randn(10)], - [false, :stability, nothing, IntrinsicsWrappers.arraylen, randn(10, 7)], - [false, :stability, nothing, IntrinsicsWrappers.ashr_int, 123456, 0x0000000000000020], + (false, :stability, nothing, IntrinsicsWrappers.abs_float, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.add_float, 4.0, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.add_float_fast, 4.0, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.add_int, 1, 2), + (false, :stability, nothing, IntrinsicsWrappers.and_int, 2, 3), + (false, :stability, nothing, IntrinsicsWrappers.arraylen, randn(10)), + (false, :stability, nothing, IntrinsicsWrappers.arraylen, randn(10, 7)), + (false, :stability, nothing, IntrinsicsWrappers.ashr_int, 123456, 0x0000000000000020), # atomic_fence -- NEEDS IMPLEMENTING AND TESTING # atomic_pointermodify -- NEEDS IMPLEMENTING AND TESTING # atomic_pointerref -- NEEDS IMPLEMENTING AND TESTING # atomic_pointerreplace -- NEEDS IMPLEMENTING AND TESTING # atomic_pointerset -- NEEDS IMPLEMENTING AND TESTING # atomic_pointerswap -- NEEDS IMPLEMENTING AND TESTING - [false, :stability, nothing, IntrinsicsWrappers.bitcast, Float64, 5], - [false, :stability, nothing, IntrinsicsWrappers.bitcast, Int64, 5.0], - [false, :stability, nothing, IntrinsicsWrappers.bswap_int, 5], - [false, :stability, nothing, 
IntrinsicsWrappers.ceil_llvm, 4.1], - [ + (false, :stability, nothing, IntrinsicsWrappers.bitcast, Float64, 5), + (false, :stability, nothing, IntrinsicsWrappers.bitcast, Int64, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.bswap_int, 5), + (false, :stability, nothing, IntrinsicsWrappers.ceil_llvm, 4.1), + ( true, :stability, nothing, IntrinsicsWrappers.__cglobal, Val{:jl_uv_stdout}(), Ptr{Cvoid}, - ], - [false, :stability, nothing, IntrinsicsWrappers.checked_sadd_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.checked_sdiv_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.checked_smul_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.checked_srem_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.checked_ssub_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.checked_uadd_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.checked_udiv_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.checked_umul_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.checked_urem_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.checked_usub_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.copysign_float, 5.0, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.copysign_float, 5.0, -3.0], + ), + (false, :stability, nothing, IntrinsicsWrappers.checked_sadd_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.checked_sdiv_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.checked_smul_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.checked_srem_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.checked_ssub_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.checked_uadd_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.checked_udiv_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.checked_umul_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.checked_urem_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.checked_usub_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.copysign_float, 5.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.copysign_float, 5.0, -3.0), [false, :stability, nothing, IntrinsicsWrappers.ctlz_int, 5], - [false, :stability, nothing, IntrinsicsWrappers.ctpop_int, 5], - [false, :stability, nothing, IntrinsicsWrappers.cttz_int, 5], - [false, :stability, nothing, IntrinsicsWrappers.div_float, 5.0, 3.0], - [false, :stability, nothing, IntrinsicsWrappers.div_float_fast, 5.0, 3.0], - [false, :stability, nothing, IntrinsicsWrappers.eq_float, 5.0, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.eq_float, 4.0, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.eq_float_fast, 5.0, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.eq_float_fast, 4.0, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.eq_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.eq_int, 4, 4], - [false, :stability, nothing, IntrinsicsWrappers.flipsign_int, 4, -3], - [false, :stability, nothing, IntrinsicsWrappers.floor_llvm, 4.1], - [false, :stability, nothing, IntrinsicsWrappers.fma_float, 5.0, 4.0, 3.0], + (false, :stability, nothing, IntrinsicsWrappers.ctpop_int, 5), + (false, :stability, nothing, IntrinsicsWrappers.cttz_int, 5), + (false, :stability, nothing, IntrinsicsWrappers.div_float, 5.0, 3.0), + (false, :stability, nothing, IntrinsicsWrappers.div_float_fast, 5.0, 3.0), + (false, :stability, nothing, IntrinsicsWrappers.eq_float, 5.0, 
4.0), + (false, :stability, nothing, IntrinsicsWrappers.eq_float, 4.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.eq_float_fast, 5.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.eq_float_fast, 4.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.eq_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.eq_int, 4, 4), + (false, :stability, nothing, IntrinsicsWrappers.flipsign_int, 4, -3), + (false, :stability, nothing, IntrinsicsWrappers.floor_llvm, 4.1), + (false, :stability, nothing, IntrinsicsWrappers.fma_float, 5.0, 4.0, 3.0), # fpext -- NEEDS IMPLEMENTING AND TESTING - [false, :stability, nothing, IntrinsicsWrappers.fpiseq, 4.1, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.fptosi, UInt32, 4.1], - [false, :stability, nothing, IntrinsicsWrappers.fptoui, Int32, 4.1], + (false, :stability, nothing, IntrinsicsWrappers.fpiseq, 4.1, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.fptosi, UInt32, 4.1), + (false, :stability, nothing, IntrinsicsWrappers.fptoui, Int32, 4.1), # fptrunc -- maybe interesting - [true, :stability, nothing, IntrinsicsWrappers.have_fma, Float64], - [false, :stability, nothing, IntrinsicsWrappers.le_float, 4.1, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.le_float_fast, 4.1, 4.0], + (true, :stability, nothing, IntrinsicsWrappers.have_fma, Float64), + (false, :stability, nothing, IntrinsicsWrappers.le_float, 4.1, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.le_float_fast, 4.1, 4.0), # llvm_call -- NEEDS IMPLEMENTING AND TESTING - [false, :stability, nothing, IntrinsicsWrappers.lshr_int, 1308622848, 0x0000000000000018], - [false, :stability, nothing, IntrinsicsWrappers.lt_float, 4.1, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.lt_float_fast, 4.1, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.mul_float, 5.0, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.mul_float_fast, 5.0, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.mul_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.muladd_float, 5.0, 4.0, 3.0], - [false, :stability, nothing, IntrinsicsWrappers.ne_float, 5.0, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.ne_float_fast, 5.0, 4.0], - [false, :stability, nothing, IntrinsicsWrappers.ne_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.ne_int, 5, 5], - [false, :stability, nothing, IntrinsicsWrappers.neg_float, 5.0], - [false, :stability, nothing, IntrinsicsWrappers.neg_float_fast, 5.0], - [false, :stability, nothing, IntrinsicsWrappers.neg_int, 5], - [false, :stability, nothing, IntrinsicsWrappers.not_int, 5], - [false, :stability, nothing, IntrinsicsWrappers.or_int, 5, 5], - # pointerref -- integration tested because pointers are awkward. See below. - # pointerset -- integration tested because pointers are awkward. See below. 
+ (false, :stability, nothing, IntrinsicsWrappers.lshr_int, 1308622848, 0x0000000000000018), + (false, :stability, nothing, IntrinsicsWrappers.lt_float, 4.1, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.lt_float_fast, 4.1, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.mul_float, 5.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.mul_float_fast, 5.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.mul_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.muladd_float, 5.0, 4.0, 3.0), + (false, :stability, nothing, IntrinsicsWrappers.ne_float, 5.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.ne_float_fast, 5.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.ne_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.ne_int, 5, 5), + (false, :stability, nothing, IntrinsicsWrappers.neg_float, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.neg_float_fast, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.neg_int, 5), + (false, :stability, nothing, IntrinsicsWrappers.not_int, 5), + (false, :stability, nothing, IntrinsicsWrappers.or_int, 5, 5), + (true, :stability, nothing, IntrinsicsWrappers.pointerref, CoDual(p, dp), 2, 1), + (true, :stability, nothing, IntrinsicsWrappers.pointerref, CoDual(q, dq), 2, 1), + (true, :stability, nothing, IntrinsicsWrappers.pointerset, CoDual(p, dp), 5.0, 2, 1), + (true, :stability, nothing, IntrinsicsWrappers.pointerset, CoDual(q, dq), 1, 2, 1), # rem_float -- untested and unimplemented because seemingly unused on master # rem_float_fast -- untested and unimplemented because seemingly unused on master - [false, :stability, nothing, IntrinsicsWrappers.rint_llvm, 5], - [false, :stability, nothing, IntrinsicsWrappers.sdiv_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.sext_int, Int64, Int32(1308622848)], - [false, :stability, nothing, IntrinsicsWrappers.shl_int, 1308622848, 0xffffffffffffffe8], - [false, :stability, nothing, IntrinsicsWrappers.sitofp, Float64, 0], - [false, :stability, nothing, IntrinsicsWrappers.sle_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.slt_int, 4, 5], - [false, :stability, nothing, IntrinsicsWrappers.sqrt_llvm, 5.0], - [false, :stability, nothing, IntrinsicsWrappers.sqrt_llvm_fast, 5.0], - [false, :stability, nothing, IntrinsicsWrappers.srem_int, 4, 1], - [false, :stability, nothing, IntrinsicsWrappers.sub_float, 4.0, 1.0], - [false, :stability, nothing, IntrinsicsWrappers.sub_float_fast, 4.0, 1.0], - [false, :stability, nothing, IntrinsicsWrappers.sub_int, 4, 1], - [false, :stability, nothing, IntrinsicsWrappers.trunc_int, UInt8, 78], - [false, :stability, nothing, IntrinsicsWrappers.trunc_llvm, 5.1], - [false, :stability, nothing, IntrinsicsWrappers.udiv_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.uitofp, Float16, 4], - [false, :stability, nothing, IntrinsicsWrappers.ule_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.ult_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.urem_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.xor_int, 5, 4], - [false, :stability, nothing, IntrinsicsWrappers.zext_int, Int64, 0xffffffff], + (false, :stability, nothing, IntrinsicsWrappers.rint_llvm, 5), + (false, :stability, nothing, IntrinsicsWrappers.sdiv_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.sext_int, Int64, Int32(1308622848)), + (false, :stability, nothing, IntrinsicsWrappers.shl_int, 1308622848, 0xffffffffffffffe8), + (false, :stability, nothing, 
IntrinsicsWrappers.sitofp, Float64, 0), + (false, :stability, nothing, IntrinsicsWrappers.sle_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.slt_int, 4, 5), + (false, :stability, nothing, IntrinsicsWrappers.sqrt_llvm, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.sqrt_llvm_fast, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.srem_int, 4, 1), + (false, :stability, nothing, IntrinsicsWrappers.sub_float, 4.0, 1.0), + (false, :stability, nothing, IntrinsicsWrappers.sub_float_fast, 4.0, 1.0), + (false, :stability, nothing, IntrinsicsWrappers.sub_int, 4, 1), + (false, :stability, nothing, IntrinsicsWrappers.trunc_int, UInt8, 78), + (false, :stability, nothing, IntrinsicsWrappers.trunc_llvm, 5.1), + (false, :stability, nothing, IntrinsicsWrappers.udiv_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.uitofp, Float16, 4), + (false, :stability, nothing, IntrinsicsWrappers.ule_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.ult_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.urem_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.xor_int, 5, 4), + (false, :stability, nothing, IntrinsicsWrappers.zext_int, Int64, 0xffffffff), # Non-intrinsic built-ins: # Core._abstracttype -- NEEDS IMPLEMENTING AND TESTING @@ -775,46 +776,32 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) # Core._structtype -- NEEDS IMPLEMENTING AND TESTING # Core._svec_ref -- NEEDS IMPLEMENTING AND TESTING # Core._typebody! -- NEEDS IMPLEMENTING AND TESTING - [true, :stability, nothing, Core._typevar, :T, Union{}, Any], - [false, :stability, nothing, <:, Float64, Int], - [false, :stability, nothing, <:, Any, Float64], - [false, :stability, nothing, <:, Float64, Any], - [false, :stability, nothing, ===, 5.0, 4.0], - [false, :stability, nothing, ===, 5.0, randn(5)], - [false, :stability, nothing, ===, randn(5), randn(3)], - [false, :stability, nothing, ===, 5.0, 5.0], - [false, :none, (lb=1e-3, ub=100.0), Core.apply_type, Vector, Float64], - [false, :none, (lb=1e-3, ub=100.0), Core.apply_type, Array, Float64, 2], - [false, :stability, nothing, Core.arraysize, randn(5, 4, 3), 2], - [false, :stability, nothing, Core.arraysize, randn(5, 4, 3, 2, 1), 100], - # Core.compilerbarrier -- NEEDS IMPLEMENTING AND TESTING - # Core.const_arrayref -- NEEDS IMPLEMENTING AND TESTING - # Core.donotdelete -- NEEDS IMPLEMENTING AND TESTING - # Core.finalizer -- NEEDS IMPLEMENTING AND TESTING - # Core.get_binding_type -- NEEDS IMPLEMENTING AND TESTING - [false, :none, nothing, Core.ifelse, true, randn(5), 1], - [false, :none, nothing, Core.ifelse, false, randn(5), 2], - (false, :stability, nothing, Core.ifelse, true, 5, 4), - (false, :stability, nothing, Core.ifelse, false, true, false), - [false, :stability, nothing, Core.ifelse, false, 1.0, 2.0], - [false, :stability, nothing, Core.ifelse, true, 1.0, 2.0], - [false, :stability, nothing, Core.ifelse, false, randn(5), randn(3)], - [false, :stability, nothing, Core.ifelse, true, randn(5), randn(3)], - # Core.set_binding_type! 
-- NEEDS IMPLEMENTING AND TESTING - [false, :stability, nothing, Core.sizeof, Float64], - [false, :stability, nothing, Core.sizeof, randn(5)], - # Core.svec -- NEEDS IMPLEMENTING AND TESTING - [false, :stability, nothing, Base.arrayref, true, randn(5), 1], - [false, :stability, nothing, Base.arrayref, false, randn(4), 1], - [false, :stability, nothing, Base.arrayref, true, randn(5, 4), 1, 1], - [false, :stability, nothing, Base.arrayref, false, randn(5, 4), 5, 4], - [false, :stability, nothing, Base.arrayset, false, randn(5), 4.0, 3], - [false, :stability, nothing, Base.arrayset, false, randn(5, 4), 3.0, 1, 3], - [false, :stability, nothing, Base.arrayset, true, randn(5), 4.0, 3], - [false, :stability, nothing, Base.arrayset, true, randn(5, 4), 3.0, 1, 3], - [false, :stability, nothing, Base.arrayset, false, [randn(3) for _ in 1:5], randn(4), 1], - # [false, :stability, Base.arrayset, false, _a, randn(4), 1], # _a is not fully initialised - [ + (false, :stability, nothing, <:, Float64, Int), + (false, :stability, nothing, <:, Any, Float64), + (false, :stability, nothing, <:, Float64, Any), + (false, :stability, nothing, ===, 5.0, 4.0), + (false, :stability, nothing, ===, 5.0, randn(5)), + (false, :stability, nothing, ===, randn(5), randn(3)), + (false, :stability, nothing, ===, 5.0, 5.0), + (true, :stability, nothing, Core._typevar, :T, Union{}, Any), + (false, :none, (lb=1e-3, ub=100.0), Core.apply_type, Vector, Float64), + (false, :none, (lb=1e-3, ub=100.0), Core.apply_type, Array, Float64, 2), + (false, :stability, nothing, Base.arrayref, true, randn(5), 1), + (false, :stability, nothing, Base.arrayref, false, randn(4), 1), + (false, :stability, nothing, Base.arrayref, true, randn(5, 4), 1, 1), + (false, :stability, nothing, Base.arrayref, false, randn(5, 4), 5, 4), + (false, :stability, nothing, Base.arrayref, true, randn(5, 4), 1), + (false, :stability, nothing, Base.arrayref, false, randn(5, 4), 5), + (false, :stability, nothing, Base.arrayref, false, [1, 2, 3], 1), + (false, :stability, nothing, Base.arrayset, false, [1, 2, 3], 4, 2), + (false, :stability, nothing, Base.arrayset, false, randn(5), 4.0, 3), + (false, :stability, nothing, Base.arrayset, false, randn(5, 4), 3.0, 1, 3), + (false, :stability, nothing, Base.arrayset, true, randn(5), 4.0, 3), + (false, :stability, nothing, Base.arrayset, true, randn(5, 4), 3.0, 1, 3), + (false, :stability, nothing, Base.arrayset, false, [randn(3) for _ in 1:5], randn(4), 1), + (false, :stability, nothing, Base.arrayset, false, _a, randn(4), 1), + (false, :stability, nothing, Base.arrayset, true, [(5.0, rand(1))], (4.0, rand(1)), 1), + ( false, :stability, nothing, @@ -823,8 +810,8 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) setindex!(Vector{Vector{Float64}}(undef, 3), randn(3), 1), randn(4), 1, - ], - [ + ), + ( false, :stability, nothing, @@ -833,53 +820,77 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) setindex!(Vector{Vector{Float64}}(undef, 3), randn(3), 2), randn(4), 1, - ], - [false, :stability, nothing, applicable, sin, Float64], - [false, :stability, nothing, applicable, sin, Type], - [false, :stability, nothing, applicable, +, Type, Float64], - [false, :stability, nothing, applicable, +, Float64, Float64], - [false, :stability, (lb=1e-3, ub=20.0), fieldtype, StructFoo, :a], - [false, :stability, (lb=1e-3, ub=20.0), fieldtype, StructFoo, :b], - [false, :stability, (lb=1e-3, ub=20.0), fieldtype, MutableFoo, :a], - [false, :stability, (lb=1e-3, ub=20.0), fieldtype, 
MutableFoo, :b], - [true, :none, _range, getfield, StructFoo(5.0), :a], - [false, :none, _range, getfield, StructFoo(5.0, randn(5)), :a], - [false, :none, _range, getfield, StructFoo(5.0, randn(5)), :b], - [true, :none, _range, getfield, StructFoo(5.0), 1], - [false, :none, _range, getfield, StructFoo(5.0, randn(5)), 1], - [false, :none, _range, getfield, StructFoo(5.0, randn(5)), 2], - [true, :none, _range, getfield, MutableFoo(5.0), :a], - [false, :none, _range, getfield, MutableFoo(5.0, randn(5)), :b], - [false, :none, _range, getfield, UnitRange{Int}(5:9), :start], - [false, :none, _range, getfield, UnitRange{Int}(5:9), :stop], - [false, :none, _range, getfield, (5.0, ), 1, false], - (false, :none, _range, getfield, (1, ), 1, false), - (false, :none, _range, getfield, (1, 2), 1), - (false, :none, _range, getfield, (a=5, b=4), 1), - (false, :none, _range, getfield, (a=5, b=4), 2), + ), + (false, :stability, nothing, Core.arraysize, randn(5, 4, 3), 2), + (false, :stability, nothing, Core.arraysize, randn(5, 4, 3, 2, 1), 100), + # Core.compilerbarrier -- NEEDS IMPLEMENTING AND TESTING + # Core.const_arrayref -- NEEDS IMPLEMENTING AND TESTING + # Core.donotdelete -- NEEDS IMPLEMENTING AND TESTING + # Core.finalizer -- NEEDS IMPLEMENTING AND TESTING + # Core.get_binding_type -- NEEDS IMPLEMENTING AND TESTING + (false, :none, nothing, Core.ifelse, true, randn(5), 1), + (false, :none, nothing, Core.ifelse, false, randn(5), 2), + (false, :stability, nothing, Core.ifelse, true, 5, 4), + (false, :stability, nothing, Core.ifelse, false, true, false), + (false, :stability, nothing, Core.ifelse, false, 1.0, 2.0), + (false, :stability, nothing, Core.ifelse, true, 1.0, 2.0), + (false, :stability, nothing, Core.ifelse, false, randn(5), randn(3)), + (false, :stability, nothing, Core.ifelse, true, randn(5), randn(3)), + # Core.set_binding_type! 
-- NEEDS IMPLEMENTING AND TESTING + (false, :stability, nothing, Core.sizeof, Float64), + (false, :stability, nothing, Core.sizeof, randn(5)), + # Core.svec -- NEEDS IMPLEMENTING AND TESTING + (false, :stability, nothing, applicable, sin, Float64), + (false, :stability, nothing, applicable, sin, Type), + (false, :stability, nothing, applicable, +, Type, Float64), + (false, :stability, nothing, applicable, +, Float64, Float64), + (false, :stability, (lb=1e-3, ub=20.0), fieldtype, StructFoo, :a), + (false, :stability, (lb=1e-3, ub=20.0), fieldtype, StructFoo, :b), + (false, :stability, (lb=1e-3, ub=20.0), fieldtype, MutableFoo, :a), + (false, :stability, (lb=1e-3, ub=20.0), fieldtype, MutableFoo, :b), + (true, :none, _range, getfield, StructFoo(5.0), :a), + (false, :none, _range, getfield, StructFoo(5.0, randn(5)), :a), + (false, :none, _range, getfield, StructFoo(5.0, randn(5)), :b), + (true, :none, _range, getfield, StructFoo(5.0), 1), + (false, :none, _range, getfield, StructFoo(5.0, randn(5)), 1), + (false, :none, _range, getfield, StructFoo(5.0, randn(5)), 2), + (true, :none, _range, getfield, MutableFoo(5.0), :a), + (false, :none, _range, getfield, MutableFoo(5.0, randn(5)), :b), + (false, :stability_and_allocs, nothing, getfield, UnitRange{Int}(5:9), :start), + (false, :stability_and_allocs, nothing, getfield, UnitRange{Int}(5:9), :stop), + (false, :stability_and_allocs, nothing, getfield, (5.0, ), 1), + (false, :stability_and_allocs, nothing, getfield, (5.0, 4.0), 1), + (false, :stability_and_allocs, nothing, getfield, (5.0, ), 1, false), + (false, :stability_and_allocs, nothing, getfield, (5.0, 4.0), 1, false), + (false, :stability_and_allocs, nothing, getfield, (1, ), 1, false), + (false, :stability_and_allocs, nothing, getfield, (1, 2), 1), + (false, :stability_and_allocs, nothing, getfield, (a=5, b=4), 1), + (false, :stability_and_allocs, nothing, getfield, (a=5, b=4), 2), + (false, :none, nothing, getfield, (Float64, Float64), 1), + (false, :none, nothing, getfield, (Float64, Float64), 2, false), (false, :none, _range, getfield, (a=5.0, b=4), 1), (false, :none, _range, getfield, (a=5.0, b=4), 2), - [false, :none, _range, getfield, UInt8, :name], - [false, :none, _range, getfield, UInt8, :super], - [true, :none, _range, getfield, UInt8, :layout], - [false, :none, _range, getfield, UInt8, :hash], - [false, :none, _range, getfield, UInt8, :flags], + (false, :none, _range, getfield, UInt8, :name), + (false, :none, _range, getfield, UInt8, :super), + (true, :none, _range, getfield, UInt8, :layout), + (false, :none, _range, getfield, UInt8, :hash), + (false, :none, _range, getfield, UInt8, :flags), # getglobal requires compositional testing, because you can't deepcopy a module # invoke -- NEEDS IMPLEMENTING AND TESTING - [false, :stability, nothing, isa, 5.0, Float64], - [false, :stability, nothing, isa, 1, Float64], - [false, :stability, nothing, isdefined, MutableFoo(5.0, randn(5)), :sim], - [false, :stability, nothing, isdefined, MutableFoo(5.0, randn(5)), :a], + (false, :stability, nothing, isa, 5.0, Float64), + (false, :stability, nothing, isa, 1, Float64), + (false, :stability, nothing, isdefined, MutableFoo(5.0, randn(5)), :sim), + (false, :stability, nothing, isdefined, MutableFoo(5.0, randn(5)), :a), # modifyfield! -- NEEDS IMPLEMENTING AND TESTING - [false, :stability, nothing, nfields, MutableFoo], - [false, :stability, nothing, nfields, StructFoo], + (false, :stability, nothing, nfields, MutableFoo), + (false, :stability, nothing, nfields, StructFoo), # replacefield! 
-- NEEDS IMPLEMENTING AND TESTING (false, :none, _range, setfield!, MutableFoo(5.0, randn(5)), :a, 4.0), (false, :none, nothing, setfield!, MutableFoo(5.0, randn(5)), :b, randn(5)), (false, :none, _range, setfield!, MutableFoo(5.0, randn(5)), 1, 4.0), (false, :none, _range, setfield!, MutableFoo(5.0, randn(5)), 2, randn(5)), - (false, :stability, _range, setfield!, NonDifferentiableFoo(5, false), 1, 4), - (false, :stability, _range, setfield!, NonDifferentiableFoo(5, true), 2, false), + (false, :none, _range, setfield!, NonDifferentiableFoo(5, false), 1, 4), + (false, :none, _range, setfield!, NonDifferentiableFoo(5, true), 2, false), # swapfield! -- NEEDS IMPLEMENTING AND TESTING # throw -- NEEDS IMPLEMENTING AND TESTING [false, :stability_and_allocs, nothing, tuple, 5.0, 4.0], @@ -889,21 +900,20 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :stability_and_allocs, nothing, tuple), (false, :stability_and_allocs, nothing, tuple, 1), (false, :stability_and_allocs, nothing, tuple, 1, 5), + (false, :stability_and_allocs, nothing, tuple, 1.0, (5, )), [false, :stability, nothing, typeassert, 5.0, Float64], [false, :stability, nothing, typeassert, randn(5), Vector{Float64}], [false, :stability, nothing, typeof, 5.0], [false, :stability, nothing, typeof, randn(5)], ] - memory = Any[_x, _dx, _a] + memory = Any[_x, _dx, _a, p, dp] return test_cases, memory end function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) test_cases = Any[ - [ - false, - :none, - nothing, + ( + false, :none, nothing, ( function (x) rx = Ref(x) @@ -911,13 +921,19 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) end ), 5.0, - ], - [false, :none, nothing, (v, x) -> (pointerset(pointer(x), v, 2, 1); x), 3.0, randn(5)], - [false, :none, nothing, x -> (pointerset(pointer(x), UInt8(3), 2, 1); x), rand(UInt8, 5)], - [false, :none, nothing, getindex, randn(5), [1, 1]], - [false, :none, nothing, getindex, randn(5), [1, 2, 2]], - [false, :none, nothing, setindex!, randn(5), [4.0, 5.0], [1, 1]], - [false, :none, nothing, setindex!, randn(5), [4.0, 5.0, 6.0], [1, 2, 2]], + ), + ( + false, :none, nothing, + (v, x) -> (pointerset(pointer(x), v, 2, 1); x), 3.0, randn(5), + ), + ( + false, :none, nothing, + x -> (pointerset(pointer(x), UInt8(3), 2, 1); x), rand(UInt8, 5), + ), + (false, :none, nothing, getindex, randn(5), [1, 1]), + (false, :none, nothing, getindex, randn(5), [1, 2, 2]), + (false, :none, nothing, setindex!, randn(5), [4.0, 5.0], [1, 1]), + (false, :none, nothing, setindex!, randn(5), [4.0, 5.0, 6.0], [1, 2, 2]), ] memory = Any[] return test_cases, memory diff --git a/src/rrules/foreigncall.jl b/src/rrules/foreigncall.jl index 42f7fe6c..eb964663 100644 --- a/src/rrules/foreigncall.jl +++ b/src/rrules/foreigncall.jl @@ -59,10 +59,6 @@ over from there. ) end -@generated function _eval(::typeof(_foreigncall_), x::Vararg{Any, N}) where {N} - return Expr(:call, :_foreigncall_, map(n -> :(getfield(x, $n)), 1:N)...) 
-end - @is_primitive MinimalCtx Tuple{typeof(_foreigncall_), Vararg} # @@ -70,114 +66,34 @@ end # @is_primitive MinimalCtx Tuple{typeof(Base.allocatedinline), Type} -function rrule!!(::CoDual{typeof(Base.allocatedinline)}, T::CoDual{<:Type}) - return CoDual(Base.allocatedinline(primal(T)), NoTangent()), NoPullback() +function rrule!!(f::CoDual{typeof(Base.allocatedinline)}, T::CoDual{<:Type}) + return zero_fcodual(Base.allocatedinline(primal(T))), NoPullback(f, T) end @is_primitive MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), Vararg} where {T, N} function rrule!!( - ::CoDual{Type{Array{T, N}}}, ::CoDual{typeof(undef)}, m::Vararg{CoDual} + f::CoDual{Type{Array{T, N}}}, u::CoDual{typeof(undef)}, m::Vararg{CoDual} ) where {T, N} - _m = map(primal, m) - p = Array{T, N}(undef, _m...) - t = Array{tangent_type(T), N}(undef, _m...) - if isassigned(t, 1) - for n in eachindex(t) - @inbounds t[n] = zero_tangent(p[n]) - end - end - return CoDual(p, t), NoPullback() + return zero_fcodual(Array{T, N}(undef, map(primal, m)...)), NoPullback(f, u, m...) end @is_primitive MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), NTuple{N}} where {T, N} function rrule!!( ::CoDual{<:Type{<:Array{T, N}}}, ::CoDual{typeof(undef)}, m::CoDual{NTuple{N}}, ) where {T, N} - _m = primal(m) - p = Array{T, N}(undef, _m) - t = Array{tangent_type(T), N}(undef, _m) - if isassigned(t, 1) - for n in eachindex(t) - @inbounds t[n] = zero_tangent(p[n]) - end - end - return CoDual(p, t), NoPullback() + return rrule!!(zero_fcodual(Array{T, N}), zero_fcodual(undef), m) end @is_primitive MinimalCtx Tuple{typeof(copy), Array} function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Array}) - y = CoDual(copy(primal(a)), copy(tangent(a))) - copy_pullback!!(dy, df, dx) = df, increment!!(dx, dy) - return y, copy_pullback!! -end - -@is_primitive MinimalCtx Tuple{typeof(fill!), Array{<:Union{UInt8, Int8}}, Integer} -function rrule!!( - ::CoDual{typeof(fill!)}, - a::CoDual{<:Union{Array{UInt8}, Array{Int8}}, <:Array{NoTangent}}, - x::CoDual{<:Integer}, -) - old_value = copy(primal(a)) - fill!(primal(a), primal(x)) - function fill!_pullback!!(dy, df, da, dx) - primal(a) .= old_value - return df, da, dx - end - return a, fill!_pullback!! -end - -@is_primitive MinimalCtx Tuple{typeof(Base._growbeg!), Vector, Integer} -function rrule!!( - ::CoDual{typeof(Base._growbeg!)}, _a::CoDual{<:Vector{T}}, _delta::CoDual{<:Integer}, -) where {T} - d = primal(_delta) - a = primal(_a) - Base._growbeg!(a, d) - Base._growbeg!(tangent(_a), d) - function _growbeg!_pb!!(_, df, da, ddelta) - Base._deletebeg!(a, d) - Base._deletebeg!(da, d) - return df, da, ddelta - end - return zero_codual(nothing), _growbeg!_pb!! -end - -@is_primitive MinimalCtx Tuple{typeof(Base._growend!), Vector, Integer} -function rrule!!( - ::CoDual{typeof(Base._growend!)}, _a::CoDual{<:Vector}, _delta::CoDual{<:Integer}, -) - d = primal(_delta) - a = primal(_a) - Base._growend!(a, d) - Base._growend!(tangent(_a), d) - function _growend!_pullback!!(dy, df, da, ddelta) - Base._deleteend!(a, d) - Base._deleteend!(da, d) - return df, da, ddelta - end - return zero_codual(nothing), _growend!_pullback!! -end - -@is_primitive MinimalCtx Tuple{typeof(Base._growat!), Vector, Integer, Integer} -function rrule!!( - ::CoDual{typeof(Base._growat!)}, - _a::CoDual{<:Vector}, - _i::CoDual{<:Integer}, - _delta::CoDual{<:Integer}, -) - # Extract data. - a, i, delta = map(primal, (_a, _i, _delta)) - - # Run the primal. 
- Base._growat!(a, i, delta) - Base._growat!(tangent(_a), i, delta) - - function _growat!_pb!!(_, df, da, di, ddelta) - deleteat!(a, i:i+delta-1) - deleteat!(da, i:i+delta-1) - return df, da, di, ddelta + dx = tangent(a) + dy = copy(dx) + y = CoDual(copy(primal(a)), dy) + function copy_pullback!!(::NoRData) + increment!!(dx, dy) + return NoRData(), NoRData() end - return zero_codual(nothing), _growat!_pb!! + return y, copy_pullback!! end @is_primitive MinimalCtx Tuple{typeof(Base._deletebeg!), Vector, Integer} @@ -186,19 +102,20 @@ function rrule!!( ) delta = primal(_delta) a = primal(_a) + da = tangent(_a) a_beg = a[1:delta] - da_beg = tangent(_a)[1:delta] + da_beg = da[1:delta] Base._deletebeg!(a, delta) - Base._deletebeg!(tangent(_a), delta) + Base._deletebeg!(da, delta) - function _deletebeg!_pb!!(_, df, da, ddelta) + function _deletebeg!_pb!!(::NoRData) splice!(a, 1:0, a_beg) splice!(da, 1:0, da_beg) - return df, da, ddelta + return NoRData(), NoRData(), NoRData() end - return zero_codual(nothing), _deletebeg!_pb!! + return zero_fcodual(nothing), _deletebeg!_pb!! end @is_primitive MinimalCtx Tuple{typeof(Base._deleteend!), Vector, Integer} @@ -207,17 +124,18 @@ function rrule!!( ) # Extract data. a = primal(_a) + da = tangent(_a) delta = primal(_delta) # Store the section to be cut for later. primal_tail = a[end-delta+1:end] - tangent_tail = tangent(_a)[end-delta+1:end] + tangent_tail = da[end-delta+1:end] # Cut the end off the primal and tangent. Base._deleteend!(a, delta) - Base._deleteend!(tangent(_a), delta) + Base._deleteend!(da, delta) - function _deleteend!_pb!!(_, df, da, ddelta) + function _deleteend!_pb!!(::NoRData) Base._growend!(a, delta) a[end-delta+1:end] .= primal_tail @@ -225,9 +143,9 @@ function rrule!!( Base._growend!(da, delta) da[end-delta+1:end] .= tangent_tail - return df, da, ddelta + return NoRData(), NoRData(), NoRData() end - return zero_codual(nothing), _deleteend!_pb!! + return zero_fcodual(nothing), _deleteend!_pb!! end @is_primitive MinimalCtx Tuple{typeof(Base._deleteat!), Vector, Integer, Integer} @@ -239,48 +157,142 @@ function rrule!!( ) # Extract data. a, i, delta = map(primal, (_a, _i, _delta)) + da = tangent(_a) # Store the cut section for later. primal_mem = a[i:i+delta-1] - tangent_mem = tangent(_a)[i:i+delta-1] + tangent_mem = da[i:i+delta-1] # Run the primal. Base._deleteat!(a, i, delta) - Base._deleteat!(tangent(_a), i, delta) + Base._deleteat!(da, i, delta) - function _deleteat!_pb!!(_, df, da, di, ddelta) + function _deleteat!_pb!!(::NoRData) splice!(a, i:i-1, primal_mem) splice!(da, i:i-1, tangent_mem) - return df, da, di, ddelta + return NoRData(), NoRData(), NoRData(), NoRData() + end + + return zero_fcodual(nothing), _deleteat!_pb!! +end + +@is_primitive MinimalCtx Tuple{typeof(Base._growbeg!), Vector, Integer} +function rrule!!( + ::CoDual{typeof(Base._growbeg!)}, _a::CoDual{<:Vector{T}}, _delta::CoDual{<:Integer}, +) where {T} + d = primal(_delta) + a = primal(_a) + da = tangent(_a) + Base._growbeg!(a, d) + Base._growbeg!(da, d) + function _growbeg!_pb!!(::NoRData) + Base._deletebeg!(a, d) + Base._deletebeg!(da, d) + return NoRData(), NoRData(), NoRData() + end + return zero_fcodual(nothing), _growbeg!_pb!! 
+end + +@is_primitive MinimalCtx Tuple{typeof(Base._growend!), Vector, Integer} +function rrule!!( + ::CoDual{typeof(Base._growend!)}, _a::CoDual{<:Vector}, _delta::CoDual{<:Integer}, +) + d = primal(_delta) + a = primal(_a) + da = tangent(_a) + Base._growend!(a, d) + Base._growend!(da, d) + function _growend!_pullback!!(::NoRData) + Base._deleteend!(a, d) + Base._deleteend!(da, d) + return NoRData(), NoRData(), NoRData() end + return zero_fcodual(nothing), _growend!_pullback!! +end + +@is_primitive MinimalCtx Tuple{typeof(Base._growat!), Vector, Integer, Integer} +function rrule!!( + ::CoDual{typeof(Base._growat!)}, + _a::CoDual{<:Vector}, + _i::CoDual{<:Integer}, + _delta::CoDual{<:Integer}, +) + # Extract data. + a, i, delta = map(primal, (_a, _i, _delta)) + da = tangent(_a) + + # Run the primal. + Base._growat!(a, i, delta) + Base._growat!(da, i, delta) - return zero_codual(nothing), _deleteat!_pb!! + function _growat!_pb!!(::NoRData) + deleteat!(a, i:i+delta-1) + deleteat!(da, i:i+delta-1) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return zero_fcodual(nothing), _growat!_pb!! +end + +@is_primitive MinimalCtx Tuple{typeof(fill!), Array{<:Union{UInt8, Int8}}, Integer} +function rrule!!( + ::CoDual{typeof(fill!)}, a::CoDual{<:Array{<:Union{UInt8, Int8}}}, x::CoDual{<:Integer} +) + pa = primal(a) + old_value = copy(pa) + fill!(pa, primal(x)) + function fill!_pullback!!(::NoRData) + pa .= old_value + return NoRData(), NoRData(), NoRData() + end + return a, fill!_pullback!! end @is_primitive MinimalCtx Tuple{typeof(sizehint!), Vector, Integer} -function rrule!!(::CoDual{typeof(sizehint!)}, x::CoDual{<:Vector}, sz::CoDual{<:Integer}) +function rrule!!(f::CoDual{typeof(sizehint!)}, x::CoDual{<:Vector}, sz::CoDual{<:Integer}) sizehint!(primal(x), primal(sz)) sizehint!(tangent(x), primal(sz)) - return x, NoPullback() + return x, NoPullback(f, x, sz) end @is_primitive MinimalCtx Tuple{typeof(objectid), Any} -function rrule!!(::CoDual{typeof(objectid)}, @nospecialize(x)) - return CoDual(objectid(primal(x)), NoTangent()), NoPullback() +function rrule!!(f::CoDual{typeof(objectid)}, @nospecialize(x)) + return zero_fcodual(objectid(primal(x))), NoPullback(f, x) end @is_primitive MinimalCtx Tuple{typeof(pointer_from_objref), Any} -function rrule!!(::CoDual{typeof(pointer_from_objref)}, x) +function rrule!!(f::CoDual{typeof(pointer_from_objref)}, x) y = CoDual( pointer_from_objref(primal(x)), bitcast(Ptr{tangent_type(Nothing)}, pointer_from_objref(tangent(x))), ) - return y, NoPullback() + return y, NoPullback(f, x) end @is_primitive MinimalCtx Tuple{typeof(CC.return_type), Vararg} -function rrule!!(::CoDual{typeof(Core.Compiler.return_type)}, args...) - return zero_codual(Core.Compiler.return_type(map(primal, args)...)), NoPullback() +function rrule!!(f::CoDual{typeof(Core.Compiler.return_type)}, args...) + pb!! = NoPullback(f, args...) + return zero_fcodual(Core.Compiler.return_type(map(primal, args)...)), pb!! 
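
# Rules such as `objectid`, `sizehint!` and `Core.Compiler.return_type` above show the new
# pattern for non-differentiable primitives: return `zero_fcodual(...)` on the forwards
# pass and let `NoPullback(f, args...)` supply (lazy) zero rdata for each argument on the
# reverse pass. A hypothetical rule written in the same style (illustration only, not part
# of this patch):
#
#     @is_primitive MinimalCtx Tuple{typeof(nameof), Module}
#     function rrule!!(f::CoDual{typeof(nameof)}, m::CoDual{Module})
#         return zero_fcodual(nameof(primal(m))), NoPullback(f, m)
#     end
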
+end + +@is_primitive MinimalCtx Tuple{typeof(Base.unsafe_pointer_to_objref), Ptr} +function rrule!!(f::CoDual{typeof(Base.unsafe_pointer_to_objref)}, x::CoDual{<:Ptr}) + y = CoDual(unsafe_pointer_to_objref(primal(x)), unsafe_pointer_to_objref(tangent(x))) + return y, NoPullback(f, x) +end + +@is_primitive MinimalCtx Tuple{typeof(Threads.threadid)} +function rrule!!(f::CoDual{typeof(Threads.threadid)}) + return zero_fcodual(Threads.threadid()), NoPullback(f) +end + +@is_primitive MinimalCtx Tuple{typeof(typeintersect), Any, Any} +function rrule!!(f::CoDual{typeof(typeintersect)}, @nospecialize(a), @nospecialize(b)) + return zero_fcodual(typeintersect(primal(a), primal(b))), NoPullback(f, a, b) +end + +function _increment_pointer!(x::Ptr{T}, y::Ptr{T}, N::Integer) where {T} + increment!!(unsafe_wrap(Vector{T}, x, N), unsafe_wrap(Vector{T}, y, N)) + return x end # unsafe_copyto! is the only function in Julia that appears to rely on a ccall to `memmove`. @@ -295,23 +307,26 @@ function rrule!!( # Record values that will be overwritten. dest_copy = Vector{T}(undef, _n) ddest_copy = Vector{T}(undef, _n) - unsafe_copyto!(pointer(dest_copy), primal(dest), _n) - unsafe_copyto!(pointer(ddest_copy), tangent(dest), _n) + pdest = primal(dest) + ddest = tangent(dest) + unsafe_copyto!(pointer(dest_copy), pdest, _n) + unsafe_copyto!(pointer(ddest_copy), ddest, _n) # Run primal computation. + dsrc = tangent(src) unsafe_copyto!(primal(dest), primal(src), _n) - unsafe_copyto!(tangent(dest), tangent(src), _n) + unsafe_copyto!(tangent(dest), dsrc, _n) - function unsafe_copyto!_pb!!(_, df, ddest, dsrc, dn) + function unsafe_copyto!_pb!!(::NoRData) # Increment dsrc. - dsrc = _increment_pointer!(dsrc, ddest, _n) + _increment_pointer!(dsrc, ddest, _n) # Restore initial state. - unsafe_copyto!(primal(dest), pointer(dest_copy), _n) - unsafe_copyto!(tangent(dest), pointer(ddest_copy), _n) + unsafe_copyto!(pdest, pointer(dest_copy), _n) + unsafe_copyto!(ddest, pointer(ddest_copy), _n) - return df, ddest, dsrc, dn + return NoRData(), NoRData(), NoRData(), NoRData() end return dest, unsafe_copyto!_pb!! end @@ -332,49 +347,32 @@ function rrule!!( _doffs = primal(doffs) dest_idx = _doffs:_doffs + _n - 1 _soffs = primal(soffs) + pdest = primal(dest) + ddest = tangent(dest) dest_copy = primal(dest)[dest_idx] ddest_copy = tangent(dest)[dest_idx] # Run primal computation. + dsrc = tangent(src) unsafe_copyto!(primal(dest), _doffs, primal(src), _soffs, _n) - unsafe_copyto!(tangent(dest), _doffs, tangent(src), _soffs, _n) + unsafe_copyto!(tangent(dest), _doffs, dsrc, _soffs, _n) - function unsafe_copyto_pb!!(_, df, ddest, ddoffs, dsrc, dsoffs, dn) + function unsafe_copyto_pb!!(::NoRData) # Increment dsrc. src_idx = _soffs:_soffs + _n - 1 dsrc[src_idx] .= increment!!.(view(dsrc, src_idx), view(ddest, dest_idx)) # Restore initial state. - primal(dest)[dest_idx] .= dest_copy - tangent(dest)[dest_idx] .= ddest_copy + pdest[dest_idx] .= dest_copy + ddest[dest_idx] .= ddest_copy - return df, ddest, ddoffs, dsrc, dsoffs, dn + return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData() end return dest, unsafe_copyto_pb!! 
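
# `_increment_pointer!` is the workhorse of the `unsafe_copyto!` pullbacks above: it views
# both raw pointers as `Vector{T}`s of length `N` via `unsafe_wrap` and reuses `increment!!`
# to accumulate one buffer into the other. A small sketch, assuming Float64 buffers:
#
#     x, y = randn(4), randn(4)
#     GC.@preserve x y _increment_pointer!(pointer(x), pointer(y), 4)  # in effect, x .+= y
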
end -@is_primitive MinimalCtx Tuple{typeof(Base.unsafe_pointer_to_objref), Ptr} -function rrule!!(::CoDual{typeof(Base.unsafe_pointer_to_objref)}, x::CoDual{<:Ptr}) - y = CoDual(unsafe_pointer_to_objref(primal(x)), unsafe_pointer_to_objref(tangent(x))) - return y, NoPullback() -end - -@is_primitive MinimalCtx Tuple{typeof(Threads.threadid)} -rrule!!(::CoDual{typeof(Threads.threadid)}) = zero_codual(Threads.threadid()), NoPullback() - -@is_primitive MinimalCtx Tuple{typeof(typeintersect), Any, Any} -function rrule!!(::CoDual{typeof(typeintersect)}, @nospecialize(a), @nospecialize(b)) - y = typeintersect(primal(a), primal(b)) - return CoDual(y, zero_tangent(y)), NoPullback() -end - -function _increment_pointer!(x::Ptr{T}, y::Ptr{T}, N::Integer) where {T} - increment!!(unsafe_wrap(Vector{T}, x, N), unsafe_wrap(Vector{T}, y, N)) - return x -end - # @@ -394,7 +392,7 @@ function rrule!!( ccall(:jl_array_ptr, Ptr{T}, (Any, ), primal(a)), ccall(:jl_array_ptr, Ptr{V}, (Any, ), tangent(a)), ) - return y, NoPullback() + return y, NoPullback(ntuple(_ -> NoRData(), 7)) end # function rrule!!( @@ -429,7 +427,7 @@ function rrule!!( ccall(:jl_reshape_array, Array{P, M}, (Any, Any, Any), Array{P, M}, primal(a), d), ccall(:jl_reshape_array, Array{T, M}, (Any, Any, Any), Array{T, M}, tangent(a), d), ) - return y, NoPullback() + return y, NoPullback(ntuple(_ -> NoRData(), 9)) end function rrule!!( @@ -444,33 +442,41 @@ function rrule!!( args..., ) y = ccall(:jl_array_isassigned, Cint, (Any, UInt), primal(a), primal(ii)) - return zero_codual(y), NoPullback() + return zero_fcodual(y), NoPullback(ntuple(_ -> NoRData(), length(args) + 8)) end @is_primitive MinimalCtx Tuple{typeof(deepcopy), Any} function rrule!!(::CoDual{typeof(deepcopy)}, x::CoDual) - deepcopy_pb!!(dy, df, dx) = df, increment!!(dx, dy) - return deepcopy(x), deepcopy_pb!! + fdx = tangent(x) + dx = zero_rdata(primal(x)) + y = deepcopy(x) + fdy = tangent(y) + function deepcopy_pb!!(dy) + increment!!(fdx, fdy) + return NoRData(), increment!!(dx, dy) + end + return y, deepcopy_pb!! end @is_primitive MinimalCtx Tuple{Type{UnionAll}, TypeVar, Any} @is_primitive MinimalCtx Tuple{Type{UnionAll}, TypeVar, Type} -function rrule!!(::CoDual{<:Type{UnionAll}}, x::CoDual{<:TypeVar}, y::CoDual{<:Type}) - return CoDual(UnionAll(primal(x), primal(y)), NoTangent()), NoPullback() +function rrule!!(f::CoDual{<:Type{UnionAll}}, x::CoDual{<:TypeVar}, y::CoDual{<:Type}) + return zero_fcodual(UnionAll(primal(x), primal(y))), NoPullback(f, x, y) end +@is_primitive MinimalCtx Tuple{typeof(hash), Union{String, SubString{String}}, UInt} function rrule!!( - ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_string_ptr}}, args::Vararg{CoDual, N} -) where {N} - x = tuple_map(primal, args) - return uninit_codual(_foreigncall_(Val(:jl_string_ptr), x...)), NoPullback() + f::CoDual{typeof(hash)}, s::CoDual{P}, h::CoDual{UInt} +) where {P<:Union{String, SubString{String}}} + return zero_fcodual(hash(primal(s), primal(h))), NoPullback(f, s, h) end -@is_primitive MinimalCtx Tuple{typeof(hash), Union{String, SubString{String}}, UInt} function rrule!!( - ::CoDual{typeof(hash)}, s::CoDual{P}, h::CoDual{UInt} -) where {P<:Union{String, SubString{String}}} - return zero_codual(hash(primal(s), primal(h))), NoPullback() + ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_string_ptr}}, args::Vararg{CoDual, N} +) where {N} + x = tuple_map(primal, args) + pb!! 
= NoPullback((NoRData(), NoRData(), tuple_map(_ -> NoRData(), args)...)) + return uninit_fcodual(_foreigncall_(Val(:jl_string_ptr), x...)), pb!! end function unexepcted_foreigncall_error(name) @@ -523,9 +529,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) (true, :stability, nothing, Array{Float64, 5}, undef, 5, 4, 3, 2, 1), (true, :stability, nothing, Array{Float64, 4}, undef, (2, 3, 4, 5)), (true, :stability, nothing, Array{Float64, 5}, undef, (2, 3, 4, 5, 6)), - (true, :stability, nothing, Base._growbeg!, randn(5), 3), - (true, :stability, nothing, Base._growend!, randn(5), 3), - (true, :stability, nothing, Base._growat!, randn(5), 2, 2), + (false, :stability, nothing, copy, randn(5, 4)), (false, :stability, nothing, Base._deletebeg!, randn(5), 0), (false, :stability, nothing, Base._deletebeg!, randn(5), 2), (false, :stability, nothing, Base._deletebeg!, randn(5), 5), @@ -536,9 +540,11 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) (false, :stability, nothing, Base._deleteat!, randn(5), 1, 5), (false, :stability, nothing, Base._deleteat!, randn(5), 5, 1), (false, :stability, nothing, sizehint!, randn(5), 10), - (false, :stability, nothing, copy, randn(5, 4)), (false, :stability, nothing, fill!, rand(Int8, 5), Int8(2)), (false, :stability, nothing, fill!, rand(UInt8, 5), UInt8(2)), + (true, :stability, nothing, Base._growbeg!, randn(5), 3), + (true, :stability, nothing, Base._growend!, randn(5), 3), + (true, :stability, nothing, Base._growat!, randn(5), 2, 2), (false, :stability, nothing, objectid, 5.0), (true, :stability, nothing, objectid, randn(5)), (true, :stability, nothing, pointer_from_objref, _x), @@ -576,6 +582,16 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) (false, :stability, nothing, deepcopy, (a=5.0, b=randn(5))), (false, :none, nothing, UnionAll, TypeVar(:a), Real), (false, :none, nothing, hash, "5", UInt(3)), + ( + true, :none, nothing, + _foreigncall_, + Val(:jl_array_ptr), + Val(Ptr{Float64}), + (Val(Any), ), + Val(0), # nreq + Val(:ccall), # calling convention + randn(5), + ) ] memory = Any[_x, _dx, _a, _da, _b, _db] return test_cases, memory @@ -591,31 +607,23 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) end test_cases = [ - Any[false, :none, nothing, reshape, randn(5, 4), (4, 5)], - Any[false, :none, nothing, reshape, randn(5, 4), (2, 10)], - Any[false, :none, nothing, reshape, randn(5, 4), (10, 2)], - Any[false, :none, nothing, reshape, randn(5, 4), (5, 4, 1)], - Any[false, :none, nothing, reshape, randn(5, 4), (2, 10, 1)], - Any[false, :none, nothing, unsafe_copyto_tester, randn(5), randn(3), 2], - Any[false, :none, nothing, unsafe_copyto_tester, randn(5), randn(6), 4], - [ - false, - :none, - nothing, - unsafe_copyto_tester, - [randn(3) for _ in 1:5], - [randn(4) for _ in 1:6], - 4, - ], - Any[ - false, - :none, - (lb=0.1, ub=150), - x -> unsafe_pointer_to_objref(pointer_from_objref(x)), - _x, - ], - Any[false, :none, nothing, isassigned, randn(5), 4], - Any[false, :none, nothing, x -> (Base._growbeg!(x, 2); x[1:2] .= 2.0), randn(5)], + (false, :none, nothing, reshape, randn(5, 4), (4, 5)), + (false, :none, nothing, reshape, randn(5, 4), (2, 10)), + (false, :none, nothing, reshape, randn(5, 4), (10, 2)), + (false, :none, nothing, reshape, randn(5, 4), (5, 4, 1)), + (false, :none, nothing, reshape, randn(5, 4), (2, 10, 1)), + (false, :none, nothing, unsafe_copyto_tester, randn(5), randn(3), 2), + (false, :none, nothing, 
unsafe_copyto_tester, randn(5), randn(6), 4), + ( + false, :none, nothing, + unsafe_copyto_tester, [randn(3) for _ in 1:5], [randn(4) for _ in 1:6], 4, + ), + ( + false, :none, (lb=0.1, ub=150), + x -> unsafe_pointer_to_objref(pointer_from_objref(x)), _x, + ), + (false, :none, nothing, isassigned, randn(5), 4), + (false, :none, nothing, x -> (Base._growbeg!(x, 2); x[1:2] .= 2.0), randn(5)), ] memory = Any[_x] return test_cases, memory diff --git a/src/rrules/iddict.jl b/src/rrules/iddict.jl index 467b04a0..8d69e888 100644 --- a/src/rrules/iddict.jl +++ b/src/rrules/iddict.jl @@ -46,6 +46,14 @@ function TestUtils.has_equal_data(p::P, q::P; equal_undefs=true) where {P<:IdDic return all([TestUtils.has_equal_data(p[k], q[k]; equal_undefs) for k in ks]) end +fdata_type(::Type{T}) where {T<:IdDict} = T +fdata(t::IdDict) = t +rdata_type(::Type{<:IdDict}) = NoRData +rdata(t::IdDict) = NoRData() + +tangent_type(::Type{T}, ::Type{NoRData}) where {T<:IdDict} = T +tangent(f::IdDict, ::NoRData) = f + # All of the rules in here are provided in order to avoid nasty `:ccall`s, and to support # standard built-in functionality on `IdDict`s. @@ -53,7 +61,7 @@ end function rrule!!(::CoDual{typeof(Base.rehash!)}, d::CoDual{<:IdDict}, newsz::CoDual) Base.rehash!(primal(d), primal(newsz)) Base.rehash!(tangent(d), primal(newsz)) - return d, NoPullback() + return d, NoPullback((NoRData(), NoRData(), NoRData())) end @is_primitive MinimalCtx Tuple{typeof(setindex!), IdDict, Any, Any} @@ -67,12 +75,14 @@ function rrule!!(::CoDual{typeof(setindex!)}, d::CoDual{IdDict{K,V}}, val, key) end setindex!(primal(d), primal(val), k) - setindex!(tangent(d), tangent(val), k) + setindex!(tangent(d), zero_tangent(primal(val), tangent(val)), k) - function setindex_pb!!(_, df, dd, dval, dkey) + dval = LazyZeroRData(primal(val)) + dkey = LazyZeroRData(primal(key)) + function setindex_pb!!(::NoRData) # Increment tangent. - dval = increment!!(dval, tangent(d)[k]) + _dval = increment!!(instantiate(dval), rdata(tangent(d)[k])) # Restore previous state if necessary. if restore_state @@ -83,7 +93,7 @@ function rrule!!(::CoDual{typeof(setindex!)}, d::CoDual{IdDict{K,V}}, val, key) delete!(tangent(d), k) end - return df, dd, dval, dkey + return NoRData(), NoRData(), _dval, instantiate(dkey) end return d, setindex_pb!! end @@ -94,15 +104,19 @@ function rrule!!( ) where {K, V} k = primal(key) has_key = in(k, keys(primal(d))) - y = has_key ? CoDual(primal(d)[k], tangent(d)[k]) : default + y = has_key ? CoDual(primal(d)[k], fdata(tangent(d)[k])) : default - function get_pb!!(dy, df, dd, dkey, ddefault) + dd = tangent(d) + dkey = LazyZeroRData(primal(key)) + rdefault = LazyZeroRData(primal(default)) + function get_pb!!(dy) if has_key - dd[k] = increment!!(dd[k], dy) + dd[k] = increment_rdata!!(dd[k], dy) + _rdefault = instantiate(rdefault) else - ddefault = increment!!(ddefault, dy) + _rdefault = increment_rdata!!(instantiate(rdefault), dy) end - return df, dd, dkey, ddefault + return NoRData(), NoRData(), instantiate(dkey), _rdefault end return y, get_pb!! 
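
# The `IdDict` rules above use the `LazyZeroRData` / `instantiate` pair so that zero rdata
# for keys, values and defaults is only materialised if the reverse pass actually runs.
# Roughly, assuming `Float64` rdata is a `Float64` as elsewhere in this diff:
#
#     r = LazyZeroRData(5.0)   # cheap to construct on the forwards pass
#     instantiate(r)           # expected to yield 0.0, the zero rdata for a Float64
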
end @@ -112,10 +126,12 @@ function rrule!!( ::CoDual{typeof(getindex)}, d::CoDual{IdDict{K, V}}, key::CoDual ) where {K, V} k = primal(key) - y = CoDual(getindex(primal(d), k), getindex(tangent(d), k)) - function getindex_pb!!(dy, df, dd, dkey) - dd[k] = increment!!(dd[k], dy) - return df, dd, dkey + y = CoDual(getindex(primal(d), k), fdata(getindex(tangent(d), k))) + dkey = LazyZeroRData(primal(key)) + dd = tangent(d) + function getindex_pb!!(dy) + dd[k] = increment_rdata!!(dd[k], dy) + return NoRData(), NoRData(), instantiate(dkey) end return y, getindex_pb!! end diff --git a/src/rrules/lapack.jl b/src/rrules/lapack.jl index c036263a..4c3c9497 100644 --- a/src/rrules/lapack.jl +++ b/src/rrules/lapack.jl @@ -1,6 +1,6 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) TInt = :(Ptr{BLAS.BlasInt}) - @eval function rrule!!( + @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(fname))}}, ::CoDual, # return type @@ -13,8 +13,8 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) _LDA::CoDual{$TInt}, # leading dimension of A _IPIV::CoDual{$TInt}, # pivot indices _INFO::CoDual{$TInt}, # some info of some kind - args..., - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} GC.@preserve args begin # Extract names. M, N, LDA, IPIV, INFO = map(primal, (_M, _N, _LDA, _IPIV, _INFO)) @@ -44,9 +44,9 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) # Zero out the tangent. foreach(n -> unsafe_store!(dA, zero($elty), n), 1:data_len) - function getrf_pb!!( - _, d1, d2, d3, d4, d5, d6, dM, dN, dA, dLDA, dIPIV, dINFO, dargs... - ) + dA = tangent(_A) + function getrf_pb!!(::NoRData) + # Run reverse-pass. L, U = UnitLowerTriangular(A_mat), UpperTriangular(A_mat) dA_mat = wrap_ptr_as_view(dA, LDA_val, M_val, N_val) @@ -62,16 +62,16 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) # Restore initial state. A_mat .= A_store - return d1, d2, d3, d4, d5, d6, dM, dN, dA, dLDA, dIPIV, dINFO, dargs... + return tuple_fill(NoRData(), Val(12 + Nargs)) end - return CoDual(Cvoid(), zero_tangent(Cvoid())), getrf_pb!! + return zero_fcodual(Cvoid()), getrf_pb!! end end for (fname, elty) in ((:dtrtrs_, :Float64), (:strtrs_, :Float32)) TInt = :(Ptr{BLAS.BlasInt}) - @eval function rrule!!( + @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(fname))}}, ::CoDual, # return type @@ -88,8 +88,9 @@ for (fname, elty) in ((:dtrtrs_, :Float64), (:strtrs_, :Float32)) _B::CoDual{Ptr{$elty}}, _ldb::CoDual{Ptr{BLAS.BlasInt}}, _info::CoDual{Ptr{BLAS.BlasInt}}, - args..., - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} + # Load in data. ul_p, tA_p, diag_p = map(primal, (_ul, _tA, _diag)) N_p, Nrhs_p, lda_p, ldb_p, info_p = map(primal, (_N, _Nrhs, _lda, _ldb, _info)) @@ -114,10 +115,10 @@ for (fname, elty) in ((:dtrtrs_, :Float64), (:strtrs_, :Float32)) 1, 1, 1, ) - function trtrs_pb!!( - _, d1, d2, d3, d4, d5, d6, - dul, dtA, ddiag, dN, dNrhs, _dA, dlda, _dB, dldb, dINFO, dargs... - ) + _dA = tangent(_A) + _dB = tangent(_B) + function trtrs_pb!!(::NoRData) + # Compute cotangent of B. dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) LAPACK.trtrs!(Char(ul), Char(tA) == 'N' ? 'T' : 'N', Char(diag), A, dB) @@ -133,15 +134,14 @@ for (fname, elty) in ((:dtrtrs_, :Float64), (:strtrs_, :Float32)) # Restore initial state. B .= B_copy - return d1, d2, d3, d4, d5, d6, - dul, dtA, ddiag, dN, dNrhs, _dA, dlda, _dB, dldb, dINFO, dargs... 
+ return tuple_fill(NoRData(), Val(16 + Nargs)) end - return CoDual(Cvoid(), zero_tangent(Cvoid())), trtrs_pb!! + return zero_fcodual(Cvoid()), trtrs_pb!! end end for (fname, elty) in ((:dgetrs_, :Float64), (:sgetrs_, :Float32)) - @eval function rrule!!( + @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(fname))}}, ::CoDual, # return type @@ -157,8 +157,9 @@ for (fname, elty) in ((:dgetrs_, :Float64), (:sgetrs_, :Float32)) _B::CoDual{Ptr{$elty}}, _ldb::CoDual{Ptr{BlasInt}}, _info::CoDual{Ptr{BlasInt}}, - args..., - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} + # Load in values. tA = Char(unsafe_load(primal(_tA))) N, Nrhs, lda, ldb, info = map(unsafe_load ∘ primal, (_N, _Nrhs, _lda, _ldb, _info)) @@ -197,10 +198,10 @@ for (fname, elty) in ((:dgetrs_, :Float64), (:sgetrs_, :Float32)) # We need to write to `info`. unsafe_store!(primal(_info), 0) - function getrs_pb!!( - _, d1, d2, d3, d4, d5, d6, - dtA, dN, dNrhs, _dA, dlda, _ipiv, _dB, dldb, dINFO, dargs... - ) + _dA = tangent(_A) + _dB = tangent(_B) + function getrs_pb!!(::NoRData) + dA = wrap_ptr_as_view(_dA, lda, N, N) dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) @@ -234,15 +235,14 @@ for (fname, elty) in ((:dgetrs_, :Float64), (:sgetrs_, :Float32)) # Restore initial state. B .= B0 - return d1, d2, d3, d4, d5, d6, - dtA, dN, dNrhs, _dA, dlda, _ipiv, _dB, dldb, dINFO, dargs... + return tuple_fill(NoRData(), Val(15 + Nargs)) end - return CoDual(Cvoid(), zero_tangent(Cvoid())), getrs_pb!! + return zero_fcodual(Cvoid()), getrs_pb!! end end for (fname, elty) in ((:dgetri_, :Float64), (:sgetri_, :Float32)) - @eval function rrule!!( + @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(fname))}}, ::CoDual, # return type @@ -256,8 +256,9 @@ for (fname, elty) in ((:dgetri_, :Float64), (:sgetri_, :Float32)) _work::CoDual{Ptr{$elty}}, _lwork::CoDual{Ptr{BlasInt}}, _info::CoDual{Ptr{BlasInt}}, - args..., - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} + # Pull out data. N_p, lda_p, lwork_p, info_p = map(primal, (_N, _lda, _lwork, _info)) N, lda, lwork, info = map(unsafe_load, (N_p, lda_p, lwork_p, info_p)) @@ -277,9 +278,8 @@ for (fname, elty) in ((:dgetri_, :Float64), (:sgetri_, :Float32)) p = LinearAlgebra.ipiv2perm(unsafe_wrap(Array, primal(_ipiv), N), N) - function getri_pb!!( - _, d1, d2, d3, d4, d5, d6, dN, _dA, dlda, dipiv, dwork, dlwork, dinfo, dargs... - ) + _dA = tangent(_A) + function getri_pb!!(::NoRData) if lwork != -1 dA = wrap_ptr_as_view(_dA, lda, N, N) A .= A[:, p] @@ -294,17 +294,16 @@ for (fname, elty) in ((:dgetri_, :Float64), (:sgetri_, :Float32)) A .= A_copy end - return d1, d2, d3, d4, d5, d6, - dN, _dA, dlda, dipiv, dwork, dlwork, dinfo, dargs... + return tuple_fill(NoRData(), Val(13 + Nargs)) end - return CoDual(Cvoid(), zero_tangent(Cvoid())), getri_pb!! + return zero_fcodual(Cvoid()), getri_pb!! end end __sym(X) = 0.5 * (X + X') for (fname, elty) in ((:dpotrf_, :Float64), (:spotrf_, :Float32)) - @eval function rrule!!( + @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(fname))}}, ::CoDual, # return type @@ -316,8 +315,9 @@ for (fname, elty) in ((:dpotrf_, :Float64), (:spotrf_, :Float32)) _A::CoDual{Ptr{$elty}}, _lda::CoDual{Ptr{BlasInt}}, _info::CoDual{Ptr{BlasInt}}, - args..., - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} + # Pull out the data. 
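
# The rewritten LAPACK pullbacks no longer thread one cotangent per argument through by
# hand: every argument to these `:ccall` wrappers has `NoRData` rdata, so the whole return
# tuple is produced with `tuple_fill`, and the `args::Vararg{Any, Nargs}` signatures make
# the argument count available statically so the tuple length is known at compile time.
# For example:
#
#     tuple_fill(NoRData(), Val(3))   # == (NoRData(), NoRData(), NoRData())
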
uplo_p, N_p, A_p, lda_p, info_p = map(primal, (_uplo, _N, _A, _lda, _info)) uplo, lda, N = map(unsafe_load, (uplo_p, lda_p, N_p)) @@ -333,9 +333,9 @@ for (fname, elty) in ((:dpotrf_, :Float64), (:spotrf_, :Float32)) uplo_p, N_p, A_p, lda_p, info_p, ) - function potrf_pb!!( - _, d1, d2, d3, d4, d5, d6, duplo, dN, _dA, dlda, dinfo, dargs... - ) + _dA = tangent(_A) + function potrf_pb!!(::NoRData) + dA = wrap_ptr_as_view(_dA, lda, N, N) dA2 = dA @@ -355,14 +355,14 @@ for (fname, elty) in ((:dpotrf_, :Float64), (:spotrf_, :Float32)) # Restore initial state. A .= A_copy - return d1, d2, d3, d4, d5, d6, duplo, dN, _dA, dlda, dinfo, dargs... + return tuple_fill(NoRData(), Val(11 + Nargs)) end - return CoDual(Cvoid(), zero_tangent(Cvoid())), potrf_pb!! + return zero_fcodual(Cvoid()), potrf_pb!! end end for (fname, elty) in ((:dpotrs_, :Float64), (:spotrs_, :Float32)) - @eval function rrule!!( + @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(fname))}}, ::CoDual, # return type @@ -377,8 +377,8 @@ for (fname, elty) in ((:dpotrs_, :Float64), (:spotrs_, :Float32)) _B::CoDual{Ptr{$elty}}, _ldb::CoDual{Ptr{BlasInt}}, _info::CoDual{Ptr{BlasInt}}, - args..., - ) + args::Vararg{Any, Nargs}, + ) where {Nargs} # Pull out the data. uplo_p, N_p, Nrhs_p, A_p, lda_p, B_p, ldb_p, info_p = map( primal, (_uplo, _N, _Nrhs, _A, _lda, _B, _ldb, _info) @@ -400,9 +400,10 @@ for (fname, elty) in ((:dpotrs_, :Float64), (:spotrs_, :Float32)) uplo_p, N_p, Nrhs_p, A_p, lda_p, B_p, ldb_p, info_p, ) - function potrs_pb!!( - _, d1, d2, d3, d4, d5, d6, duplo, dN, dNrhs, _dA, dlda, _dB, dldb, dinfo, dargs... - ) + _dA = tangent(_A) + _dB = tangent(_B) + function potrs_pb!!(::NoRData) + dA = wrap_ptr_as_view(_dA, lda, N, N) dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) @@ -420,9 +421,9 @@ for (fname, elty) in ((:dpotrs_, :Float64), (:spotrs_, :Float32)) # Restore initial state. B .= B_copy - return d1, d2, d3, d4, d5, d6, duplo, dN, dNrhs, _dA, dlda, _dB, dldb, dinfo, dargs... + return tuple_fill(NoRData(), Val(14 + Nargs)) end - return CoDual(Cvoid(), zero_tangent(Cvoid())), potrs_pb!! + return zero_fcodual(Cvoid()), potrs_pb!! end end @@ -434,12 +435,12 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) # getrf! 
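
# Each test case below is a tuple of the form used throughout this diff, which appears to
# read as
#
#     (interface_only, perf_flag, extra_config, f, args...)
#
# where `interface_only` skips numerical-correctness checks when true, `perf_flag` is one
# of `:none`, `:stability`, `:allocs` or `:stability_and_allocs`, `extra_config` is either
# `nothing` or extra settings such as the performance bounds `(lb=0.1, ub=150)` seen
# earlier, and the remaining entries are the function under test and its arguments.
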
[ - Any[false, :none, nothing, getrf_wrapper!, randn(5, 5), false], - Any[false, :none, nothing, getrf_wrapper!, randn(5, 5), true], - Any[false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 1:5, 1:5), false], - Any[false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 1:5, 1:5), true], - Any[false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 2:7, 3:8), false], - Any[false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 3:8, 2:7), true], + (false, :none, nothing, getrf_wrapper!, randn(5, 5), false), + (false, :none, nothing, getrf_wrapper!, randn(5, 5), true), + (false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 1:5, 1:5), false), + (false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 1:5, 1:5), true), + (false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 2:7, 3:8), false), + (false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 3:8, 2:7), true), ], # trtrs @@ -451,7 +452,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) As = [randn(N, N) + 10I, view(randn(15, 15) + 10I, 2:N+1, 2:N+1)] Bs = [randn(N, Nrhs), view(randn(15, 15), 4:N+3, 3:N+2)] return map(product(As, Bs)) do (A, B) - Any[false, :none, nothing, trtrs!, ul, tA, diag, A, B] + (false, :none, nothing, trtrs!, ul, tA, diag, A, B) end end, )), @@ -466,7 +467,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) ]) Bs = [randn(N, Nrhs), view(randn(15, 15), 4:N+3, 3:Nrhs+2)] return map(product(As, Bs)) do ((A, ipiv), B) - Any[false, :none, nothing, getrs!, trans, A, ipiv, B] + (false, :none, nothing, getrs!, trans, A, ipiv, B) end end, )), @@ -478,7 +479,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) As = getrf!.([randn(N, N) + 5I, view(randn(15, 15) + I, 2:N+1, 2:N+1)]) As = getrf!.([randn(N, N) + 5I]) return map(As) do (A, ipiv) - Any[false, :none, nothing, getri!, A, ipiv] + (false, :none, nothing, getri!, A, ipiv) end end, )), @@ -489,9 +490,9 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) map([1, 3, 9]) do N X = randn(N, N) A = X * X' + I - return [ - Any[false, :none, nothing, potrf!, 'L', A], - Any[false, :none, nothing, potrf!, 'U', A], + return Any[ + (false, :none, nothing, potrf!, 'L', A), + (false, :none, nothing, potrf!, 'U', A), ] end, )), @@ -503,9 +504,9 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) X = randn(N, N) A = X * X' + I B = randn(N, Nrhs) - return [ - Any[false, :none, nothing, potrs!, 'L', potrf!('L', copy(A))[1], copy(B)], - Any[false, :none, nothing, potrs!, 'U', potrf!('U', copy(A))[1], copy(B)], + return Any[ + (false, :none, nothing, potrs!, 'L', potrf!('L', copy(A))[1], copy(B)), + (false, :none, nothing, potrs!, 'U', potrf!('U', copy(A))[1], copy(B)), ] end, )), diff --git a/src/rrules/low_level_maths.jl b/src/rrules/low_level_maths.jl index 7c358e07..39cbeb5a 100644 --- a/src/rrules/low_level_maths.jl +++ b/src/rrules/low_level_maths.jl @@ -5,33 +5,48 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) continue # Skip rules for methods not defined in the current scope end (f == :rem2pi || f == :ldexp) && continue # not designed for Float64s - (f == :+ || f == :*) && continue # use intrinsics instead + (f in [:+, :*, :sin, :cos]) && continue # use intrinsics instead + P = Float64 if arity == 1 dx = DiffRules.diffrule(M, f, :x) pb_name = Symbol("$(M).$(f)_pb!!") @eval begin - @is_primitive MinimalCtx Tuple{typeof($M.$f), Float64} - function rrule!!(::CoDual{typeof($M.$f)}, x::CoDual{Float64}) - x = 
primal(x) - $pb_name(ȳ, f̄, x̄) = f̄, x̄ + ȳ * $dx - return CoDual(($M.$f)(x), zero(Float64)), $pb_name + @is_primitive MinimalCtx Tuple{typeof($M.$f), $P} + function rrule!!(::CoDual{typeof($M.$f)}, _x::CoDual{$P}) + x = primal(_x) # needed for dx expression + $pb_name(ȳ) = NoRData(), ȳ * $dx + return CoDual(($M.$f)(x), NoFData()), $pb_name end end elseif arity == 2 da, db = DiffRules.diffrule(M, f, :a, :b) pb_name = Symbol("$(M).$(f)_pb!!") @eval begin - @is_primitive MinimalCtx Tuple{typeof($M.$f), Float64, Float64} - function rrule!!(::CoDual{typeof($M.$f)}, a::CoDual{Float64}, b::CoDual{Float64}) - a = primal(a) - b = primal(b) - $pb_name(ȳ, f̄, ā, b̄) = f̄, ā + ȳ * $da, b̄ + ȳ * $db - return CoDual(($M.$f)(a, b), zero(Float64)), $pb_name + @is_primitive MinimalCtx Tuple{typeof($M.$f), $P, $P} + function rrule!!(::CoDual{typeof($M.$f)}, _a::CoDual{$P}, _b::CoDual{$P}) + a = primal(_a) + b = primal(_b) + $pb_name(ȳ) = NoRData(), ȳ * $da, ȳ * $db + return CoDual(($M.$f)(a, b), NoFData()), $pb_name end end end end +@is_primitive MinimalCtx Tuple{typeof(sin), Float64} +function rrule!!(::CoDual{typeof(sin), NoFData}, x::CoDual{Float64, NoFData}) + s, c = sincos(primal(x)) + sin_pullback!!(dy::Float64) = NoRData(), dy * c + return CoDual(s, NoFData()), sin_pullback!! +end + +@is_primitive MinimalCtx Tuple{typeof(cos), Float64} +function rrule!!(::CoDual{typeof(cos), NoFData}, x::CoDual{Float64, NoFData}) + s, c = sincos(primal(x)) + cos_pullback!!(dy::Float64) = NoRData(), -dy * s + return CoDual(c, NoFData()), cos_pullback!! +end + rand_inputs(rng, f, arity) = randn(rng, arity) rand_inputs(rng, ::typeof(acosh), _) = (rand(rng) + 1 + 1e-3, ) rand_inputs(rng, ::typeof(asech), _) = (rand(rng) * 0.9, ) diff --git a/src/rrules/misc.jl b/src/rrules/misc.jl index 226f60e6..df4061d6 100644 --- a/src/rrules/misc.jl +++ b/src/rrules/misc.jl @@ -28,17 +28,18 @@ for name in [ :(Base.depwarn), :(Base.reduced_indices), :(Base.check_reducedims), + :(Base.throw_boundserror), + :(Base.Broadcast.eltypes), ] @eval @is_primitive DefaultCtx Tuple{typeof($name), Vararg} - @eval function rrule!!(::CoDual{_typeof($name)}, args::CoDual...) - v = $name(map(primal, args)...) - return CoDual(v, zero_tangent(v)), NoPullback() + @eval function rrule!!(f::CoDual{_typeof($name)}, args::Vararg{CoDual, N}) where {N} + return zero_fcodual($name(map(primal, args)...)), NoPullback(f, args...) end end @is_primitive MinimalCtx Tuple{Type, TypeVar, Type} function rrule!!(x::CoDual{<:Type}, y::CoDual{<:TypeVar}, z::CoDual{<:Type}) - return CoDual(primal(x)(primal(y), primal(z)), NoTangent()), NoPullback() + return zero_fcodual(primal(x)(primal(y), primal(z))), NoPullback(x, y, z) end """ @@ -59,40 +60,67 @@ This approach is identical to the one taken by `Zygote.jl` to circumvent the sam """ lgetfield(x, ::Val{f}) where {f} = getfield(x, f) -@is_primitive MinimalCtx Tuple{typeof(lgetfield), Any, Any} -function rrule!!(::CoDual{typeof(lgetfield)}, x::CoDual, ::CoDual{Val{f}}) where {f} - lgetfield_pb!!(dy, df, dx, dsym) = df, increment_field!!(dx, dy, Val{f}()), dsym - y = CoDual(getfield(primal(x), f), _get_tangent_field(primal(x), tangent(x), f)) - return y, lgetfield_pb!! +@is_primitive MinimalCtx Tuple{typeof(lgetfield), Any, Val} +@inline function rrule!!( + ::CoDual{typeof(lgetfield)}, x::CoDual{P}, ::CoDual{Val{f}} +) where {P, f} + pb!! 
= if ismutabletype(P) + dx = tangent(x) + function mutable_lgetfield_pb!!(dy) + increment_field_rdata!(dx, dy, Val{f}()) + return NoRData(), NoRData(), NoRData() + end + else + dx_r = LazyZeroRData(primal(x)) + field = Val{f}() + function immutable_lgetfield_pb!!(dy) + return NoRData(), increment_field!!(instantiate(dx_r), dy, field), NoRData() + end + end + y = CoDual(getfield(primal(x), f), _get_fdata_field(primal(x), tangent(x), f)) + return y, pb!! end -# Specialise for non-differentiable arguments. -function rrule!!( - ::CoDual{typeof(lgetfield)}, x::CoDual{<:Any, NoTangent}, ::CoDual{Val{f}} -) where {f} - return uninit_codual(getfield(primal(x), f)), NoPullback() +@inline _get_fdata_field(_, t::Union{Tuple, NamedTuple}, f...) = getfield(t, f...) +@inline _get_fdata_field(_, data::FData, f...) = val(getfield(data.data, f...)) +@inline _get_fdata_field(primal, ::NoFData, f...) = uninit_fdata(getfield(primal, f...)) +@inline _get_fdata_field(_, t::MutableTangent, f...) = fdata(val(getfield(t.fields, f...))) + +increment_field_rdata!(dx::MutableTangent, ::NoRData, ::Val) = dx +increment_field_rdata!(dx::NoFData, ::NoRData, ::Val) = dx +function increment_field_rdata!(dx::T, dy_rdata, ::Val{f}) where {T<:MutableTangent, f} + set_tangent_field!(dx, f, increment_rdata!!(get_tangent_field(dx, f), dy_rdata)) + return dx end +# +# lgetfield with order argument +# + +# This is largely copy + pasted from the above. Attempts were made to refactor to avoid +# code duplication, but it wound up not being any cleaner than this copy + pasted version. + lgetfield(x, ::Val{f}, ::Val{order}) where {f, order} = getfield(x, f, order) -@is_primitive MinimalCtx Tuple{typeof(lgetfield), Any, Any, Any} -function rrule!!( - ::CoDual{typeof(lgetfield)}, x::CoDual, ::CoDual{Val{f}}, ::CoDual{Val{order}} -) where {f, order} - function lgetfield_pb!!(dy, df, dx, dsym, dorder) - return df, increment_field!!(dx, dy, Val{f}()), dsym, dorder +@is_primitive MinimalCtx Tuple{typeof(lgetfield), Any, Val, Val} +@inline function rrule!!( + ::CoDual{typeof(lgetfield)}, x::CoDual{P}, ::CoDual{Val{f}}, ::CoDual{Val{order}} +) where {P, f, order} + pb!! = if ismutabletype(P) + dx = tangent(x) + function mutable_lgetfield_pb!!(dy) + increment_field_rdata!(dx, dy, Val{f}()) + return NoRData(), NoRData(), NoRData(), NoRData() + end + else + dx_r = LazyZeroRData(primal(x)) + function immutable_lgetfield_pb!!(dy) + tmp = increment_field!!(instantiate(dx_r), dy, Val{f}()) + return NoRData(), tmp, NoRData(), NoRData() + end end - y = CoDual(getfield(primal(x), f), _get_tangent_field(primal(x), tangent(x), f)) - return y, lgetfield_pb!! -end - -function rrule!!( - ::CoDual{typeof(lgetfield)}, - x::CoDual{<:Any, NoTangent}, - ::CoDual{Val{f}}, - ::CoDual{Val{order}}, -) where {f, order} - return uninit_codual(getfield(primal(x), f)), NoPullback() + y = CoDual(getfield(primal(x), f, order), _get_fdata_field(primal(x), tangent(x), f)) + return y, pb!! end """ @@ -107,40 +135,34 @@ setfield!(copy(x), 2, v) == lsetfield(copy(x), Val(2), v) lsetfield!(value, ::Val{name}, x) where {name} = setfield!(value, name, x) @is_primitive MinimalCtx Tuple{typeof(lsetfield!), Any, Any, Any} -function rrule!!( - ::CoDual{typeof(lsetfield!)}, value::CoDual, ::CoDual{Val{name}}, x::CoDual -) where {name} +@inline function rrule!!( + ::CoDual{typeof(lsetfield!)}, value::CoDual{P}, ::CoDual{Val{name}}, x::CoDual +) where {P, name} + F = fdata_type(tangent_type(P)) save = isdefined(primal(value), name) old_x = save ? 
getfield(primal(value), name) : nothing - old_dx = save ? val(getfield(tangent(value).fields, name)) : nothing - function setfield!_pullback(dy, df, dvalue, dname, dx) - new_dx = increment!!(dx, val(getfield(dvalue.fields, name))) - new_dx = increment!!(new_dx, dy) - old_x !== nothing && lsetfield!(primal(value), Val(name), old_x) - old_x !== nothing && set_tangent_field!(tangent(value), name, old_dx) - return df, dvalue, dname, new_dx + old_dx = if F == NoFData + NoFData() + else + save ? val(getfield(tangent(value).fields, name)) : nothing end - y = CoDual( - lsetfield!(primal(value), Val(name), primal(x)), - set_tangent_field!(tangent(value), name, tangent(x)), - ) - return y, setfield!_pullback -end - -function rrule!!( - ::CoDual{typeof(lsetfield!)}, - value::CoDual{<:Any, NoTangent}, - ::CoDual{Val{name}}, - x::CoDual, -) where {name} - save = isdefined(primal(value), name) - old_x = save ? getfield(primal(value), name) : nothing - function setfield!_pullback(dy, df, dvalue, dname, dx) - old_x !== nothing && lsetfield!(primal(value), Val(name), old_x) - return df, dvalue, dname, dx + dvalue = tangent(value) + pb!! = if F == NoFData + function __setfield!_pullback(dy) + old_x !== nothing && lsetfield!(primal(value), Val(name), old_x) + return NoRData(), NoRData(), NoRData(), dy + end + else + function setfield!_pullback(dy) + new_dx = increment!!(dy, rdata(val(getfield(dvalue.fields, name)))) + old_x !== nothing && lsetfield!(primal(value), Val(name), old_x) + old_x !== nothing && set_tangent_field!(dvalue, name, old_dx) + return NoRData(), NoRData(), NoRData(), new_dx + end end - y = CoDual(lsetfield!(primal(value), Val(name), primal(x)), NoTangent()) - return y, setfield!_pullback + yf = F == NoFData ? NoFData() : fdata(set_tangent_field!(dvalue, name, zero_tangent(primal(x), tangent(x)))) + y = CoDual(lsetfield!(primal(value), Val(name), primal(x)), yf) + return y, pb!! end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) @@ -150,7 +172,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) _dx = Ref(4.0) memory = Any[_x, _dx] - test_cases = Any[ + specific_test_cases = Any[ # Rules to avoid pointer type conversions. ( true, :stability, nothing, @@ -191,18 +213,6 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) ), (false, :allocs, nothing, Threads.nthreads), - # Literal replacements for getfield. - (false, :stability_and_allocs, nothing, lgetfield, (5.0, 4), Val(1)), - (false, :stability_and_allocs, nothing, lgetfield, (5.0, 4), Val(2)), - (false, :stability_and_allocs, nothing, lgetfield, (1, 4), Val(2)), - (false, :stability_and_allocs, nothing, lgetfield, ((), 4), Val(2)), - (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(1)), - (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(2)), - (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(:a)), - (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(:b)), - (false, :stability_and_allocs, nothing, lgetfield, 1:5, Val(:start)), - (false, :stability_and_allocs, nothing, lgetfield, 1:5, Val(:stop)), - # Literal replacement for setfield!. ( false, :stability_and_allocs, nothing, @@ -221,6 +231,84 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) lsetfield!, NonDifferentiableFoo(5, false), Val(:y), true, ) ] + + # Some specific test cases for lgetfield to test the basics. 
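
# As defined above, `lgetfield(x, Val(f))` is just `getfield(x, f)` with the field name
# carried in the type domain, e.g.
#
#     lgetfield((a=5.0, b=4), Val(:a))  # == 5.0
#     lgetfield((5.0, 4), Val(2))       # == 4
#
# which is what the cases below exercise across Tuples, NamedTuples, structs and mutable
# structs.
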
+ specific_lgetfield_test_cases = Any[ + + # Tuple + (false, :stability_and_allocs, nothing, lgetfield, (5.0, 4), Val(1)), + (false, :stability_and_allocs, nothing, lgetfield, (5.0, 4), Val(2)), + (false, :stability_and_allocs, nothing, lgetfield, (1, 4), Val(2)), + (false, :stability_and_allocs, nothing, lgetfield, ((), 4), Val(2)), + (false, :stability_and_allocs, nothing, lgetfield, (randn(2),), Val(1)), + (false, :stability_and_allocs, nothing, lgetfield, (randn(2), 5), Val(1)), + (false, :stability_and_allocs, nothing, lgetfield, (randn(2), 5), Val(2)), + + # NamedTuple + (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(1)), + (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(2)), + (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(:a)), + (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(:b)), + (false, :stability_and_allocs, nothing, lgetfield, (y=randn(2),), Val(1)), + (false, :stability_and_allocs, nothing, lgetfield, (y=randn(2),), Val(:y)), + (false, :stability_and_allocs, nothing, lgetfield, (y=randn(2), x=5), Val(1)), + (false, :stability_and_allocs, nothing, lgetfield, (y=randn(2), x=5), Val(2)), + (false, :stability_and_allocs, nothing, lgetfield, (y=randn(2), x=5), Val(:y)), + (false, :stability_and_allocs, nothing, lgetfield, (y=randn(2), x=5), Val(:x)), + + # structs + (false, :stability_and_allocs, nothing, lgetfield, 1:5, Val(:start)), + (false, :stability_and_allocs, nothing, lgetfield, 1:5, Val(:stop)), + (true, :none, (lb=1, ub=100), lgetfield, StructFoo(5.0), Val(:a)), + (false, :none, (lb=1, ub=100), lgetfield, StructFoo(5.0, randn(5)), Val(:a)), + (false, :none, (lb=1, ub=100), lgetfield, StructFoo(5.0, randn(5)), Val(:b)), + (true, :none, (lb=1, ub=100), lgetfield, StructFoo(5.0), Val(1)), + (false, :none, (lb=1, ub=100), lgetfield, StructFoo(5.0, randn(5)), Val(1)), + (false, :none, (lb=1, ub=100), lgetfield, StructFoo(5.0, randn(5)), Val(2)), + + # mutable structs + (true, :none, nothing, lgetfield, MutableFoo(5.0), Val(:a)), + (false, :none, nothing, lgetfield, MutableFoo(5.0, randn(5)), Val(:b)), + (false, :none, nothing, lgetfield, UInt8, Val(:name)), + (false, :none, nothing, lgetfield, UInt8, Val(:super)), + (true, :none, nothing, lgetfield, UInt8, Val(:layout)), + (false, :none, nothing, lgetfield, UInt8, Val(:hash)), + (false, :none, nothing, lgetfield, UInt8, Val(:flags)), + ] + + # Create `lgetfield` tests for each type in TestTypes in order to increase coverage. + general_lgetfield_test_cases = map(TestTypes.PRIMALS) do (interface_only, P, args) + _, primal = TestTypes.instantiate((interface_only, P, args)) + names = fieldnames(P)[1:length(args)] # only query fields which get initialised + return Any[ + (interface_only, :none, nothing, lgetfield, primal, Val(name)) for + name in names + ] + end + + # lgetfield has both 3 and 4 argument forms. Create test cases for both scenarios. + all_lgetfield_test_cases = Any[ + (case..., order...) for + case in vcat(specific_lgetfield_test_cases, general_lgetfield_test_cases...) for + order in Any[(), (Val(false), )] + ] + + # Create `lsetfield` testsfor each type in TestTypes in order to increase coverage. 
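
# `lsetfield!(value, Val(name), x)` mirrors `lgetfield`: it is `setfield!(value, name, x)`
# with the field name lifted into the type domain. For a mutable test type used above this
# looks roughly like (sketch, assuming `MutableFoo` has a `Float64` field `:a`):
#
#     m = MutableFoo(5.0, randn(5))
#     lsetfield!(m, Val(:a), 4.0)   # same effect as setfield!(m, :a, 4.0); returns 4.0
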
+ general_lsetfield_test_cases = map(TestTypes.PRIMALS) do (interface_only, P, args) + ismutabletype(P) || return Any[] + _, primal = TestTypes.instantiate((interface_only, P, args)) + names = fieldnames(P)[1:length(args)] # only query fields which get initialised + return Any[ + (interface_only, :none, nothing, lsetfield!, primal, Val(name), args[n]) for + (n, name) in enumerate(names) + ] + end + + test_cases = vcat( + specific_test_cases, + all_lgetfield_test_cases..., + general_lsetfield_test_cases..., + ) return test_cases, memory end diff --git a/src/rrules/new.jl b/src/rrules/new.jl index f2497011..6c3a4399 100644 --- a/src/rrules/new.jl +++ b/src/rrules/new.jl @@ -1,47 +1,87 @@ -for N in 0:32 - @eval @inline function _new_(::Type{T}, x::Vararg{Any, $N}) where {T} - return $(Expr(:new, :T, map(n -> :(x[$n]), 1:N)...)) - end - @eval function _new_pullback!!(dy, d_new_, d_T, dx::Vararg{Any, $N}) - return d_new_, d_T, map((x, y) -> increment!!(x, _value(y)), dx, Tuple(dy.fields))... - end - @eval function _new_pullback!!( - dy::Union{Tuple, NamedTuple}, d_new_, d_T, dx::Vararg{Any, $N} - ) - return d_new_, d_T, map(increment!!, dx, Tuple(dy))... - end - @eval function _new_pullback!!(::NoTangent, d_new_, d_T, dx::Vararg{Any, $N}) - return d_new_, NoTangent(), dx... +@inline @generated function _new_(::Type{T}, x::Vararg{Any, N}) where {T, N} + return Expr(:new, :T, map(n -> :(x[$n]), 1:N)...) +end + +@is_primitive MinimalCtx Tuple{typeof(_new_), Vararg} + +function rrule!!( + f::CoDual{typeof(_new_)}, p::CoDual{Type{P}}, x::Vararg{CoDual, N} +) where {P, N} + y = _new_(P, tuple_map(primal, x)...) + F = fdata_type(tangent_type(P)) + R = rdata_type(tangent_type(P)) + dy = F == NoFData ? NoFData() : build_fdata(P, tuple_map(primal, x), tuple_map(tangent, x)) + pb!! = if ismutabletype(P) + if F == NoFData + NoPullback(f, p, x...) + else + function _mutable_new_pullback!!(::NoRData) + rdatas = tuple_map(rdata ∘ val, Tuple(dy.fields)[1:N]) + return NoRData(), NoRData(), rdatas... + end + end + else + if R == NoRData + NoPullback(f, p, x...) + else + function _new_pullback_for_immutable!!(dy::T) where {T} + data = Tuple(T <: NamedTuple ? dy : dy.data)[1:N] + return NoRData(), NoRData(), map(val, data)... + end + end end - @eval function rrule!!( - ::CoDual{typeof(_new_)}, ::CoDual{Type{P}}, x::Vararg{CoDual, $N} - ) where {P} - y = $(Expr(:new, :P, map(n -> :(primal(x[$n])), 1:N)...)) - T = tangent_type(P) - dy = T == NoTangent ? NoTangent() : build_tangent(P, tuple_map(tangent, x)...) - return CoDual(y, dy), _new_pullback!! + return CoDual(y, dy), pb!! +end + +@generated function build_fdata(::Type{P}, x::Tuple, fdata::Tuple) where {P} + names = fieldnames(P) + fdata_exprs = map(eachindex(names)) do n + F = fdata_field_type(P, n) + if n <= length(fdata.parameters) + data_expr = Expr(:call, __get_data, P, :x, :fdata, n) + return F <: PossiblyUninitTangent ? Expr(:call, F, data_expr) : data_expr + else + return :($F()) + end end + F_out = fdata_type(tangent_type(P)) + return :($F_out(NamedTuple{$names}($(Expr(:call, tuple, fdata_exprs...))))) end -@is_primitive MinimalCtx Tuple{typeof(_new_), Vararg} +# Helper for build_fdata +@inline function __get_data(::Type{P}, x, f, n) where {P} + tmp = getfield(f, n) + return ismutabletype(P) ? 
zero_tangent(getfield(x, n), tmp) : tmp +end + +@inline function build_fdata(::Type{P}, x::Tuple, fdata::Tuple) where {P<:NamedTuple} + return fdata_type(tangent_type(P))(fdata) +end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:new}) - test_cases = Any[ - (false, :stability_and_allocs, nothing, _new_, Tuple{}), - (false, :stability_and_allocs, nothing, _new_, Tuple{Float64, Int}, 5.0, 4), - (false, :stability_and_allocs, nothing, _new_, Tuple{Float64, Float64}, 5.0, 4.0), - (false, :stability_and_allocs, nothing, _new_, Tuple{Int, Int}, 5, 5), + + # Specialised test cases for _new_. + specific_test_cases = Any[ (false, :stability_and_allocs, nothing, _new_, @NamedTuple{}), (false, :stability_and_allocs, nothing, _new_, @NamedTuple{y::Float64}, 5.0), + (false, :stability_and_allocs, nothing, _new_, @NamedTuple{y::Int, x::Int}, 5, 4), ( false, :stability_and_allocs, nothing, _new_, @NamedTuple{y::Float64, x::Int}, 5.0, 4, ), - (false, :stability_and_allocs, nothing, _new_, @NamedTuple{y::Int, x::Int}, 5, 4), + ( + false, :stability_and_allocs, nothing, + _new_, @NamedTuple{y::Vector{Float64}, x::Int}, randn(2), 4, + ), + ( + false, :stability_and_allocs, nothing, + _new_, @NamedTuple{y::Vector{Float64}}, randn(2), + ), ( false, :stability_and_allocs, nothing, _new_, TestResources.TypeStableStruct{Float64}, 5, 4.0, ), + (false, :stability_and_allocs, nothing, _new_, UnitRange{Int64}, 5, 4), ( false, :stability_and_allocs, nothing, _new_, TestResources.TypeStableMutableStruct{Float64}, 5.0, 4.0, @@ -50,8 +90,33 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:new}) false, :none, nothing, _new_, TestResources.TypeStableMutableStruct{Any}, 5.0, 4.0, ), - (false, :stability_and_allocs, nothing, _new_, UnitRange{Int64}, 5, 4), + (false, :none, nothing, _new_, TestResources.StructFoo, 6.0, [1.0, 2.0]), + (false, :none, nothing, _new_, TestResources.StructFoo, 6.0), + (false, :none, nothing, _new_, TestResources.MutableFoo, 6.0, [1.0, 2.0]), + (false, :none, nothing, _new_, TestResources.MutableFoo, 6.0), + (false, :stability_and_allocs, nothing, _new_, TestResources.StructNoFwds, 5.0), + (false, :stability_and_allocs, nothing, _new_, TestResources.StructNoRvs, [5.0]), + ( + false, :stability_and_allocs, nothing, + _new_, LowerTriangular{Float64, Matrix{Float64}}, randn(2, 2), + ), + ( + false, :stability_and_allocs, nothing, + _new_, UpperTriangular{Float64, Matrix{Float64}}, randn(2, 2), + ), + ( + false, :stability_and_allocs, nothing, + _new_, UnitLowerTriangular{Float64, Matrix{Float64}}, randn(2, 2), + ), + ( + false, :stability_and_allocs, nothing, + _new_, UnitUpperTriangular{Float64, Matrix{Float64}}, randn(2, 2), + ), ] + general_test_cases = map(TestTypes.PRIMALS) do (interface_only, P, args) + return (interface_only, :none, nothing, _new_, P, args...) + end + test_cases = vcat(specific_test_cases, general_test_cases) memory = Any[] return test_cases, memory end @@ -60,4 +125,4 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:new}) test_cases = Any[] memory = Any[] return test_cases, memory -end \ No newline at end of file +end diff --git a/src/safe_mode.jl b/src/safe_mode.jl new file mode 100644 index 00000000..321299a9 --- /dev/null +++ b/src/safe_mode.jl @@ -0,0 +1,145 @@ + +""" + SafePullback(pb, y, x) + +Construct a callable which is equivalent to `pb`, but which enforces type-based pre- and +post-conditions to `pb`. 
Let `dx = pb.pb(dy)`, for some rdata `dy`, then this function +- checks that `dy` has the correct rdata type for `y`, and +- checks that each element of `dx` has the correct rdata type for `x`. +""" +struct SafePullback{Tpb, Ty, Tx} + pb::Tpb +end + +""" + (pb::SafePullback)(dy) + +Apply type checking to enforce pre- and post-conditions on `pb.pb`. See the docstring for +`SafePullback` for details. +""" +@inline function (pb::SafePullback{Tpb, Ty, Tx})(dy) where {Tpb, Ty, Tx} + verify_rvs_input(Ty, dy) + dx = pb.pb(dy) + verify_rvs_output(Tx, dx) + return dx +end + +@noinline verify_rvs_input(::Type{Ty}, dy) where {Ty} = verify_rvs(Ty, dy) + +@noinline function verify_rvs_output(::Type{Tx}, dx) where {Tx} + @nospecialize pb dx + + # Number of arguments and number of elements in pullback must match. Have to check this + # because `zip` doesn't require equal lengths for arguments. + l_pb = length(Tx.parameters) + l_dx = length(dx) + if l_pb != l_dx + error("Number of args = $l_pb but number of rdata = $l_dx. They must to be equal.") + end + + # Use for-loop to keep stack trace as simple as possible. + for (x, dx) in zip(Tx.parameters, dx) + verify_rvs(x, dx) + end +end + +@noinline function verify_rvs(::Type{P}, dx::R) where {P, R} + _R = rdata_type(tangent_type(P)) + R <: ZeroRData && return nothing + (R <: _R) || throw(ArgumentError("Type $P has rdata type $_R, but got $R.")) +end + +""" + SafeRRule(rule) + +Construct a callable which is equivalent to `rule`, but inserts additional type checking. +In particular: +- check that the fdata in each argument is of the correct type for the primal +- check that the fdata in the `CoDual` returned from the rule is of the correct type for the + primal. + +Additionally, dynamic checks may be performed (e.g. that an fdata array of the same size as +its primal). + +Let `rule` return `y, pb!!`, then `SafeRRule(rule)` returns `y, SafePullback(pb!!)`. +`SafePullback` inserts the same kind of checks as `SafeRRule`, but on the reverse-pass. See +the docstring for details. + +*Note:* at any given point in time, the checks performed by this function constitute a +necessary but insufficient set of conditions to ensure correctness. If you find that an +error isn't being caught by these tests, but you believe it ought to be, please open an +issue or (better still) a PR. +""" +struct SafeRRule{Trule} + rule::Trule +end + +""" + (rule::SafeRRule)(x::CoDual...) + +Apply type checking to enforce pre- and post-conditions on `rule.rule`. See the docstring +for `SafeRRule` for details. +""" +@inline function (rule::SafeRRule)(x::Vararg{CoDual, N}) where {N} + verify_fwds_inputs(x) + y, pb = rule.rule(x...) + verify_fwds_output(x, y) + return y::CoDual, SafePullback{_typeof(pb), _typeof(primal(y)), Tuple{tuple_map(_typeof ∘ primal, x)...}}(pb) +end + +@noinline function verify_fwds_inputs(@nospecialize(x::Tuple)) + try + # Use for-loop to keep the stack trace as simple as possible. 
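
# Taken together, `SafeRRule` and `SafePullback` give an end-to-end checked rule. A rough
# usage sketch (not from the diff), assuming the `sin` rule from the low-level maths
# section of this patch:
#
#     safe_rule = SafeRRule(rrule!!)
#     y, pb!! = safe_rule(zero_fcodual(sin), CoDual(1.0, NoFData()))
#     pb!!(1.0)   # fdata/rdata types checked on the way in and out; expected (NoRData(), cos(1.0))
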
+ for _x in x + verify_fwds(_x) + end + catch e + error("error in inputs to rule with input types $(_typeof(x))") + end +end + +@noinline function verify_fwds_output(@nospecialize(x), @nospecialize(y)) + try + verify_fwds(y) + catch e + error("error in outputs of rule with input types $(_typeof(x))") + end +end + +@noinline function verify_fwds(x::CoDual{P, F}) where {P, F} + _fdata_type_checker(P, F) + verify_fwds_values(primal(x), tangent(x)) +end + +function verify_fwds_values(p::P, f::F) where {P, F} + _fdata_type_checker(P, F) + if F == NoFData + return + elseif P <: Array + if size(p) != size(f) + throw(ArgumentError("size of P is $(size(p)) but size of F is $(size(f))")) + end + isbitstype(eltype(P)) && return + for n in eachindex(p) + !isassigned(p, n) && continue + Fn = _typeof(f[n]) + Pn = _typeof(p[n]) + Tn = tangent_type(Pn) + if Fn != Tn + throw(ArgumentError( + "the type of each element of an fdata Array must be the tangent_type " * + "of the corresponding element of the primal array. Found that " * + "element $n of fdata array is of type $Fn, while primal is of " * + "type $Pn, whose tangent type is $Tn.", + )) + end + end + elseif isstructtype(P) + return + end +end + +function _fdata_type_checker(P, F) + _F = fdata_type(tangent_type(P)) + F == _F || throw(ArgumentError("type $P has fdata type $_F, but got $F.")) +end diff --git a/src/stack.jl b/src/stack.jl index 4d29a268..ec83ddaf 100644 --- a/src/stack.jl +++ b/src/stack.jl @@ -62,39 +62,6 @@ end Base.eltype(::Stack{T}) where {T} = T -top_ref(x::Stack) = Ref(getfield(x, :memory), getfield(x, :position)) - -""" - NoTangentStack() - -If a type has `NoTangent` as its tangent type, it should use one of these stacks. -Probably needs to be generalised to an inactive-tangent stack in future, as we also need to -handle constants, which aren't always active. -""" -struct NoTangentStack end - -Base.push!(::NoTangentStack, ::Any) = nothing -Base.getindex(::NoTangentStack) = NoTangent() -Base.setindex!(::NoTangentStack, ::NoTangent) = nothing -Base.pop!(::NoTangentStack) = NoTangent() - -struct NoTangentRef <: Ref{NoTangent} end - -Base.getindex(::NoTangentRef) = NoTangent() -Base.setindex!(::NoTangentRef, ::NoTangent) = nothing - -top_ref(::NoTangentStack) = NoTangentRef() - -""" - NoTangentRefStack - -Stack for `NoTangentRef`s. -""" -struct NoTangentRefStack end - -Base.push!(::NoTangentRefStack, ::Any) = nothing -Base.pop!(::NoTangentRefStack) = NoTangentRef() - struct SingletonStack{T} end @@ -102,46 +69,7 @@ Base.push!(::SingletonStack, ::Any) = nothing @generated Base.pop!(::SingletonStack{T}) where {T} = T.instance -function tangent_stack_type(::Type{P}) where {P} - P === DataType && return Stack{Any} - T = tangent_type(P) - return T === NoTangent ? NoTangentStack : Stack{T} -end - -__array_ref_type(::Type{P}) where {P} = Base.RefArray{P, Vector{P}, Nothing} - -function tangent_ref_type_ub(::Type{P}) where {P} - P === DataType && return Ref - T = tangent_type(P) - T === NoTangent && return NoTangentRef - return isconcretetype(P) ? 
__array_ref_type(T) : Ref -end - -tangent_ref_type_ub(::Type{Type{P}}) where {P} = NoTangentRef - -struct InactiveStack{T} - zero_tangent::T -end - -Base.pop!(s::InactiveStack{T}) where {T} = s.zero_tangent - -struct InactiveRef{T} - x::T -end - -Base.getindex(x::InactiveRef{T}) where {T} = x.x - -increment_ref!(::InactiveRef{T}, ::T) where {T} = nothing - -top_ref(::Nothing) = InactiveRef(nothing) - -increment_ref!(::InactiveRef{Nothing}, ::T) where {T} = nothing - - -struct FixedStackTangentRefStack{T} - x::Stack{T} +function reverse_data_ref_type(::Type{P}) where {P} + P === DataType && return Ref{Any} + return Base.RefValue{rdata_type(tangent_type(P))} end - -Base.push!(x::FixedStackTangentRefStack, t) = nothing - -Base.pop!(x::FixedStackTangentRefStack) = top_ref(x.x) diff --git a/src/tangents.jl b/src/tangents.jl index c8b740bf..21e966c6 100644 --- a/src/tangents.jl +++ b/src/tangents.jl @@ -23,8 +23,6 @@ end @inline PossiblyUninitTangent(tangent::T) where {T} = PossiblyUninitTangent{T}(tangent) @inline PossiblyUninitTangent(T::Type) = PossiblyUninitTangent{T}() -const __PUT = PossiblyUninitTangent - @inline is_init(t::PossiblyUninitTangent) = isdefined(t, :tangent) is_init(t) = true @@ -49,6 +47,8 @@ _wrap_type(::Type{T}) where {T} = PossiblyUninitTangent{T} _wrap_field(::Type{Q}, x::T) where {Q, T} = PossiblyUninitTangent{Q}(x) _wrap_field(x::T) where {T} = _wrap_field(T, x) +Base.eltype(::Type{PossiblyUninitTangent{T}}) where {T} = T + struct Tangent{Tfields<:NamedTuple} fields::Tfields end @@ -61,6 +61,10 @@ end Base.:(==)(x::MutableTangent, y::MutableTangent) = x.fields == y.fields +fields_type(::Type{MutableTangent{Tfields}}) where {Tfields<:NamedTuple} = Tfields +fields_type(::Type{Tangent{Tfields}}) where {Tfields<:NamedTuple} = Tfields +fields_type(::Type{<:Union{MutableTangent, Tangent}}) = NamedTuple + const PossiblyMutableTangent{T} = Union{MutableTangent{T}, Tangent{T}} """ @@ -72,9 +76,9 @@ Has the same semantics that `setfield!` would have if the data in the `fields` f were actually fields of `t`. This is the moral equivalent of `getfield` for `MutableTangent`. """ -function get_tangent_field(t::PossiblyMutableTangent{Tfields}, i::Int) where {Tfields} +@inline function get_tangent_field(t::PossiblyMutableTangent{Tfs}, i::Int) where {Tfs} v = getfield(t.fields, i) - return fieldtype(Tfields, i) <: PossiblyUninitTangent ? val(v) : v + return fieldtype(Tfs, i) <: PossiblyUninitTangent ? val(v) : v end @inline function get_tangent_field(t::PossiblyMutableTangent{F}, s::Symbol) where {F} @@ -143,9 +147,6 @@ function __tangent_from_non_concrete(::Type{P}, fields) where {names, P<:NamedTu return NamedTuple{names}(fields) end -_value(v::PossiblyUninitTangent) = val(v) -_value(v) = v - """ tangent_type(T) @@ -154,9 +155,7 @@ determined entirely by its type. """ tangent_type(T) -function tangent_type(x) - throw(error("$x is not a type. Perhaps you meant typeof(x)?")) -end +tangent_type(x) = throw(error("$x is not a type. Perhaps you meant typeof(x)?")) # This is essential for DataType, as the recursive definition always recurses infinitely, # because one of the fieldtypes is itself always a DataType. 
In particular, we'll always @@ -207,11 +206,16 @@ tangent_type(::Type{Core.TypeName}) = NoTangent tangent_type(::Type{Core.MethodTable}) = NoTangent +tangent_type(::Type{DimensionMismatch}) = NoTangent + +tangent_type(::Type{Method}) = NoTangent + @generated function tangent_type(::Type{P}) where {P<:Tuple} isa(P, Union) && return Union{tangent_type(P.a), tangent_type(P.b)} isempty(P.parameters) && return NoTangent isa(last(P.parameters), Core.TypeofVararg) && return Any all(p -> tangent_type(p) == NoTangent, P.parameters) && return NoTangent + isconcretetype(P) || return Any return Tuple{map(tangent_type, fieldtypes(P))...} end @@ -267,59 +271,6 @@ function tangent_field_type(::Type{P}, n::Int) where {P} return is_always_initialised(P, n) ? t : _wrap_type(t) end -""" - is_always_initialised(::Type{P}, n::Int) - -True if the `n`th field of `P` is always initialised. If the `n`th fieldtype of `P` -`isbitstype`, then this is distinct from asking whether the `n`th field is always defined. -An isbits field is always defined, but is not always explicitly initialised. -""" -function is_always_initialised(::Type{P}, n::Int) where {P} - return n <= Core.Compiler.datatype_min_ninitialized(P) -end - -""" - is_always_fully_initialised(::Type{P}) where {P} - -True if all fields in `P` are always initialised. Put differently, there are no inner -constructors which permit partial initialisation. -""" -function is_always_fully_initialised(::Type{P}) where {P} - return Core.Compiler.datatype_min_ninitialized(P) == fieldcount(P) -end - -function _map_if_assigned!(f::F, y::Array, x::Array{P}) where {F, P} - @assert size(y) == size(x) - @inbounds for n in eachindex(y) - if isbitstype(P) || isassigned(x, n) - y[n] = f(x[n]) - end - end - return y -end - -function _map_if_assigned!(f::F, y::Array, x1::Array{P}, x2::Array) where {F, P} - @assert size(y) == size(x1) - @assert size(y) == size(x2) - @inbounds for n in eachindex(y) - if isbitstype(P) || isassigned(x1, n) - y[n] = f(x1[n], x2[n]) - end - end - return y -end - -""" - _map(f, x...) - -Same as `map` but requires all elements of `x` to have equal length. -The usual function `map` doesn't enforce this for `Array`s. -""" -@inline function _map(f::F, x::Vararg{Any, N}) where {F, N} - @assert allequal(map(length, x)) - return map(f, x...) -end - """ zero_tangent(x) @@ -337,7 +288,7 @@ end return _map_if_assigned!(zero_tangent, Array{tangent_type(P), N}(undef, size(x)...), x) end @inline function zero_tangent(x::P) where {P<:Union{Tuple, NamedTuple}} - return tangent_type(P) == NoTangent ? NoTangent() : map(zero_tangent, x) + return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(zero_tangent, x) end @generated function zero_tangent(x::P) where {P} @@ -391,7 +342,8 @@ function randn_tangent(rng::AbstractRNG, x::SimpleVector) end end function randn_tangent(rng::AbstractRNG, x::P) where {P <: Union{Tuple, NamedTuple}} - return tangent_type(P) == NoTangent ? NoTangent() : map(x -> randn_tangent(rng, x), x) + tangent_type(P) == NoTangent && return NoTangent() + return tuple_map(x -> randn_tangent(rng, x), x) end function randn_tangent(rng::AbstractRNG, x::T) where {T<:Union{Tangent, MutableTangent}} return T(randn_tangent(rng, x.fields)) @@ -433,7 +385,7 @@ increment!!(x::Ptr{T}, y::Ptr{T}) where {T} = x === y ? x : throw(error("eurgh") function increment!!(x::T, y::T) where {P, N, T<:Array{P, N}} return x === y ? 
x : _map_if_assigned!(increment!!, x, x, y) end -increment!!(x::T, y::T) where {T<:Tuple} = _map(increment!!, x, y) +increment!!(x::T, y::T) where {T<:Tuple} = tuple_map(increment!!, x, y) increment!!(x::T, y::T) where {T<:NamedTuple} = T(increment!!(Tuple(x), Tuple(y))) function increment!!(x::T, y::T) where {T<:PossiblyUninitTangent} is_init(x) && is_init(y) && return T(increment!!(val(x), val(y))) @@ -441,7 +393,7 @@ function increment!!(x::T, y::T) where {T<:PossiblyUninitTangent} !is_init(x) && is_init(y) && error("x is not initialised, but y is") return x end -increment!!(x::T, y::T) where {T<:Tangent} = Tangent(increment!!(x.fields, y.fields)) +increment!!(x::T, y::T) where {T<:Tangent} = T(increment!!(x.fields, y.fields)) function increment!!(x::T, y::T) where {T<:MutableTangent} x === y && return x x.fields = increment!!(x.fields, y.fields) @@ -466,59 +418,6 @@ function set_to_zero!!(x::MutableTangent) return x end -""" - set_immutable_to_zero(x::T) where {T} - -Return a `T` whose immutable components are zero, and whose mutable components are `===` to -`x`. Please consult implementation for details. -""" -set_immutable_to_zero(x::NoTangent) = NoTangent() -set_immutable_to_zero(x::Base.IEEEFloat) = zero(x) -set_immutable_to_zero(x::Union{Tuple, NamedTuple}) = map(set_immutable_to_zero, x) -set_immutable_to_zero(x::Array) = x -function set_immutable_to_zero(x::T) where {T<:PossiblyUninitTangent} - return is_init(x) ? T(set_immutable_to_zero(val(x))) : x -end -set_immutable_to_zero(x::Tangent) = Tangent(set_immutable_to_zero(x.fields)) -set_immutable_to_zero(x::MutableTangent) = x -set_immutable_to_zero(x::Ptr) = x - -""" - increment_field!!(x::T, y::V, f) where {T, V} - -`increment!!` the field `f` of `x` by `y`, and return the updated `x`. -""" -@inline @generated function increment_field!!(x::Tuple, y, ::Val{i}) where {i} - exprs = map(n -> n == i ? :(increment!!(x[$n], y)) : :(x[$n]), fieldnames(x)) - return Expr(:tuple, exprs...) -end - -@inline @generated function increment_field!!(x::T, y, ::Val{f}) where {T<:NamedTuple, f} - i = f isa Symbol ? findfirst(==(f), fieldnames(T)) : f - new_fields = Expr(:call, increment_field!!, :(Tuple(x)), :y, :(Val($i))) - return Expr(:call, T, new_fields) -end - -function increment_field!!(x::Tangent{T}, y, f::Val{F}) where {T, F} - y isa NoTangent && return x - new_val = fieldtype(T, F) <: PossiblyUninitTangent ? fieldtype(T, F)(y) : y - return Tangent(increment_field!!(x.fields, new_val, f)) -end -function increment_field!!(x::MutableTangent{T}, y, f::V) where {T, F, V<:Val{F}} - y isa NoTangent && return x - new_val = fieldtype(T, F) <: PossiblyUninitTangent ? fieldtype(T, F)(y) : y - setfield!(x, :fields, increment_field!!(x.fields, new_val, f)) - return x -end - -increment_field!!(x, y, f::Symbol) = increment_field!!(x, y, Val(f)) -increment_field!!(x, y, n::Int) = increment_field!!(x, y, Val(n)) - -# Fallback method for when a tangent type for a struct is declared to be `NoTangent`. -for T in [Symbol, Int, Val] - @eval increment_field!!(::NoTangent, ::NoTangent, f::Union{$T}) = NoTangent() -end - """ _scale(a::Float64, t::T) where {T} @@ -645,14 +544,222 @@ function _containerlike_diff(p::P, q::P) where {P} return build_tangent(P, diffed_fields...) 
end -@generated function might_be_active(::Type{P}) where {P} - tangent_type(P) == NoTangent && return :(return false) - Base.issingletontype(P) && return :(return false) - Base.isabstracttype(P) && return :(return true) - isprimitivetype(P) && return :(return true) - return :(return $(any(might_be_active, fieldtypes(P)))) +""" + increment_field!!(x::T, y::V, f) where {T, V} + +`increment!!` the field `f` of `x` by `y`, and return the updated `x`. +""" +@inline @generated function increment_field!!(x::Tuple, y, ::Val{i}) where {i} + exprs = map(n -> n == i ? :(increment!!(x[$n], y)) : :(x[$n]), fieldnames(x)) + return Expr(:tuple, exprs...) +end + +# Optimal for homogeneously-typed Tuples with dynamic field choice. +function increment_field!!(x::Tuple, y, i::Int) + return ntuple(n -> n == i ? increment!!(x[n], y) : x[n], length(x)) end -@generated function might_be_active(::Type{<:Array{P}}) where {P} - return :(return $(might_be_active(P))) + +@inline @generated function increment_field!!(x::T, y, ::Val{f}) where {T<:NamedTuple, f} + i = f isa Symbol ? findfirst(==(f), fieldnames(T)) : f + new_fields = Expr(:call, increment_field!!, :(Tuple(x)), :y, :(Val($i))) + return Expr(:call, T, new_fields) +end + +# Optimal for homogeneously-typed NamedTuples with dynamic field choice. +function increment_field!!(x::T, y, i::Int) where {T<:NamedTuple} + return T(increment_field!!(Tuple(x), y, i)) +end +function increment_field!!(x::T, y, s::Symbol) where {T<:NamedTuple} + return T(tuple_map(n -> n == s ? increment!!(x[n], y) : x[n], fieldnames(T))) +end + +function increment_field!!(x::Tangent{T}, y, f::Val{F}) where {T, F} + y isa NoTangent && return x + new_val = fieldtype(T, F) <: PossiblyUninitTangent ? fieldtype(T, F)(y) : y + return Tangent(increment_field!!(x.fields, new_val, f)) +end +function increment_field!!(x::MutableTangent{T}, y, f::V) where {T, F, V<:Val{F}} + y isa NoTangent && return x + new_val = fieldtype(T, F) <: PossiblyUninitTangent ? fieldtype(T, F)(y) : y + setfield!(x, :fields, increment_field!!(x.fields, new_val, f)) + return x +end + +increment_field!!(x, y, f::Symbol) = increment_field!!(x, y, Val(f)) +increment_field!!(x, y, n::Int) = increment_field!!(x, y, Val(n)) + +# Fallback method for when a tangent type for a struct is declared to be `NoTangent`. +for T in [Symbol, Int, Val] + @eval increment_field!!(::NoTangent, ::NoTangent, f::Union{$T}) = NoTangent() +end + +#= + tangent_test_cases() + +Constructs a `Vector` of `Tuple`s containing test cases for the tangent infrastructure. + +If the returned tuple has 2 elements, the elements should be interpreted as follows: +1 - interface_only +2 - primal value + +interface_only is a Bool which will be used to determine which subset of tests to run. + +If the returned tuple has 5 elements, then the elements are interpreted as follows: +1 - interface_only +2 - primal value +3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>). + +Generally speaking, it's very straightforward to produce test cases in the first format, +while the second requires more work. Consequently, at the time of writing there are many +more instances of the first format than the second. + +Test cases in the first format make use of `zero_tangent` / `randn_tangent` etc to generate +tangents, but they're unable to check that `increment!!` is correct in an absolute sense. 
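As an illustration, two entries from the lists constructed below: the `abs_test_cases`
entry `(5.1, 4.0, 3.0, 7.0)` becomes the 5-element case

    (false, 5.1, 4.0, 3.0, 7.0)   # requires increment!!(4.0, 3.0) == 7.0 for primal 5.1

while the `rel_test_cases` entry `(2.0, 3)` becomes the 2-element case

    (false, (2.0, 3))             # tangents generated via zero_tangent / randn_tangent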
+=# +function tangent_test_cases() + + N_large = 33 + _names = Tuple(map(n -> Symbol("x$n"), 1:N_large)) + + abs_test_cases = vcat( + [ + (sin, NoTangent(), NoTangent(), NoTangent()), + (map(Float16, (5.0, 4.0, 3.1, 7.1))...), + (5f0, 4f0, 3f0, 7f0), + (5.1, 4.0, 3.0, 7.0), + (svec(5.0), Any[4.0], Any[3.0], Any[7.0]), + ([3.0, 2.0], [1.0, 2.0], [2.0, 3.0], [3.0, 5.0]), + ( + [1, 2], + [NoTangent(), NoTangent()], + [NoTangent(), NoTangent()], + [NoTangent(), NoTangent()], + ), + ( + [[1.0], [1.0, 2.0]], + [[2.0], [2.0, 3.0]], + [[3.0], [4.0, 5.0]], + [[5.0], [6.0, 8.0]], + ), + ( + setindex!(Vector{Vector{Float64}}(undef, 2), [1.0], 1), + setindex!(Vector{Vector{Float64}}(undef, 2), [2.0], 1), + setindex!(Vector{Vector{Float64}}(undef, 2), [3.0], 1), + setindex!(Vector{Vector{Float64}}(undef, 2), [5.0], 1), + ), + ( + setindex!(Vector{Vector{Float64}}(undef, 2), [1.0], 2), + setindex!(Vector{Vector{Float64}}(undef, 2), [2.0], 2), + setindex!(Vector{Vector{Float64}}(undef, 2), [3.0], 2), + setindex!(Vector{Vector{Float64}}(undef, 2), [5.0], 2), + ), + ( + (6.0, [1.0, 2.0]), + (5.0, [3.0, 4.0]), + (4.0, [4.0, 3.0]), + (9.0, [7.0, 7.0]), + ), + ((), NoTangent(), NoTangent(), NoTangent()), + ((1,), NoTangent(), NoTangent(), NoTangent()), + ((2, 3), NoTangent(), NoTangent(), NoTangent()), + ( + Tapir.tuple_fill(5.0, Val(N_large)), + Tapir.tuple_fill(6.0, Val(N_large)), + Tapir.tuple_fill(7.0, Val(N_large)), + Tapir.tuple_fill(13.0, Val(N_large)), + ), + ( + (a=6.0, b=[1.0, 2.0]), + (a=5.0, b=[3.0, 4.0]), + (a=4.0, b=[4.0, 3.0]), + (a=9.0, b=[7.0, 7.0]), + ), + ((;), NoTangent(), NoTangent(), NoTangent()), + ( + NamedTuple{_names}(Tapir.tuple_fill(5.0, Val(N_large))), + NamedTuple{_names}(Tapir.tuple_fill(6.0, Val(N_large))), + NamedTuple{_names}(Tapir.tuple_fill(7.0, Val(N_large))), + NamedTuple{_names}(Tapir.tuple_fill(13.0, Val(N_large))), + ), + ( + TestResources.TypeStableMutableStruct{Float64}(5.0, 3.0), + build_tangent(TestResources.TypeStableMutableStruct{Float64}, 5.0, 4.0), + build_tangent(TestResources.TypeStableMutableStruct{Float64}, 3.0, 3.0), + build_tangent(TestResources.TypeStableMutableStruct{Float64}, 8.0, 7.0), + ), + ( # complete init + TestResources.StructFoo(6.0, [1.0, 2.0]), + build_tangent(TestResources.StructFoo, 5.0, [3.0, 4.0]), + build_tangent(TestResources.StructFoo, 3.0, [2.0, 1.0]), + build_tangent(TestResources.StructFoo, 8.0, [5.0, 5.0]), + ), + ( # partial init + TestResources.StructFoo(6.0), + build_tangent(TestResources.StructFoo, 5.0), + build_tangent(TestResources.StructFoo, 4.0), + build_tangent(TestResources.StructFoo, 9.0), + ), + ( # complete init + TestResources.MutableFoo(6.0, [1.0, 2.0]), + build_tangent(TestResources.MutableFoo, 5.0, [3.0, 4.0]), + build_tangent(TestResources.MutableFoo, 3.0, [2.0, 1.0]), + build_tangent(TestResources.MutableFoo, 8.0, [5.0, 5.0]), + ), + ( # partial init + TestResources.MutableFoo(6.0), + build_tangent(TestResources.MutableFoo, 5.0), + build_tangent(TestResources.MutableFoo, 4.0), + build_tangent(TestResources.MutableFoo, 9.0), + ), + ( + TestResources.StructNoFwds(5.0), + build_tangent(TestResources.StructNoFwds, 5.0), + build_tangent(TestResources.StructNoFwds, 4.0), + build_tangent(TestResources.StructNoFwds, 9.0), + ), + ( + TestResources.StructNoRvs([5.0]), + build_tangent(TestResources.StructNoRvs, [5.0]), + build_tangent(TestResources.StructNoRvs, [4.0]), + build_tangent(TestResources.StructNoRvs, [9.0]), + ), + (UnitRange{Int}(5, 7), NoTangent(), NoTangent(), NoTangent()), + ], + map([ + 
LowerTriangular{Float64, Matrix{Float64}}, + UpperTriangular{Float64, Matrix{Float64}}, + UnitLowerTriangular{Float64, Matrix{Float64}}, + UnitUpperTriangular{Float64, Matrix{Float64}}, + ]) do T + return ( + T(randn(2, 2)), + build_tangent(T, [1.0 2.0; 3.0 4.0]), + build_tangent(T, [2.0 1.0; 5.0 4.0]), + build_tangent(T, [3.0 3.0; 8.0 8.0]), + ) + end, + [ + (p, NoTangent(), NoTangent(), NoTangent()) for p in + [Array, Float64, Union{Float64, Float32}, Union, UnionAll, + typeof(<:)] + ], + ) + rel_test_cases = Any[ + (2.0, 3), + (3, 2.0), + (2.0, 1.0), + (randn(10), 3), + (3, randn(10)), + (randn(10), randn(10)), + (a=2.0, b=3), + (a=3, b=2.0), + (a=randn(10), b=3), + (a=3, b=randn(10)), + (a=randn(10), b=randn(10)), + ] + return vcat( + map(x -> (false, x...), abs_test_cases), + map(x -> (false, x), rel_test_cases), + map(Tapir.TestTypes.instantiate, Tapir.TestTypes.PRIMALS), + ) end -might_be_active(::Type{SimpleVector}) = true diff --git a/src/test_utils.jl b/src/test_utils.jl index 8541377f..851b65f2 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -13,7 +13,7 @@ using Core: svec using ExprTools: combinedef using ..Tapir: NoTangent, tangent_type, _typeof -const PRIMALS = Tuple{Bool, Any}[] +const PRIMALS = Tuple{Bool, Any, Tuple}[] # Generate all of the composite types against which we might wish to test. function generate_primals() @@ -66,14 +66,16 @@ function generate_primals() t = @eval $name for n in n_always_def:n_fields interface_only = any(x -> isbitstype(x.type), fields[n+1:end]) - p = invokelatest(t, map(x -> deepcopy(x.primal), fields[1:n])...) - push!(PRIMALS, (interface_only, p)) + fields_copies = map(x -> deepcopy(x.primal), fields[1:n]) + push!(PRIMALS, (interface_only, t, fields_copies)) end end end return nothing end +instantiate(test_case) = (test_case[1], test_case[2](test_case[3]...)) + end """ @@ -88,7 +90,9 @@ using JET, Random, Tapir, Test, InteractiveUtils using Tapir: CoDual, NoTangent, rrule!!, is_init, zero_codual, DefaultCtx, @is_primitive, val, is_always_fully_initialised, get_tangent_field, set_tangent_field!, MutableTangent, - Tangent, _typeof + Tangent, _typeof, rdata, NoFData, to_fwds, uninit_fdata, zero_rdata, + zero_rdata_from_type, CannotProduceZeroRDataFromType, LazyZeroRData, instantiate, + can_produce_zero_rdata_from_type, increment_rdata!!, fcodual_type has_equal_data(x::T, y::T; equal_undefs=true) where {T<:String} = x == y has_equal_data(x::Type, y::Type; equal_undefs=true) = x == y @@ -155,6 +159,7 @@ throws an `AssertionError` if the same address is not mapped to in `tangent` eac function populate_address_map!(m::AddressMap, primal::P, tangent::T) where {P, T} isprimitivetype(P) && return m T === NoTangent && return m + T === NoFData && return m if ismutabletype(P) @assert T <: MutableTangent k = pointer_from_objref(primal) @@ -163,7 +168,7 @@ function populate_address_map!(m::AddressMap, primal::P, tangent::T) where {P, T m[k] = v end foreach(fieldnames(P)) do n - t_field = getfield(tangent.fields, n) + t_field = __get_data_field(tangent, n) if isdefined(primal, n) && is_init(t_field) populate_address_map!(m, getfield(primal, n), val(t_field)) elseif isdefined(primal, n) && !is_init(t_field) @@ -175,7 +180,11 @@ function populate_address_map!(m::AddressMap, primal::P, tangent::T) where {P, T return m end +__get_data_field(t::Union{Tangent, MutableTangent}, n) = getfield(t.fields, n) +__get_data_field(t::Union{Tapir.FData, Tapir.RData}, n) = getfield(t.data, n) + function populate_address_map!(m::AddressMap, p::P, t) where 
{P<:Union{Tuple, NamedTuple}} + t isa NoFData && return m t isa NoTangent && return m foreach(n -> populate_address_map!(m, getfield(p, n), getfield(t, n)), fieldnames(P)) return m @@ -237,9 +246,11 @@ function test_rrule_numerical_correctness(rng::AbstractRNG, f_f̄, x_x̄...; rul # Run `rrule!!` on copies of `f` and `x`. We use randomly generated tangents so that we # can later verify that non-zero values do not get propagated by the rule. - x_x̄_rule = map(x -> CoDual(_deepcopy(x), zero_tangent(x)), x) + x̄_zero = map(zero_tangent, x) + x̄_fwds = map(Tapir.fdata, x̄_zero) + x_x̄_rule = map((x, x̄_f) -> fcodual_type(_typeof(x))(_deepcopy(x), x̄_f), x, x̄_fwds) inputs_address_map = populate_address_map(map(primal, x_x̄_rule), map(tangent, x_x̄_rule)) - y_ȳ_rule, pb!! = rule(f_f̄, x_x̄_rule...) + y_ȳ_rule, pb!! = rule(to_fwds(f_f̄), x_x̄_rule...) # Verify that inputs / outputs are the same under `f` and its rrule. @test has_equal_data(x_primal, map(primal, x_x̄_rule)) @@ -256,11 +267,13 @@ function test_rrule_numerical_correctness(rng::AbstractRNG, f_f̄, x_x̄...; rul ȳ_delta = randn_tangent(rng, primal(y_ȳ_rule)) x̄_delta = map(Base.Fix1(randn_tangent, rng) ∘ primal, x_x̄_rule) - ȳ_init = set_to_zero!!(tangent(y_ȳ_rule)) - x̄_init = map(set_to_zero!! ∘ tangent, x_x̄_rule) + ȳ_init = set_to_zero!!(zero_tangent(primal(y_ȳ_rule), tangent(y_ȳ_rule))) + x̄_init = map(set_to_zero!!, x̄_zero) ȳ = increment!!(ȳ_init, ȳ_delta) - x̄ = map(increment!!, x̄_init, x̄_delta) - _, x̄... = pb!!(ȳ, tangent(f_f̄), x̄...) + map(increment!!, x̄_init, x̄_delta) + _, x̄_rvs_inc... = pb!!(Tapir.rdata(ȳ)) + x̄_rvs = map((x, x_inc) -> increment!!(rdata(x), x_inc), x̄_delta, x̄_rvs_inc) + x̄ = map(tangent, x̄_fwds, x̄_rvs) # Check that inputs have been returned to their original value. @test all(map(has_equal_data_up_to_undefs, x, map(primal, x_x̄_rule))) @@ -274,7 +287,7 @@ get_address(x) = ismutable(x) ? pointer_from_objref(x) : nothing _deepcopy(x) = deepcopy(x) _deepcopy(x::Module) = x -rrule_output_type(::Type{Ty}) where {Ty} = Tuple{codual_type(Ty), Any} +rrule_output_type(::Type{Ty}) where {Ty} = Tuple{Tapir.fcodual_type(Ty), Any} function test_rrule_interface(f_f̄, x_x̄...; is_primitive, ctx::C, rule) where {C} @nospecialize f_f̄ x_x̄ @@ -307,24 +320,29 @@ function test_rrule_interface(f_f̄, x_x̄...; is_primitive, ctx::C, rule) where @test _typeof(tangent(x_x̄)) == tangent_type(_typeof(primal(x_x̄))) end + # Extract the forwards-data from the tangents. + f_fwds = to_fwds(f_f̄) + x_fwds = map(to_fwds, x_x̄) + # Run the rrule, check it has output a thing of the correct type, and extract results. # Throw a meaningful exception if the rrule doesn't run at all. x_addresses = map(get_address, x) rrule_ret = try - rule(f_f̄, x_x̄...) + rule(f_fwds, x_fwds...) catch e display(e) println() throw(ArgumentError( - "rrule!! for $(_typeof(f_f̄)) with argument types $(_typeof(x_x̄)) does not run." + "rrule!! for $(_typeof(f_fwds)) with argument types $(_typeof(x_fwds)) does not run." )) end @test rrule_ret isa rrule_output_type(_typeof(y)) y_ȳ, pb!! = rrule_ret # Run the reverse-pass. Throw a meaningful exception if it doesn't run at all. + ȳ = Tapir.rdata(zero_tangent(primal(y_ȳ), tangent(y_ȳ))) f̄_new, x̄_new... = try - pb!!(tangent(y_ȳ), f̄, x̄...) + pb!!(ȳ) catch e display(e) println() @@ -339,14 +357,13 @@ function test_rrule_interface(f_f̄, x_x̄...; is_primitive, ctx::C, rule) where # Check the tangent types output by the reverse-pass, and that memory addresses of # mutable objects have remained constant. 
- @test _typeof(f̄_new) == _typeof(f̄) - @test all(map((a, b) -> _typeof(a) == _typeof(b), x̄_new, x̄)) - @test all(map((x̄, x̄_new) -> ismutable(x̄) ? x̄ === x̄_new : true, x̄, x̄_new)) + @test _typeof(f̄_new) == _typeof(rdata(f̄)) + @test all(map((a, b) -> _typeof(a) == _typeof(rdata(b)), x̄_new, x̄)) end function __forwards_and_backwards(rule, x_x̄::Vararg{Any, N}) where {N} out, pb!! = rule(x_x̄...) - return pb!!(tangent(out), map(tangent, x_x̄)...) + return pb!!(Tapir.zero_rdata(primal(out))) end function test_rrule_performance( @@ -368,14 +385,12 @@ function test_rrule_performance( JET.test_opt(primal(f_f̄), map(_typeof ∘ primal, x_x̄)) # Test forwards-pass stability. - JET.test_opt(rule, (_typeof(f_f̄), map(_typeof, x_x̄)...)) + JET.test_opt(rule, (_typeof(to_fwds(f_f̄)), map(_typeof ∘ to_fwds, x_x̄)...)) # Test reverse-pass stability. - y_ȳ, pb!! = rule(f_f̄, _deepcopy(x_x̄)...) - JET.test_opt( - pb!!, - (_typeof(tangent(y_ȳ)), _typeof(tangent(f_f̄)), map(_typeof ∘ tangent, x_x̄)...), - ) + y_ȳ, pb!! = rule(to_fwds(f_f̄), map(to_fwds, _deepcopy(x_x̄))...) + rvs_data = Tapir.rdata(zero_tangent(primal(y_ȳ), tangent(y_ȳ))) + JET.test_opt(pb!!, (_typeof(rvs_data), )) end if performance_checks_flag in (:allocs, :stability_and_allocs) @@ -387,8 +402,10 @@ function test_rrule_performance( @test (@allocations f(x...)) == 0 # Test allocations in round-trip. - __forwards_and_backwards(rule, f_f̄, x_x̄...) - @test (@allocations __forwards_and_backwards(rule, f_f̄, x_x̄...)) == 0 + f_f̄_fwds = to_fwds(f_f̄) + x_x̄_fwds = map(to_fwds, x_x̄) + __forwards_and_backwards(rule, f_f̄_fwds, x_x̄_fwds...) + @test (@allocations __forwards_and_backwards(rule, f_f̄_fwds, x_x̄_fwds...)) == 0 end end @@ -434,18 +451,11 @@ function test_rrule!!( test_rrule_performance(perf_flag, rule, x_x̄...) end -""" - test_interpreted_rrule!!(rng::AbstractRNG, x...; interp, kwargs...) - -A thin wrapper around a call to `set_up_gradient_problem` and `test_rrule!!`. -Does not require that the method being called is a primitive. -""" -function test_interpreted_rrule!!(rng::AbstractRNG, x...; interp, kwargs...) - rule, in_f = set_up_gradient_problem(x...; interp) - test_rrule!!(rng, in_f, x...; rule, kwargs...) -end - -function test_derived_rule(rng::AbstractRNG, x...; interp, kwargs...) +function test_derived_rule(rng::AbstractRNG, x...; safety_on=true, interp, kwargs...) + if safety_on + safe_rule = Tapir.build_rrule(interp, _typeof(__get_primals(x)); safety_on=true) + test_rrule!!(rng, x...; rule=safe_rule, interface_only=true, perf_flag=:none, is_primitive=false) + end rule = Tapir.build_rrule(interp, _typeof(__get_primals(x))) test_rrule!!(rng, x...; rule, kwargs...) end @@ -656,7 +666,6 @@ To verify that this is the case, ensure that all tests in either `test_tangent` `test_tangent_consistency` pass. """ function test_tangent_performance(rng::AbstractRNG, p::P) where {P} - @nospecialize rng, p # Should definitely infer, because tangent type must be known statically from primal. z = @inferred zero_tangent(p) @@ -668,8 +677,8 @@ function test_tangent_performance(rng::AbstractRNG, p::P) where {P} # Check there are no allocations when there ought not to be. if !__tangent_generation_should_allocate(P) - @test (@allocations zero_tangent_wrapper(p)) == 0 - @test (@allocations randn_tangent_wrapper(rng, p)) == 0 + JET.test_opt(Tuple{typeof(zero_tangent), P}) + JET.test_opt(Tuple{typeof(randn_tangent), Xoshiro, P}) end # `increment!!` should always infer. 
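The signature-tuple form of `JET.test_opt` used in `test_tangent_performance` above can also be
exercised on its own. A minimal sketch, assuming the same `Tapir`, `JET` and `Random` packages
that these test utilities already depend on:

using JET, Random, Tapir

# Constructing tangents for an isbits primal should infer cleanly; these mirror the
# checks made above, specialised to a concrete primal type.
JET.test_opt(Tuple{typeof(Tapir.zero_tangent), Float64})
JET.test_opt(Tuple{typeof(Tapir.randn_tangent), Xoshiro, Float64})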
@@ -726,6 +735,8 @@ function test_get_tangent_field_performance(t::Union{MutableTangent, Tangent}) V = Tapir._typeof(t.fields) for n in 1:fieldcount(V) !is_init(t.fields[n]) && continue + Tfield = fieldtype(Tapir.fields_type(Tapir._typeof(t)), n) + !__is_completely_stable_type(Tfield) && continue # Int mode. i = Val(n) @@ -746,12 +757,10 @@ function count_allocs(f::F, x::Vararg{Any, N}) where {F, N} @allocations f(x...) end -@noinline zero_tangent_wrapper(p) = zero_tangent(p) -@noinline randn_tangent_wrapper(rng, p) = randn_tangent(rng, p) - -# Returns true if both `zero_tangent` and `randn_tangnet` should allocate when run on +# Returns true if both `zero_tangent` and `randn_tangent` should allocate when run on # an object of type `P`. function __tangent_generation_should_allocate(::Type{P}) where {P} + (!isconcretetype(P) || isabstracttype(P)) && return true (fieldcount(P) == 0 && !ismutabletype(P)) && return false return ismutabletype(P) || any(__tangent_generation_should_allocate, fieldtypes(P)) end @@ -763,9 +772,16 @@ function __increment_should_allocate(::Type{P}) where {P} Tapir.tangent_field_type(P, n) <: PossiblyUninitTangent end end +__increment_should_allocate(::Type{Core.SimpleVector}) = true + +function __is_completely_stable_type(::Type{P}) where {P} + (!isconcretetype(P) || isabstracttype(P)) && return false + isprimitivetype(P) && return true + return all(__is_completely_stable_type, fieldtypes(P)) +end """ - test_tangent(rng::AbstractRNG, p::P, z_target::T, x::T, y::T) where {P, T} + test_tangent(rng::AbstractRNG, p::P, x::T, y::T, z_target::T) where {P, T} Verify that primal `p` with tangents `z_target`, `x`, and `y`, satisfies the tangent interface. If these tests pass, then it should be possible to write `rrule!!`s for primals @@ -774,8 +790,10 @@ of type `P`, and to test them using `test_rrule!!`. As always, there are limits to the errors that these tests can identify -- they form necessary but not sufficient conditions for the correctness of your code. """ -function test_tangent(rng::AbstractRNG, p::P, z_target::T, x::T, y::T) where {P, T} - @nospecialize rng p z_target x y +function test_tangent( + rng::AbstractRNG, p::P, x::T, y::T, z_target::T; interface_only, perf=true +) where {P, T} + @nospecialize rng p x y z_target # Check the interface. test_tangent_consistency(rng, p; interface_only=false) @@ -786,19 +804,20 @@ function test_tangent(rng::AbstractRNG, p::P, z_target::T, x::T, y::T) where {P, # Check that zero_tangent infers. @inferred Tapir.zero_tangent(p) - # Verify that the zero tangent is zero via its action. - z = zero_tangent(p) - t = randn_tangent(rng, p) - @test has_equal_data(@inferred(increment!!(z, z)), z) - @test has_equal_data(increment!!(z, t), t) - @test has_equal_data(increment!!(t, z), t) - # Verify that adding together `x` and `y` gives the value the user expected. z_pred = increment!!(x, y) @test has_equal_data(z_pred, z_target) if ismutabletype(P) @test z_pred === x end + + # Check performance is as expected. + perf && test_tangent_performance(rng, p) +end + +function test_tangent(rng::AbstractRNG, p::P; interface_only=false) where {P} + test_tangent_consistency(rng, p; interface_only) + test_tangent_performance(rng, p) end function test_equality_comparison(x) @@ -809,10 +828,78 @@ function test_equality_comparison(x) @test has_equal_data_up_to_undefs(x, x) end +""" + test_fwds_rvs_data_interface(rng::AbstractRNG, p::P) where {P} + +Verify that the forwards data and reverse data functionality associated to primal `p` works +correctly. 
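A typical call, mirroring its use in `test/fwds_rvs_data.jl` below (and assuming `Xoshiro`
from `Random` is in scope), looks like

    TestUtils.test_fwds_rvs_data(Xoshiro(123456), 5.0)

which checks `fdata` / `rdata` extraction, recombination via `tangent`, and the zero /
lazy-zero rdata utilities for the primal `5.0`.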
+""" +function test_fwds_rvs_data(rng::AbstractRNG, p::P) where {P} + + # Check that fdata_type and rdata_type run and produce types. + T = tangent_type(P) + F = Tapir.fdata_type(T) + @test F isa Type + R = Tapir.rdata_type(T) + @test R isa Type + + # Check that fdata and rdata produce the correct types. + t = randn_tangent(rng, p) + f = Tapir.fdata(t) + @test f isa F + r = Tapir.rdata(t) + @test r isa R + + # Check that uninit_fdata yields data of the correct type. + @test uninit_fdata(p) isa F + + # Compute the tangent type associated to `F` and `R`, and check it is equal to `T`. + @test tangent_type(F, R) == T + + # Check that combining f and r yields a tangent of the correct type and value. + t_combined = Tapir.tangent(f, r) + @test t_combined isa T + @test t_combined === t + + # Check that pulling out `f` and `r` from `t_combined` yields the correct values. + @test Tapir.fdata(t_combined) === f + @test Tapir.rdata(t_combined) === r + + # Test that `zero_rdata` produces valid reverse data. + @test zero_rdata(p) isa R + + # Check that constructing a zero tangent from reverse data yields the original tangent. + z = zero_tangent(p) + f_z = Tapir.fdata(z) + @test f_z isa Tapir.fdata_type(T) + z_new = zero_tangent(p, f_z) + @test z_new isa tangent_type(P) + @test z_new === z + + # Query whether or not the rdata type can be built given only the primal type. + can_make_zero = @inferred can_produce_zero_rdata_from_type(P) + + # Check that when the zero element is asked from the primal type alone, the result is + # either an instance of R _or_ a `CannotProduceZeroRDataFromType`. + JET.test_opt(zero_rdata_from_type, Tuple{Type{P}}) + rzero_from_type = @inferred zero_rdata_from_type(P) + @test rzero_from_type isa R || rzero_from_type isa CannotProduceZeroRDataFromType + @test can_make_zero != isa(rzero_from_type, CannotProduceZeroRDataFromType) + + # Check that we can produce a lazy zero rdata, and that it has the correct type. + JET.test_opt(LazyZeroRData, Tuple{P}) + lazy_rzero = @inferred LazyZeroRData(p) + @test instantiate(lazy_rzero) isa R + + # Check incrementing the rdata component of a tangent yields the correct type. + @test increment_rdata!!(t, r) isa T +end + function run_hand_written_rrule!!_test_cases(rng_ctor, v::Val) test_cases, memory = Tapir.generate_hand_written_rrule!!_test_cases(rng_ctor, v) + interp = Tapir.PInterp() GC.@preserve memory @testset "$f, $(_typeof(x))" for (interface_only, perf_flag, _, f, x...) in test_cases - test_rrule!!(rng_ctor(123), f, x...; interface_only, perf_flag) + test_derived_rule(rng_ctor(123), f, x...; interface_only, perf_flag, interp) end end @@ -833,39 +920,9 @@ function run_rrule!!_test_cases(rng_ctor, v::Val) end function to_benchmark(__rrule!!::R, dx::Vararg{CoDual, N}) where {R, N} - out, pb!! = __rrule!!(dx...) - return pb!!(tangent(out), map(tangent, dx)...) -end - -""" - set_up_gradient_problem(fargs...; interp=Tapir.PInterp()) - -Constructs a `rule` and `InterpretedFunction` which can be passed to `value_and_gradient!!`. - -For example: -```julia -f(x) = sum(abs2, x) -x = randn(25) -rule, in_f = Tapir.TestUtils.set_up_gradient_problem(f, x) -y, dx = Tapir.TestUtils.value_and_gradient!!(rule, in_f, f, x) -``` -will yield the value and associated gradient for `f` and `x`. - -You only need to run this function once, and may then call `value_and_gradient!!` many times -with the same `rule` and `in_f` arguments, but with different values of `x`. - -Optionally, an interpreter may be provided via the `interp` kwarg. 
- -See also: `Tapir.TestUtils.value_and_gradient!!`. -""" -function set_up_gradient_problem(fargs...; interp=Tapir.PInterp()) - sig = _typeof(__get_primals(fargs)) - if Tapir.is_primitive(DefaultCtx, sig) - return rrule!!, Tapir._eval - else - in_f = Tapir.InterpretedFunction(DefaultCtx(), sig, interp) - return Tapir.build_rrule!!(in_f), in_f - end + dx_f = Tapir.tuple_map(x -> CoDual(primal(x), Tapir.fdata(tangent(x))), dx) + out, pb!! = __rrule!!(dx_f...) + return pb!!(Tapir.zero_rdata(primal(out))) end __get_primals(xs) = map(x -> x isa CoDual ? primal(x) : x, xs) @@ -958,6 +1015,14 @@ function Base.:(==)(a::FullyInitMutableStruct, b::FullyInitMutableStruct) return equal_field(a, b, :x) && equal_field(a, b, :y) end +struct StructNoFwds + x::Float64 +end + +struct StructNoRvs + x::Vector{Float64} +end + # # Tests for AD. There are not rules defined directly on these functions, and they require # that most language primitives have rules defined. @@ -987,6 +1052,11 @@ function bar(x, y) return x5 end +function unused_expression(x, n) + y = getfield((Float64, ), n) + return x +end + const_tester_non_differentiable() = 1 const_tester() = cos(5.0) @@ -1207,11 +1277,11 @@ end # This catches the case where there are multiple phi nodes at the start of the block, and # they refer to one another. It is in this instance that the distinction between phi nodes # acting "instanteneously" and "in sequence" becomes apparent. -function test_multiple_phinode_block(x::Float64) +function test_multiple_phinode_block(x::Float64, N::Int) a = 1.0 b = x i = 1 - while i < 3 + while i < N temp = a a = b b = 2temp @@ -1281,6 +1351,11 @@ function test_union_of_types(x::Ref{Union{Type{Float64}, Type{Int}}}) return x[] end +function test_small_union(x::Ref{Union{Float64, Vector{Float64}}}) + v = x[] + return v isa Float64 ? v : v[1] +end + # Only one of these is a primitive. Lots of methods to prevent the compiler from # over-specialising. @noinline edge_case_tester(x::Float64) = 5x @@ -1290,8 +1365,8 @@ end @noinline edge_case_tester(x::String) = "hi" @is_primitive MinimalCtx Tuple{typeof(edge_case_tester), Float64} function Tapir.rrule!!(::CoDual{typeof(edge_case_tester)}, x::CoDual{Float64}) - edge_case_tester_pb!!(dy, df, dx) = df, dx + 5 * dy - return CoDual(5 * primal(x), 0.0), edge_case_tester_pb!! + edge_case_tester_pb!!(dy) = Tapir.NoRData(), 5 * dy + return Tapir.zero_fcodual(5 * primal(x)), edge_case_tester_pb!! 
end # To test the edge case properly, call this with x = Any[5.0, false] @@ -1329,7 +1404,7 @@ end function test_handwritten_sum(x::AbstractArray{<:Real}) y = 0.0 n = 0 - while n < length(x) + @inbounds while n < length(x) n += 1 y += x[n] end @@ -1350,6 +1425,7 @@ function generate_test_functions() (false, :allocs, nothing, foo, 5.0), (false, :allocs, nothing, non_differentiable_foo, 5), (false, :allocs, nothing, bar, 5.0, 4.0), + (false, :allocs, nothing, unused_expression, 5.0, 1), (false, :none, nothing, type_unstable_argument_eval, sin, 5.0), (false, :none, (lb=1, ub=1_000), pi_node_tester, Ref{Any}(5.0)), (false, :none, (lb=1, ub=1_000), pi_node_tester, Ref{Any}(5)), @@ -1357,7 +1433,7 @@ function generate_test_functions() (false, :allocs, nothing, goto_tester, 5.0), (false, :allocs, nothing, new_tester, 5.0, :hello), (false, :allocs, nothing, new_tester_2, 4.0), - (false, :none, nothing, new_tester_3, Ref{Any}(Tuple{Float64})), + (false, :none, nothing, new_tester_3, Ref{Any}(StructFoo)), (false, :allocs, nothing, type_stable_getfield_tester_1, StableFoo(5.0, :hi)), (false, :allocs, nothing, type_stable_getfield_tester_2, StableFoo(5.0, :hi)), (false, :none, nothing, globalref_tester), @@ -1366,18 +1442,18 @@ function generate_test_functions() (false, :none, nothing, globalref_tester_2, false), (false, :allocs, nothing, globalref_tester_3), (false, :allocs, nothing, globalref_tester_4), - (false, :none, (lb=1, ub=500), globalref_tester_5), + (false, :none, nothing, globalref_tester_5), (false, :none, (lb=1, ub=1_000), type_unstable_tester_0, Ref{Any}(5.0)), (false, :none, nothing, type_unstable_tester, Ref{Any}(5.0)), (false, :none, nothing, type_unstable_tester_2, Ref{Real}(5.0)), - (false, :none, (lb=1, ub=1000), type_unstable_tester_3, Ref{Any}(5.0)), - (false, :none, (lb=1, ub=10_000), test_primitive_dynamic_dispatch, Any[5.0, false]), + (false, :none, (lb=1, ub=500), type_unstable_tester_3, Ref{Any}(5.0)), + (false, :none, (lb=1, ub=500), test_primitive_dynamic_dispatch, Any[5.0, false]), (false, :none, nothing, type_unstable_function_eval, Ref{Any}(sin), 5.0), (false, :allocs, nothing, phi_const_bool_tester, 5.0), (false, :allocs, nothing, phi_const_bool_tester, -5.0), (false, :allocs, nothing, phi_node_with_undefined_value, true, 4.0), (false, :allocs, nothing, phi_node_with_undefined_value, false, 4.0), - (false, :allocs, nothing, test_multiple_phinode_block, 3.0), + (false, :allocs, nothing, test_multiple_phinode_block, 3.0, 3), ( false, :none, @@ -1392,7 +1468,7 @@ function generate_test_functions() (false, :allocs, nothing, simple_foreigncall_tester, randn(5)), (false, :none, nothing, simple_foreigncall_tester_2, randn(6), (2, 3)), (false, :allocs, nothing, foreigncall_tester, randn(5)), - (false, :none, (lb=1, ub=1_000), no_primitive_inlining_tester, 5.0), + (false, :none, nothing, no_primitive_inlining_tester, 5.0), (false, :allocs, nothing, varargs_tester, 5.0), (false, :allocs, nothing, varargs_tester, 5.0, 4), (false, :allocs, nothing, varargs_tester, 5.0, 4, 3.0), @@ -1416,14 +1492,14 @@ function generate_test_functions() (false, :none, (lb=1, ub=1_000), datatype_slot_tester, 2), (false, :none, (lb=1, ub=100_000_000), test_union_of_arrays, randn(5), true), ( - false, :none, (lb=1, ub=500), + false, :none, nothing, test_union_of_types, Ref{Union{Type{Float64}, Type{Int}}}(Float64), ), (false, :allocs, nothing, test_self_reference, 1.1, 1.5), (false, :allocs, nothing, test_self_reference, 1.5, 1.1), (false, :none, nothing, test_recursive_sum, randn(2)), ( - false, 
:none, (lb=1, ub=1_000), + false, :none, nothing, LinearAlgebra._modify!, LinearAlgebra.MulAddMul(5.0, 4.0), 5.0, @@ -1441,8 +1517,10 @@ function generate_test_functions() setfield_tester_right!, FullyInitMutableStruct(5.0, randn(3)), randn(5), ), (false, :none, nothing, mul!, randn(3, 5)', randn(5, 5), randn(5, 3), 4.0, 3.0), + (false, :none, nothing, Random.make_seed, 5), + (false, :none, nothing, Random.SHA.digest!, Random.SHA.SHA2_256_CTX()), (false, :none, nothing, Xoshiro, 123456), - (false, :none, (lb=1, ub=100_000), *, randn(250, 500), randn(500, 250)), + (false, :none, nothing, *, randn(250, 500), randn(500, 250)), (false, :allocs, nothing, test_sin, 1.0), (false, :allocs, nothing, test_cos_sin, 2.0), (false, :allocs, nothing, test_isbits_multiple_usage, 5.0), @@ -1454,7 +1532,7 @@ function generate_test_functions() (false, :allocs, nothing, test_isbits_multiple_usage_phi, true, 1.1), (false, :allocs, nothing, test_multiple_call_non_primitive, 5.0), (false, :none, (lb=1, ub=1500), test_multiple_pi_nodes, Ref{Any}(5.0)), - (false, :none, (lb=1, ub=1500), test_multi_use_pi_node, Ref{Any}(5.0)), + (false, :none, (lb=1, ub=500), test_multi_use_pi_node, Ref{Any}(5.0)), (false, :allocs, nothing, test_getindex, [1.0, 2.0]), (false, :allocs, nothing, test_mutation!, [1.0, 2.0]), (false, :allocs, nothing, test_while_loop, 2.0), @@ -1462,7 +1540,7 @@ function generate_test_functions() (false, :none, nothing, test_mutable_struct_basic, 5.0), (false, :none, nothing, test_mutable_struct_basic_sin, 5.0), (false, :none, nothing, test_mutable_struct_setfield, 4.0), - (false, :none, (lb=1, ub=2_000), test_mutable_struct, 5.0), + (false, :none, (lb=1, ub=500), test_mutable_struct, 5.0), (false, :none, nothing, test_struct_partial_init, 3.5), (false, :none, nothing, test_mutable_partial_init, 3.3), ( @@ -1471,35 +1549,79 @@ function generate_test_functions() ), ( false, :allocs, nothing, - (A, C) -> test_naive_mat_mul!(C, A, A), randn(100, 100), randn(100, 100), + (A, C) -> test_naive_mat_mul!(C, A, A), randn(25, 25), randn(25, 25), ), - (false, :allocs, (lb=10, ub=1_000), sum, randn(30)), - (false, :none, (lb=10, ub=1_000), test_diagonal_to_matrix, Diagonal(randn(30))), + (false, :allocs, nothing, sum, randn(32)), + (false, :none, nothing, test_diagonal_to_matrix, Diagonal(randn(30))), ( - false, :allocs, (lb=100, ub=1_000), + false, :allocs, nothing, ldiv!, randn(20, 20), Diagonal(rand(20) .+ 1), randn(20, 20), ), ( - false, :allocs, (lb=10, ub=500), - LinearAlgebra._kron!, randn(400, 400), randn(20, 20), randn(20, 20), + false, :allocs, nothing, + LinearAlgebra._kron!, randn(25, 25), randn(5, 5), randn(5, 5), ), ( - false, :allocs, (lb=10, ub=500), - kron!, randn(400, 400), Diagonal(randn(20)), randn(20, 20), + false, :allocs, nothing, + kron!, randn(25, 25), Diagonal(randn(5)), randn(5, 5), ), ( false, :none, nothing, test_mlp, - randn(sr(1), 500, 200), - randn(sr(2), 700, 500), - randn(sr(3), 300, 700), + randn(sr(1), 50, 20), + randn(sr(2), 70, 50), + randn(sr(3), 30, 70), ), - (false, :allocs, (lb=1.0, ub=150), test_handwritten_sum, randn(1024 * 1024)), + (false, :allocs, nothing, test_handwritten_sum, randn(128, 128)), + (false, :allocs, nothing, _naive_map_sin_cos_exp, randn(1024), randn(1024)), + (false, :allocs, nothing, _naive_map_negate, randn(1024), randn(1024)), + (false, :allocs, nothing, test_from_slack, randn(10_000)), (false, :none, nothing, _sum, randn(1024)), (false, :none, nothing, test_map, randn(1024), randn(1024)), + (false, :none, nothing, _broadcast_sin_cos_exp, randn(10, 
10)), + (false, :none, nothing, _map_sin_cos_exp, randn(10, 10)), + (false, :none, nothing, ArgumentError, "hi"), + (false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}(5.0)), + (false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}([1.0])), ] end +_broadcast_sin_cos_exp(x::AbstractArray{<:Real}) = sum(sin.(cos.(exp.(x)))) + +_map_sin_cos_exp(x::AbstractArray{<:Real}) = sum(map(x -> sin(cos(exp(x))), x)) + +function _naive_map_sin_cos_exp(y::AbstractArray{<:Real}, x::AbstractArray{<:Real}) + n = 1 + while n <= length(x) + y[n] = sin(cos(exp(x[n]))) + n += 1 + end + return y +end + +function _naive_map_negate(y::AbstractArray{<:Real}, x::AbstractArray{<:Real}) + n = 1 + while n <= length(x) + y[n] = -x[n] + n += 1 + end + return y +end + +function test_from_slack(x::AbstractVector{T}) where {T} + y = zero(T) + n = 1 + while n <= length(x) + if iseven(n) + y += sin(x[n]) + else + y += cos(x[n]) + end + n += 1 + end + return y +end + function value_dependent_control_flow(x, n) while n > 0 x = cos(x) diff --git a/src/utils.jl b/src/utils.jl index 07d37c97..6adcd586 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,7 +10,7 @@ _typeof(x::NamedTuple{names}) where {names} = NamedTuple{names, _typeof(Tuple(x) """ tuple_map(f::F, x::Tuple) where {F} -This function is semantically equivalent to `map(f, x)`, but always specialises on all of +This function is largely equivalent to `map(f, x)`, but always specialises on all of the element types of `x`, regardless the length of `x`. This contrasts with `map`, in which the number of element types specialised upon is a fixed constant in the compiler. @@ -18,13 +18,140 @@ As a consequence, if `x` is very long, this function may have very large compile tuple_map(f::F, x::Tuple, y::Tuple) where {F} -Binary extension of `tuple_map`. Equivalent to `map(f, x, y`, but guaranteed to specialise -on all element types of `x` and `y`. +Binary extension of `tuple_map`. Nearly equivalent to `map(f, x, y)`, but guaranteed to +specialise on all element types of `x` and `y`. Furthermore, errors if `x` and `y` aren't +the same length, while `map` will just produce a new tuple whose length is equal to the +shorter of `x` and `y`. """ @inline @generated function tuple_map(f::F, x::Tuple) where {F} return Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), eachindex(x.parameters))...) end @inline @generated function tuple_map(f::F, x::Tuple, y::Tuple) where {F} - return Expr(:call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), eachindex(x.parameters))...) + if length(x.parameters) != length(y.parameters) + return :(throw(ArgumentError("length(x) != length(y)"))) + else + stmts = map(n -> :(f(getfield(x, $n), getfield(y, $n))), eachindex(x.parameters)) + return Expr(:call, :tuple, stmts...) + end +end + +function tuple_map(f::F, x::NamedTuple{names}) where {F, names} + return NamedTuple{names}(tuple_map(f, Tuple(x))) +end + +function tuple_map(f::F, x::NamedTuple{names}, y::NamedTuple{names}) where {F, names} + return NamedTuple{names}(tuple_map(f, Tuple(x), Tuple(y))) +end + +@inline @generated function tuple_splat(f, x::Tuple) + return Expr(:call, :f, map(n -> :(x[$n]), 1:length(x.parameters))...) +end + +@inline @generated function tuple_splat(f, v, x::Tuple) + return Expr(:call, :f, :v, map(n -> :(x[$n]), 1:length(x.parameters))...) +end + +@inline @generated function tuple_fill(val ,::Val{N}) where {N} + return Expr(:call, :tuple, map(_ -> :val, 1:N)...) 
+end + + +#= + _map_if_assigned!(f, y::Array, x::Array{P}) where {P} + +For all `n`, if `x[n]` is assigned, then writes the value returned by `f(x[n])` to `y[n]`, +otherwise leaves `y[n]` unchanged. + +Equivalent to `map!(f, y, x)` if `P` is a bits type as element will always be assigned. + +Requires that `y` and `x` have the same size. +=# +function _map_if_assigned!(f::F, y::Array, x::Array{P}) where {F, P} + @assert size(y) == size(x) + @inbounds for n in eachindex(y) + if isbitstype(P) || isassigned(x, n) + y[n] = f(x[n]) + end + end + return y +end + +#= + _map_if_assigned!(f::F, y::Array, x1::Array{P}, x2::Array) + +Similar to the other method of `_map_if_assigned!` -- for all `n`, if `x1[n]` is assigned, +writes `f(x1[n], x2[n])` to `y[n]`, otherwise leaves `y[n]` unchanged. + +Requires that `y`, `x1`, and `x2` have the same size. +=# +function _map_if_assigned!(f::F, y::Array, x1::Array{P}, x2::Array) where {F, P} + @assert size(y) == size(x1) + @assert size(y) == size(x2) + @inbounds for n in eachindex(y) + if isbitstype(P) || isassigned(x1, n) + y[n] = f(x1[n], x2[n]) + end + end + return y +end + +#= + _map(f, x...) + +Same as `map` but requires all elements of `x` to have equal length. +The usual function `map` doesn't enforce this for `Array`s. +=# +@inline function _map(f::F, x::Vararg{Any, N}) where {F, N} + @assert allequal(map(length, x)) + return map(f, x...) +end + +#= + is_vararg_sig_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} + +Returns a 2-tuple. The first element is true if the method associated to `sig` is a vararg +method, and false if not. The second element contains all of the names of the static +parameters associated to said method. +=# +function is_vararg_sig_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} + world = Base.get_world_counter() + min = Base.RefValue{UInt}(typemin(UInt)) + max = Base.RefValue{UInt}(typemax(UInt)) + ms = Base._methods_by_ftype(sig, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))::Vector + m = only(ms).method + return m.isva, sparam_names(m) +end + +# Returns the names of all of the static parameters in `m`. +function sparam_names(m::Core.Method)::Vector{Symbol} + whereparams = ExprTools.where_parameters(m.sig) + whereparams === nothing && return Symbol[] + return map(whereparams) do name + name isa Symbol && return name + Meta.isexpr(name, :(<:)) && return name.args[1] + Meta.isexpr(name, :(>:)) && return name.args[1] + error("unrecognised type param $name") + end +end + +""" + is_always_initialised(::Type{P}, n::Int) + +True if the `n`th field of `P` is always initialised. If the `n`th fieldtype of `P` +`isbitstype`, then this is distinct from asking whether the `n`th field is always defined. +An isbits field is always defined, but is not always explicitly initialised. +""" +function is_always_initialised(::Type{P}, n::Int) where {P} + return n <= Core.Compiler.datatype_min_ninitialized(P) +end + +""" + is_always_fully_initialised(::Type{P}) where {P} + +True if all fields in `P` are always initialised. Put differently, there are no inner +constructors which permit partial initialisation. 
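For example (an illustrative sketch; `PartiallyInit` is a hypothetical type, not one defined
in this codebase):

    struct PartiallyInit
        x::Float64
        y::Vector{Float64}
        PartiallyInit(x) = new(x)   # inner constructor leaves `y` undefined
    end

    is_always_fully_initialised(ComplexF64)      # true  -- both fields always initialised
    is_always_fully_initialised(PartiallyInit)   # false -- `y` may be left uninitialised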
+""" +function is_always_fully_initialised(::Type{P}) where {P} + return Core.Compiler.datatype_min_ninitialized(P) == fieldcount(P) end diff --git a/test/codual.jl b/test/codual.jl index da8e8267..21115f51 100644 --- a/test/codual.jl +++ b/test/codual.jl @@ -12,4 +12,9 @@ codual_type(Union{Float64, Int}), Union{CoDual{Float64, Float64}, CoDual{Int, NoTangent}}, )) + @test codual_type(UnionAll) == CoDual + @testset "NoPullback" begin + @test Base.issingletontype(typeof(NoPullback(zero_fcodual(5.0)))) + @test NoPullback(zero_codual(5.0))(4.0) == (0.0, ) + end end diff --git a/test/front_matter.jl b/test/front_matter.jl index de4ae81b..32ed8395 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -30,20 +30,9 @@ using Tapir: rrule!!, lgetfield, lsetfield!, - might_be_active, build_tangent, - SlotRef, - ConstSlot, - TypedGlobalRef, - build_inst, - TypedPhiNode, - build_coinsts, Stack, _typeof, - get_codual, - get_tangent_stack, - top_ref, - NoTangentRef, BBCode, ID, IDPhiNode, @@ -55,7 +44,18 @@ using Tapir: ad_stmt_info, ADInfo, SharedDataPairs, - increment_field!! + increment_field!!, + NoFData, + NoRData, + zero_fcodual, + zero_like_rdata_from_type, + zero_rdata, + instantiate, + LazyZeroRData, + new_inst, + characterise_unique_predecessor_blocks, + NoPullback, + characterise_used_ids using .TestUtils: test_rrule!!, diff --git a/test/fwds_rvs_data.jl b/test/fwds_rvs_data.jl new file mode 100644 index 00000000..b5215f30 --- /dev/null +++ b/test/fwds_rvs_data.jl @@ -0,0 +1,35 @@ +module FwdsRvsDataTestResources + struct Foo{A} end +end + +@testset "fwds_rvs_data" begin + @testset "$(typeof(p))" for (_, p, _...) in Tapir.tangent_test_cases() + TestUtils.test_fwds_rvs_data(Xoshiro(123456), p) + end + @testset "zero_rdata_from_type checks" begin + @test Tapir.can_produce_zero_rdata_from_type(Vector) == true + @test Tapir.zero_rdata_from_type(Vector) == NoRData() + @test Tapir.can_produce_zero_rdata_from_type(FwdsRvsDataTestResources.Foo) == false + @test ==( + Tapir.zero_rdata_from_type(FwdsRvsDataTestResources.Foo), + Tapir.CannotProduceZeroRDataFromType(), + ) + end + @testset "lazy construction checks" begin + # Check that lazy construction is in fact lazy for some cases where performance + # really matters -- floats, things with no rdata, etc. 
+ @testset "$p" for (p, fully_lazy) in Any[ + (5, true), + (Int32(5), true), + (5.0, true), + (5f0, true), + (Float16(5.0), true), + (StructFoo(5.0), false), + (StructFoo(5.0, randn(4)), false), + (Bool, true), + ] + @test fully_lazy == Base.issingletontype(typeof(LazyZeroRData(p))) + @inferred Tapir.instantiate(LazyZeroRData(p)) + end + end +end diff --git a/test/integration_testing/misc.jl b/test/integration_testing/misc.jl index 4e26270d..b978fd65 100644 --- a/test/integration_testing/misc.jl +++ b/test/integration_testing/misc.jl @@ -23,7 +23,6 @@ (false, Array{Vector{Float64}, 1}, undef, (1, )), (false, Array{Vector{Float64}, 2}, undef, (2, 3)), (false, Array{Vector{Float64}, 3}, undef, (2, 3, 4)), - (false, Xoshiro, 123456), (false, push!, randn(5), 3.0), (false, x -> (a=x, b=x), 5.0), ], @@ -88,7 +87,7 @@ @info "$(_typeof((f, x...)))" TestUtils.test_derived_rule( Xoshiro(123456), f, x...; - interp, perf_flag=:none, interface_only, is_primitive=false, + safety_on=false, interp, perf_flag=:none, interface_only, is_primitive=false, ) end end diff --git a/test/integration_testing/special_functions.jl b/test/integration_testing/special_functions.jl index f1b43d9f..65187b79 100644 --- a/test/integration_testing/special_functions.jl +++ b/test/integration_testing/special_functions.jl @@ -9,6 +9,9 @@ (false, :stability, erfc, 0.1), (false, :stability, erfc, 0.0), (false, :stability, erfc, -0.5), + (false, :stability, erfcx, 0.1), + (false, :stability, erfcx, 0.0), + (false, :stability, erfcx, -0.5), ] test_rrule!!(Xoshiro(123456), f, x...; interface_only, perf_flag) end diff --git a/test/integration_testing/temporalgps.jl b/test/integration_testing/temporalgps.jl new file mode 100644 index 00000000..b73f18df --- /dev/null +++ b/test/integration_testing/temporalgps.jl @@ -0,0 +1,44 @@ +using AbstractGPs, KernelFunctions, TemporalGPs + +build_gp() = to_sde(GP(Matern12Kernel()), SArrayStorage(Float64)) + +temporalgps_logpdf_tester(x, y, s) = logpdf(build_gp()(x, s), y) + +@testset "temporalgps" begin + x = range(-5.0; step=0.1, length=10_000) + s = 1.0 + y = rand(build_gp()(x, s)) + + f = temporalgps_logpdf_tester + x = (x, y, s) + sig = _typeof((temporalgps_logpdf_tester, x...)) + @info "$sig" + interp = Tapir.PInterp() + TestUtils.test_derived_rule( + Xoshiro(123456), f, x...; + interp, perf_flag=:none, interface_only=false, is_primitive=false + ) + + # codual_args = map(zero_codual, (f, x...)) + # rule = Tapir.build_rrule(interp, sig) + + # primal = @benchmark $f($x...) + # gradient = @benchmark(TestUtils.to_benchmark($rule, $codual_args...)) + + # println("primal") + # display(primal) + # println() + + # println("gradient") + # display(gradient) + # println() + + # @show time(gradient) / time(primal) + + # # @profview run_many_times(100, f, x...) + # TestUtils.to_benchmark(rule, codual_args...) + # @profview run_many_times(10, TestUtils.to_benchmark, rule, codual_args...) + # Profile.clear() + # @profile run_many_times(10, TestUtils.to_benchmark, rule, codual_args...) + # pprof() +end diff --git a/test/integration_testing/turing.jl b/test/integration_testing/turing.jl index 1bf2ca9c..064fc9ce 100644 --- a/test/integration_testing/turing.jl +++ b/test/integration_testing/turing.jl @@ -150,6 +150,12 @@ end # @show time(gradient) / time(primal) + # @profview run_many_times(10_000, TestUtils.to_benchmark, rule, codualed_args...) + + # Profile.clear() + # @profile run_many_times(10_000, TestUtils.to_benchmark, rule, codualed_args...) 
+ # pprof() + # push!(turing_bench_results, (name, primal, gradient, revdiff)) end end diff --git a/test/interpreter/bbcode.jl b/test/interpreter/bbcode.jl index ffed52ab..a54ffe95 100644 --- a/test/interpreter/bbcode.jl +++ b/test/interpreter/bbcode.jl @@ -21,6 +21,10 @@ end @test bb isa BBlock @test length(bb) == 2 + ids, phi_nodes = Tapir.phi_nodes(bb) + @test only(ids) == bb.inst_ids[1] + @test only(phi_nodes) == bb.insts[1] + insert!(bb, 1, ID(), CC.NewInstruction(nothing, Nothing)) @test length(bb) == 3 @test bb.insts[1].stmt === nothing @@ -49,4 +53,145 @@ end @test length(Tapir.collect_stmts(bb_code)) == length(ir.stmts.inst) @test Tapir.id_to_line_map(bb_code) isa Dict{ID, Int} end + @testset "_characterise_unique_predecessor_blocks" begin + @testset "single block" begin + blk_id = ID() + blks = BBlock[BBlock(blk_id, [ID()], [new_inst(ReturnNode(5))])] + upreds, pred_is_upred = characterise_unique_predecessor_blocks(blks) + @test upreds[blk_id] == true + @test pred_is_upred[blk_id] == true + end + @testset "pair of blocks" begin + blk_id_1 = ID() + blk_id_2 = ID() + blks = BBlock[ + BBlock(blk_id_1, [ID()], [new_inst(IDGotoNode(blk_id_2))]), + BBlock(blk_id_2, [ID()], [new_inst(ReturnNode(5))]), + ] + upreds, pred_is_upred = characterise_unique_predecessor_blocks(blks) + @test upreds[blk_id_1] == true + @test upreds[blk_id_2] == true + @test pred_is_upred[blk_id_1] == true + @test pred_is_upred[blk_id_2] == true + end + @testset "Non-Unique Exit Node" begin + blk_id_1 = ID() + blk_id_2 = ID() + blk_id_3 = ID() + blks = BBlock[ + BBlock(blk_id_1, [ID()], [new_inst(IDGotoIfNot(true, blk_id_3))]), + BBlock(blk_id_2, [ID()], [new_inst(ReturnNode(5))]), + BBlock(blk_id_3, [ID()], [new_inst(ReturnNode(5))]), + ] + upreds, pred_is_upred = characterise_unique_predecessor_blocks(blks) + @test upreds[blk_id_1] == true + @test upreds[blk_id_2] == false + @test upreds[blk_id_3] == false + @test pred_is_upred[blk_id_1] == true + @test pred_is_upred[blk_id_2] == true + @test pred_is_upred[blk_id_3] == true + end + @testset "diamond structure of four blocks" begin + blk_id_1 = ID() + blk_id_2 = ID() + blk_id_3 = ID() + blk_id_4 = ID() + blks = BBlock[ + BBlock(blk_id_1, [ID()], [new_inst(IDGotoIfNot(true, blk_id_3))]), + BBlock(blk_id_2, [ID()], [new_inst(IDGotoNode(blk_id_4))]), + BBlock(blk_id_3, [ID()], [new_inst(IDGotoNode(blk_id_4))]), + BBlock(blk_id_4, [ID()], [new_inst(ReturnNode(0))]), + ] + upreds, pred_is_upred = characterise_unique_predecessor_blocks(blks) + @test upreds[blk_id_1] == true + @test upreds[blk_id_2] == false + @test upreds[blk_id_3] == false + @test upreds[blk_id_4] == true + @test pred_is_upred[blk_id_1] == true + @test pred_is_upred[blk_id_2] == true + @test pred_is_upred[blk_id_3] == true + @test pred_is_upred[blk_id_4] == false + end + @testset "simple loop back to first block" begin + blk_id_1 = ID() + blk_id_2 = ID() + blks = BBlock[ + BBlock(blk_id_1, [ID()], [new_inst(IDGotoIfNot(true, blk_id_1))]), + BBlock(blk_id_2, [ID()], [new_inst(ReturnNode(5))]), + ] + upreds, pred_is_upred = characterise_unique_predecessor_blocks(blks) + @test upreds[blk_id_1] == true + @test upreds[blk_id_2] == true + @test pred_is_upred[blk_id_1] == false + @test pred_is_upred[blk_id_2] == true + end + end + @testset "characterise_used_ids" begin + @testset "_find_id_uses!" 
begin + @testset "Expr" begin + id = ID() + d = Dict{ID, Bool}(id => false) + Tapir._find_id_uses!(d, Expr(:call, sin, 5)) + @test d[id] == false + Tapir._find_id_uses!(d, Expr(:call, sin, id)) + @test d[id] == true + end + @testset "IDGotoIfNot" begin + id = ID() + d = Dict{ID, Bool}(id => false) + Tapir._find_id_uses!(d, IDGotoIfNot(ID(), ID())) + @test d[id] == false + Tapir._find_id_uses!(d, IDGotoIfNot(true, ID())) + @test d[id] == false + Tapir._find_id_uses!(d, IDGotoIfNot(id, ID())) + @test d[id] == true + end + @testset "IDGotoNode" begin + id = ID() + d = Dict{ID, Bool}(id => false) + Tapir._find_id_uses!(d, IDGotoNode(ID())) + @test d[id] == false + end + @testset "IDPhiNode" begin + id = ID() + d = Dict{ID, Bool}(id => false) + Tapir._find_id_uses!(d, IDPhiNode([ID()], Vector{Any}(undef, 1))) + @test d[id] == false + Tapir._find_id_uses!(d, IDPhiNode([ID()], Any[id])) + @test d[id] == true + end + @testset "PiNode" begin + id = ID() + d = Dict{ID, Bool}(id => false) + Tapir._find_id_uses!(d, PiNode(false, Bool)) + @test d[id] == false + Tapir._find_id_uses!(d, PiNode(id, Bool)) + @test d[id] == true + end + @testset "ReturnNode" begin + id = ID() + d = Dict{ID, Bool}(id => false) + Tapir._find_id_uses!(d, ReturnNode()) + @test d[id] == false + Tapir._find_id_uses!(d, ReturnNode(5)) + @test d[id] == false + Tapir._find_id_uses!(d, ReturnNode(id)) + @test d[id] == true + end + end + @testset "some used some unused" begin + id_1 = ID() + id_2 = ID() + id_3 = ID() + stmts = Tuple{ID, Core.Compiler.NewInstruction}[ + (id_1, new_inst(Expr(:call, sin, Argument(1)))), + (id_2, new_inst(Expr(:call, cos, id_1))), + (id_3, new_inst(ReturnNode(id_2))), + ] + result = characterise_used_ids(stmts) + @test result[id_1] == true + @test result[id_2] == true + @test result[id_3] == false + end + end end diff --git a/test/interpreter/interpreted_function.jl b/test/interpreter/interpreted_function.jl deleted file mode 100644 index 508e0ca1..00000000 --- a/test/interpreter/interpreted_function.jl +++ /dev/null @@ -1,230 +0,0 @@ -@testset "interpreted_function" begin - @testset "TypedGlobalRef" begin - @testset "tracks changes" begin - global __x_for_gref = 5.0 - r = TypedGlobalRef(GlobalRef(Main, :__x_for_gref)) - @test r[] == 5.0 - global __x_for_gref = 4.0 - @test r[] == 4.0 - end - @testset "is type stable" begin - global __y_for_gref::Float64 = 5.0 - r = TypedGlobalRef(GlobalRef(Main, :__y_for_gref)) - @test @inferred(r[]) == 5.0 - global __y_for_gref = 4.0 - @test @inferred(r[]) == 4.0 - end - end - - # Check correctness and performance of the ArgInfo type. We really need everything to - # infer correctly. - @testset "ArgInfo: $Tx, $(x), $is_va" for (Tx, x, is_va) in Any[ - - # No varargs examples. - Any[Tuple{Float64}, (5.0,), false], - Any[Tuple{Float64, Int}, (5.0, 3), false], - Any[Tuple{Type{Float64}}, (Float64, ), false], - Any[Tuple{Type{Any}}, (Any, ), false], - - # Varargs examples. 
- Any[Tuple{Tuple{Float64}}, (5.0, ), true], - Any[Tuple{Tuple{Float64, Int}}, (5.0, 3), true], - Any[Tuple{Float64, Tuple{Int}}, (5.0, 3), true], - Any[Tuple{Float64, Tuple{Int, Float64}}, (5.0, 3, 4.0), true], - ] - ai = Tapir.ArgInfo(Tx, is_va) - @test @inferred Tapir.load_args!(ai, x) === nothing - end - - @testset "TypedPhiNode" begin - @testset "standard example of a phi node" begin - node = TypedPhiNode( - SlotRef{Float64}(), - SlotRef{Float64}(), - (1, 2), - (ConstSlot(5.0), SlotRef(4.0)), - ) - Tapir.store_tmp_value!(node, 1) - @test node.tmp_slot[] == 5.0 - Tapir.transfer_tmp_value!(node) - @test node.ret_slot[] == 5.0 - Tapir.store_tmp_value!(node, 2) - @test node.tmp_slot[] == 4.0 - @test node.ret_slot[] == 5.0 - Tapir.transfer_tmp_value!(node) - @test node.ret_slot[] == 4.0 - end - @testset "phi node with nothing in it" begin - node = TypedPhiNode(SlotRef{Union{}}(), SlotRef{Union{}}(), (), ()) - Tapir.store_tmp_value!(node, 1) - Tapir.transfer_tmp_value!(node) - end - @testset "phi node with undefined value" begin - node = TypedPhiNode( - SlotRef{Float64}(), SlotRef{Float64}(), (1, ), (SlotRef{Float64}(),) - ) - Tapir.store_tmp_value!(node, 1) - Tapir.transfer_tmp_value!(node) - end - end - - @testset "Unit tests for nodes and associated instructions" begin - - global __x_for_gref = 5.0 - global __y_for_gref::Float64 = 4.0 - - @testset "ReturnNode" begin - @testset "build_instruction: ReturnNode, $(_typeof(args))" for args in Any[ - (SlotRef(5.0), SlotRef{Float64}()), - (SlotRef(4), SlotRef{Any}()), - (ConstSlot(5), SlotRef{Int}()), - (ConstSlot(5.0), SlotRef{Real}()), - (ConstSlot(:hi), SlotRef{Symbol}()), - (ConstSlot(:hi), SlotRef{Any}()), - (TypedGlobalRef(GlobalRef(Main, :__x_for_gref)), SlotRef{Any}()), - (ConstSlot(sin), SlotRef{typeof(sin)}()), - ] - val, ret_slot = args - oc = build_inst(ReturnNode, ret_slot, val) - @test oc isa Tapir.Inst - output = oc(0) - @test output == -1 - @test ret_slot[] == val[] - end - end - - @testset "GotoNode $label" for label in Any[1, 2, 3, 4, 5] - oc = build_inst(GotoNode, label) - @test oc isa Tapir.Inst - @test oc(3) == label - end - - global __global_bool = false - @testset "GotoIfNot $cond" for cond in Any[ - SlotRef(true), SlotRef(false), - ConstSlot(true), ConstSlot(false), - SlotRef{Any}(true), SlotRef{Real}(false), - ConstSlot{Any}(true), ConstSlot{Any}(false), - TypedGlobalRef(GlobalRef(Main, :__global_bool)), - ] - oc = build_inst(GotoIfNot, cond, 1, 2) - @test oc isa Tapir.Inst - @test oc(5) == (cond[] ? 
1 : 2) - end - - global __global_bool = true - @testset "PiNode" for (input, out, prev_blk, next_blk) in Any[ - (SlotRef{Any}(5.0), SlotRef{Float64}(), 2, 3), - (ConstSlot{Float64}(5.0), SlotRef{Float64}(), 2, 2), - (TypedGlobalRef(GlobalRef(Main, :__global_bool)), ConstSlot(true), 2, 2) - ] - oc = build_inst(PiNode, input, out, next_blk) - @test oc isa Tapir.Inst - @test oc(prev_blk) == next_blk - @test out[] == input[] - end - - global __x_for_gref = 5.0 - @testset "GlobalRef" for (out, x, next_blk) in Any[ - (SlotRef{Float64}(), TypedGlobalRef(Main, :__x_for_gref), 5), - (SlotRef{typeof(sin)}(), ConstSlot(sin), 4), - ] - oc = build_inst(GlobalRef, x, out, next_blk) - @test oc isa Tapir.Inst - @test oc(4) == next_blk - @test out[] == x[] - end - - @testset "QuoteNode and literals" for (x, out, next_blk) in Any[ - (ConstSlot(5), SlotRef{Int}(), 5), - ] - oc = build_inst(nothing, x, out, next_blk) - @test oc isa Tapir.Inst - @test oc(1) == next_blk - @test out[] == x[] - end - - @testset "Val{:boundscheck}" begin - val_ref = SlotRef{Bool}() - oc = build_inst(Val(:boundscheck), val_ref, 3) - @test oc isa Tapir.Inst - @test oc(5) == 3 - @test val_ref[] == true - end - - global __int_output = 5 - @testset "Val{:call}" for (arg_slots, evaluator, val_slot, next_blk) in Any[ - ((ConstSlot(sin), SlotRef(5.0)), Tapir._eval, SlotRef{Float64}(), 3), - ((ConstSlot(*), SlotRef(4.0), ConstSlot(4.0)), Tapir._eval, SlotRef{Any}(), 3), - ( - (ConstSlot(+), ConstSlot(4), ConstSlot(5)), - Tapir._eval, - TypedGlobalRef(Main, :__int_output), - 2, - ), - ( - (ConstSlot(getfield), SlotRef((5.0, 5)), ConstSlot(1)), - Tapir.get_evaluator( - Tapir.MinimalCtx(), - Tuple{typeof(getfield), Tuple{Float64, Int}, Int}, - nothing, - false, - ), - SlotRef{Float64}(), - 3, - ), - ] - oc = build_inst(Val(:call), arg_slots, evaluator, val_slot, next_blk) - @test oc isa Tapir.Inst - @test oc(0) == next_blk - f, args... = map(getindex, arg_slots) - @test val_slot[] == f(args...) - end - - @testset "Val{:skipped_expression}" begin - oc = build_inst(Val(:skipped_expression), 3) - @test oc isa Tapir.Inst - @test oc(5) == 3 - end - - @testset "Val{:throw_undef_if_not}" begin - @testset "defined" begin - slot_to_check = SlotRef(5.0) - oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - @test oc isa Tapir.Inst - @test oc(0) == 2 - end - @testset "undefined (non-isbits)" begin - slot_to_check = SlotRef{Any}() - oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - @test oc isa Tapir.Inst - @test_throws ErrorException oc(3) - end - @testset "undefined (isbits)" begin - slot_to_check = SlotRef{Float64}() - oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - @test oc isa Tapir.Inst - - # a placeholder for failing to throw an ErrorException when evaluated - @test_broken oc(5) == 1 - end - end - end - - # Check that a suite of test cases run and give the correct answer. - interp = Tapir.PInterp() - @testset "$(_typeof((f, x...)))" for (a, b, c, f, x...) in - TestResources.generate_test_functions() - - sig = _typeof((f, x...)) - @info "$sig" - in_f = Tapir.InterpretedFunction(DefaultCtx(), sig, interp) - - # Verify correctness. - @assert f(x...) == f(x...) 
# check that the primal runs - x_cpy_1 = deepcopy(x) - x_cpy_2 = deepcopy(x) - @test has_equal_data(in_f(f, x_cpy_1...), f(x_cpy_2...)) - @test has_equal_data(x_cpy_1, x_cpy_2) - end -end diff --git a/test/interpreter/registers.jl b/test/interpreter/registers.jl deleted file mode 100644 index 12bc1eba..00000000 --- a/test/interpreter/registers.jl +++ /dev/null @@ -1,8 +0,0 @@ -@testset "registers" begin - @test Tapir.register_type(Float64) <: Tapir.AugmentedRegister{CoDual{Float64, Float64}} - @test Tapir.register_type(Bool) <: Tapir.AugmentedRegister{CoDual{Bool, NoTangent}} - @test Tapir.register_type(Any) == Tapir.AugmentedRegister - @test Tapir.register_type(Real) == Tapir.AugmentedRegister - @test ==(Tapir.register_type(Union{Float64, Float32}), Tapir.AugmentedRegister) - @test Tapir.register_type(Union{Float64, Bool}) <: Union{Tapir.AugmentedRegister, Bool} -end diff --git a/test/interpreter/reverse_mode_ad.jl b/test/interpreter/reverse_mode_ad.jl deleted file mode 100644 index b6c6c39f..00000000 --- a/test/interpreter/reverse_mode_ad.jl +++ /dev/null @@ -1,365 +0,0 @@ -array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} - -@testset "reverse_mode_ad" begin - - # Testing specific nodes. - @testset "ReturnNode" begin - @testset "SlotRefs" begin - ret = SlotRef{CoDual{Float64, Float64}}() - ret_tangent = SlotRef{Float64}() - val = SlotRef((CoDual(5.0, 1.0), top_ref(Stack(1.0)))) - fwds_inst, bwds_inst = build_coinsts(ReturnNode, ret, ret_tangent, val) - - # Test forwards instruction. - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(5) == -1 - @test ret[] == get_codual(val) - @test (@allocations fwds_inst(5)) == 0 - - # Test backwards instruction. - @test bwds_inst isa Tapir.BwdsInst - ret_tangent[] = 2.0 - @test bwds_inst(5) isa Int - @test get_tangent_stack(val)[] == 3.0 - @test (@allocations bwds_inst(5)) == 0 - end - @testset "val slot is const" begin - ret = SlotRef{CoDual{Float64, Float64}}() - ret_tangent = SlotRef{Float64}() - val = ConstSlot((CoDual(5.0, 0.0), top_ref(Stack(0.0)))) - fwds_inst, bwds_inst = build_coinsts(ReturnNode, ret, ret_tangent, val) - - # Test forwards instruction. - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(5) == -1 - @test ret[] == get_codual(val) - @test (@allocations fwds_inst(5)) == 0 - - # Test backwards instruction. - @test bwds_inst isa Tapir.BwdsInst - @test bwds_inst(5) isa Int - @test (@allocations bwds_inst(5)) == 0 - end - end - @testset "GotoNode" begin - dest = 5 - fwds_inst, bwds_inst = build_coinsts(GotoNode, dest) - - # Test forwards instructions. - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(1) == dest - @test (@allocations fwds_inst(1)) == 0 - - # Test reverse instructions. - @test bwds_inst isa Tapir.BwdsInst - @test bwds_inst(1) == 1 - @test (@allocations bwds_inst(1)) == 0 - end - @testset "GotoIfNot" begin - @testset "SlotRef cond" begin - dest = 5 - next_blk = 3 - cond = SlotRef((zero_codual(true), NoTangentRef())) - fwds_inst, bwds_inst = build_coinsts(GotoIfNot, dest, next_blk, cond) - - # Test forwards instructions. - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(1) == next_blk - @test (@allocations fwds_inst(1)) == 0 - cond[] = (zero_codual(false), get_tangent_stack(cond)) - @test fwds_inst(1) == dest - @test (@allocations fwds_inst(1)) == 0 - - # Test backwards instructions. 
- @test bwds_inst isa Tapir.BwdsInst - @test bwds_inst(4) == 4 - @test (@allocations bwds_inst(1)) == 0 - end - @testset "ConstSlot" begin - dest = 5 - next_blk = 3 - cond = ConstSlot((zero_codual(true), NoTangentRef())) - fwds_inst, bwds_inst = build_coinsts(GotoIfNot, dest, next_blk, cond) - - # Test forwards instructions. - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(1) == next_blk - @test (@allocations fwds_inst(1)) == 0 - - # Test backwards instructions. - @test bwds_inst isa Tapir.BwdsInst - @test bwds_inst(4) == 4 - @test (@allocations bwds_inst(1)) == 0 - end - end - @testset "TypedPhiNode" begin - @testset "standard example of a phi node" begin - nodes = ( - TypedPhiNode( - SlotRef{Tuple{CoDual{Float64, Float64}, Base.RefArray{Float64, Vector{Float64}, Nothing}}}(), - SlotRef{Tuple{CoDual{Float64, Float64}, Base.RefArray{Float64, Vector{Float64}, Nothing}}}(), - (1, 2), - ( - ConstSlot((CoDual(5.0, 1.0), top_ref(Stack(1.0)))), - SlotRef((CoDual(4.0, 1.2), top_ref(Stack(1.2)))), - ), - ), - TypedPhiNode( - SlotRef{Tuple{CoDual{Union{}, NoTangent}, NoTangentRef}}(), - SlotRef{Tuple{CoDual{Union{}, NoTangent}, NoTangentRef}}(), - (), - (), - ), - TypedPhiNode( - SlotRef{Tuple{CoDual{Int, NoTangent}, NoTangentRef}}(), - SlotRef{Tuple{CoDual{Int, NoTangent}, NoTangentRef}}(), - (1, ), - (SlotRef{Tuple{CoDual{Int, NoTangent}, NoTangentRef}}(),), # undef element - ), - ) - next_blk = 0 - prev_blk = 1 - fwds_inst, bwds_inst = build_coinsts(Vector{PhiNode}, nodes, next_blk) - - # Test forwards instructions. - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(1) == next_blk - @test (@allocations fwds_inst(1)) == 0 - @test nodes[1].tmp_slot[] == nodes[1].values[1][] - @test nodes[1].ret_slot[] == nodes[1].tmp_slot[] - @test !isassigned(nodes[2].tmp_slot) - @test !isassigned(nodes[2].ret_slot) - @test nodes[3].tmp_slot[] == nodes[3].values[1][] - @test nodes[3].ret_slot[] == nodes[3].tmp_slot[] - - # Test backwards instructions. - @test bwds_inst isa Tapir.BwdsInst - @test bwds_inst(4) == 4 - @test (@allocations bwds_inst(1)) == 0 - end - end - @testset "PiNode" begin - val = SlotRef((CoDual{Any, Any}(5.0, 0.0), top_ref(Stack{Any}(0.0)))) - ret = SlotRef{Tuple{CoDual{Float64, Float64}, Base.RefArray{Float64, Vector{Float64}, Nothing}}}() - next_blk = 5 - fwds_inst, bwds_inst = build_coinsts(PiNode, Float64, val, ret, next_blk) - - # Test forwards instruction. - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(1) == next_blk - @test primal(get_codual(ret)) == primal(get_codual(val)) - @test tangent(get_codual(ret)) == tangent(get_codual(val)) - @test length(get_tangent_stack(ret)) == 1 - @test get_tangent_stack(ret)[] == tangent(get_codual(val)) - - # Increment tangent associated to `val`. This is done in order to check that the - # tangent to `val` is incremented on the reverse-pass, not replaced. - Tapir.increment_ref!(get_tangent_stack(val), 0.1) - - # Test backwards instruction. - @test bwds_inst isa Tapir.BwdsInst - Tapir.increment_ref!(get_tangent_stack(ret), 1.6) - @test bwds_inst(3) == 3 - @test get_tangent_stack(val)[] == 1.6 + 0.1 # check increment has happened. 
- end - global __x_for_gref = 5.0 - @testset "GlobalRef" for (P, out, gref, next_blk) in Any[ - ( - Float64, - SlotRef{Tuple{CoDual{Float64, Float64}, array_ref_type(Float64)}}(), - TypedGlobalRef(Main, :__x_for_gref), - 5, - ), - ( - typeof(sin), - SlotRef{Tuple{codual_type(typeof(sin)), NoTangentRef}}(), - ConstSlot(sin), - 4, - ), - ] - fwds_inst, bwds_inst = build_coinsts(GlobalRef, P, gref, out, next_blk) - - # Forwards pass. - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(4) == next_blk - @test primal(get_codual(out)) == gref[] - - # Backwards pass. - @test bwds_inst isa Tapir.BwdsInst - @test bwds_inst(10) == 10 - end - @testset "QuoteNode and literals" for (x, out, next_blk) in Any[ - ( - ConstSlot(CoDual(5, NoTangent())), - SlotRef{Tuple{CoDual{Int, NoTangent}, NoTangentRef}}(), - 5, - ), - ] - fwds_inst, bwds_inst = build_coinsts(nothing, x, out, next_blk) - - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(1) == next_blk - @test get_codual(out) == x[] - @test length(get_tangent_stack(out)) == 1 - @test get_tangent_stack(out)[] == tangent(get_codual(out)) - - @test bwds_inst isa Tapir.BwdsInst - @test bwds_inst(10) == 10 - end - - @testset "Expr(:boundscheck)" begin - val_ref = SlotRef{Tuple{codual_type(Bool), NoTangentRef}}() - next_blk = 3 - fwds_inst, bwds_inst = build_coinsts(Val(:boundscheck), val_ref, next_blk) - - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(0) == next_blk - @test get_codual(val_ref) == zero_codual(true) - @test length(get_tangent_stack(val_ref)) == 1 - @test bwds_inst isa Tapir.BwdsInst - @test bwds_inst(2) == 2 - end - - global __int_output = 5 - @testset "Expr(:call)" for (P, out, arg_slots, next_blk) in Any[ - ( - Float64, - SlotRef{Tuple{codual_type(Float64), array_ref_type(Float64)}}(), - ( - ConstSlot((zero_codual(sin), top_ref(Stack(zero_tangent(sin))))), - SlotRef((zero_codual(5.0), top_ref(Stack(0.0)))), - ), - 3, - ), - ( - Any, - SlotRef{Tuple{CoDual, array_ref_type(Any)}}(), - ( - ConstSlot((zero_codual(*), top_ref(Stack(zero_tangent(*))))), - SlotRef((zero_codual(4.0), top_ref(Stack(0.0)))), - ConstSlot((zero_codual(4.0), top_ref(Stack(0.0)))), - ), - 3, - ), - ( - Int, - SlotRef{Tuple{codual_type(Int), NoTangentRef}}(), - ( - ConstSlot((zero_codual(+), top_ref(Stack(zero_tangent(+))))), - ConstSlot((zero_codual(4), NoTangentRef())), - ConstSlot((zero_codual(5), NoTangentRef())), - ), - 2, - ), - ( - Float64, - SlotRef{Tuple{codual_type(Float64), array_ref_type(Float64)}}(), - ( - ConstSlot((zero_codual(getfield), NoTangentRef())), - SlotRef((zero_codual((5.0, 5)), top_ref(Stack(zero_tangent((5.0, 5)))))), - ConstSlot((zero_codual(1), NoTangentRef())), - ), - 3, - ), - ] - sig = _typeof(map(primal ∘ get_codual, arg_slots)) - interp = Tapir.PInterp() - evaluator = Tapir.get_evaluator(Tapir.MinimalCtx(), sig, interp, true) - __rrule!! = Tapir.get_rrule!!_evaluator(evaluator) - pb_stack = Tapir.build_pb_stack(__rrule!!, evaluator, arg_slots) - fwds_inst, bwds_inst = build_coinsts( - Val(:call), P, out, arg_slots, evaluator, __rrule!!, pb_stack, next_blk - ) - - # Test forwards-pass. - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(0) == next_blk - - # Test reverse-pass. - @test bwds_inst isa Tapir.BwdsInst - @test bwds_inst(5) == 5 - end - - @testset "Expr(:skipped_expression)" begin - next_blk = 3 - fwds_inst, bwds_inst = build_coinsts(Val(:skipped_expression), next_blk) - - # Test forwards pass. - @test fwds_inst isa Tapir.FwdsInst - @test fwds_inst(1) == next_blk - - # Test backwards pass. 
- @test bwds_inst isa Tapir.BwdsInst - end - - # @testset "Expr(:throw_undef_if_not)" begin - # @testset "defined" begin - # slot_to_check = SlotRef(5.0) - # oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - # @test oc isa Tapir.Inst - # @test oc(0) == 2 - # end - # @testset "undefined (non-isbits)" begin - # slot_to_check = SlotRef{Any}() - # oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - # @test oc isa Tapir.Inst - # @test_throws ErrorException oc(3) - # end - # @testset "undefined (isbits)" begin - # slot_to_check = SlotRef{Float64}() - # oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - # @test oc isa Tapir.Inst - - # # a placeholder for failing to throw an ErrorException when evaluated - # @test_broken oc(5) == 1 - # end - # end - - interp = Tapir.PInterp() - - # nothings inserted for consistency with generate_test_functions. - @testset "$(_typeof((f, x...)))" for (interface_only, perf_flag, bnds, f, x...) in - TestResources.generate_test_functions() - - sig = _typeof((f, x...)) - @info "$sig" - in_f = Tapir.InterpretedFunction(DefaultCtx(), sig, interp); - - # Verify correctness. - @assert f(deepcopy(x)...) == f(deepcopy(x)...) # primal runs - x_cpy_1 = deepcopy(x) - x_cpy_2 = deepcopy(x) - @test has_equal_data(in_f(f, x_cpy_1...), f(x_cpy_2...)) - @test has_equal_data(x_cpy_1, x_cpy_2) - rule = Tapir.build_rrule!!(in_f); - TestUtils.test_rrule!!( - Xoshiro(123456), in_f, f, x...; - perf_flag, interface_only, is_primitive=false, rule - ) - - # # Estimate primal performance. - # original = @benchmark $(Ref(f))[]($(Ref(deepcopy(x)))[]...); - - # # Estimate interpretered function performance. - # r = @benchmark $(Ref(in_f))[]($(Ref(f))[], $(Ref(deepcopy(x)))[]...); - - # # Estimate overal forwards-pass and pullback performance. - # __rrule!! = Tapir.build_rrule!!(in_f); - # df = zero_codual(in_f); - # codual_x = map(zero_codual, (f, x...)); - # overall_timing = @benchmark TestUtils.to_benchmark($__rrule!!, $df, $codual_x...); - - # # Print the results. - # println("original") - # display(original) - # println() - # println("phi") - # display(r) - # println() - # println("overall") - # display(overall_timing) - # println() - - # @profview run_many_times(10, TestUtils.to_benchmark, __rrule!!, df, codual_x) - end -end diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index 25add9b1..78813c10 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -21,7 +21,8 @@ end id_ssa_1 => CC.NewInstruction(nothing, Float64), id_ssa_2 => CC.NewInstruction(nothing, Any), ) - info = ADInfo(Tapir.PInterp(), arg_types, ssa_insts, Any[]) + is_used_dict = Dict{ID, Bool}(id_ssa_1 => true, id_ssa_2 => true) + info = ADInfo(Tapir.PInterp(), arg_types, ssa_insts, is_used_dict, false) # Verify that we can access the interpreter and terminator block ID. 
@test info.interp isa Tapir.PInterp @@ -48,10 +49,11 @@ end Tapir.PInterp(), Dict{Argument, Any}(Argument(1) => typeof(sin), Argument(2) => Float64), Dict{ID, CC.NewInstruction}( - id_line_1 => CC.NewInstruction(Expr(:invoke, nothing, cos, Argument(2)), Float64), - id_line_2 => CC.NewInstruction(nothing, Any), + id_line_1 => new_inst(Expr(:invoke, nothing, cos, Argument(2)), Float64), + id_line_2 => new_inst(nothing, Any), ), - Any[Tapir.NoTangentStack(), Stack{Float64}()], + Dict{ID, Bool}(id_line_1=>true, id_line_2=>true), + false, ) @testset "Nothing" begin @@ -71,10 +73,10 @@ end end @testset "Argument" begin val = Argument(4) - @test TestUtils.has_equal_data( - make_ad_stmts!(ReturnNode(Argument(2)), line, info), - ad_stmt_info(line, ReturnNode(Argument(3)), nothing), - ) + stmts = make_ad_stmts!(ReturnNode(Argument(2)), line, info) + @test only(stmts.fwds)[2].stmt == ReturnNode(Argument(3)) + @test Meta.isexpr(only(stmts.rvs)[2].stmt, :call) + @test only(stmts.rvs)[2].stmt.args[1] == Tapir.increment_ref! end @testset "literal" begin stmt_info = make_ad_stmts!(ReturnNode(5.0), line, info) @@ -138,8 +140,7 @@ end @testset "differentiable const globals" begin stmt_info = make_ad_stmts!(GlobalRef(S2SGlobals, :const_float), ID(), info) @test stmt_info isa Tapir.ADStmtInfo - @test Meta.isexpr(only(stmt_info.fwds)[2].stmt, :call) - @test only(stmt_info.fwds)[2].stmt.args[1] == identity + @test only(stmt_info.fwds)[2].stmt isa CoDual{Float64} end end @testset "PhiCNode" begin @@ -155,15 +156,6 @@ end ) end @testset "Expr" begin - @testset "invoke" begin - stmt = Expr(:invoke, nothing, cos, Argument(2)) - ad_stmts = make_ad_stmts!(stmt, id_line_1, info) - fwds_stmt = ad_stmts.fwds[2][2].stmt - @test Meta.isexpr(fwds_stmt, :call) - @test fwds_stmt.args[1] == Tapir.__fwds_pass! - @test Meta.isexpr(ad_stmts.rvs[2][2].stmt, :call) - @test ad_stmts.rvs[2][2].stmt.args[1] == Tapir.__rvs_pass! - end @testset "copyast" begin stmt = Expr(:copyast, QuoteNode(:(hi))) ad_stmts = make_ad_stmts!(stmt, ID(), info) @@ -202,29 +194,30 @@ end ) # codual_args = map(zero_codual, (f, x...)) + # fwds_args = map(Tapir.to_fwds, codual_args) # rule = Tapir.build_rrule(interp, sig) - # out, pb!! = rule(codual_args...) + # out, pb!! = rule(fwds_args...) # # @code_warntype optimize=true rule(codual_args...) # # @code_warntype optimize=true pb!!(tangent(out), map(tangent, codual_args)...) # primal_time = @benchmark $f($(Ref(x))[]...) - # s2s_time = @benchmark $rule($codual_args...)[2]($(tangent(out)), $(map(tangent, codual_args))...) - # in_f = in_f = Tapir.InterpretedFunction(DefaultCtx(), sig, interp); - # __rrule!! = Tapir.build_rrule!!(in_f); - # df = zero_codual(in_f); - # codual_x = map(zero_codual, (f, x...)); - # interp_time = @benchmark TestUtils.to_benchmark($__rrule!!, $df, $codual_x...) + # s2s_time = @benchmark $rule($fwds_args...)[2]($(Tapir.zero_rdata(primal(out)))) + # # in_f = in_f = Tapir.InterpretedFunction(DefaultCtx(), sig, interp); + # # __rrule!! = Tapir.build_rrule!!(in_f); + # # df = zero_codual(in_f); + # # codual_x = map(zero_codual, (f, x...)); + # # interp_time = @benchmark TestUtils.to_benchmark($__rrule!!, $df, $codual_x...) 
# display(primal_time) # display(s2s_time) - # display(interp_time) + # # display(interp_time) # s2s_ratio = time(s2s_time) / time(primal_time) - # interp_ratio = time(interp_time) / time(primal_time) + # # interp_ratio = time(interp_time) / time(primal_time) # println("s2s ratio ratio: $(s2s_ratio)") - # println("interp ratio: $(interp_ratio)") + # # println("interp ratio: $(interp_ratio)") - # f(rule, codual_args, out) = rule(codual_args...)[2](tangent(out), map(tangent, codual_args)...) - # f(rule, codual_args, out) - # @profview(run_many_times(1_000, f, rule, codual_args, out)) + # f(rule, fwds_args, out) = rule(fwds_args...)[2]((Tapir.zero_rdata(primal(out)))) + # f(rule, fwds_args, out) + # @profview(run_many_times(500, f, rule, fwds_args, out)) end end diff --git a/test/interpreter/zero_like_rdata.jl b/test/interpreter/zero_like_rdata.jl new file mode 100644 index 00000000..98d5f4c8 --- /dev/null +++ b/test/interpreter/zero_like_rdata.jl @@ -0,0 +1,13 @@ +@testset "zero_like_rdata" begin + @testset "zero_like_rdata_from_type" begin + @testset "$P" for P in Any[ + @NamedTuple{a}, + Tuple{Any}, + Float64, + Int, + Vector{Float64}, + ] + @test Tapir.zero_like_rdata_from_type(P) isa Tapir.zero_like_rdata_type(P) + end + end +end diff --git a/test/rrules/iddict.jl b/test/rrules/iddict.jl index 000e9583..76fc0a17 100644 --- a/test/rrules/iddict.jl +++ b/test/rrules/iddict.jl @@ -1,10 +1,11 @@ @testset "iddict" begin @testset "IdDict tangent functionality" begin p = IdDict(true => 5.0, false => 4.0) - z = IdDict(true => 3.0, false => 2.0) x = IdDict(true => 1.0, false => 1.0) y = IdDict(true => 2.0, false => 1.0) - test_tangent(sr(123456), p, z, x, y) + z = IdDict(true => 3.0, false => 2.0) + TestUtils.test_tangent(sr(123456), p, x, y, z; interface_only=false, perf=false) + TestUtils.test_fwds_rvs_data(sr(123456), p) end TestUtils.run_rrule!!_test_cases(StableRNG, Val(:iddict)) end diff --git a/test/runtests.jl b/test/runtests.jl index 00401763..e3707d9c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,9 @@ include("front_matter.jl") if test_group == "basic" include("utils.jl") include("tangents.jl") + include("fwds_rvs_data.jl") include("codual.jl") + include("safe_mode.jl") include("stack.jl") @testset "interpreter" begin include(joinpath("interpreter", "contexts.jl")) @@ -12,9 +14,7 @@ include("front_matter.jl") include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_utils.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) - include(joinpath("interpreter", "registers.jl")) - # include(joinpath("interpreter", "interpreted_function.jl")) - # include(joinpath("interpreter", "reverse_mode_ad.jl")) + include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) end elseif test_group == "rrules" @@ -28,6 +28,8 @@ include("front_matter.jl") include(joinpath("rrules", "builtins.jl")) @info "foreigncall" include(joinpath("rrules", "foreigncall.jl")) + @info "iddict" + include(joinpath("rrules", "iddict.jl")) @info "lapack" include(joinpath("rrules", "lapack.jl")) @info "low_level_maths" diff --git a/test/safe_mode.jl b/test/safe_mode.jl new file mode 100644 index 00000000..f6b13f17 --- /dev/null +++ b/test/safe_mode.jl @@ -0,0 +1,36 @@ +@testset "safety" begin + + # Forwards-pass tests. 
+    x = (CoDual(sin, NoTangent()), CoDual(5.0, NoFData()))
+    @test_throws(ErrorException, Tapir.SafeRRule(rrule!!)(x...))
+    x = (CoDual(sin, NoFData()), CoDual(5.0, NoFData()))
+    @test_throws(
+        ErrorException, Tapir.SafeRRule((x..., ) -> (CoDual(1.0, 0.0), nothing))(x...)
+    )
+
+    # Basic type checking.
+    x = (CoDual(size, NoFData()), CoDual(randn(10), randn(Float16, 11)))
+    @test_throws ErrorException Tapir.SafeRRule(rrule!!)(x...)
+
+    # Element type checking. Abstractly typed elements prevent determining incorrectness
+    # just by looking at the type of the array.
+    x = (
+        CoDual(size, NoFData()),
+        CoDual(Any[rand() for _ in 1:10], Any[rand(Float16) for _ in 1:10])
+    )
+    @test_throws ErrorException Tapir.SafeRRule(rrule!!)(x...)
+
+    # Test that bad rdata is caught as a pre-condition.
+    y, pb!! = Tapir.SafeRRule(rrule!!)(zero_fcodual(sin), zero_fcodual(5.0))
+    @test_throws(ArgumentError, pb!!(5))
+
+    # Test that bad rdata is caught as a post-condition.
+    rule_with_bad_pb(x::CoDual{Float64}) = x, dy -> (5, ) # returns the wrong type
+    y, pb!! = Tapir.SafeRRule(rule_with_bad_pb)(zero_fcodual(5.0))
+    @test_throws ArgumentError pb!!(1.0)
+
+    # Test that rdata of the wrong length is caught as a post-condition.
+    rule_with_bad_pb_length(x::CoDual{Float64}) = x, dy -> (5, 5.0) # returns the wrong number of rdata
+    y, pb!! = Tapir.SafeRRule(rule_with_bad_pb_length)(zero_fcodual(5.0))
+    @test_throws ErrorException pb!!(1.0)
+end
diff --git a/test/stack.jl b/test/stack.jl
index 74c767ed..1c760929 100644
--- a/test/stack.jl
+++ b/test/stack.jl
@@ -14,17 +14,4 @@
         @test length(s) == 0
         @test isempty(s)
     end
-    @testset "tangent_stack_type" begin
-        @test Tapir.tangent_stack_type(Float64) == Stack{Float64}
-        @test Tapir.tangent_stack_type(Int) == Tapir.NoTangentStack
-        @test Tapir.tangent_stack_type(Any) == Stack{Any}
-        @test Tapir.tangent_stack_type(DataType) == Stack{Any}
-        @test Tapir.tangent_stack_type(Type{Float64}) == Tapir.NoTangentStack
-
-        @test Tapir.tangent_ref_type_ub(Float64) == Tapir.__array_ref_type(Float64)
-        @test Tapir.tangent_ref_type_ub(Int) == Tapir.NoTangentRef
-        @test Tapir.tangent_ref_type_ub(Any) == Ref
-        @test Tapir.tangent_ref_type_ub(DataType) == Ref
-        @test Tapir.tangent_ref_type_ub(Type{Float64}) == Tapir.NoTangentRef
-    end
 end
diff --git a/test/tangents.jl b/test/tangents.jl
index 72ae8e58..51a8ff56 100644
--- a/test/tangents.jl
+++ b/test/tangents.jl
@@ -1,131 +1,7 @@
 @testset "tangents" begin
-    # Each tuple is of the form (primal, t1, t2, increment!!(t1, t2)).
- @testset "$(typeof(p))" for (p, x, y, z) in vcat( - [ - (sin, NoTangent(), NoTangent(), NoTangent()), - map(Float16, (5.0, 4.0, 3.1, 7.1)), - (5f0, 4f0, 3f0, 7f0), - (5.1, 4.0, 3.0, 7.0), - (svec(5.0), Any[4.0], Any[3.0], Any[7.0]), - ([3.0, 2.0], [1.0, 2.0], [2.0, 3.0], [3.0, 5.0]), - ( - [1, 2], - [NoTangent(), NoTangent()], - [NoTangent(), NoTangent()], - [NoTangent(), NoTangent()], - ), - ( - [[1.0], [1.0, 2.0]], - [[2.0], [2.0, 3.0]], - [[3.0], [4.0, 5.0]], - [[5.0], [6.0, 8.0]], - ), - ( - setindex!(Vector{Vector{Float64}}(undef, 2), [1.0], 1), - setindex!(Vector{Vector{Float64}}(undef, 2), [2.0], 1), - setindex!(Vector{Vector{Float64}}(undef, 2), [3.0], 1), - setindex!(Vector{Vector{Float64}}(undef, 2), [5.0], 1), - ), - ( - setindex!(Vector{Vector{Float64}}(undef, 2), [1.0], 2), - setindex!(Vector{Vector{Float64}}(undef, 2), [2.0], 2), - setindex!(Vector{Vector{Float64}}(undef, 2), [3.0], 2), - setindex!(Vector{Vector{Float64}}(undef, 2), [5.0], 2), - ), - ( - (6.0, [1.0, 2.0]), - (5.0, [3.0, 4.0]), - (4.0, [4.0, 3.0]), - (9.0, [7.0, 7.0]), - ), - ((), NoTangent(), NoTangent(), NoTangent()), - ((1,), NoTangent(), NoTangent(), NoTangent()), - ((2, 3), NoTangent(), NoTangent(), NoTangent()), - ( - (a=6.0, b=[1.0, 2.0]), - (a=5.0, b=[3.0, 4.0]), - (a=4.0, b=[4.0, 3.0]), - (a=9.0, b=[7.0, 7.0]), - ), - ((;), NoTangent(), NoTangent(), NoTangent()), - ( - TypeStableMutableStruct{Float64}(5.0, 3.0), - build_tangent(TypeStableMutableStruct{Float64}, 5.0, 4.0), - build_tangent(TypeStableMutableStruct{Float64}, 3.0, 3.0), - build_tangent(TypeStableMutableStruct{Float64}, 8.0, 7.0), - ), - ( # complete init - StructFoo(6.0, [1.0, 2.0]), - build_tangent(StructFoo, 5.0, [3.0, 4.0]), - build_tangent(StructFoo, 3.0, [2.0, 1.0]), - build_tangent(StructFoo, 8.0, [5.0, 5.0]), - ), - ( # partial init - StructFoo(6.0), - build_tangent(StructFoo, 5.0), - build_tangent(StructFoo, 4.0), - build_tangent(StructFoo, 9.0), - ), - ( # complete init - MutableFoo(6.0, [1.0, 2.0]), - build_tangent(MutableFoo, 5.0, [3.0, 4.0]), - build_tangent(MutableFoo, 3.0, [2.0, 1.0]), - build_tangent(MutableFoo, 8.0, [5.0, 5.0]), - ), - ( # partial init - MutableFoo(6.0), - build_tangent(MutableFoo, 5.0), - build_tangent(MutableFoo, 4.0), - build_tangent(MutableFoo, 9.0), - ), - (UnitRange{Int}(5, 7), NoTangent(), NoTangent(), NoTangent()), - ], - map([ - LowerTriangular{Float64, Matrix{Float64}}, - UpperTriangular{Float64, Matrix{Float64}}, - UnitLowerTriangular{Float64, Matrix{Float64}}, - UnitUpperTriangular{Float64, Matrix{Float64}}, - ]) do T - return ( - T(randn(2, 2)), - build_tangent(T, [1.0 2.0; 3.0 4.0]), - build_tangent(T, [2.0 1.0; 5.0 4.0]), - build_tangent(T, [3.0 3.0; 8.0 8.0]), - ) - end, - [ - (p, NoTangent(), NoTangent(), NoTangent()) for p in - [Array, Float64, Union{Float64, Float32}, Union, UnionAll, - Core.Intrinsics.xor_int, typeof(<:)] - ], - ) - rng = Xoshiro(123456) - test_tangent(rng, p, z, x, y) - end - - __x = randn(10) - p = pointer(__x, 3) - @testset "set_immutable_to_zero($(Tapir._typeof(x)))" for x in Any[ - NoTangent(), - 5.0, - 5f0, - (5.0, NoTangent()), - (a=5.0, b=NoTangent(), c=(5.0, )), - randn(5), - [randn(3), 5.0, NoTangent()], - randn_tangent(Xoshiro(1), TestResources.StableFoo(5.0, :hi)), - randn_tangent(Xoshiro(1), TestResources.MutableFoo(5.0, randn(3))), - p, - ] - @test Tapir.set_immutable_to_zero(x) isa Tapir._typeof(x) - end - - # Bulk test auto-generated tangents. 
- @testset "autogenerated $n $x" for (n, x) in collect(enumerate(Tapir.TestTypes.PRIMALS)) - (interface_only, p) = x - TestUtils.test_tangent_consistency(sr(1), p; interface_only) - TestUtils.test_tangent_performance(sr(1), p) + @testset "$(typeof(data))" for (interface_only, data...) in Tapir.tangent_test_cases() + test_tangent(Xoshiro(123456), data...; interface_only) end tangent(nt::NamedTuple) = Tangent(map(PossiblyUninitTangent, nt)) @@ -149,6 +25,9 @@ # Slow versions. @test increment_field!!(x, y, 1) == (8.0, nt) @test increment_field!!(x, nt, 2) == (5.0, nt) + + # Homogeneous type optimisation. + @test @inferred(increment_field!!((5.0, 4.0), 3.0, 2)) == (5.0, 7.0) end @testset "NamedTuple" begin nt = NoTangent() @@ -163,6 +42,10 @@ @test increment_field!!(x, nt, :b) == (a=5.0, b=nt) @test increment_field!!(x, 3.0, 1) == (a=8.0, b=nt) @test increment_field!!(x, nt, 2) == (a=5.0, b=nt) + + # Homogeneous type optimisation. + @test @inferred(increment_field!!((a=5.0, b=4.0), 3.0, 1)) == (a=8.0, b=4.0) + @test @inferred(increment_field!!((a=5.0, b=4.0), 3.0, :a)) == (a=8.0, b=4.0) end @testset "Tangent" begin nt = NoTangent() @@ -216,6 +99,7 @@ tangent_type(Union{Tuple{Float64}, Tuple{Int}, Tuple{Float64, Int}}), Union{Tuple{Float64}, NoTangent, Tuple{Float64, NoTangent}}, ) + @test tangent_type(Tuple{Any, Any}) == Any end end diff --git a/test/utils.jl b/test/utils.jl index 6e5ba8d7..9425ec11 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -14,5 +14,66 @@ map(*, (5, 4.0, 3), (5.0, 4, 3.0)), Tapir.tuple_map(*, (5, 4.0, 3), (5.0, 4, 3.0)), ) + + @test map(sin, (a=5.0, b=4)) == Tapir.tuple_map(sin, (a=5.0, b=4)) + @test ==( + map(*, (a=5, b=4.0, c=3), (a=5.0, b=4, c=3.0)), + Tapir.tuple_map(*, (a=5, b=4.0, c=3), (a=5.0, b=4, c=3.0)), + ) + + # Require that length of arguments are equal. + @test_throws ArgumentError Tapir.tuple_map(*, (5.0, 4.0), (4.0, )) + @test_throws ArgumentError Tapir.tuple_map(*, (4.0, ), (5.0, 4.0)) + end + @testset "_map_if_assigned!" begin + @testset "unary bits type" begin + x = Vector{Float64}(undef, 10) + y = randn(10) + z = Tapir._map_if_assigned!(sin, y, x) + @test z === y + @test all(map(isequal, z, map(sin, x))) + end + @testset "unary non bits type" begin + x = Vector{Vector{Float64}}(undef, 2) + x[1] = randn(5) + y = [1.0, 1.0] + z = Tapir._map_if_assigned!(sum, y, x) + @test z === y + + # The first element of `x` is assigned, so z[1] should be its sum. + @test z[1] ≈ sum(x[1]) + + # The second element of `x` is unassigned, so z[2] should be unchanged. + @test z[2] == 1.0 + end + @testset "binary bits type" begin + x1 = Vector{Float64}(undef, 7) + x2 = Vector{Float64}(undef, 7) + y = Vector{Float64}(undef, 7) + z = Tapir._map_if_assigned!(*, y, x1, x2) + @test z === y + @test all(map(isequal, z, map(*, x1, x2))) + end + @testset "binary non bits type" begin + x1 = Vector{Vector{Float64}}(undef, 2) + x1[1] = randn(3) + x2 = [randn(3), randn(2)] + y = [1.0, 1.0] + z = Tapir._map_if_assigned!(dot, y, x1, x2) + @test z === y + + # The first element of x1 is assigned, so should have the inner product in z[1]. + @test z[1] ≈ dot(x1[1], x2[1]) + + # The second element of x2 is not assigned, so z[2] should be unchanged. + @test z[2] == 1 + end + end + @testset "_map" begin + x = randn(10) + y = randn(10) + @test Tapir._map(*, x, y) == map(*, x, y) + @assert length(map(*, x, randn(11))) == 10 + @test_throws AssertionError Tapir._map(*, x, randn(11)) end end