From 163e1731d0a3bbc71464772d61501572be931208 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 22 Jun 2021 18:45:47 +0100 Subject: [PATCH 1/3] use rrules even when all the arguments are types --- src/compiler/interface2.jl | 10 ++++++---- test/chainrules.jl | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index ac4a5a76a..f0c4fa690 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -7,22 +7,24 @@ function edge!(m::IRTools.Meta, edge::Core.MethodInstance) end @generated function _pullback(ctx::AContext, f, args...) - T = Tuple{f,args...} - ignore_sig(T) && return :(f(args...), Pullback{$T}(())) - + # Try using ChainRulesCore if is_kwfunc(f, args...) # if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function cr_T = Tuple{ZygoteRuleConfig{ctx}, args[2:end]...} chain_rrule_f = :chain_rrule_kw else cr_T = Tuple{ZygoteRuleConfig{ctx}, f, args...} + Core.println("cr_T=", cr_T) chain_rrule_f = :chain_rrule end hascr, cr_edge = has_chain_rrule(cr_T) - hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...)) + # No ChainRule, going to have to work it out. + T = Tuple{f,args...} + ignore_sig(T) && return :(f(args...), Pullback{$T}(())) + g = try _generate_pullback_via_decomposition(T) catch e e end g === nothing && return :(f(args...), Pullback{$T}((f,))) meta, forw, _ = g diff --git a/test/chainrules.jl b/test/chainrules.jl index 519c10ad6..66058c93d 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -214,6 +214,24 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote @test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2), 10.4) @test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2; kw=2.0), 10.4) end + + @testset "Type only rrule" begin + struct StructForTestingTypeOnlyRRules{T} + x::T + end + StructForTestingTypeOnlyRRules() = StructForTestingTypeOnlyRRules(1.0) + + function ChainRulesCore.rrule(P::Type{<:StructForTestingTypeOnlyRRules}) + # notice here we mess with the primal doing 2.0 rather than 1.0, this is for testing purposes + # and also because apparently people actually want to do this. Weird, but 🤷 + # https://github.com/SciML/SciMLBase.jl/issues/69#issuecomment-865639754 + P(2.0), _->NoTangent() + end + + @assert StructForTestingTypeOnlyRRules().x == 1.0 + aug_primal_val, _ = Zygote.pullback(x->StructForTestingTypeOnlyRRules(), 1.2) + @test aug_primal_val.x == 2.0 + end end @testset "ChainRulesCore.rrule_via_ad" begin From 2dab48fdfaddd8a908c341eb011ef82817fba0f9 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 23 Jun 2021 17:50:32 +0100 Subject: [PATCH 2/3] Remove leftove debugging statements --- src/compiler/interface2.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index f0c4fa690..0f7da4b32 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -14,7 +14,6 @@ end chain_rrule_f = :chain_rrule_kw else cr_T = Tuple{ZygoteRuleConfig{ctx}, f, args...} - Core.println("cr_T=", cr_T) chain_rrule_f = :chain_rrule end From b9f186f8f044ad94b469773ae8fe722fea457983 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 24 Jun 2021 08:48:27 +0100 Subject: [PATCH 3/3] Update test/chainrules.jl Co-authored-by: Dhairya Gandhi --- test/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 66058c93d..32bdd3799 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -225,7 +225,7 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote # notice here we mess with the primal doing 2.0 rather than 1.0, this is for testing purposes # and also because apparently people actually want to do this. Weird, but 🤷 # https://github.com/SciML/SciMLBase.jl/issues/69#issuecomment-865639754 - P(2.0), _->NoTangent() + P(2.0), _ -> (NoTangent(),) end @assert StructForTestingTypeOnlyRRules().x == 1.0