diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index cce7c4d6d..c09d6db31 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -115,12 +115,34 @@ as a named tuple. julia> y, ∇ = withgradient(/, 1, 2) (val = 0.5, grad = (0.5, -0.25)) -julia> ∇ == gradient(/, 1, 2) # explicit mode +julia> ∇ == gradient(/, 1, 2) true +``` + +Allows you to capture auxillary outputs, in addition to the scalar +used by `gradient`. To do this, `f` must return a Tuple or NamedTuple. +Then it calculates `grad = gradient(first∘f, args...) +but returns the whole `val = f(args...)`: + +```jldoctest; setup=:(using Zygote) +julia> withgradient([1,2,4]) do x + z = 1 ./ x + sum(z), z # here z is an auxillary output + end +(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],)) + +julia> withgradient(3.0, 4.0) do x, y + (div = x/y, mul = x*y) + end +(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875)) +``` + +Also supports implicit mode: +```jldoctest; setup=:(using Zygote) julia> w = [3.0]; -julia> res = withgradient(() -> sum(abs2, w), Params([w])) # implicit mode +julia> res = withgradient(() -> sum(abs2, w), Params([w])) (val = 9.0, grad = Grads(...)) julia> res.grad[w] @@ -130,7 +152,15 @@ julia> res.grad[w] """ function withgradient(f, args...) y, back = pullback(f, args...) - grad = back(sensitivity(y)) + grad = if y isa Tuple + dy = (sensitivity(first(y)), map(_ -> nothing, Base.tail(y))...) + back(dy) + elseif y isa NamedTuple + dy = (sensitivity(first(y)), map(_ -> nothing, Base.tail(y))...) + back(NamedTuple{propertynames(y), typeof(dy)}(dy)) + else + back(sensitivity(y)) + end results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad) (val=y, grad=results) end diff --git a/test/features.jl b/test/features.jl index 0499987d8..908ae5815 100644 --- a/test/features.jl +++ b/test/features.jl @@ -866,3 +866,21 @@ end end @test gradient(f760, 3)[1] ≈ 123.93054835019153 end + +@testset "withgradient" begin + @test withgradient([1,2,4]) do x + z = 1 ./ x + sum(z), z + end == (val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],)) + + @test withgradient(3.0, 4.0) do x, y + (div = x/y, mul = x*y) + end == (val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875)) + + f3(x) = sum(sin, x), sum(cos, x), sum(tan, x) + g1 = gradient(first∘f3, [1,2,3.0]) + y2, g2 = withgradient(first∘f3, [1,2,3.0]) + y3, g3 = withgradient(f3, [1,2,3.0]) + @test g1[1] ≈ g2[1] ≈ g3[1] +end +