Skip to content

Commit

Permalink
Merge pull request #1419 from mcabbott/withgrad3
Browse files Browse the repository at this point in the history
Allow `f` to return a Tuple in `withgradient(f, args...)`
  • Loading branch information
mcabbott authored Jul 10, 2023
2 parents 2f49370 + e0d3d8b commit 29fa32a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
36 changes: 33 additions & 3 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,34 @@ as a named tuple.
julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))
julia> ∇ == gradient(/, 1, 2)
true
```
Allows you to capture auxiliary outputs, in addition to the scalar
used by `gradient`. To do this, `f` must return a Tuple or NamedTuple.
Then it calculates `grad = gradient(first∘f, args...)`
but returns the whole `val = f(args...)`:
```jldoctest; setup=:(using Zygote)
julia> withgradient([1,2,4]) do x
z = 1 ./ x
sum(z), z # here z is an auxiliary output
end
(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))
julia> withgradient(3.0, 4.0) do x, y
(div = x/y, mul = x*y)
end
(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))
```
Also supports implicit mode:
```jldoctest; setup=:(using Zygote)
julia> w = [3.0];
julia> res = withgradient(() -> sum(abs2, w), Params([w]))
(val = 9.0, grad = Grads(...))
julia> res.grad[w]
Expand All @@ -130,7 +152,15 @@ julia> res.grad[w]
"""
function withgradient(f, args...)
    # Run the forward pass once; `back` maps an output seed to input gradients.
    y, back = pullback(f, args...)
    # Choose the seed passed to `back` based on the shape of the output.
    # NOTE: `back` must be called exactly once here. Seeding with
    # `sensitivity(y)` before this branch would apply the scalar-only seed
    # to Tuple/NamedTuple outputs as well.
    grad = if y isa Tuple
        # Differentiate the first element only; every other element is
        # seeded with `nothing`.
        dy = (sensitivity(first(y)), map(_ -> nothing, Base.tail(y))...)
        back(dy)
    elseif y isa NamedTuple
        # Same seed as the Tuple case, re-wrapped with the field names of `y`.
        dy = (sensitivity(first(y)), map(_ -> nothing, Base.tail(y))...)
        back(NamedTuple{propertynames(y), typeof(dy)}(dy))
    else
        back(sensitivity(y))
    end
    # `back` may return `nothing` as a whole; report `nothing` per argument then.
    # Otherwise `_project` is applied to each (argument, gradient) pair.
    results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
    (val=y, grad=results)
end
Expand Down
18 changes: 18 additions & 0 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -866,3 +866,21 @@ end
end
@test gradient(f760, 3)[1] 123.93054835019153
end

@testset "withgradient" begin
    # Tuple output: the gradient is taken of the first element, `val` is the whole tuple.
    @test withgradient([1,2,4]) do x
        z = 1 ./ x
        sum(z), z
    end == (val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))

    # NamedTuple output: the gradient is taken of the first field (`div`).
    @test withgradient(3.0, 4.0) do x, y
        (div = x/y, mul = x*y)
    end == (val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))

    # Differentiating `first∘f3` explicitly must agree with passing the
    # tuple-returning `f3` straight to `withgradient`.
    f3(x) = sum(sin, x), sum(cos, x), sum(tan, x)
    g1 = gradient(first∘f3, [1,2,3.0])
    y2, g2 = withgradient(first∘f3, [1,2,3.0])
    y3, g3 = withgradient(f3, [1,2,3.0])
    @test g1[1] ≈ g2[1] ≈ g3[1]
end

0 comments on commit 29fa32a

Please sign in to comment.