Skip to content

Commit

Permalink
Merge pull request #1006 from FluxML/ox/typeonlyrrules
Browse files Browse the repository at this point in the history
use rrules even when all the arguments are types
  • Loading branch information
oxinabox authored Jun 24, 2021
2 parents 87e2f12 + b9f186f commit b170521
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ function edge!(m::IRTools.Meta, edge::Core.MethodInstance)
end

@generated function _pullback(ctx::AContext, f, args...)
T = Tuple{f,args...}
ignore_sig(T) && return :(f(args...), Pullback{$T}(()))

# Try using ChainRulesCore
if is_kwfunc(f, args...)
# if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function
cr_T = Tuple{ZygoteRuleConfig{ctx}, args[2:end]...}
Expand All @@ -20,9 +18,12 @@ end
end

hascr, cr_edge = has_chain_rrule(cr_T)

hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))

# No ChainRule, going to have to work it out.
T = Tuple{f,args...}
ignore_sig(T) && return :(f(args...), Pullback{$T}(()))

g = try _generate_pullback_via_decomposition(T) catch e e end
g === nothing && return :(f(args...), Pullback{$T}((f,)))
meta, forw, _ = g
Expand Down
18 changes: 18 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,24 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2), 10.4)
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2; kw=2.0), 10.4)
end

@testset "Type only rrule" begin
struct StructForTestingTypeOnlyRRules{T}
x::T
end
StructForTestingTypeOnlyRRules() = StructForTestingTypeOnlyRRules(1.0)

function ChainRulesCore.rrule(P::Type{<:StructForTestingTypeOnlyRRules})
# notice here we mess with the primal doing 2.0 rather than 1.0, this is for testing purposes
# and also because apparently people actually want to do this. Weird, but 🤷
# https://github.com/SciML/SciMLBase.jl/issues/69#issuecomment-865639754
P(2.0), _ -> (NoTangent(),)
end

@assert StructForTestingTypeOnlyRRules().x == 1.0
aug_primal_val, _ = Zygote.pullback(x->StructForTestingTypeOnlyRRules(), 1.2)
@test aug_primal_val.x == 2.0
end
end

@testset "ChainRulesCore.rrule_via_ad" begin
Expand Down

0 comments on commit b170521

Please sign in to comment.