diff --git a/dev/adjoints/index.html b/dev/adjoints/index.html index 3c675637d..667930c62 100644 --- a/dev/adjoints/index.html +++ b/dev/adjoints/index.html @@ -100,4 +100,4 @@ 1 levels of nesting julia> grad(x -> x*grad(f, x), 1); -2 levels of nesting +2 levels of nesting diff --git a/dev/complex/index.html b/dev/complex/index.html index ca8022f7f..4b8c216bd 100644 --- a/dev/complex/index.html +++ b/dev/complex/index.html @@ -27,4 +27,4 @@ (8.0 + 12.0im, 0.0 + 0.0im) julia> wirtinger(x -> abs2(x), 1+2im) -(1.0 - 2.0im, 1.0 + 2.0im) +(1.0 - 2.0im, 1.0 + 2.0im) diff --git a/dev/glossary/index.html b/dev/glossary/index.html index be11966be..2fad4fa59 100644 --- a/dev/glossary/index.html +++ b/dev/glossary/index.html @@ -6,4 +6,4 @@ ga('create', 'UA-36890222-9', 'auto'); ga('send', 'pageview', {'page': location.pathname + location.search + location.hash}); -

Glossary

Differentiation is a minefield of conflicting and overlapping terminology, partly because the ideas have been re-discovered in many different fields (e.g. calculus and differential geometry, the traditional AD community, deep learning, finance, etc.). Many of these terms are not well-defined, and different communities may disagree on the details. Nevertheless, we aim to at least say how we use these terms, which will be helpful when reading over Zygote issues, discussions and source code.

The list is certainly not complete; if you see new terms you'd like defined, or would like to add one yourself, please do open an issue or PR.

Adjoint: See pullback. Used when defining new pullbacks (i.e. the @adjoint macro) since this involves defining the adjoint of the Jacobian, in most cases.

Backpropagation: Essentially equivalent to "reverse-mode AD". Used particularly in the machine learning world to refer to simple chains of functions f(g(h(x))), but has generalised beyond that.

Derivative: Given a scalar function $y = f(x)$, the derivative is $\frac{\partial y}{\partial x}$. "Partial" is taken for granted in AD; there's no interesting distinction between partial and total derivatives for our purposes. It's all in the eye of the beholder.

Differential: Given a function $f(x)$, the linearisation $\partial f$ such that $f(x + \epsilon) \approx f(x) + \partial f \epsilon$. This is a generalisation of the derivative since it applies to, for example, vector-to-vector functions ($\partial f$ is a Jacobian) and holomorphic complex functions ($\partial f$ is the first Wirtinger derivative). This is not, in general, what Zygote calculates, though differentials can usually be derived from gradients.

IR: Intermediate Representation. Essentially source code, but usually lower level – e.g. control flow constructs like loops and branches have all been replaced by gotos. The idea is that it's harder for humans to read/write but easier to manipulate programmatically. Worth looking at SSA form as a paradigmatic example.

Gradient: See sensitivity. There is no technical difference in Zygote's view, though "gradient" sometimes distinguishes the sensitivity we actually want from e.g. the internal ones that Zygote produces as it backpropagates.

Graph: ML people tend to think of models as "computation graphs", but this is no more true than any program is a graph. In fact, pretty much anything is a graph if you squint hard enough. This also refers to the data structure that e.g. TensorFlow and PyTorch build to represent your model, but see trace for that.

Pullback: Given $y = f(x)$, the function $\bar x = back(\bar y)$. In other words, the function back in y, back = Zygote.pullback(f, x).

Sensitivity: Used to refer to the gradient $\bar x = \frac{\partial l}{\partial x}$ with some scalar loss $l$. In other words, you have a value $x$ (which need not be scalar) at some point in your program, and $\bar x$ tells you how you should change that value to decrease the loss. In the AD world, sometimes used to refer to adjoint rules.

Source to Source Differentiation: Or Source Code Transformation (SCT). As opposed to tracing programs to simplify them, an alternative is to operate directly on a language's source code or IR, generating new source code for pullbacks. This describes Zygote, Swift for TensorFlow, Tapenade and a few other old ADs that worked on C source files. Zygote and Swift are unusual in that they work on in-memory IR rather than text source.

To an extent, tracing ADs can be viewed as source transform of a Wengert list / trace. The key difference is that the trace is a lossy representation of the original semantics, which causes problems with e.g. control flow. Systems which can preserve some of those semantics (e.g. autograph) begin to blur the line here, though they are still not nearly as expressive as language IRs.

Symbolic Differentiation: Used to refer to differentiation of "mathematical expressions", that is, things like 3x^2 + sin(x). Often distinguished from AD, though this is somewhat arbitrary; you can happily produce a symbolic adjoint for a Wengert list, the only difference being that you're allowed to make variable bindings. So it's really just a special case of AD on an unusually limited language.

Tape: This term can refer to pretty much any part of an AD implementation. In particular, confusion is caused by conflating the trace with the set of values sometimes closed over by a pullback. Autograd has a combined trace/closure data structure which is usually described as the tape. On the other hand, PyTorch describes its implementation as tape-free because the trace/closure is stored as a DAG rather than a vector, so basically all bets are off here.

Trace: A recording of each mathematical operation used by a program, made at runtime and usually forming a Wengert list. Traces may or may not also record actual runtime values (e.g. PyTorch vs. TensorFlow). They can often be treated as an IR and compiled, but are distinguished from true IRs in that they unroll and inline all control flow, functions and data structures. The tracing process can be thought of as a kind of partial evaluation, though tracers are typically much less worried about losing information.

Vector-Jacobian product: See pullback. So called because all pullbacks are linear functions that can be represented by (left) multiplication with the Jacobian matrix.

Wengert List: A set of simple variable assignments and mathematical expressions, forming a directed graph. Can be thought of as a limited programming language with variable bindings and numerical functions but no control flow or data structures. If you trace a program for AD it will typically take this form.

+

Glossary

Differentiation is a minefield of conflicting and overlapping terminology, partly because the ideas have been re-discovered in many different fields (e.g. calculus and differential geometry, the traditional AD community, deep learning, finance, etc.). Many of these terms are not well-defined, and different communities may disagree on the details. Nevertheless, we aim to at least say how we use these terms, which will be helpful when reading over Zygote issues, discussions and source code.

The list is certainly not complete; if you see new terms you'd like defined, or would like to add one yourself, please do open an issue or PR.

Adjoint: See pullback. Used when defining new pullbacks (i.e. the @adjoint macro) since this involves defining the adjoint of the Jacobian, in most cases.

Backpropagation: Essentially equivalent to "reverse-mode AD". Used particularly in the machine learning world to refer to simple chains of functions f(g(h(x))), but has generalised beyond that.

Derivative: Given a scalar function $y = f(x)$, the derivative is $\frac{\partial y}{\partial x}$. "Partial" is taken for granted in AD; there's no interesting distinction between partial and total derivatives for our purposes. It's all in the eye of the beholder.
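
For example, Zygote's gradient returns exactly this derivative for a scalar function (a quick REPL check):

julia> using Zygote

julia> gradient(x -> 3x^2, 2.0)  # dy/dx = 6x = 12 at x = 2
(12.0,)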

Differential: Given a function $f(x)$, the linearisation $\partial f$ such that $f(x + \epsilon) \approx f(x) + \partial f \epsilon$. This is a generalisation of the derivative since it applies to, for example, vector-to-vector functions ($\partial f$ is a Jacobian) and holomorphic complex functions ($\partial f$ is the first Wirtinger derivative). This is not, in general, what Zygote calculates, though differentials can usually be derived from gradients.
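
A minimal numerical check of this definition (not Zygote itself): for f(x) = x^3 the differential at x = 2 is 3x^2 = 12, so f(x + ε) − f(x) ≈ 12ε for small ε:

julia> f(x) = x^3;

julia> isapprox((f(2 + 1e-6) - f(2)) / 1e-6, 12; atol = 1e-4)
true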

IR: Intermediate Representation. Essentially source code, but usually lower level – e.g. control flow constructs like loops and branches have all been replaced by gotos. The idea is that it's harder for humans to read/write but easier to manipulate programmatically. Worth looking at SSA form as a paradigmatic example.

Gradient: See sensitivity. There is no technical difference in Zygote's view, though "gradient" sometimes distinguishes the sensitivity we actually want from e.g. the internal ones that Zygote produces as it backpropagates.

Graph: ML people tend to think of models as "computation graphs", but this is no more true than any program is a graph. In fact, pretty much anything is a graph if you squint hard enough. This also refers to the data structure that e.g. TensorFlow and PyTorch build to represent your model, but see trace for that.

Pullback: Given $y = f(x)$, the function $\bar x = back(\bar y)$. In other words, the function back in y, back = Zygote.pullback(f, x).
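
Concretely:

julia> y, back = Zygote.pullback(sin, 0.5);

julia> back(1.0)  # feeding ȳ = 1 gives the gradient, cos(0.5)
(0.8775825618903728,)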

Sensitivity: Used to refer to the gradient $\bar x = \frac{\partial l}{\partial x}$ with some scalar loss $l$. In other words, you have a value $x$ (which need not be scalar) at some point in your program, and $\bar x$ tells you how you should change that value to decrease the loss. In the AD world, sometimes used to refer to adjoint rules.
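
For example, a plain gradient-descent step uses the sensitivity directly (a minimal sketch; the step size 0.1 is arbitrary):

julia> x = [1.0, 2.0];

julia> x̄ = gradient(x -> sum(abs2, x), x)[1]  # sensitivity of the loss w.r.t. x
2-element Vector{Float64}:
 2.0
 4.0

julia> x .-= 0.1 .* x̄;  # step against the gradient to decrease the loss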

Source to Source Differentiation: Or Source Code Transformation (SCT). As opposed to tracing programs to simplify them, an alternative is to operate directly on a language's source code or IR, generating new source code for pullbacks. This describes Zygote, Swift for TensorFlow, Tapenade and a few other old ADs that worked on C source files. Zygote and Swift are unusual in that they work on in-memory IR rather than text source.

To an extent, tracing ADs can be viewed as source transform of a Wengert list / trace. The key difference is that the trace is a lossy representation of the original semantics, which causes problems with e.g. control flow. Systems which can preserve some of those semantics (e.g. autograph) begin to blur the line here, though they are still not nearly as expressive as language IRs.

Symbolic Differentiation: Used to refer to differentiation of "mathematical expressions", that is, things like 3x^2 + sin(x). Often distinguished from AD, though this is somewhat arbitrary; you can happily produce a symbolic adjoint for a Wengert list, the only difference being that you're allowed to make variable bindings. So it's really just a special case of AD on an unusually limited language.

Tape: This term can refer to pretty much any part of an AD implementation. In particular, confusion is caused by conflating the trace with the set of values sometimes closed over by a pullback. Autograd has a combined trace/closure data structure which is usually described as the tape. On the other hand, PyTorch describes its implementation as tape-free because the trace/closure is stored as a DAG rather than a vector, so basically all bets are off here.

Trace: A recording of each mathematical operation used by a program, made at runtime and usually forming a Wengert list. Traces may or may not also record actual runtime values (e.g. PyTorch vs. TensorFlow). They can often be treated as an IR and compiled, but are distinguished from true IRs in that they unroll and inline all control flow, functions and data structures. The tracing process can be thought of as a kind of partial evaluation, though tracers are typically much less worried about losing information.

Vector-Jacobian product: See pullback. So called because all pullbacks are linear functions that can be represented by (left) multiplication with the Jacobian matrix.
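
For example, feeding a basis vector to a pullback computes one v'J product, i.e. one row of the Jacobian (a small sketch):

julia> _, back = Zygote.pullback(x -> [x[1]^2, x[1]*x[2]], [3.0, 4.0]);

julia> back([1.0, 0.0])[1]  # first row of the Jacobian: [2x₁, 0]
2-element Vector{Float64}:
 6.0
 0.0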

Wengert List: A set of simple variable assignments and mathematical expressions, forming a directed graph. Can be thought of as a limited programming language with variable bindings and numerical functions but no control flow or data structures. If you trace a program for AD it will typically take this form.
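
For example, tracing y = sin(x) * cos(x) would typically produce a Wengert list like this (variable names are illustrative):

y1 = sin(x)
y2 = cos(x)
y3 = y1 * y2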

diff --git a/dev/index.html b/dev/index.html index e1b0b2503..2b38bb21b 100644 --- a/dev/index.html +++ b/dev/index.html @@ -79,7 +79,7 @@ p = size(x, d) sum(x.^p .+ y) end -([14.0, 22.0], 2.0, nothing)source
julia> linear(θ, x) = θ[:W] * x .+ θ[:b]
+([14.0, 22.0], 2.0, nothing)
julia> linear(θ, x) = θ[:W] * x .+ θ[:b]
 linear (generic function with 1 method)
 
 julia> x = rand(5);
@@ -121,7 +121,7 @@
  8.0  80.0  800.0
 
 julia> haskey(g, z)  # only x and y are parameters
-false
julia> W = rand(2, 5); b = rand(2);
+false
julia> W = rand(2, 5); b = rand(2);
 
 julia> linear(x) = W * x .+ b
 linear (generic function with 2 methods)
@@ -130,4 +130,4 @@
 Grads(...)
 
 julia> grads[W], grads[b] # access gradients using arrays as keys
-([0.652543 … 0.683588], [1.0, 1.0])

Here grads is a dictionary-like object, whose keys are the same parameters we indicated in Params. (In fact it wraps a dictionary using objectid(W) as keys, which does not change if the values in W are mutated).

This implicit style is the one presently used by Flux.jl, a closely related machine learning library. It uses structs like Linear above to define layers, and the function Flux.params(model) returns a Params object containing all the parameters of all layers. See its documentation for more details. When using Zygote for most other purposes, however, the explicit style is usually preferred.

+([0.652543 … 0.683588], [1.0, 1.0])

Here grads is a dictionary-like object, whose keys are the same parameters we indicated in Params. (In fact it wraps a dictionary using objectid(W) as keys, which does not change if the values in W are mutated).

This implicit style is the one presently used by Flux.jl, a closely related machine learning library. It uses structs like Linear above to define layers, and the function Flux.params(model) returns a Params object containing all the parameters of all layers. See its documentation for more details. When using Zygote for most other purposes, however, the explicit style is usually preferred.
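
For comparison, the explicit style simply passes W and b as arguments, and the gradients come back positionally rather than keyed by objectid (a minimal sketch reusing the W, b and x defined above; the gradient with respect to b is deterministically ones):

julia> dW, db = gradient((W, b) -> sum(W * x .+ b), W, b);

julia> db
2-element Vector{Float64}:
 1.0
 1.0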

diff --git a/dev/internals/index.html b/dev/internals/index.html index faab2702d..9ab9ead1f 100644 --- a/dev/internals/index.html +++ b/dev/internals/index.html @@ -135,4 +135,4 @@ julia> y, back = Zygote._pullback(bad, 1); julia> back(1) # ok, here's our issue. Lather, rinse, repeat. -ERROR: bad

Of course, our goal is that you never have to do this, but until Zygote is more mature it can be a useful way to narrow down test cases.

+ERROR: bad

Of course, our goal is that you never have to do this, but until Zygote is more mature it can be a useful way to narrow down test cases.

diff --git a/dev/limitations/index.html b/dev/limitations/index.html index 269a945bf..17de4ce6b 100644 --- a/dev/limitations/index.html +++ b/dev/limitations/index.html @@ -23,7 +23,7 @@ https://fluxml.ai/Zygote.jl/latest/limitations Stacktrace: - ...

We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling copyto! (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes x .= ... which is given as an example of array mutation. Other examples of mutating operations include:

setting values (x .= ...)
appending/popping values (push!(x, v) / pop!(x))
calling mutating functions (mul!(C, A, B))

Warning

Non-mutating functions may also use mutation under the hood. This can be done for performance reasons or code re-use.

function g!(x, y)
+  ...

We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling copyto! (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes x .= ... which is given as an example of array mutation. Other examples of mutating operations include:

setting values (x .= ...)
appending/popping values (push!(x, v) / pop!(x))
calling mutating functions (mul!(C, A, B))

Warning

Non-mutating functions might also use mutation under the hood. This can be done for performance reasons or code re-use.

function g!(x, y)
   x .= 2 .* y
 
   return x
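
When hidden mutation like this is unavoidable, the usual escape hatch is Zygote.Buffer, as covered under array mutation above: write into a Buffer, then copy it to get back an ordinary (differentiable) array:

function g(y)
  x = Zygote.Buffer(y)  # Buffer supports syntax like similar
  g!(x, y)
  return copy(x)        # makes the Buffer immutable without actually copying
end

julia> gradient(y -> sum(g(y)), rand(3))
([2.0, 2.0, 2.0],)
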
@@ -87,4 +87,4 @@
     tot += x^n  # binds symbol `tot` to new value
   end
   return tot
-end

However, sometimes such re-binding confuses Zygote, especially if the type of the value changes, and especially if the variable is "boxed", as will happen if you re-bind from within a closure (such as the function created by a do block).

Second derivatives

In principle Zygote supports taking derivatives of derivatives. There are, however, a few problems:

Quite a few of its rules are not written in a way that is itself differentiable. For instance they may work by making an array then writing into it, which is mutation of the sort forbidden above.
The complexity of the code grows rapidly, as Zygote differentiates its own un-optimised output.
Reverse mode over reverse mode is seldom the best algorithm.

The issue tracker has a label for second order, which will outline where the bodies are buried.

Often using a different AD system over Zygote is a better solution. This is what hessian does, using ForwardDiff over Zygote, but other combinations are possible. (Note that rules defined here mean that Zygote over ForwardDiff is translated to ForwardDiff over ForwardDiff.)

+end

However, sometimes such re-binding confuses Zygote, especially if the type of the value changes, and especially if the variable is "boxed", as will happen if you re-bind from within a closure (such as the function created by a do block).
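
A minimal sketch of the boxed case (illustrative code, not from the docs): here tot is re-bound inside a do-block closure, so Julia boxes it, and gradients of such code are more likely to go wrong or fail:

function sum_squares(xs)
  tot = 0.0
  foreach(xs) do x
    tot += x^2  # re-binding a captured variable boxes `tot`
  end
  return tot
end

Writing the reduction without the closure, e.g. sum(x -> x^2, xs), avoids the box entirely.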

Second derivatives

In principle Zygote supports taking derivatives of derivatives. There are, however, a few problems:

Quite a few of its rules are not written in a way that is itself differentiable. For instance they may work by making an array then writing into it, which is mutation of the sort forbidden above.
The complexity of the code grows rapidly, as Zygote differentiates its own un-optimised output.
Reverse mode over reverse mode is seldom the best algorithm.

The issue tracker has a label for second order, which will outline where the bodies are buried.

Often using a different AD system over Zygote is a better solution. This is what hessian does, using ForwardDiff over Zygote, but other combinations are possible. (Note that rules defined here mean that Zygote over ForwardDiff is translated to ForwardDiff over ForwardDiff.)
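
For example, using the ForwardDiff-over-Zygote hessian:

julia> Zygote.hessian(x -> sum(x.^3), [1.0, 2.0])  # H[i,j] = ∂²/∂x[i]∂x[j]
2×2 Matrix{Float64}:
 6.0   0.0
 0.0  12.0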

diff --git a/dev/profiling/index.html b/dev/profiling/index.html index f73169869..57dfed967 100644 --- a/dev/profiling/index.html +++ b/dev/profiling/index.html @@ -28,4 +28,4 @@ │ %2 = (Base.mul_int)(Δ, 1)::Int64 │ %3 = (Zygote.tuple)(nothing, %1, %2)::PartialTuple(Tuple{Nothing,Int64,Int64}, Any[Const(nothing, false), Int64, Int64]) └── return %3 -) => Tuple{Nothing,Int64,Int64} +) => Tuple{Nothing,Int64,Int64} diff --git a/dev/search/index.html b/dev/search/index.html index 631cc9c57..eb15b39d8 100644 --- a/dev/search/index.html +++ b/dev/search/index.html @@ -6,4 +6,4 @@ ga('create', 'UA-36890222-9', 'auto'); ga('send', 'pageview', {'page': location.pathname + location.search + location.hash}); -

Loading search...

    +

    Loading search...

      diff --git a/dev/search_index.js b/dev/search_index.js index 993c5bfe3..63d287640 100644 --- a/dev/search_index.js +++ b/dev/search_index.js @@ -1,3 +1,3 @@ var documenterSearchIndex = {"docs": -[{"location":"profiling/#Debugging-in-Time-and-Space-1","page":"Profiling","title":"Debugging in Time and Space","text":"","category":"section"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"Because Zygote generates Julia code for the backwards pass, many of Julia's normal profiling and performance debugging tools work well on it out of the box.","category":"page"},{"location":"profiling/#Performance-Profiling-1","page":"Profiling","title":"Performance Profiling","text":"","category":"section"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"Julia's sampling profiler is useful for understanding performance. We recommend running the profiler in Juno, but the terminal or ProfileView.jl also work well.","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"(Image: )","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"The bars indicate time taken in both the forwards and backwards passes at that line. The canopy chart on the right shows us each function call as a block, arranged so that when f calls g, g gets a block just below f, which is bigger the longer it took to run. If we dig down the call stack we'll eventually find the adjoints for things like matmul, which we can click on to view.","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"(Image: )","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"The trace inside the adjoint can be used to distinguish time taken by the forwards and backwards passes.","category":"page"},{"location":"profiling/#Memory-Profiling-1","page":"Profiling","title":"Memory Profiling","text":"","category":"section"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"Reverse-mode AD typically uses memory proportional to the number of operations in the program, so long-running programs can also suffer memory usage issues. Zygote includes a space profiler to help debug these issues. Like the time profiler, it shows a canopy chart, but this time hovering over it displays the number of bytes stored by each line of the program.","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"(Image: )","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"Note that this currently only works inside Juno.","category":"page"},{"location":"profiling/#Reflection-1","page":"Profiling","title":"Reflection","text":"","category":"section"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"Julia's code and type inference reflection tools can also be useful, though Zygote's use of closures can make the output noisy. To see the code Julia runs you should use the low-level _pullback method and the pullback it returns. 
This will directly show either the derived adjoint code or the code for a custom adjoint, if there is one.","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"julia> using Zygote: Context, _pullback\n\njulia> add(a, b) = a+b\n\njulia> @code_typed _pullback(Context(), add, 1, 2)\nCodeInfo(\n1 ─ %1 = (Base.getfield)(args, 1)::Int64\n│ %2 = (Base.getfield)(args, 2)::Int64\n│ %3 = (Base.add_int)(%1, %2)::Int64\n│ %4 = (Base.tuple)(%3, $(QuoteNode(∂(add))))::PartialTuple(Tuple{Int64,typeof(∂(add))}, Any[Int64, Const(∂(add), false)])\n└── return %4\n) => Tuple{Int64,typeof(∂(add))}\n\njulia> y, back = _pullback(Context(), add, 1, 2)\n(3, ∂(add))\n\njulia> @code_typed back(1)\nCodeInfo(\n1 ─ %1 = (Base.mul_int)(Δ, 1)::Int64\n│ %2 = (Base.mul_int)(Δ, 1)::Int64\n│ %3 = (Zygote.tuple)(nothing, %1, %2)::PartialTuple(Tuple{Nothing,Int64,Int64}, Any[Const(nothing, false), Int64, Int64])\n└── return %3\n) => Tuple{Nothing,Int64,Int64}","category":"page"},{"location":"complex/#Complex-Differentiation-1","page":"Complex Differentiation","title":"Complex Differentiation","text":"","category":"section"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"Complex numbers add some difficulty to the idea of a \"gradient\". To talk about gradient(f, x) here we need to talk a bit more about f.","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"If f returns a real number, things are fairly straightforward. For c = x + yi and z = f(c), we can define the adjoint bar c = fracpartial zpartial x + fracpartial zpartial yi = bar x + bar y i (note that bar c means gradient, and c means conjugate). It's exactly as if the complex number were just a pair of reals (re, im). This works out of the box.","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"julia> using Zygote\n\njulia> gradient(c -> abs2(c), 1+2im)\n(2.0 + 4.0im,)","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"However, while this is a very pragmatic definition that works great for gradient descent, it's not quite aligned with the mathematical notion of the derivative: i.e. f(c + epsilon) approx f(c) + bar c epsilon. In general, such a bar c is not possible for complex numbers except when f is holomorphic (or analytic). Roughly speaking this means that the function is defined over c as if it were a normal real number, without exploiting its complex structure – it can't use real, imag, conj, or anything that depends on these like abs2 (abs2(x) = x*x'). (This constraint also means there's no overlap with the Real case above; holomorphic functions always return complex numbers for complex input.) 
But most \"normal\" numerical functions – exp, log, anything that can be represented by a Taylor series – are fine.","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"Fortunately it's also possible to get these derivatives; they are the conjugate of the gradients for the real part.","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"julia> gradient(x -> real(log(x)), 1+2im)[1] |> conj\n0.2 - 0.4im","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"We can check that this function is holomorphic – and thus that the gradient we got out is sensible – by checking the Cauchy-Riemann equations. In other words this should give the same answer:","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"julia> -im*gradient(x -> imag(log(x)), 1+2im)[1] |> conj\n0.2 - 0.4im","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"Notice that this fails in a non-holomorphic case, f(x) = log(x'):","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"julia> gradient(x -> real(log(x')), 1+2im)[1] |> conj\n0.2 - 0.4im\n\njulia> -im*gradient(x -> imag(log(x')), 1+2im)[1] |> conj\n-0.2 + 0.4im","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"In cases like these, all bets are off. The gradient can only be described with more information; either a 2x2 Jacobian (a generalisation of the Real case, where the second column is now non-zero), or by the two Wirtinger derivatives (a generalisation of the holomorphic case, where frac f z is now non-zero). To get these efficiently, as we would a Jacobian, we can just call the backpropagators twice.","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"function jacobi(f, x)\n y, back = Zygote.pullback(f, x)\n back(1)[1], back(im)[1]\nend\n\nfunction wirtinger(f, x)\n du, dv = jacobi(f, x)\n (du' + im*dv')/2, (du + im*dv)/2\nend","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"julia> wirtinger(x -> 3x^2 + 2x + 1, 1+2im)\n(8.0 + 12.0im, 0.0 + 0.0im)\n\njulia> wirtinger(x -> abs2(x), 1+2im)\n(1.0 - 2.0im, 1.0 + 2.0im)","category":"page"},{"location":"utils/#Utilities-1","page":"Utilities","title":"Utilities","text":"","category":"section"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Zygote's gradients can be used to construct a Jacobian (by repeated evaluation) or a Hessian (by taking a second derivative).","category":"page"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Zygote.jacobian\nZygote.hessian\nZygote.hessian_reverse\nZygote.diaghessian","category":"page"},{"location":"utils/#Zygote.jacobian","page":"Utilities","title":"Zygote.jacobian","text":"jacobian(f, args...) -> Tuple\n\nFor each array a ∈ args this returns a matrix with Ja[k,i] = ∂y[k]/∂a[i] where y = f(args...) is usually a vector. 
Arrays of higher dimension are treated like vec(a), or vec(y) for output.\n\nFor scalar x::Number ∈ args, the result is a vector Jx[k] = ∂y[k]/∂x, while for scalar y all results have just one row.\n\nWith any other argument type, no result is produced, even if gradient would work.\n\nThis reverse-mode Jacobian needs to evaluate the pullback once for each element of y. Doing so is usually only efficient when length(y) is small compared to length(a), otherwise forward mode is likely to be better.\n\nSee also withjacobian, hessian, hessian_reverse.\n\nExamples\n\njulia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output\n3×7 Matrix{Int64}:\n 200 0 0 0 0 0 0\n 0 400 0 0 0 0 0\n 0 0 600 0 0 0 0\n\njulia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) # scalar argument has vector jacobian\n([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])\n\njulia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)\n([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])\n\nwarning: Warning\nFor arguments of any type except Number & AbstractArray, the result is nothing.\n\njulia> jacobian((a,s) -> a.^length(s), [1,2,3], \"str\")\n([3 0 0; 0 12 0; 0 0 27], nothing)\n\njulia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))\n([4 4 4], nothing)\n\njulia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient understands the tuple\n([4 4 4], (6, 1))\n\n\n\n\n\njacobian(loss, ::Params)\n\nLike gradient with implicit parameters, this method takes a zero-argument function and returns an IdDict-like object, now containing the Jacobian for each parameter.\n\nExamples\n\njulia> xs = [1 2; 3 4]; ys = [5,7,9];\n\njulia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))\nGrads(...)\n\njulia> Jxy[ys]\n2×3 Matrix{Int64}:\n 1 0 0\n 0 1 0\n\njulia> Jxy[xs]\n2×4 Matrix{Int64}:\n 2 6 4 8\n 2 6 4 8\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.hessian","page":"Utilities","title":"Zygote.hessian","text":"hessian(f, x)\n\nConstruct the Hessian ∂²f/∂x², where x is a real number or an array, and f(x) is a real number. When x is an array, the result is a matrix H[i,j] = ∂²f/∂x[i]∂x[j], using linear indexing x[i] even if the argument is higher-dimensional.\n\nThis uses forward over reverse, ForwardDiff over Zygote, calling hessian_dual(f, x). See hessian_reverse for an all-Zygote alternative.\n\nSee also diaghessian to compute only the diagonal part.\n\nExamples\n\njulia> hessian(x -> x[1]*x[2], randn(2))\n2×2 Matrix{Float64}:\n 0.0 1.0\n 1.0 0.0\n\njulia> hessian(x -> sum(x.^3), [1 2; 3 4]) # uses linear indexing of x\n4×4 Matrix{Int64}:\n 6 0 0 0\n 0 18 0 0\n 0 0 12 0\n 0 0 0 24\n\njulia> hessian(sin, pi/2)\n-1.0\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.hessian_reverse","page":"Utilities","title":"Zygote.hessian_reverse","text":"hessian_reverse(f, x)\n\nThis should be equivalent to hessian(f, x), but implemented using reverse over reverse mode, all Zygote. (This is usually much slower, and more likely to find errors.)\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.diaghessian","page":"Utilities","title":"Zygote.diaghessian","text":"diaghessian(f, args...) -> Tuple\n\nDiagonal part of the Hessian. Returns a tuple containing, for each argument x, h of the same shape with h[i] = Hᵢᵢ = ∂²y/∂x[i]∂x[i]. The original evaluation y = f(args...) must give a real number y.\n\nFor one vector argument x, this is equivalent to (diag(hessian(f,x)),). Like hessian it uses ForwardDiff over Zygote. 
\n\nwarning: Warning\nFor arguments of any type except Number & AbstractArray, the result is nothing.\n\nExamples\n\njulia> diaghessian(x -> sum(x.^3), [1 2; 3 4])[1]\n2×2 Matrix{Int64}:\n 6 12\n 18 24\n\njulia> Diagonal(vec(ans)) == hessian(x -> sum(x.^3), [1 2; 3 4]) # full Hessian is diagonal\ntrue\n\njulia> diaghessian((x,y) -> sum(x .* y .* y'), [1 22; 333 4], [0.5, 0.666]) # two array arguments\n([0.0 0.0; 0.0 0.0], [2.0, 8.0])\n\njulia> diaghessian(atan, 1, 2) # two scalar arguments\n(-0.16, 0.16)\n\njulia> hessian(xy -> atan(xy[1], xy[2]), [1, 2]) # full Hessian is not diagonal\n2×2 Matrix{Float64}:\n -0.16 -0.12\n -0.12 0.16\n\n\n\n\n\n","category":"function"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Zygote also provides a set of helpful utilities. These are all \"user-level\" tools – in other words you could have written them easily yourself, but they live in Zygote for convenience.","category":"page"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"See ChainRules.ignore_derivatives if you want to exclude some of your code from the gradient calculation. This replaces previous Zygote-specific ignore and dropgrad functionality.","category":"page"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Zygote.withgradient\nZygote.withjacobian\nZygote.@showgrad\nZygote.hook\nZygote.Buffer\nZygote.forwarddiff\nZygote.checkpointed","category":"page"},{"location":"utils/#Zygote.withgradient","page":"Utilities","title":"Zygote.withgradient","text":"withgradient(f, args...)\nwithgradient(f, ::Params)\n\nReturns both the value of the function and the gradient, as a named tuple. \n\njulia> y, ∇ = withgradient(/, 1, 2)\n(val = 0.5, grad = (0.5, -0.25))\n\njulia> ∇ == gradient(/, 1, 2)\ntrue\n\nAllows you to capture auxiliary outputs, in addition to the scalar used by gradient. To do this, f must return a Tuple or NamedTuple. Then it calculates grad = gradient(first∘f, args...) but returns the whole val = f(args...):\n\njulia> withgradient([1,2,4]) do x\n z = 1 ./ x\n sum(z), z # here z is an auxiliary output\n end\n(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))\n\njulia> withgradient(3.0, 4.0) do x, y\n (div = x/y, mul = x*y)\n end\n(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))\n\nAlso supports implicit mode:\n\njulia> w = [3.0];\n\njulia> res = withgradient(() -> sum(abs2, w), Params([w]))\n(val = 9.0, grad = Grads(...))\n\njulia> res.grad[w]\n1-element Vector{Float64}:\n 6.0\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.withjacobian","page":"Utilities","title":"Zygote.withjacobian","text":"withjacobian(f, args...)\n\nReturns both the value f(args...) and the jacobian as a named tuple.\n\njulia> withjacobian(cumsum, [1,2,3])\n(val = [1, 3, 6], grad = ([1 0 0; 1 1 0; 1 1 1],))\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.@showgrad","page":"Utilities","title":"Zygote.@showgrad","text":"@showgrad(x) -> x\n\nMuch like @show, but shows the gradient about to accumulate to x. Useful for debugging gradients.\n\njulia> gradient(2, 3) do a, b\n @showgrad(a)*b\n end\n∂(a) = 3\n(3, 2)\n\nNote that the gradient depends on how the output of @showgrad is used, and is not the overall gradient of the variable a. 
For example:\n\njulia> gradient(2) do a\n @showgrad(a)*a\n end\n∂(a) = 2\n(4,)\n\njulia> gradient(2, 3) do a, b\n @showgrad(a) # not used, so no gradient\n a*b\n end\n∂(a) = nothing\n(3, 2)\n\n\n\n\n\n","category":"macro"},{"location":"utils/#Zygote.hook","page":"Utilities","title":"Zygote.hook","text":"hook(x̄ -> ..., x) -> x\n\nGradient hooks. Allows you to apply an arbitrary function to the gradient for x.\n\njulia> gradient(2, 3) do a, b\n hook(ā -> @show(ā), a)*b\n end\nā = 3\n(3, 2)\n\njulia> gradient(2, 3) do a, b\n hook(-, a)*b\n end\n(-3, 2)\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.Buffer","page":"Utilities","title":"Zygote.Buffer","text":"Buffer(xs, ...)\n\nBuffer is an array-like type which is mutable when taking gradients. You can construct a Buffer with the same syntax as similar (e.g. Buffer(xs, 5)) and then use normal indexing. Finally, use copy to get back a normal array.\n\nFor example:\n\njulia> function vstack(xs)\n buf = Buffer(xs, length(xs), 5)\n for i = 1:5\n buf[:, i] = xs\n end\n return copy(buf)\n end\nvstack (generic function with 1 method)\n\njulia> vstack([1, 2, 3])\n3×5 Array{Int64,2}:\n 1 1 1 1 1\n 2 2 2 2 2\n 3 3 3 3 3\n\njulia> gradient(x -> sum(vstack(x)), [1, 2, 3])\n([5.0, 5.0, 5.0],)\n\nBuffer is not an AbstractArray and can't be used for linear algebra operations like matrix multiplication. This prevents it from being captured by pullbacks.\n\ncopy is a semantic copy, but does not allocate memory. Instead the Buffer is made immutable after copying.\n\n\n\n\n\n","category":"type"},{"location":"utils/#Zygote.forwarddiff","page":"Utilities","title":"Zygote.forwarddiff","text":"forwarddiff(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD) -> f(x)\n\nRuns f(x) as usual, but instructs Zygote to differentiate f using forward mode, rather than the usual reverse mode. The chunk_threshold argument controls the maximum chunk size (c.f. ForwardDiff documentation).\n\nForward mode takes time linear in length(x) but only has constant memory overhead, and is very efficient for scalars, so in some cases this can be a useful optimisation.\n\njulia> function pow(x, n)\n r = one(x)\n for i = 1:n\n r *= x\n end\n return r\n end\npow (generic function with 1 method)\n\njulia> gradient(5) do x\n forwarddiff(x) do x\n pow(x, 2)\n end\n end\n(10,)\n\nNote that the function f will drop gradients for any closed-over values.\n\njulia> gradient(2, 3) do a, b\n forwarddiff(a) do a\n a*b\n end\n end\n(3, nothing)\n\nThis can be rewritten by explicitly passing through b, i.e.\n\ngradient(2, 3) do a, b\n forwarddiff([a, b]) do (a, b)\n a*b\n end\nend\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.checkpointed","page":"Utilities","title":"Zygote.checkpointed","text":"checkpointed(f, xs...)\n\nUse gradient checkpointing on the call f(xs...). This means that checkpointed(f, xs...) === f(xs...), but when computing the derivative intermediate results from the forward pass of f will not be stored. Instead the forward pass will be repeated, when computing the derivative. This saves memory at the cost of increasing execution time.\n\nwarning: Warning\nIf f is not a pure function, checkpointed will likely give wrong results.\n\n\n\n\n\n","category":"function"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Params and Grads can be copied to and from arrays using the copy! 
function.","category":"page"},{"location":"utils/#Working-with-Grads-1","page":"Utilities","title":"Working with Grads","text":"","category":"section"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Map, broadcast, and iteration are supported for the dictionary-like Grads objects. These operations are value based and preserve the keys.","category":"page"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"using Zygote, Test\n\nw, x1, x2, b = rand(2), rand(2), rand(2), rand(2)\n\ngs1 = gradient(() -> sum(tanh.(w .* x1 .+ b)), Params([w, b]))\ngs2 = gradient(() -> sum(tanh.(w .* x2 .+ b)), Params([w, b]))\n\n# accumulate gradients\ngs = gs1 .+ gs2\n@test gs[w] ≈ gs1[w] + gs2[w]\n@test gs[b] ≈ gs1[b] + gs2[b]\n\n# gradients and IdDict interact nicely\n# note that an IdDict must be used for gradient algebra on the GPU\ngs .+= IdDict(p => randn(size(p)) for p in keys(gs))\n\n# clip gradients\nmap(x -> clamp.(x, -0.1, 0.1), gs)\n\n# clip gradients in-place\nforeach(x -> clamp!(x, -0.1, 0.1), gs)\n\nfor (p, g) in pairs(gs)\n # do something with parameter `p` and corresponding gradient `g`\nend\n\n# note that gradients must be w.r.t. to the same parameter key set\ngs3 = gradient(() -> sum(tanh.(w .* x2)), Params([w]))\n# gs3 does not have the key b\n@test_throws ArgumentError gs1 .+ gs3","category":"page"},{"location":"adjoints/#Custom-Adjoints-1","page":"Custom Adjoints","title":"Custom Adjoints","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"note: Prefer to use ChainRulesCore to define custom adjoints\nZygote supports the use of ChainRulesCore to define custom sensitivities. It is preferred to define the custom sensitivities using ChainRulesCore.rrule as they will work for many AD systems, not just Zygote. These sensitivities can be added in your own package, or for Base/StdLib functions they can be added to ChainRules.jl. To define custom sensitivities using ChainRulesCore, define a ChainRulesCore.rrule(f, args...; kwargs...). Head to ChainRules project's documentation for more information. If you are defining your custom adjoints using ChainRulesCore then you do not need to read this page, and can consider it as documenting a legacy feature.This page exists to describe how Zygote works, and how adjoints can be directly defined for Zygote. Defining adjoints this way does not make them accessible to other AD systems, but does let you do things that directly depend on how Zygote works. It allows for specific definitions of adjoints that are only defined for Zygote (which might work differently to more generic definitions defined for all AD).","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"The @adjoint macro is an important part of Zygote's interface; customising your backwards pass is not only possible but widely used and encouraged. While there are specific utilities available for common things like gradient clipping, understanding adjoints will give you the most flexibility. 
We first give a bit more background on what these pullback things are.","category":"page"},{"location":"adjoints/#Pullbacks-1","page":"Custom Adjoints","title":"Pullbacks","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"gradient is really just syntactic sugar around the more fundamental function pullback.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> using Zygote\n\njulia> y, back = Zygote.pullback(sin, 0.5);\n\njulia> y\n0.479425538604203","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"pullback gives two outputs: the result of the original function, sin(0.5), and a pullback, here called back. back implements the gradient computation for sin, accepting a derivative and producing a new one. In mathematical terms, it implements a vector-Jacobian product. Where y = f(x) and the gradient fracpartial lpartial x is written barx, the pullback mathcalB_y computes:","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"barx = fracpartial lpartial x = fracpartial lpartial y fracpartial ypartial x = mathcalB_y(bary)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"To make this concrete, take the function y = sin(x). fracpartial ypartial x = cos(x), so the pullback is bary cos(x). In other words pullback(sin, x) behaves the same as","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),))","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"gradient takes a function l = f(x) and assumes l = fracpartial lpartial l = 1 and feeds this in to the pullback. In the case of sin,","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> function gradsin(x)\n _, back = dsin(x)\n back(1)\n end\ngradsin (generic function with 1 method)\n\njulia> gradsin(0.5)\n(0.8775825618903728,)\n\njulia> cos(0.5)\n0.8775825618903728","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"More generally","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> function mygradient(f, x...)\n _, back = Zygote.pullback(f, x...)\n back(1)\n end\nmygradient (generic function with 1 method)\n\njulia> mygradient(sin, 0.5)\n(0.8775825618903728,)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"The rest of this section contains more technical detail. It can be skipped if you only need an intuition for pullbacks; you generally won't need to worry about it as a user.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"If x and y are vectors, fracpartial ypartial x becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. v'J, rather than the more usual J*v. Transposing v to a row vector and back (v'J)' is equivalent to J'v so our gradient rules actually implement the adjoint of the Jacobian. This is relevant even for scalar code: the adjoint for y = sin(x) is x̄ = cos(x)'*ȳ; the conjugation is usually moot but gives the correct behaviour for complex code. 
\"Pullbacks\" are therefore sometimes called \"vector-Jacobian products\" (VJPs), and we refer to the reverse mode rules themselves as \"adjoints\".","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"Zygote has many adjoints for non-mathematical operations such as for indexing and data structures. Though these can still be seen as linear functions of vectors, it's not particularly enlightening to implement them with an actual matrix multiply. In these cases it's easiest to think of the adjoint as a kind of inverse. For example, the gradient of a function that takes a tuple to a struct (e.g. y = Complex(a, b)) will generally take a struct to a tuple ((ȳ.re, ȳ.im)). The gradient of a getindex y = x[i...] is a setindex! x̄[i...] = ȳ, etc.","category":"page"},{"location":"adjoints/#Custom-Adjoints-2","page":"Custom Adjoints","title":"Custom Adjoints","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"We can extend Zygote to a new function with the @adjoint function.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> mul(a, b) = a*b;\n\njulia> using Zygote: @adjoint\n\njulia> @adjoint mul(a, b) = mul(a, b), c̄ -> (c̄*b, c̄*a)\n\njulia> gradient(mul, 2, 3)\n(3.0, 2.0)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"It might look strange that we write mul(a, b) twice here. In this case we want to call the normal mul function for the pullback pass, but you may also want to modify the pullback pass (for example, to capture intermediate results in the pullback).","category":"page"},{"location":"adjoints/#Custom-Types-1","page":"Custom Adjoints","title":"Custom Types","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"One good use for custom adjoints is to customise how your own types behave during differentiation. 
For example, in our Point example we noticed that the adjoint is a named tuple, rather than another point.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"import Base: +, -\n\nstruct Point\n x::Float64\n y::Float64\nend\n\nwidth(p::Point) = p.x\nheight(p::Point) = p.y\n\na::Point + b::Point = Point(width(a) + width(b), height(a) + height(b))\na::Point - b::Point = Point(width(a) - width(b), height(a) - height(b))\ndist(p::Point) = sqrt(width(p)^2 + height(p)^2)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> gradient(a -> dist(a), Point(1, 2))[1]\n(x = 0.4472135954999579, y = 0.8944271909999159)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"Fundamentally, this happens because of Zygote's default adjoint for getfield.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> gradient(a -> a.x, Point(1, 2))\n((x = 1, y = nothing),)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"We can overload this by modifying the getters height and width.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> @adjoint width(p::Point) = p.x, x̄ -> (Point(x̄, 0),)\n\njulia> @adjoint height(p::Point) = p.y, ȳ -> (Point(0, ȳ),)\n\njulia> Zygote.refresh() # currently needed when defining new adjoints\n\njulia> gradient(a -> height(a), Point(1, 2))\n(Point(0.0, 1.0),)\n\njulia> gradient(a -> dist(a), Point(1, 2))[1]\nPoint(0.4472135954999579, 0.8944271909999159)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"If you do this you should also overload the Point constructor, so that it can handle a Point gradient (otherwise this function will error).","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> @adjoint Point(a, b) = Point(a, b), p̄ -> (p̄.x, p̄.y)\n\njulia> gradient(x -> dist(Point(x, 1)), 1)\n(0.7071067811865475,)","category":"page"},{"location":"adjoints/#Advanced-Adjoints-1","page":"Custom Adjoints","title":"Advanced Adjoints","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"We usually use custom adjoints to add gradients that Zygote can't derive itself (for example, because they ccall to BLAS). But there are some more advanced and fun things we can do with @adjoint.","category":"page"},{"location":"adjoints/#Gradient-Hooks-1","page":"Custom Adjoints","title":"Gradient Hooks","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> hook(f, x) = x\nhook (generic function with 1 method)\n\njulia> @adjoint hook(f, x) = x, x̄ -> (nothing, f(x̄))","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"hook doesn't seem that interesting, as it doesn't do anything. But the fun part is in the adjoint; it's allowing us to apply a function f to the gradient of x.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> gradient((a, b) -> hook(-, a)*b, 2, 3)\n(-3.0, 2.0)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"We could use this for debugging or modifying gradients (e.g. 
gradient clipping).","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> gradient((a, b) -> hook(ā -> @show(ā), a)*b, 2, 3)\nā = 3.0\n(3.0, 2.0)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"Zygote provides both hook and @showgrad so you don't have to write these yourself.","category":"page"},{"location":"adjoints/#Checkpointing-1","page":"Custom Adjoints","title":"Checkpointing","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"A more advanced example is checkpointing, in which we save memory by re-computing the forward pass of a function during the backwards pass. To wit:","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> checkpoint(f, x) = f(x)\ncheckpoint (generic function with 1 method)\n\njulia> @adjoint checkpoint(f, x) = f(x), ȳ -> Zygote._pullback(f, x)[2](ȳ)\n\njulia> gradient(x -> checkpoint(sin, x), 1)\n(0.5403023058681398,)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"If a function has side effects we'll see that the forward pass happens twice, as expected.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> foo(x) = (println(x); sin(x))\nfoo (generic function with 1 method)\n\njulia> gradient(x -> checkpoint(foo, x), 1)\n1\n1\n(0.5403023058681398,)","category":"page"},{"location":"adjoints/#Gradient-Reflection-1","page":"Custom Adjoints","title":"Gradient Reflection","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"It's easy to check whether the code we're running is currently being differentiated.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"isderiving() = false\n\n@adjoint isderiving() = true, _ -> nothing","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"A more interesting example is to actually detect how many levels of nesting are going on.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"nestlevel() = 0\n\n@adjoint nestlevel() = nestlevel()+1, _ -> nothing","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"Demo:","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> function f(x)\n println(nestlevel(), \" levels of nesting\")\n return x\n end\nf (generic function with 1 method)\n\njulia> grad(f, x) = gradient(f, x)[1]\ngrad (generic function with 1 method)\n\njulia> f(1);\n0 levels of nesting\n\njulia> grad(f, 1);\n1 levels of nesting\n\njulia> grad(x -> x*grad(f, x), 1);\n2 levels of nesting","category":"page"},{"location":"limitations/#Design-Limitations-1","page":"Limitations","title":"Design Limitations","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Zygote aims to support differentiating any Julia code, but it still has a few limitations. 
Notably, you might encounter errors when trying to differentiate:","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"array mutation,\ntry/catch statements,\n\"foreign call\" expressions.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"This section gives examples where each of these errors occurs, as well as possible work-arounds.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Below, it also describes some known bugs in expressions Zygote ought to be able to handle.","category":"page"},{"location":"limitations/#Array-mutation-1","page":"Limitations","title":"Array mutation","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Array mutation is by far the most commonly encountered Zygote limitation.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Automatic differentiation (AD) systems like Zygote are built on basic principles of calculus where we encounter pure functions. This means that the function, y = f(x), does not modify x and only produces the output y based on x. If we have a chain of functions, such as y = h(g(f(x))), we can apply the chain rule to differentiate it. AD systems are built to programmatically apply the chain rule to a series of function calls. Unfortunately, typical programs do not behave this way. We might allocate some memory, x, then call a function y = f!(x) that modifies x to produce the output y. This mutating behavior is a side-effect of f!. Side-effects are difficult for AD systems to handle, because they must track changes to mutated variables and store older versions of the variable. For these reasons, Zygote does not handle array mutation for now.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Let's explore this with a more concrete example. Here we define a simple mutating function, f!, which modifies the elements of its input argument, x, in place.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"function f!(x)\n x .= 2 .* x\n\n return x\nend","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Let's see what happens when we differentiate f!","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"julia> gradient(rand(3)) do x\n sum(f!(x))\n end\nERROR: Mutating arrays is not supported -- called copyto!(Vector{Float64}, ...)\nThis error occurs when you ask Zygote to differentiate operations that change\nthe elements of arrays in-place (e.g. setting values with x .= ...)\n\nPossible fixes:\n- avoid mutating operations (preferred)\n- or read the documentation and solutions for this error\n https://fluxml.ai/Zygote.jl/latest/limitations\n\nStacktrace:\n ...","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling copyto! (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes x .= ... which is given as an example of array mutation. 
Other examples of mutating operations include:","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"setting values (x .= ...)\nappending/popping values (push!(x, v) / pop!(x))\ncalling mutating functions (mul!(C, A, B))","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"warning: Warning\nNon-mutating functions may also use mutation under the hood. This can be done for performance reasons or code re-use.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"function g!(x, y)\n x .= 2 .* y\n\n return x\nend\ng(y) = g!(similar(y), y)","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Here g is a \"non-mutating function,\" and it indeed does not mutate y, its only argument. But it still allocates a new array and calls g! on this array which will result in a mutating operation. You may encounter such functions when working with another package.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Specifically for array mutation, we can use Zygote.Buffer to re-write our function. For example, let's fix the function g! above.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"function g!(x, y)\n x .= 2 .* y\n\n return x\nend\n\nfunction g(y)\n x = Zygote.Buffer(y) # Buffer supports syntax like similar\n g!(x, y)\n return copy(x) # this step makes the Buffer immutable (w/o actually copying)\nend\n\njulia> gradient(rand(3)) do y\n sum(g(y))\n end\n([2.0, 2.0, 2.0],)","category":"page"},{"location":"limitations/#Try-catch-statements-1","page":"Limitations","title":"Try-catch statements","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Any expression involving try/catch statements is not supported.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"function tryme(x)\n try\n 2 * x\n catch e\n throw(e)\n end\nend\n\njulia> gradient(rand(3)) do x\n sum(tryme(x))\n end\nERROR: Compiling Tuple{typeof(tryme), Vector{Float64}}: try/catch is not supported.\nRefer to the Zygote documentation for fixes.\nhttps://fluxml.ai/Zygote.jl/latest/limitations\n\nStacktrace:\n ...","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Here tryme uses a try/catch statement, and Zygote throws an error when trying to differentiate it as expected. try/catch expressions are used for error handling, but they are less common in Julia compared to some other languages.","category":"page"},{"location":"limitations/#Foreign-call-expressions-1","page":"Limitations","title":"Foreign call expressions","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Foreign call expressions refer to expressions that call external libraries such as code written in C or Fortran. You may want to read more about these calls in the Julia documentation. Scientific computing libraries in Julia may call established C or Fortran libraries under the hood. Since the underlying code for a foreign call expression is not in Julia, it is not possible for Zygote to differentiate this expression.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Below, we define a function that calls a standard C function, clock. 
","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"julia> jclock(x) = ccall(:clock, Int32, ()) * x\njclock (generic function with 1 method)\n\njulia> jclock(2)\n30921278\n\njulia> gradient(jclock, rand())\nERROR: Can't differentiate foreigncall expression\nYou might want to check the Zygote limitations documentation.\nhttps://fluxml.ai/Zygote.jl/latest/limitations\n\nStacktrace:\n ...","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"jclock will multiply the result of our C function by an argument. When we try to differentiate with respect to this argument, we get a foreigncall error.","category":"page"},{"location":"limitations/#Solutions-1","page":"Limitations","title":"Solutions","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"For all of the errors above, the suggested solutions are similar. You have the following possible work-arounds available (in order of preference):","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"avoid the error-inducing operation (e.g. do not use mutating functions)\ndefine a custom ChainRulesCore.rrule\nopen an issue on Zygote","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Avoiding the operation is simple: just don't do it! If you are using a mutating function, try to use a non-mutating variant. If you are using try/catch statements, try to use more graceful error handling such as returning nothing or another sentinel value. Recall that array mutation can also be avoided by using Zygote.Buffer as discussed above.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Sometimes, we cannot avoid expressions that Zygote cannot differentiate, but we may be able to manually derive a gradient. In these cases, you can write a custom rrule using ChainRules.jl. Please refer to the linked ChainRules documentation for how to do this. This is the only solution available for foreign call expressions. Below, we provide a custom rrule for jclock.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"jclock(x) = ccall(:clock, Int32, ()) * x\n\nfunction ChainRulesCore.rrule(::typeof(jclock), x)\n y = jclock(x)\n pb(ȳ) = (ChainRulesCore.NoTangent(), ȳ * y / x) # the derivative is the constant clock value, i.e. y / x\n\n return y, pb\nend\n\njulia> gradient(jclock, rand())\n(674298.4243400148,)","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Lastly, if the code causing problems can be fixed, but it is package code instead of your code, then you should open an issue. For functions built into Julia or its standard libraries, you can open an issue with Zygote.jl or ChainRules.jl. For functions in other packages, you can open an issue with the corresponding package issue tracker.","category":"page"},{"location":"limitations/#Known-Issues-1","page":"Limitations","title":"Known Issues","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Zygote's issue tracker has the current list of open bugs. 
There are some general principles about things you may wish to avoid if you can:","category":"page"},{"location":"limitations/#mutable-structs-1","page":"Limitations","title":"mutable structs","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Zygote has limited support for mutation, and in particular will allow you to change a field in some mutable struct X; a; b; end by setting x.a = val.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"However, this has many limitations and should be avoided if possible.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"The simple solution is to use only immutable structs. ","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"If you need to modify them, using something like @set from Accessors.jl should work well. This returns a new object, but does not have side-effects on other copies of it. ","category":"page"},{"location":"limitations/#Re-using-variable-names-1","page":"Limitations","title":"Re-using variable names","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"It is common to accumulate values in a loop by re-binding the same variable name to a new value many times, for example:","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"function mysum(x::Real, n::Int)\n tot = 0.0\n for i in 1:n\n tot += x^i # binds symbol `tot` to new value\n end\n return tot\nend","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"However, sometimes such re-binding confuses Zygote, especially if the type of the value changes, and especially if the variable is \"boxed\", as will happen if you re-bind from within a closure (such as the function created by a do block).","category":"page"},{"location":"limitations/#Second-derivatives-1","page":"Limitations","title":"Second derivatives","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"In principle Zygote supports taking derivatives of derivatives. There are, however, a few problems:","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Quite a few of its rules are not written in a way that is itself differentiable. For instance they may work by making an array then writing into it, which is mutation of the sort forbidden above. \nThe complexity of the code grows rapidly, as Zygote differentiates its own un-optimised output.\nReverse mode over reverse mode is seldom the best algorithm.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"The issue tracker has a label for second order, which will outline where the bodies are buried.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Often using a different AD system over Zygote is a better solution. This is what hessian does, using ForwardDiff over Zygote, but other combinations are possible. (Note that rules defined here mean that Zygote over ForwardDiff is translated to ForwardDiff over ForwardDiff.) 
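For example, a minimal forward-over-reverse sketch for a scalar second derivative (second_derivative is a hypothetical helper, not part of Zygote's API):\n\nusing Zygote, ForwardDiff\n\n# differentiate Zygote's reverse-mode gradient once more with ForwardDiff\nsecond_derivative(f, x) = ForwardDiff.derivative(y -> Zygote.gradient(f, y)[1], x)\n\njulia> second_derivative(x -> x^3, 2.0) # d²/dx² of x³ is 6x\n12.0\n\n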
","category":"page"},{"location":"glossary/#Glossary-1","page":"Glossary","title":"Glossary","text":"","category":"section"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Differentiation is a minefield of conflicting and overlapping terminology, partly because the ideas have been re-discovered in many different fields (e.g. calculus and differential geometry, the traditional AD community, deep learning, finance, etc.) Many of these terms are not well-defined and others may disagree on the details. Nevertheless, we aim to at least say how we use these terms, which will be helpful when reading over Zygote issues, discussions and source code.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"The list is certainly not complete; if you see new terms you'd like defined, or would like to add one yourself, please do open an issue or PR.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Adjoint: See pullback. Used when defining new pullbacks (i.e. the @adjoint macro) since this involves defining the adjoint of the Jacobian, in most cases.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Backpropagation: Essentially equivalent to \"reverse-mode AD\". Used particularly in the machine learning world to refer to simple chains of functions f(g(h(x))), but has generalised beyond that.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Derivative: Given a scalar function $y = f(x)$, the derivative is $\frac{\partial y}{\partial x}$. \"Partial\" is taken for granted in AD; there's no interesting distinction between partial and total derivatives for our purposes. It's all in the eye of the beholder.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Differential: Given a function $f(x)$, the linearisation $\partial f$ such that $f(x + \epsilon) \approx f(x) + \partial f \epsilon$. This is a generalisation of the derivative since it applies to, for example, vector-to-vector functions ($\partial f$ is a Jacobian) and holomorphic complex functions ($\partial f$ is the first Wirtinger derivative). This is not, in general, what Zygote calculates, though differentials can usually be derived from gradients.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"IR: Intermediate Representation. Essentially source code, but usually lower level – e.g. control flow constructs like loops and branches have all been replaced by gotos. The idea is that it's harder for humans to read/write but easier to manipulate programmatically. Worth looking at SSA form as a paradigmatic example.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Gradient: See sensitivity. There is no technical difference in Zygote's view, though \"gradient\" sometimes distinguishes the sensitivity we actually want from e.g. the internal ones that Zygote produces as it backpropagates.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Graph: ML people tend to think of models as \"computation graphs\", but this is no more true than any program is a graph. In fact, pretty much anything is a graph if you squint hard enough. This also refers to the data structure that e.g. TensorFlow and PyTorch build to represent your model, but see trace for that. 
","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Pullback: Given $y = f(x)$, the function $\bar x = back(\bar y)$. In other words, the function back in y, back = Zygote.pullback(f, x).","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Sensitivity: Used to refer to the gradient $\bar x = \frac{\partial l}{\partial x}$ with some scalar loss $l$. In other words, you have a value x (which need not be scalar) at some point in your program, and $\bar x$ tells you how you should change that value to decrease the loss. In the AD world, sometimes used to refer to adjoint rules.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Source to Source Differentiation: Or Source Code Transformation (SCT). As opposed to tracing programs to simplify them, an alternative is to operate directly on a language's source code or IR, generating new source code for pullbacks. This describes Zygote, Swift for TensorFlow, Tapenade and a few other older ADs that worked on C source files. Zygote and Swift are unusual in that they work on in-memory IR rather than text source.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"To an extent, tracing ADs can be viewed as source transform of a Wengert list / trace. The key difference is that the trace is a lossy representation of the original semantics, which causes problems with e.g. control flow. Systems which can preserve some of those semantics (e.g. autograph) begin to blur the line here, though they are still not nearly as expressive as language IRs.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Symbolic Differentiation: Used to refer to differentiation of \"mathematical expressions\", that is, things like 3x^2 + sin(x). Often distinguished from AD, though this is somewhat arbitrary; you can happily produce a symbolic adjoint for a Wengert list, the only difference being that you're allowed to make variable bindings. So it's really just a special case of AD on an unusually limited language.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Tape: This term can refer to pretty much any part of an AD implementation. In particular confusion is caused by conflating the trace with the set of values sometimes closed over by a pullback. Autograd has a combined trace/closure data structure which is usually described as the tape. On the other hand, PyTorch described their implementation as tape-free because the trace/closure is stored as a DAG rather than a vector, so basically all bets are off here.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Trace: A recording of each mathematical operation used by a program, made at runtime and usually forming a Wengert list. Traces may or may not also record actual runtime values (e.g. PyTorch vs. TensorFlow). They can often be treated as an IR and compiled, but are distinguished from true IRs in that they unroll and inline all control flow, functions and data structures. The tracing process can be thought of as a kind of partial evaluation, though tracers are typically much less worried about losing information.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Vector-Jacobian product: see pullback. 
So called because all pullbacks are linear functions that can be represented by (left) multiplication with the Jacobian matrix.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Wengert List: A set of simple variable assignments and mathematical expressions, forming a directed graph. Can be thought of as a limited programming language with variable bindings and numerical functions but no control flow or data structures. If you trace a program for AD it will typically take this form.","category":"page"},{"location":"#Zygote-1","page":"Home","title":"Zygote","text":"","category":"section"},{"location":"#","page":"Home","title":"Home","text":"Welcome! Zygote extends the Julia language to support differentiable programming. With Zygote you can write down any Julia code you feel like – including using existing Julia packages – then get gradients and optimise your program. Deep learning, ML and probabilistic programming are all different kinds of differentiable programming that you can do with Zygote.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"At least, that's the idea. We're still in beta so expect some adventures.","category":"page"},{"location":"#Setup-1","page":"Home","title":"Setup","text":"","category":"section"},{"location":"#","page":"Home","title":"Home","text":"Zygote can be installed from the package manager in Julia's REPL:","category":"page"},{"location":"#","page":"Home","title":"Home","text":"] add Zygote","category":"page"},{"location":"#Taking-Gradients-1","page":"Home","title":"Taking Gradients","text":"","category":"section"},{"location":"#","page":"Home","title":"Home","text":"Zygote is easy to understand since, at its core, it has a one-function API (pullback), along with a few simple conveniences. Before explaining pullback, we'll look at the higher-level function gradient.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"gradient calculates derivatives. For example, the derivative of 3x^2 + 2x + 1 is 6x + 2, so when x = 5, dx = 32.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> using Zygote\n\njulia> gradient(x -> 3x^2 + 2x + 1, 5)\n(32.0,)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"gradient returns a tuple, with a gradient for each argument to the function.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> gradient((a, b) -> a*b, 2, 3)\n(3.0, 2.0)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"This will work equally well if the arguments are arrays, structs, or any other Julia type, but the function should return a scalar (like a loss or objective l, if you're doing optimisation / ML).","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> W = rand(2, 3); x = rand(3);\n\njulia> gradient(W -> sum(W*x), W)[1]\n2×3 Array{Float64,2}:\n 0.0462002 0.817608 0.979036\n 0.0462002 0.817608 0.979036\n\njulia> gradient(x -> 3x^2 + 2x + 1, 1//4)\n(7//2,)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"Control flow is fully supported, including recursion.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> function pow(x, n)\n r = 1\n for i = 1:n\n r *= x\n end\n return r\n end\npow (generic function with 1 method)\n\njulia> gradient(x -> pow(x, 3), 5)\n(75.0,)\n\njulia> pow2(x, n) = n <= 0 ? 
1 : x*pow2(x, n-1)\npow2 (generic function with 1 method)\n\njulia> gradient(x -> pow2(x, 3), 5)\n(75.0,)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"Data structures are also supported, including mutable ones like dictionaries. Arrays are currently immutable, though this may change in future.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> d = Dict()\nDict{Any, Any}()\n\njulia> gradient(5) do x\n d[:x] = x\n d[:x] * d[:x]\n end\n(10.0,)\n\njulia> d[:x]\n5","category":"page"},{"location":"#Structs-and-Types-1","page":"Home","title":"Structs and Types","text":"","category":"section"},{"location":"#","page":"Home","title":"Home","text":"Julia makes it easy to work with custom types, and Zygote makes it easy to differentiate them. For example, given a simple Point type:","category":"page"},{"location":"#","page":"Home","title":"Home","text":"import Base: +, -\n\nstruct Point\n x::Float64\n y::Float64\nend\n\na::Point + b::Point = Point(a.x + b.x, a.y + b.y)\na::Point - b::Point = Point(a.x - b.x, a.y - b.y)\ndist(p::Point) = sqrt(p.x^2 + p.y^2)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> a = Point(1, 2)\nPoint(1.0, 2.0)\n\njulia> b = Point(3, 4)\nPoint(3.0, 4.0)\n\njulia> dist(a + b)\n7.211102550927978\n\njulia> gradient(a -> dist(a + b), a)[1]\n(x = 0.5547001962252291, y = 0.8320502943378437)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"Zygote's default representation of the \"point adjoint\" is a named tuple with gradients for both fields, but this can of course be customised too.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"This means we can do something very powerful: differentiating through Julia libraries, even if they weren't designed for this. For example, colordiff might be a smarter loss function on colours than simple mean-squared-error:","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> using Colors\n\njulia> colordiff(RGB(1, 0, 0), RGB(0, 1, 0))\n86.60823557376344\n\njulia> gradient(colordiff, RGB(1, 0, 0), RGB(0, 1, 0))\n((r = 0.4590887719632896, g = -9.598786801605689, b = 14.181383399012862), (r = -1.7697549557037275, g = 28.88472330558805, b = -0.044793892637761346))","category":"page"},{"location":"#Explicit-and-Implicit-Parameters-1","page":"Home","title":"Explicit and Implicit Parameters","text":"","category":"section"},{"location":"#","page":"Home","title":"Home","text":"It's easy to work with even very large and complex models, and there are a few ways to do this. Autograd-style models pass around a collection of weights. Depending on how you write your model, there are multiple ways to explicitly take gradients with respect to parameters. For example, the function linear below accepts the parameters as an argument, so we directly pass in the parameters, θ, as an argument to the function being differentiated.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"gradient(f, args...)","category":"page"},{"location":"#Zygote.gradient-Tuple{Any, Vararg{Any}}","page":"Home","title":"Zygote.gradient","text":"gradient(f, args...)\n\nReturns a tuple containing ∂f/∂x for each argument x, the derivative (for scalar x) or the gradient.\n\nf(args...) 
must be a real number, see jacobian for array output.\n\nSee also withgradient to keep the value f(args...), and pullback for value and back-propagator.\n\njulia> gradient(*, 2.0, 3.0, 5.0)\n(15.0, 10.0, 6.0)\n\njulia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])\n([14.0, 22.0, 26.0],)\n\njulia> gradient([7, 11], 0, 1) do x, y, d\n p = size(x, d)\n sum(x.^p .+ y)\n end\n([14.0, 22.0], 2.0, nothing)\n\n\n\n\n\n","category":"method"},{"location":"#","page":"Home","title":"Home","text":"julia> linear(θ, x) = θ[:W] * x .+ θ[:b]\nlinear (generic function with 1 method)\n\njulia> x = rand(5);\n\njulia> θ = Dict(:W => rand(2, 5), :b => rand(2))\nDict{Any,Any} with 2 entries:\n :b => [0.0430585, 0.530201]\n :W => [0.923268 … 0.589691]\n\n# Alternatively, use a named tuple or struct rather than a dict.\n# θ = (W = rand(2, 5), b = rand(2))\n\njulia> θ̄ = gradient(θ -> sum(linear(θ, x)), θ)[1]\nDict{Any,Any} with 2 entries:\n :b => [1.0, 1.0]\n :W => [0.628998 … 0.433006]","category":"page"},{"location":"#","page":"Home","title":"Home","text":"We can combine the role of the dictionary and the function here by making a callable struct which contains the parameters, equivalent to a closure. Passed explicitly to gradient, we get a named tuple with the same field names:","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> struct Linear\n W\n b\n end\n\njulia> (l::Linear)(x) = l.W * x .+ l.b\n\njulia> model = Linear(rand(2, 5), rand(2))\nLinear([0.267663 … 0.334385], [0.0386873, 0.0203294])\n\njulia> x = rand(5);\n\njulia> dmodel = gradient(model -> sum(model(x)), model)[1]\n(W = [0.652543 … 0.683588], b = [1.0, 1.0])","category":"page"},{"location":"#","page":"Home","title":"Home","text":"Zygote also supports another way to take gradients, via implicit parameters. Here the loss function takes zero arguments, but the variables of interest are indicated by a special Params object. The function linear which depends on W and b is executed when the loss function () -> sum(linear(x)) is called, and hence this dependence is visible to Zygote:","category":"page"},{"location":"#","page":"Home","title":"Home","text":"gradient","category":"page"},{"location":"#Zygote.gradient","page":"Home","title":"Zygote.gradient","text":"gradient(() -> loss(), ps::Params) -> Grads\n\nGradient with implicit parameters. Takes a zero-argument function, and returns a dictionary-like container, whose keys are arrays x in ps.\n\nSee also withgradient to keep the value loss().\n\njulia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100];\n\njulia> g = gradient(Params([x, y])) do\n sum(x .* y .* z')\n end\nGrads(...)\n\njulia> g[x]\n2×3 Matrix{Float64}:\n 7.0 70.0 700.0\n 8.0 80.0 800.0\n\njulia> haskey(g, z) # only x and y are parameters\nfalse\n\n\n\n\n\n","category":"function"},{"location":"#","page":"Home","title":"Home","text":"julia> W = rand(2, 5); b = rand(2);\n\njulia> linear(x) = W * x .+ b\nlinear (generic function with 2 methods)\n\njulia> grads = gradient(() -> sum(linear(x)), Params([W, b]))\nGrads(...)\n\njulia> grads[W], grads[b] # access gradients using arrays as keys\n([0.652543 … 0.683588], [1.0, 1.0])","category":"page"},{"location":"#","page":"Home","title":"Home","text":"Here grads is a dictionary-like object, whose keys are the same parameters we indicated in Params. 
(In fact it wraps a dictionary using objectid(W) as keys, which does not change if the values in W are mutated).","category":"page"},{"location":"#","page":"Home","title":"Home","text":"This implicit style is the one presently used by Flux.jl, a closely related machine learning library. It uses structs like Linear above to define layers, and the function Flux.params(model) returns a Params object containing all the parameters of all layers. See its documentation for more details. When using Zygote for most other purposes, however, the explicit style is usually preferred.","category":"page"},{"location":"internals/#Internals-1","page":"Internals","title":"Internals","text":"","category":"section"},{"location":"internals/#What-Zygote-Does-1","page":"Internals","title":"What Zygote Does","text":"","category":"section"},{"location":"internals/#","page":"Internals","title":"Internals","text":"These notebooks and the Zygote paper provide useful background on Zygote's transform; this page is particularly focused on implementation details.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Before we think about AD, we'll consider some simple cases. We can start by defining a function that produces pullbacks, J, explicitly for some simple functions.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"J(::typeof(sin), x) = sin(x), ȳ -> ȳ*cos(x)\nJ(::typeof(cos), x) = cos(x), ȳ -> -ȳ*sin(x)\nJ(::typeof(*), a, b) = a*b, c̄ -> (b*c̄, a*c̄)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Now we can call J to take a gradient.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"gradient(f, x...) = J(f, x...)[2](1)\n\ngradient(sin, 1) # (0.540,)\n\ngradient(*, 2, 3) # (3, 2)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Now consider a composite function that calls two simple ones:","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"function foo(x)\n a = sin(x)\n b = cos(a)\n return b\nend","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"We can easily differentiate foo if we can differentiate the functions it calls. If we can get pullbacks via J, the pullback for foo looks as follows. Where the forward pass calculates x -> a -> b, the backward pass takes b̄ -> ā -> x̄ via the pullbacks.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"function J(::typeof(foo), x)\n a, da = J(sin, x)\n b, db = J(cos, a)\n return b, function(b̄)\n ā, = db(b̄)\n x̄, = da(ā)\n return x̄\n end\nend\n\ngradient(foo, 1) # (-0.403,)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Things get just a little more complex when control flow is involved. You can see that the derived adjoint for pow mirrors the original function, except that the loop runs in reverse. The easiest way to see why it looks like this is to imagine unrolling the loop n times, working out the adjoint, and then turning it back into a loop. 
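For n = 3, a sketch of that unrolling (using the pullbacks returned by J(*, r, x) above, written as comments):\n\n# forward, unrolled: r1 = 1*x; r2 = r1*x; r3 = r2*x\n# reverse, unrolled: each J(*, r, x) pullback maps r̄ to (r̄*x, r̄*r)\n# r̄2, x̄3 = r̄3*x, r̄3*r2\n# r̄1, x̄2 = r̄2*x, r̄2*r1\n# r̄0, x̄1 = r̄1*x, r̄1*1\n# x̄ = x̄1 + x̄2 + x̄3\n\nCollapsing the repeated reverse steps back into a loop gives the adjoint below.\n\n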
","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"function pow(x, n) # x^n\n r = 1\n for _ = 1:n\n r *= x\n end\n return r\nend\n\nfunction J(::typeof(pow), x, n)\n r = 1\n Js = []\n for i = 1:n\n r, back = J(*, r, x)\n push!(Js, back)\n end\n return r, function(r̄)\n x̄ = 0\n for i = n:-1:1\n r̄, x̄′ = Js[i](r̄)\n x̄ += x̄′\n end\n return (x̄, 0)\n end\nend\n\ngradient(pow, 2, 3) # (12, 0)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Despite being reasonably fiddly, this is a fully mechanical transformation, so the only remaining thing is to automate it – a small matter of programming.","category":"page"},{"location":"internals/#Closures-1","page":"Internals","title":"Closures","text":"","category":"section"},{"location":"internals/#","page":"Internals","title":"Internals","text":"The J function here corresponds to pullback in Zygote. However, pullback is actually a wrapper around the lower level _pullback function.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> y, back = Zygote._pullback(sin, 0.5);\n\njulia> back(1)\n(nothing, 0.8775825618903728)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Why the extra nothing here? This actually represents the gradient of the function sin. This is often nothing, but when we have closures the function contains data we need gradients for.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> f = let a = 3; x -> x*a; end\n#19 (generic function with 1 method)\n\njulia> y, back = Zygote._pullback(f, 2);\n\njulia> back(1)\n((a = 2,), 3)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"This is a minor point for the most part, but _pullback will come up in future examples.","category":"page"},{"location":"internals/#Entry-Points-1","page":"Internals","title":"Entry Points","text":"","category":"section"},{"location":"internals/#","page":"Internals","title":"Internals","text":"You might notice that Zygote is, in effect, just a macro. We could happily implement Zygote by writing definitions like","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"@differentiable foo(x) = sin(cos(x))","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"which would expand to generate an appropriate overload of J. As long as every function we want to differentiate is annotated, this will work just fine. However, it's obviously not ideal to have to annotate every function inside every Julia package in order to make it differentiable.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"This is where generated functions come in. Making J a generated function allows us to apply the Zygote macro on an as-needed basis; calling J(f, x...) looks up the code for f(x...), transforms it, and then behaves as if you had defined J for that specific function ahead of time. 
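As a toy illustration of generated functions themselves (unrelated to Zygote's actual transform; typename_length is a hypothetical name):\n\n# the body runs at compile time: inside it, x is bound to the argument's type,\n# and the returned expression becomes the method body for that type\n@generated function typename_length(x)\n n = length(string(x))\n return :($n)\nend\n\njulia> typename_length(1.0) # length(\"Float64\")\n7\n\n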
","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"When we look up the code, we actually get lowered (desugared) code rather than an AST.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> foo(x) = baz(bar(x))\nfoo (generic function with 1 method)\n\njulia> @code_lowered foo(1)\nCodeInfo(\n1 ─ %1 = (Main.bar)(x)\n│ %2 = (Main.baz)(%1)\n└── return %2\n)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"We convert the code to SSA form using Julia's built-in IR data structure, after which it looks like this.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> Zygote.@code_ir foo(1)\n1 1 ─ %1 = (Main.bar)(_2)::Any\n │ %2 = (Main.baz)(%1)::Any\n └── return %2","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"(There isn't much difference unless there's some control flow.)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"The code is then differentiated by the code in compiler/reverse.jl. You can see the output with @code_adjoint.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> Zygote.@code_adjoint foo(1)\n1 1 ─ %1 = (Zygote._pullback)(_2, Zygote.unwrap, Main.bar)::Any\n │ %2 = (Base.getindex)(%1, 1)::Any\n │ (Base.getindex)(%1, 2)::Any\n │ %4 = (Zygote._pullback)(_2, %2, _4)::Any\n │ %5 = (Base.getindex)(%4, 1)::Any\n │ (Base.getindex)(%4, 2)::Any\n │ %7 = (Zygote._pullback)(_2, Zygote.unwrap, Main.baz)::Any\n │ %8 = (Base.getindex)(%7, 1)::Any\n │ (Base.getindex)(%7, 2)::Any\n │ %10 = (Zygote._pullback)(_2, %8, %5)::Any\n │ %11 = (Base.getindex)(%10, 1)::Any\n │ (Base.getindex)(%10, 2)::Any\n └── return %11\n 1 ─ %1 = Δ()::Any\n1 │ %2 = (@12)(%1)::Any\n │ %3 = (Zygote.gradindex)(%2, 1)::Any\n │ %4 = (Zygote.gradindex)(%2, 2)::Any\n │ (@9)(%3)::Any\n │ %6 = (@6)(%4)::Any\n │ %7 = (Zygote.gradindex)(%6, 1)::Any\n │ %8 = (Zygote.gradindex)(%6, 2)::Any\n │ (@3)(%7)::Any\n │ %10 = (Zygote.tuple)(nothing, %8)::Any\n └── return %10\n, [1])","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"This code is quite verbose, mainly due to all the tuple unpacking (gradindex is just like getindex, but handles nothing gracefully). There are two pieces of IR here, one for the modified forward pass and one for the pullback closure. The @ nodes allow the closure to refer to values from the forward pass, and the Δ() represents the incoming gradient ȳ. In essence, this is just what we wrote above by hand for J(::typeof(foo), x).","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"compiler/emit.jl lowers this code into runnable IR (e.g. by turning @ references into getfields and stacks), and it's then turned back into lowered code for Julia to run.","category":"page"},{"location":"internals/#Closure-Conversion-1","page":"Internals","title":"Closure Conversion","text":"","category":"section"},{"location":"internals/#","page":"Internals","title":"Internals","text":"There are no closures in lowered Julia code, so we can't actually emit one directly in lowered code. 
To work around this we use a trick: a generic struct like","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"struct Pullback{F}\n data\nend","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"We can put whatever we want in data, and the F will be the signature for the original call, like Tuple{typeof(foo),Int}. When the pullback gets called it hits another generated function which emits the pullback code.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"In hand-written code this would look like:","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"struct Pullback{F}\n data\nend\n\nfunction J(::typeof(foo), x)\n a, da = J(sin, x)\n b, db = J(cos, a)\n return b, Pullback{typeof(foo)}((da, db))\nend\n\nfunction (p::Pullback{typeof(foo)})(b̄)\n da, db = p.data[1], p.data[2]\n ā = db(b̄)\n x̄ = da(ā)\n return x̄\nend","category":"page"},{"location":"internals/#Debugging-1","page":"Internals","title":"Debugging","text":"","category":"section"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Say some of our code is throwing an error.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"bad(x) = x\n\nZygote.@adjoint bad(x) = x, _ -> error(\"bad\")\n\nfoo(x) = bad(sin(x))\n\ngradient(foo, 1) # error!","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Zygote can usually give a stacktrace pointing right to the issue here, but in some cases there are compiler crashes that make this harder. In these cases it's best to (a) use _pullback and (b) take advantage of Zygote's recursion to narrow down the problem function.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> y, back = Zygote._pullback(foo, 1);\n\njulia> back(1) # just make up a value here, it just needs to look similar to `y`\nERROR: bad\n\n# Ok, so we try functions that foo calls\n\njulia> y, back = Zygote._pullback(sin, 1);\n\njulia> back(1)\n(nothing, 0.5403023058681398)\n\n# Looks like that's fine\n\njulia> y, back = Zygote._pullback(bad, 1);\n\njulia> back(1) # ok, here's our issue. Lather, rinse, repeat.\nERROR: bad","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Of course, our goal is that you never have to do this, but until Zygote is more mature it can be a useful way to narrow down test cases.","category":"page"}] +[{"location":"profiling/#Debugging-in-Time-and-Space-1","page":"Profiling","title":"Debugging in Time and Space","text":"","category":"section"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"Because Zygote generates Julia code for the backwards pass, many of Julia's normal profiling and performance debugging tools work well on it out of the box.","category":"page"},{"location":"profiling/#Performance-Profiling-1","page":"Profiling","title":"Performance Profiling","text":"","category":"section"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"Julia's sampling profiler is useful for understanding performance. 
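For example, a minimal sketch (loss and the sizes here are arbitrary stand-ins for your own code):\n\nusing Zygote, Profile\n\nloss(W, x) = sum(W * x)\nW, x = rand(100, 100), rand(100)\n\ngradient(loss, W, x) # run once so compilation is not profiled\nProfile.clear()\n@profile for _ in 1:1000\n gradient(loss, W, x)\nend\nProfile.print() # or view the samples graphically\n\n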
We recommend running the profiler in Juno, but the terminal or ProfileView.jl also work well.","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"(Image: )","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"The bars indicate time taken in both the forwards and backwards passes at that line. The canopy chart on the right shows us each function call as a block, arranged so that when f calls g, g gets a block just below f, which is bigger the longer it took to run. If we dig down the call stack we'll eventually find the adjoints for things like matmul, which we can click on to view.","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"(Image: )","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"The trace inside the adjoint can be used to distinguish time taken by the forwards and backwards passes.","category":"page"},{"location":"profiling/#Memory-Profiling-1","page":"Profiling","title":"Memory Profiling","text":"","category":"section"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"Reverse-mode AD typically uses memory proportional to the number of operations in the program, so long-running programs can also suffer memory usage issues. Zygote includes a space profiler to help debug these issues. Like the time profiler, it shows a canopy chart, but this time hovering over it displays the number of bytes stored by each line of the program.","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"(Image: )","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"Note that this currently only works inside Juno.","category":"page"},{"location":"profiling/#Reflection-1","page":"Profiling","title":"Reflection","text":"","category":"section"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"Julia's code and type inference reflection tools can also be useful, though Zygote's use of closures can make the output noisy. To see the code Julia runs you should use the low-level _pullback method and the pullback it returns. This will directly show either the derived adjoint code or the code for a custom adjoint, if there is one.","category":"page"},{"location":"profiling/#","page":"Profiling","title":"Profiling","text":"julia> using Zygote: Context, _pullback\n\njulia> add(a, b) = a+b\n\njulia> @code_typed _pullback(Context(), add, 1, 2)\nCodeInfo(\n1 ─ %1 = (Base.getfield)(args, 1)::Int64\n│ %2 = (Base.getfield)(args, 2)::Int64\n│ %3 = (Base.add_int)(%1, %2)::Int64\n│ %4 = (Base.tuple)(%3, $(QuoteNode(∂(add))))::PartialTuple(Tuple{Int64,typeof(∂(add))}, Any[Int64, Const(∂(add), false)])\n└── return %4\n) => Tuple{Int64,typeof(∂(add))}\n\njulia> y, back = _pullback(Context(), add, 1, 2)\n(3, ∂(add))\n\njulia> @code_typed back(1)\nCodeInfo(\n1 ─ %1 = (Base.mul_int)(Δ, 1)::Int64\n│ %2 = (Base.mul_int)(Δ, 1)::Int64\n│ %3 = (Zygote.tuple)(nothing, %1, %2)::PartialTuple(Tuple{Nothing,Int64,Int64}, Any[Const(nothing, false), Int64, Int64])\n└── return %3\n) => Tuple{Nothing,Int64,Int64}","category":"page"},{"location":"complex/#Complex-Differentiation-1","page":"Complex Differentiation","title":"Complex Differentiation","text":"","category":"section"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"Complex numbers add some difficulty to the idea of a \"gradient\". 
To talk about gradient(f, x) here we need to talk a bit more about f.","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"If f returns a real number, things are fairly straightforward. For $c = x + yi$ and $z = f(c)$, we can define the adjoint $\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i$ (note that $\bar c$ means gradient, and $c'$ means conjugate). It's exactly as if the complex number were just a pair of reals (re, im). This works out of the box.","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"julia> using Zygote\n\njulia> gradient(c -> abs2(c), 1+2im)\n(2.0 + 4.0im,)","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"However, while this is a very pragmatic definition that works great for gradient descent, it's not quite aligned with the mathematical notion of the derivative: i.e. $f(c + \epsilon) \approx f(c) + \bar c \epsilon$. In general, such a $\bar c$ is not possible for complex numbers except when f is holomorphic (or analytic). Roughly speaking this means that the function is defined over c as if it were a normal real number, without exploiting its complex structure – it can't use real, imag, conj, or anything that depends on these like abs2 (abs2(x) = x*x'). (This constraint also means there's no overlap with the Real case above; holomorphic functions always return complex numbers for complex input.) But most \"normal\" numerical functions – exp, log, anything that can be represented by a Taylor series – are fine.","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"Fortunately it's also possible to get these derivatives; they are the conjugate of the gradients for the real part.","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"julia> gradient(x -> real(log(x)), 1+2im)[1] |> conj\n0.2 - 0.4im","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"We can check that this function is holomorphic – and thus that the gradient we got out is sensible – by checking the Cauchy-Riemann equations. In other words this should give the same answer:","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"julia> -im*gradient(x -> imag(log(x)), 1+2im)[1] |> conj\n0.2 - 0.4im","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"Notice that this fails in a non-holomorphic case, f(x) = log(x'):","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"julia> gradient(x -> real(log(x')), 1+2im)[1] |> conj\n0.2 - 0.4im\n\njulia> -im*gradient(x -> imag(log(x')), 1+2im)[1] |> conj\n-0.2 + 0.4im","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"In cases like these, all bets are off. The gradient can only be described with more information; either a 2x2 Jacobian (a generalisation of the Real case, where the second column is now non-zero), or by the two Wirtinger derivatives (a generalisation of the holomorphic case, where $\frac{\partial f}{\partial \bar c}$ is now non-zero). 
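For reference, with $c = x + yi$, the two Wirtinger derivatives are defined as\n\n$\frac{\partial f}{\partial c} = \frac{1}{2}\left(\frac{\partial f}{\partial x} - i \frac{\partial f}{\partial y}\right), \qquad \frac{\partial f}{\partial \bar c} = \frac{1}{2}\left(\frac{\partial f}{\partial x} + i \frac{\partial f}{\partial y}\right)$\n\nand $\frac{\partial f}{\partial \bar c} = 0$ is exactly the Cauchy-Riemann (holomorphy) condition. 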
To get these efficiently, as we would a Jacobian, we can just call the backpropagators twice.","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"function jacobi(f, x)\n y, back = Zygote.pullback(f, x)\n back(1)[1], back(im)[1]\nend\n\nfunction wirtinger(f, x)\n du, dv = jacobi(f, x)\n (du' + im*dv')/2, (du + im*dv)/2\nend","category":"page"},{"location":"complex/#","page":"Complex Differentiation","title":"Complex Differentiation","text":"julia> wirtinger(x -> 3x^2 + 2x + 1, 1+2im)\n(8.0 + 12.0im, 0.0 + 0.0im)\n\njulia> wirtinger(x -> abs2(x), 1+2im)\n(1.0 - 2.0im, 1.0 + 2.0im)","category":"page"},{"location":"utils/#Utilities-1","page":"Utilities","title":"Utilities","text":"","category":"section"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Zygote's gradients can be used to construct a Jacobian (by repeated evaluation) or a Hessian (by taking a second derivative).","category":"page"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Zygote.jacobian\nZygote.hessian\nZygote.hessian_reverse\nZygote.diaghessian","category":"page"},{"location":"utils/#Zygote.jacobian","page":"Utilities","title":"Zygote.jacobian","text":"jacobian(f, args...) -> Tuple\n\nFor each array a ∈ args this returns a matrix with Ja[k,i] = ∂y[k]/∂a[i] where y = f(args...) is usually a vector. Arrays of higher dimension are treated like vec(a), or vec(y) for output.\n\nFor scalar x::Number ∈ args, the result is a vector Jx[k] = ∂y[k]/∂x, while for scalar y all results have just one row.\n\nWith any other argument type, no result is produced, even if gradient would work.\n\nThis reverse-mode Jacobian needs to evaluate the pullback once for each element of y. Doing so is usually only efficient when length(y) is small compared to length(a); otherwise forward mode is likely to be better.\n\nSee also withjacobian, hessian, hessian_reverse.\n\nExamples\n\njulia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output\n3×7 Matrix{Int64}:\n 200 0 0 0 0 0 0\n 0 400 0 0 0 0 0\n 0 0 600 0 0 0 0\n\njulia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) # scalar argument has vector jacobian\n([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])\n\njulia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)\n([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])\n\nwarning: Warning\nFor arguments of any type except Number & AbstractArray, the result is nothing.\n\njulia> jacobian((a,s) -> a.^length(s), [1,2,3], \"str\")\n([3 0 0; 0 12 0; 0 0 27], nothing)\n\njulia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))\n([4 4 4], nothing)\n\njulia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient understands the tuple\n([4 4 4], (6, 1))\n\n\n\n\n\njacobian(loss, ::Params)\n\nLike gradient with implicit parameters, this method takes a zero-argument function and returns an IdDict-like object, now containing the Jacobian for each parameter.\n\nExamples\n\njulia> xs = [1 2; 3 4]; ys = [5,7,9];\n\njulia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))\nGrads(...)\n\njulia> Jxy[ys]\n2×3 Matrix{Int64}:\n 1 0 0\n 0 1 0\n\njulia> Jxy[xs]\n2×4 Matrix{Int64}:\n 2 6 4 8\n 2 6 4 8\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.hessian","page":"Utilities","title":"Zygote.hessian","text":"hessian(f, x)\n\nConstruct the Hessian ∂²f/∂x², where x is a real number or an array, and f(x) is a real number. 
When x is an array, the result is a matrix H[i,j] = ∂²f/∂x[i]∂x[j], using linear indexing x[i] even if the argument is higher-dimensional.\n\nThis uses forward over reverse, ForwardDiff over Zygote, calling hessian_dual(f, x). See hessian_reverse for an all-Zygote alternative.\n\nSee also diaghessian to compute only the diagonal part.\n\nExamples\n\njulia> hessian(x -> x[1]*x[2], randn(2))\n2×2 Matrix{Float64}:\n 0.0 1.0\n 1.0 0.0\n\njulia> hessian(x -> sum(x.^3), [1 2; 3 4]) # uses linear indexing of x\n4×4 Matrix{Int64}:\n 6 0 0 0\n 0 18 0 0\n 0 0 12 0\n 0 0 0 24\n\njulia> hessian(sin, pi/2)\n-1.0\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.hessian_reverse","page":"Utilities","title":"Zygote.hessian_reverse","text":"hessian_reverse(f, x)\n\nThis should be equivalent to hessian(f, x), but implemented using reverse over reverse mode, all Zygote. (This is usually much slower, and more likely to find errors.)\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.diaghessian","page":"Utilities","title":"Zygote.diaghessian","text":"diaghessian(f, args...) -> Tuple\n\nDiagonal part of the Hessian. Returns a tuple containing, for each argument x, h of the same shape with h[i] = Hᵢᵢ = ∂²y/∂x[i]∂x[i]. The original evaluation y = f(args...) must give a real number y.\n\nFor one vector argument x, this is equivalent to (diag(hessian(f,x)),). Like hessian it uses ForwardDiff over Zygote. \n\nwarning: Warning\nFor arguments of any type except Number & AbstractArray, the result is nothing.\n\nExamples\n\njulia> diaghessian(x -> sum(x.^3), [1 2; 3 4])[1]\n2×2 Matrix{Int64}:\n 6 12\n 18 24\n\njulia> Diagonal(vec(ans)) == hessian(x -> sum(x.^3), [1 2; 3 4]) # full Hessian is diagonal\ntrue\n\njulia> diaghessian((x,y) -> sum(x .* y .* y'), [1 22; 333 4], [0.5, 0.666]) # two array arguments\n([0.0 0.0; 0.0 0.0], [2.0, 8.0])\n\njulia> diaghessian(atan, 1, 2) # two scalar arguments\n(-0.16, 0.16)\n\njulia> hessian(xy -> atan(xy[1], xy[2]), [1, 2]) # full Hessian is not diagonal\n2×2 Matrix{Float64}:\n -0.16 -0.12\n -0.12 0.16\n\n\n\n\n\n","category":"function"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Zygote also provides a set of helpful utilities. These are all \"user-level\" tools – in other words you could have written them easily yourself, but they live in Zygote for convenience.","category":"page"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"See ChainRules.ignore_derivatives if you want to exclude some of your code from the gradient calculation. This replaces previous Zygote-specific ignore and dropgrad functionality.","category":"page"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Zygote.withgradient\nZygote.withjacobian\nZygote.@showgrad\nZygote.hook\nZygote.Buffer\nZygote.forwarddiff\nZygote.checkpointed","category":"page"},{"location":"utils/#Zygote.withgradient","page":"Utilities","title":"Zygote.withgradient","text":"withgradient(f, args...)\nwithgradient(f, ::Params)\n\nReturns both the value of the function and the gradient, as a named tuple. \n\njulia> y, ∇ = withgradient(/, 1, 2)\n(val = 0.5, grad = (0.5, -0.25))\n\njulia> ∇ == gradient(/, 1, 2)\ntrue\n\nAllows you to capture auxiliary outputs, in addition to the scalar used by gradient. To do this, f must return a Tuple or NamedTuple. Then it calculates grad = gradient(first∘f, args...) 
but returns the whole val = f(args...):\n\njulia> withgradient([1,2,4]) do x\n z = 1 ./ x\n sum(z), z # here z is an auxiliary output\n end\n(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))\n\njulia> withgradient(3.0, 4.0) do x, y\n (div = x/y, mul = x*y)\n end\n(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))\n\nAlso supports implicit mode:\n\njulia> w = [3.0];\n\njulia> res = withgradient(() -> sum(abs2, w), Params([w]))\n(val = 9.0, grad = Grads(...))\n\njulia> res.grad[w]\n1-element Vector{Float64}:\n 6.0\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.withjacobian","page":"Utilities","title":"Zygote.withjacobian","text":"withjacobian(f, args...)\n\nReturns both the value f(args...) and the jacobian as a named tuple.\n\njulia> withjacobian(cumsum, [1,2,3])\n(val = [1, 3, 6], grad = ([1 0 0; 1 1 0; 1 1 1],))\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.@showgrad","page":"Utilities","title":"Zygote.@showgrad","text":"@showgrad(x) -> x\n\nMuch like @show, but shows the gradient about to accumulate to x. Useful for debugging gradients.\n\njulia> gradient(2, 3) do a, b\n @showgrad(a)*b\n end\n∂(a) = 3\n(3, 2)\n\nNote that the gradient depends on how the output of @showgrad is used, and is not the overall gradient of the variable a. For example:\n\njulia> gradient(2) do a\n @showgrad(a)*a\n end\n∂(a) = 2\n(4,)\n\njulia> gradient(2, 3) do a, b\n @showgrad(a) # not used, so no gradient\n a*b\n end\n∂(a) = nothing\n(3, 2)\n\n\n\n\n\n","category":"macro"},{"location":"utils/#Zygote.hook","page":"Utilities","title":"Zygote.hook","text":"hook(x̄ -> ..., x) -> x\n\nGradient hooks. Allows you to apply an arbitrary function to the gradient for x.\n\njulia> gradient(2, 3) do a, b\n hook(ā -> @show(ā), a)*b\n end\nā = 3\n(3, 2)\n\njulia> gradient(2, 3) do a, b\n hook(-, a)*b\n end\n(-3, 2)\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.Buffer","page":"Utilities","title":"Zygote.Buffer","text":"Buffer(xs, ...)\n\nBuffer is an array-like type which is mutable when taking gradients. You can construct a Buffer with the same syntax as similar (e.g. Buffer(xs, 5)) and then use normal indexing. Finally, use copy to get back a normal array.\n\nFor example:\n\njulia> function vstack(xs)\n buf = Buffer(xs, length(xs), 5)\n for i = 1:5\n buf[:, i] = xs\n end\n return copy(buf)\n end\nvstack (generic function with 1 method)\n\njulia> vstack([1, 2, 3])\n3×5 Array{Int64,2}:\n 1 1 1 1 1\n 2 2 2 2 2\n 3 3 3 3 3\n\njulia> gradient(x -> sum(vstack(x)), [1, 2, 3])\n([5.0, 5.0, 5.0],)\n\nBuffer is not an AbstractArray and can't be used for linear algebra operations like matrix multiplication. This prevents it from being captured by pullbacks.\n\ncopy is a semantic copy, but does not allocate memory. Instead the Buffer is made immutable after copying.\n\n\n\n\n\n","category":"type"},{"location":"utils/#Zygote.forwarddiff","page":"Utilities","title":"Zygote.forwarddiff","text":"forwarddiff(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD) -> f(x)\n\nRuns f(x) as usual, but instructs Zygote to differentiate f using forward mode, rather than the usual reverse mode. The chunk_threshold argument controls the maximum chunk size (cf. 
ForwardDiff documentation).\n\nForward mode takes time linear in length(x) but only has constant memory overhead, and is very efficient for scalars, so in some cases this can be a useful optimisation.\n\njulia> function pow(x, n)\n r = one(x)\n for i = 1:n\n r *= x\n end\n return r\n end\npow (generic function with 1 method)\n\njulia> gradient(5) do x\n forwarddiff(x) do x\n pow(x, 2)\n end\n end\n(10,)\n\nNote that the function f will drop gradients for any closed-over values.\n\njulia> gradient(2, 3) do a, b\n forwarddiff(a) do a\n a*b\n end\n end\n(3, nothing)\n\nThis can be rewritten by explicitly passing through b, i.e.\n\ngradient(2, 3) do a, b\n forwarddiff([a, b]) do (a, b)\n a*b\n end\nend\n\n\n\n\n\n","category":"function"},{"location":"utils/#Zygote.checkpointed","page":"Utilities","title":"Zygote.checkpointed","text":"checkpointed(f, xs...)\n\nUse gradient checkpointing on the call f(xs...). This means that checkpointed(f, xs...) === f(xs...), but when computing the derivative intermediate results from the forward pass of f will not be stored. Instead the forward pass will be repeated, when computing the derivative. This saves memory at the cost of increasing execution time.\n\nwarning: Warning\nIf f is not a pure function, checkpointed will likely give wrong results.\n\n\n\n\n\n","category":"function"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Params and Grads can be copied to and from arrays using the copy! function.","category":"page"},{"location":"utils/#Working-with-Grads-1","page":"Utilities","title":"Working with Grads","text":"","category":"section"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"Map, broadcast, and iteration are supported for the dictionary-like Grads objects. These operations are value based and preserve the keys.","category":"page"},{"location":"utils/#","page":"Utilities","title":"Utilities","text":"using Zygote, Test\n\nw, x1, x2, b = rand(2), rand(2), rand(2), rand(2)\n\ngs1 = gradient(() -> sum(tanh.(w .* x1 .+ b)), Params([w, b]))\ngs2 = gradient(() -> sum(tanh.(w .* x2 .+ b)), Params([w, b]))\n\n# accumulate gradients\ngs = gs1 .+ gs2\n@test gs[w] ≈ gs1[w] + gs2[w]\n@test gs[b] ≈ gs1[b] + gs2[b]\n\n# gradients and IdDict interact nicely\n# note that an IdDict must be used for gradient algebra on the GPU\ngs .+= IdDict(p => randn(size(p)) for p in keys(gs))\n\n# clip gradients\nmap(x -> clamp.(x, -0.1, 0.1), gs)\n\n# clip gradients in-place\nforeach(x -> clamp!(x, -0.1, 0.1), gs)\n\nfor (p, g) in pairs(gs)\n # do something with parameter `p` and corresponding gradient `g`\nend\n\n# note that gradients must be w.r.t. to the same parameter key set\ngs3 = gradient(() -> sum(tanh.(w .* x2)), Params([w]))\n# gs3 does not have the key b\n@test_throws ArgumentError gs1 .+ gs3","category":"page"},{"location":"adjoints/#Custom-Adjoints-1","page":"Custom Adjoints","title":"Custom Adjoints","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"note: Prefer to use ChainRulesCore to define custom adjoints\nZygote supports the use of ChainRulesCore to define custom sensitivities. It is preferred to define the custom sensitivities using ChainRulesCore.rrule as they will work for many AD systems, not just Zygote. These sensitivities can be added in your own package, or for Base/StdLib functions they can be added to ChainRules.jl. To define custom sensitivities using ChainRulesCore, define a ChainRulesCore.rrule(f, args...; kwargs...). 
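For example, a minimal sketch for a scalar product function (mymul is a hypothetical name, and this rule only handles scalar arguments):\n\nusing ChainRulesCore\n\nmymul(a, b) = a * b\n\nfunction ChainRulesCore.rrule(::typeof(mymul), a, b)\n y = mymul(a, b)\n # one tangent per input, plus NoTangent() for the function object itself\n mymul_pullback(ȳ) = (NoTangent(), ȳ * b, ȳ * a)\n return y, mymul_pullback\nend\n\n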
Head to the ChainRules project's documentation for more information. If you are defining your custom adjoints using ChainRulesCore then you do not need to read this page, and can consider it as documenting a legacy feature. This page exists to describe how Zygote works, and how adjoints can be directly defined for Zygote. Defining adjoints this way does not make them accessible to other AD systems, but does let you do things that directly depend on how Zygote works. It allows for specific definitions of adjoints that are only defined for Zygote (which might work differently to more generic definitions defined for all AD).","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"The @adjoint macro is an important part of Zygote's interface; customising your backwards pass is not only possible but widely used and encouraged. While there are specific utilities available for common things like gradient clipping, understanding adjoints will give you the most flexibility. We first give a bit more background on what these pullback things are.","category":"page"},{"location":"adjoints/#Pullbacks-1","page":"Custom Adjoints","title":"Pullbacks","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"gradient is really just syntactic sugar around the more fundamental function pullback.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> using Zygote\n\njulia> y, back = Zygote.pullback(sin, 0.5);\n\njulia> y\n0.479425538604203","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"pullback gives two outputs: the result of the original function, sin(0.5), and a pullback, here called back. back implements the gradient computation for sin, accepting a derivative and producing a new one. In mathematical terms, it implements a vector-Jacobian product. Where $y = f(x)$ and the gradient $\frac{\partial l}{\partial x}$ is written $\bar{x}$, the pullback $\mathcal{B}_y$ computes:","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"$\bar{x} = \frac{\partial l}{\partial x} = \frac{\partial l}{\partial y} \frac{\partial y}{\partial x} = \mathcal{B}_y(\bar{y})$","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"To make this concrete, take the function $y = \sin(x)$. $\frac{\partial y}{\partial x} = \cos(x)$, so the pullback is $\bar{y} \cos(x)$. In other words, pullback(sin, x) behaves the same as","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),))","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"gradient takes a function l = f(x) and assumes $\bar{l} = \frac{\partial l}{\partial l} = 1$ and feeds this into the pullback. 
In the case of sin,","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> function gradsin(x)\n _, back = dsin(x)\n back(1)\n end\ngradsin (generic function with 1 method)\n\njulia> gradsin(0.5)\n(0.8775825618903728,)\n\njulia> cos(0.5)\n0.8775825618903728","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"More generally","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> function mygradient(f, x...)\n _, back = Zygote.pullback(f, x...)\n back(1)\n end\nmygradient (generic function with 1 method)\n\njulia> mygradient(sin, 0.5)\n(0.8775825618903728,)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"The rest of this section contains more technical detail. It can be skipped if you only need an intuition for pullbacks; you generally won't need to worry about it as a user.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"If x and y are vectors, $\frac{\partial y}{\partial x}$ becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. v'J, rather than the more usual J*v. Transposing v to a row vector and back (v'J)' is equivalent to J'v so our gradient rules actually implement the adjoint of the Jacobian. This is relevant even for scalar code: the adjoint for y = sin(x) is x̄ = cos(x)'*ȳ; the conjugation is usually moot but gives the correct behaviour for complex code. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints".","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"Zygote has many adjoints for non-mathematical operations such as for indexing and data structures. Though these can still be seen as linear functions of vectors, it's not particularly enlightening to implement them with an actual matrix multiply. In these cases it's easiest to think of the adjoint as a kind of inverse. For example, the gradient of a function that takes a tuple to a struct (e.g. y = Complex(a, b)) will generally take a struct to a tuple ((ȳ.re, ȳ.im)). The gradient of a getindex y = x[i...] is a setindex! x̄[i...] = ȳ, etc.","category":"page"}
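For instance, the adjoint of getindex really does behave like a "one-hot" setindex! (this is Zygote's ordinary behaviour, shown for a toy case):

julia> gradient(x -> x[3], [1.0, 2.0, 3.0])
([0.0, 0.0, 1.0],)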
,{"location":"adjoints/#Custom-Adjoints-2","page":"Custom Adjoints","title":"Custom Adjoints","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"We can extend Zygote to a new function with the @adjoint function.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> mul(a, b) = a*b;\n\njulia> using Zygote: @adjoint\n\njulia> @adjoint mul(a, b) = mul(a, b), c̄ -> (c̄*b, c̄*a)\n\njulia> gradient(mul, 2, 3)\n(3.0, 2.0)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"It might look strange that we write mul(a, b) twice here. In this case we want to call the normal mul function for the forward pass, but you may also want to modify the forward pass (for example, to capture intermediate results in the pullback).","category":"page"},{"location":"adjoints/#Custom-Types-1","page":"Custom Adjoints","title":"Custom Types","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"One good use for custom adjoints is to customise how your own types behave during differentiation. For example, in our Point example we noticed that the adjoint is a named tuple, rather than another point.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"import Base: +, -\n\nstruct Point\n x::Float64\n y::Float64\nend\n\nwidth(p::Point) = p.x\nheight(p::Point) = p.y\n\na::Point + b::Point = Point(width(a) + width(b), height(a) + height(b))\na::Point - b::Point = Point(width(a) - width(b), height(a) - height(b))\ndist(p::Point) = sqrt(width(p)^2 + height(p)^2)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> gradient(a -> dist(a), Point(1, 2))[1]\n(x = 0.4472135954999579, y = 0.8944271909999159)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"Fundamentally, this happens because of Zygote's default adjoint for getfield.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> gradient(a -> a.x, Point(1, 2))\n((x = 1, y = nothing),)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"We can overload this by modifying the getters height and width.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> @adjoint width(p::Point) = p.x, x̄ -> (Point(x̄, 0),)\n\njulia> @adjoint height(p::Point) = p.y, ȳ -> (Point(0, ȳ),)\n\njulia> Zygote.refresh() # currently needed when defining new adjoints\n\njulia> gradient(a -> height(a), Point(1, 2))\n(Point(0.0, 1.0),)\n\njulia> gradient(a -> dist(a), Point(1, 2))[1]\nPoint(0.4472135954999579, 0.8944271909999159)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"If you do this you should also overload the Point constructor, so that it can handle a Point gradient (otherwise this function will error).","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> @adjoint Point(a, b) = Point(a, b), p̄ -> (p̄.x, p̄.y)\n\njulia> gradient(x -> dist(Point(x, 1)), 1)\n(0.7071067811865475,)","category":"page"}
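With all three adjoints defined, downstream code can stay entirely in terms of Point. A minimal sketch of one gradient-descent step on dist (assuming the definitions above; the helper name descend and step size 0.1 are ours):

function descend(p::Point; η = 0.1)
    p̄ = gradient(dist, p)[1]              # a Point, thanks to the custom adjoints
    return p - Point(η * width(p̄), η * height(p̄))
end

julia> descend(Point(1, 2))                # one step towards the origin
Point(0.9552786404500042, 1.9105572809000084)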
,{"location":"adjoints/#Advanced-Adjoints-1","page":"Custom Adjoints","title":"Advanced Adjoints","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"We usually use custom adjoints to add gradients that Zygote can't derive itself (for example, because they ccall to BLAS). But there are some more advanced and fun things we can do with @adjoint.","category":"page"},{"location":"adjoints/#Gradient-Hooks-1","page":"Custom Adjoints","title":"Gradient Hooks","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> hook(f, x) = x\nhook (generic function with 1 method)\n\njulia> @adjoint hook(f, x) = x, x̄ -> (nothing, f(x̄))","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"hook doesn't seem that interesting, as it doesn't do anything. But the fun part is in the adjoint; it allows us to apply a function f to the gradient of x.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> gradient((a, b) -> hook(-, a)*b, 2, 3)\n(-3.0, 2.0)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"We could use this for debugging or modifying gradients (e.g. gradient clipping).","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> gradient((a, b) -> hook(ā -> @show(ā), a)*b, 2, 3)\nā = 3.0\n(3.0, 2.0)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"Zygote provides both hook and @showgrad so you don't have to write these yourself.","category":"page"},{"location":"adjoints/#Checkpointing-1","page":"Custom Adjoints","title":"Checkpointing","text":"","category":"section"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"A more advanced example is checkpointing, in which we save memory by re-computing the forward pass of a function during the backwards pass. To wit:","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> checkpoint(f, x) = f(x)\ncheckpoint (generic function with 1 method)\n\njulia> @adjoint checkpoint(f, x) = f(x), ȳ -> Zygote._pullback(f, x)[2](ȳ)\n\njulia> gradient(x -> checkpoint(sin, x), 1)\n(0.5403023058681398,)","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"If a function has side effects we'll see that the forward pass happens twice, as expected.","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> foo(x) = (println(x); sin(x))\nfoo (generic function with 1 method)\n\njulia> gradient(x -> checkpoint(foo, x), 1)\n1\n1\n(0.5403023058681398,)","category":"page"}
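Zygote packages this pattern as Zygote.checkpointed (documented above), so you rarely need to write the adjoint yourself. A minimal usage sketch, with an illustrative stand-in for a memory-hungry stage:

julia> layer(x) = sin.(x);  # stands in for an expensive layer

julia> loss(x) = sum(abs2, Zygote.checkpointed(layer, x));

julia> gradient(loss, [0.1, 0.2])
([0.19866933079506122, 0.3894183423086505],)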
Adjoints","text":"Demo:","category":"page"},{"location":"adjoints/#","page":"Custom Adjoints","title":"Custom Adjoints","text":"julia> function f(x)\n println(nestlevel(), \" levels of nesting\")\n return x\n end\nf (generic function with 1 method)\n\njulia> grad(f, x) = gradient(f, x)[1]\ngrad (generic function with 1 method)\n\njulia> f(1);\n0 levels of nesting\n\njulia> grad(f, 1);\n1 levels of nesting\n\njulia> grad(x -> x*grad(f, x), 1);\n2 levels of nesting","category":"page"},{"location":"limitations/#Design-Limitations-1","page":"Limitations","title":"Design Limitations","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Zygote aims to support differentiating any Julia code, but it still has a few limitations. Notably, you might encounter errors when trying to differentiate:","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"array mutation,\ntry/catch statements,\n\"foreign call\" expressions.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"This section gives examples where each of these errors occurs, as well as possible work-arounds.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Below, it also describes some known bugs in expressions Zygote ought to be able to handle.","category":"page"},{"location":"limitations/#Array-mutation-1","page":"Limitations","title":"Array mutation","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Array mutation is by far the most commonly encountered Zygote limitation.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Automatic differentiation (AD) systems like Zygote are built on basic principles of calculus where we encounter pure functions. This means that the function, y = f(x), does not modify x and only produces the output y based on x. If we have a chain of functions, such as y = h(g(f(x))), we can apply the chain rule to differentiate it. AD systems are built to programmatically apply the chain rule to a series of function calls. Unfortunately, typical programs do not behave this way. We might allocate some memory, x, then call a function y = f!(x) that modifies x to produce the output y. This mutating behavior is a side-effect of f!. Side-effects are difficult for AD systems to handle, because the must track changes to mutated variables and store older versions of the variable. For these reasons, Zygote does not handle array mutation for now.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Let's explore this with a more concrete example. Here we define a simple mutating function, f!, which modifies the elements of its input argument, x, in place.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"function f!(x)\n x .= 2 .* x\n\n return x\nend","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Let's see what happens when we differentiate f!","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"julia> gradient(rand(3)) do x\n sum(f!(x))\n end\nERROR: Mutating arrays is not supported -- called copyto!(Vector{Float64}, ...)\nThis error occurs when you ask Zygote to differentiate operations that change\nthe elements of arrays in-place (e.g. 
setting values with x .= ...)\n\nPossible fixes:\n- avoid mutating operations (preferred)\n- or read the documentation and solutions for this error\n https://fluxml.ai/Zygote.jl/latest/limitations\n\nStacktrace:\n ...","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling copyto! (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes x .= ... which is given as an example of array mutation. Other examples of mutating operations include:","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"setting values (x .= ...)\nappending/popping values (push!(x, v) / pop!(x))\ncalling mutating functions (mul!(C, A, B))","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"warning: Warning\nNon-mutating functions might also use mutation under the hood. This can be done for performance reasons or code re-use.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"function g!(x, y)\n x .= 2 .* y\n\n return x\nend\ng(y) = g!(similar(y), y)","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Here g is a "non-mutating function," and it indeed does not mutate y, its only argument. But it still allocates a new array and calls g! on this array, which results in a mutating operation. You may encounter such functions when working with another package.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Specifically for array mutation, we can use Zygote.Buffer to re-write our function. For example, let's fix the function g! above.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"function g!(x, y)\n x .= 2 .* y\n\n return x\nend\n\nfunction g(y)\n x = Zygote.Buffer(y) # Buffer supports syntax like similar\n g!(x, y)\n return copy(x) # this step makes the Buffer immutable (w/o actually copying)\nend\n\njulia> gradient(rand(3)) do y\n sum(g(y))\n end\n([2.0, 2.0, 2.0],)","category":"page"}
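The same trick handles code that fills an array element-by-element in a loop; a minimal sketch (the function squares is ours):

function squares(x)
    buf = Zygote.Buffer(x)       # writable while differentiating
    for i in eachindex(x)
        buf[i] = x[i]^2
    end
    return copy(buf)             # freeze into a normal array
end

julia> gradient(x -> sum(squares(x)), [1.0, 2.0, 3.0])
([2.0, 4.0, 6.0],)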
,{"location":"limitations/#Try-catch-statements-1","page":"Limitations","title":"Try-catch statements","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Expressions involving try/catch statements are not supported.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"function tryme(x)\n try\n 2 * x\n catch e\n throw(e)\n end\nend\n\njulia> gradient(rand(3)) do x\n sum(tryme(x))\n end\nERROR: Compiling Tuple{typeof(tryme), Vector{Float64}}: try/catch is not supported.\nRefer to the Zygote documentation for fixes.\nhttps://fluxml.ai/Zygote.jl/latest/limitations\n\nStacktrace:\n ...","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Here tryme uses a try/catch statement, and Zygote throws an error when trying to differentiate it as expected. try/catch expressions are used for error handling, but they are less common in Julia compared to some other languages.","category":"page"},{"location":"limitations/#Foreign-call-expressions-1","page":"Limitations","title":"Foreign call expressions","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Foreign call expressions refer to expressions that call external libraries such as code written in C or Fortran. You may want to read more about these calls in the Julia documentation. Scientific computing libraries in Julia may call established C or Fortran libraries under the hood. Since the underlying code for a foreign call expression is not in Julia, it is not possible for Zygote to differentiate this expression.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Below, we define a function that calls a standard C function, clock. This function returns the Unix clock as an Int32.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"julia> jclock(x) = ccall(:clock, Int32, ()) * x\njclock (generic function with 1 method)\n\njulia> jclock(2)\n30921278\n\njulia> gradient(jclock, rand())\nERROR: Can't differentiate foreigncall expression\nYou might want to check the Zygote limitations documentation.\nhttps://fluxml.ai/Zygote.jl/latest/limitations\n\nStacktrace:\n ...","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"jclock will multiply the result of our C function by an argument. When we try to differentiate with respect to this argument, we get a foreigncall error.","category":"page"},{"location":"limitations/#Solutions-1","page":"Limitations","title":"Solutions","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"For all of the errors above, the suggested solutions are similar. You have the following possible work-arounds available (in order of preference):","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"avoid the error-inducing operation (e.g. do not use mutating functions)\ndefine a custom ChainRulesCore.rrule\nopen an issue on Zygote","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Avoiding the operation is simple: just don't do it! If you are using a mutating function, try to use a non-mutating variant. If you are using try/catch statements, try to use more graceful error handling such as returning nothing or another sentinel value. Recall that array mutation can also be avoided by using Zygote.Buffer as discussed above.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Sometimes, we cannot avoid expressions that Zygote cannot differentiate, but we may be able to manually derive a gradient. In these cases, you can write a custom rrule using ChainRules.jl. Please refer to the linked ChainRules documentation for how to do this. This is the only solution available for foreign call expressions. 
Below, we provide a custom rrule for jclock.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"jclock(x) = ccall(:clock, Int32, ()) * x\n\nfunction ChainRulesCore.rrule(::typeof(jclock), x)\n c = ccall(:clock, Int32, ()) # capture the clock value used in the primal\n y = c * x\n pb(ȳ) = (ChainRulesCore.NoTangent(), ȳ * c) # ∂y/∂x is the clock value c\n\n return y, pb\nend\n\njulia> gradient(jclock, rand())\n(674298.0,)","category":"page"}
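Alternatively, when the foreign call computes something you never need a gradient for, ChainRulesCore's @non_differentiable macro can mark just that call as constant; a sketch on a variant of the same example (clockval and jclock2 are illustrative names):

clockval() = ccall(:clock, Int32, ())

ChainRulesCore.@non_differentiable clockval()

jclock2(x) = clockval() * x  # only the multiplication is differentiated

julia> gradient(jclock2, rand())[1]  # the current clock value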
,{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Lastly, if the code causing problems can be fixed, but it is package code instead of your code, then you should open an issue. For functions built into Julia or its standard libraries, you can open an issue with Zygote.jl or ChainRules.jl. For functions in other packages, you can open an issue with the corresponding package issue tracker.","category":"page"},{"location":"limitations/#Known-Issues-1","page":"Limitations","title":"Known Issues","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Zygote's issue tracker has the current list of open bugs. There are some general principles about things you may wish to avoid if you can:","category":"page"},{"location":"limitations/#mutable-structs-1","page":"Limitations","title":"mutable structs","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Zygote has limited support for mutation, and in particular will allow you to change a field in some mutable struct X; a; b; end by setting x.a = val.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"However, this has many limitations and should be avoided if possible.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"The simple solution is to use only immutable structs. ","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"If you need to modify them, using something like @set from Accessors.jl should work well. This returns a new object, but does not have side-effects on other copies of it. ","category":"page"},{"location":"limitations/#Re-using-variable-names-1","page":"Limitations","title":"Re-using variable names","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"It is common to accumulate values in a loop by re-binding the same variable name to a new value many times, for example:","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"function mysum(x::Real, n::Int)\n tot = 0.0\n for i in 1:n\n tot += x^n # binds symbol `tot` to new value\n end\n return tot\nend","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"However, sometimes such re-binding confuses Zygote, especially if the type of the value changes, and especially if the variable is "boxed", as will happen if you re-bind from within a closure (such as the function created by a do block).","category":"page"},{"location":"limitations/#Second-derivatives-1","page":"Limitations","title":"Second derivatives","text":"","category":"section"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"In principle Zygote supports taking derivatives of derivatives. There are, however, a few problems:","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Quite a few of its rules are not written in a way that is itself differentiable. For instance, they may work by making an array then writing into it, which is mutation of the sort forbidden above. \nThe complexity of the code grows rapidly, as Zygote differentiates its own un-optimised output.\nReverse mode over reverse mode is seldom the best algorithm.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"The issue tracker has a label for second order, which will outline where the bodies are buried.","category":"page"},{"location":"limitations/#","page":"Limitations","title":"Limitations","text":"Often using a different AD system over Zygote is a better solution. This is what hessian does, using ForwardDiff over Zygote, but other combinations are possible. (Note that rules defined here mean that Zygote over ForwardDiff is translated to ForwardDiff over ForwardDiff.)","category":"page"}
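For instance, hessian (ForwardDiff over Zygote) handles a second derivative directly; the result below follows from the cubic's analytic Hessian diag(6x):

julia> Zygote.hessian(x -> sum(x.^3), [1.0, 2.0])
2×2 Matrix{Float64}:
 6.0   0.0
 0.0  12.0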
,{"location":"glossary/#Glossary-1","page":"Glossary","title":"Glossary","text":"","category":"section"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Differentiation is a minefield of conflicting and overlapping terminology, partly because the ideas have been re-discovered in many different fields (e.g. calculus and differential geometry, the traditional AD community, deep learning, finance, etc.) Many of these terms are not well-defined and others may disagree on the details. Nevertheless, we aim to at least say how we use these terms, which will be helpful when reading over Zygote issues, discussions and source code.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"The list is certainly not complete; if you see new terms you'd like defined, or would like to add one yourself, please do open an issue or PR.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Adjoint: See pullback. Used when defining new pullbacks (i.e. the @adjoint macro) since this involves defining the adjoint of the Jacobian, in most cases.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Backpropagation: Essentially equivalent to "reverse-mode AD". Used particularly in the machine learning world to refer to simple chains of functions f(g(h(x))), but has generalised beyond that.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Derivative: Given a scalar function $y = f(x)$, the derivative is $\frac{\partial y}{\partial x}$. "Partial" is taken for granted in AD; there's no interesting distinction between partial and total derivatives for our purposes. It's all in the eye of the beholder.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Differential: Given a function $f(x)$, the linearisation $\partial f$ such that $f(x + \epsilon) \approx f(x) + \partial f \epsilon$. This is a generalisation of the derivative since it applies to, for example, vector-to-vector functions ($\partial f$ is a Jacobian) and holomorphic complex functions ($\partial f$ is the first Wirtinger derivative). This is not, in general, what Zygote calculates, though differentials can usually be derived from gradients.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"IR: Intermediate Representation. Essentially source code, but usually lower level – e.g. control flow constructs like loops and branches have all been replaced by gotos. The idea is that it's harder for humans to read/write but easier to manipulate programmatically. Worth looking at SSA form as a paradigmatic example.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Gradient: See sensitivity. There is no technical difference in Zygote's view, though "gradient" sometimes distinguishes the sensitivity we actually want from e.g. the internal ones that Zygote produces as it backpropagates.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Graph: ML people tend to think of models as "computation graphs", but this is no more true than any program is a graph. In fact, pretty much anything is a graph if you squint hard enough. This also refers to the data structure that e.g. TensorFlow and PyTorch build to represent your model, but see trace for that.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Pullback: Given $y = f(x)$ the function $\bar{x} = back(\bar{y})$. In other words, the function back in y, back = Zygote.pullback(f, x).","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Sensitivity: Used to refer to the gradient $\bar{x} = \frac{\partial l}{\partial x}$ with some scalar loss $l$. In other words, you have a value x (which need not be scalar) at some point in your program, and $\bar{x}$ tells you how you should change that value to decrease the loss. In the AD world, sometimes used to refer to adjoint rules.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Source to Source Differentiation: Or Source Code Transformation (SCT). As opposed to tracing programs to simplify them, an alternative is to operate directly on a language's source code or IR, generating new source code for pullbacks. This describes Zygote, Swift for TensorFlow, Tapenade and a few other old ADs that worked on C source files. Zygote and Swift are unusual in that they work on in-memory IR rather than text source.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"To an extent, tracing ADs can be viewed as source transform of a Wengert list / trace. The key difference is that the trace is a lossy representation of the original semantics, which causes problems with e.g. control flow. Systems which can preserve some of those semantics (e.g. autograph) begin to blur the line here, though they are still not nearly as expressive as language IRs.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Symbolic Differentiation: Used to refer to differentiation of "mathematical expressions", that is, things like 3x^2 + sin(x). Often distinguished from AD, though this is somewhat arbitrary; you can happily produce a symbolic adjoint for a Wengert list, the only difference being that you're allowed to make variable bindings. So it's really just a special case of AD on an unusually limited language.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Tape: This term can refer to pretty much any part of an AD implementation. In particular, confusion is caused by conflating the trace with the set of values sometimes closed over by a pullback. Autograd has a combined trace/closure data structure which is usually described as the tape. 
On the other hand, PyTorch described their implementation as tape-free because the trace/closure is stored as a DAG rather than a vector, so basically all bets are off here.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Trace: A recording of each mathematical operation used by a program, made at runtime and usually forming a Wengert list. Traces may or may not also record actual runtime values (e.g. PyTorch vs. TensorFlow). They can often be treated as an IR and compiled, but are distinguished from true IRs in that they unroll and inline all control flow, functions and data structures. The tracing process can be thought of as a kind of partial evaluation, though tracers are typically much less worried about losing information.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Vector-Jacobian product: see pullback. So called because all pullbacks are linear functions that can be represented by (left) multiplication with the Jacobian matrix.","category":"page"},{"location":"glossary/#","page":"Glossary","title":"Glossary","text":"Wengert List: A set of simple variable assignments and mathematical expressions, forming a directed graph. Can be thought of as a limited programming language with variable bindings and numerical functions but no control flow or data structures. If you trace a program for AD it will typically take this form.","category":"page"},{"location":"#Zygote-1","page":"Home","title":"Zygote","text":"","category":"section"},{"location":"#","page":"Home","title":"Home","text":"Welcome! Zygote extends the Julia language to support differentiable programming. With Zygote you can write down any Julia code you feel like – including using existing Julia packages – then get gradients and optimise your program. Deep learning, ML and probabilistic programming are all different kinds of differentiable programming that you can do with Zygote.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"At least, that's the idea. We're still in beta so expect some adventures.","category":"page"},{"location":"#Setup-1","page":"Home","title":"Setup","text":"","category":"section"},{"location":"#","page":"Home","title":"Home","text":"Zygote can be installed from the package manager in Julia's REPL:","category":"page"},{"location":"#","page":"Home","title":"Home","text":"] add Zygote","category":"page"},{"location":"#Taking-Gradients-1","page":"Home","title":"Taking Gradients","text":"","category":"section"},{"location":"#","page":"Home","title":"Home","text":"Zygote is easy to understand since, at its core, it has a one-function API (pullback), along with a few simple conveniences. Before explaining pullback, we'll look at the higher-level function gradient.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"gradient calculates derivatives. 
For example, the derivative of 3x^2 + 2x + 1 is 6x + 2, so when x = 5, dx = 32.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> using Zygote\n\njulia> gradient(x -> 3x^2 + 2x + 1, 5)\n(32.0,)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"gradient returns a tuple, with a gradient for each argument to the function.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> gradient((a, b) -> a*b, 2, 3)\n(3.0, 2.0)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"This will work equally well if the arguments are arrays, structs, or any other Julia type, but the function should return a scalar (like a loss or objective l, if you're doing optimisation / ML).","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> W = rand(2, 3); x = rand(3);\n\njulia> gradient(W -> sum(W*x), W)[1]\n2×3 Array{Float64,2}:\n 0.0462002 0.817608 0.979036\n 0.0462002 0.817608 0.979036\n\njulia> gradient(x -> 3x^2 + 2x + 1, 1//4)\n(7//2,)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"Control flow is fully supported, including recursion.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> function pow(x, n)\n r = 1\n for i = 1:n\n r *= x\n end\n return r\n end\npow (generic function with 1 method)\n\njulia> gradient(x -> pow(x, 3), 5)\n(75.0,)\n\njulia> pow2(x, n) = n <= 0 ? 1 : x*pow2(x, n-1)\npow2 (generic function with 1 method)\n\njulia> gradient(x -> pow2(x, 3), 5)\n(75.0,)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"Data structures are also supported, including mutable ones like dictionaries. Arrays are currently immutable, though this may change in future.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> d = Dict()\nDict{Any, Any}()\n\njulia> gradient(5) do x\n d[:x] = x\n d[:x] * d[:x]\n end\n(10.0,)\n\njulia> d[:x]\n5","category":"page"},{"location":"#Structs-and-Types-1","page":"Home","title":"Structs and Types","text":"","category":"section"},{"location":"#","page":"Home","title":"Home","text":"Julia makes it easy to work with custom types, and Zygote makes it easy to differentiate them. For example, given a simple Point type:","category":"page"},{"location":"#","page":"Home","title":"Home","text":"import Base: +, -\n\nstruct Point\n x::Float64\n y::Float64\nend\n\na::Point + b::Point = Point(a.x + b.x, a.y + b.y)\na::Point - b::Point = Point(a.x - b.x, a.y - b.y)\ndist(p::Point) = sqrt(p.x^2 + p.y^2)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> a = Point(1, 2)\nPoint(1.0, 2.0)\n\njulia> b = Point(3, 4)\nPoint(3.0, 4.0)\n\njulia> dist(a + b)\n7.211102550927978\n\njulia> gradient(a -> dist(a + b), a)[1]\n(x = 0.5547001962252291, y = 0.8320502943378437)","category":"page"},{"location":"#","page":"Home","title":"Home","text":"Zygote's default representation of the \"point adjoint\" is a named tuple with gradients for both fields, but this can of course be customised too.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"This means we can do something very powerful: differentiating through Julia libraries, even if they weren't designed for this. 
For example, colordiff might be a smarter loss function on colours than simple mean-squared-error:","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> using Colors\n\njulia> colordiff(RGB(1, 0, 0), RGB(0, 1, 0))\n86.60823557376344\n\njulia> gradient(colordiff, RGB(1, 0, 0), RGB(0, 1, 0))\n((r = 0.4590887719632896, g = -9.598786801605689, b = 14.181383399012862), (r = -1.7697549557037275, g = 28.88472330558805, b = -0.044793892637761346))","category":"page"},{"location":"#Explicit-and-Implicit-Parameters-1","page":"Home","title":"Explicit and Implicit Parameters","text":"","category":"section"},{"location":"#","page":"Home","title":"Home","text":"It's easy to work with even very large and complex models, and there are a few ways to do this. Autograd-style models pass around a collection of weights. Depending on how you write your model, there are multiple ways to explicitly take gradients with respect to parameters. For example, the function linear accepts the parameters as an argument to the model. So, we directly pass in the parameters, θ, as an argument to the function being differentiated.","category":"page"},{"location":"#","page":"Home","title":"Home","text":"gradient(f, args...)","category":"page"},{"location":"#Zygote.gradient-Tuple{Any, Vararg{Any}}","page":"Home","title":"Zygote.gradient","text":"gradient(f, args...)\n\nReturns a tuple containing ∂f/∂x for each argument x, the derivative (for scalar x) or the gradient.\n\nf(args...) must be a real number, see jacobian for array output.\n\nSee also withgradient to keep the value f(args...), and pullback for value and back-propagator.\n\njulia> gradient(*, 2.0, 3.0, 5.0)\n(15.0, 10.0, 6.0)\n\njulia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])\n([14.0, 22.0, 26.0],)\n\njulia> gradient([7, 11], 0, 1) do x, y, d\n p = size(x, d)\n sum(x.^p .+ y)\n end\n([14.0, 22.0], 2.0, nothing)\n\n\n\n\n\n","category":"method"},{"location":"#","page":"Home","title":"Home","text":"julia> linear(θ, x) = θ[:W] * x .+ θ[:b]\nlinear (generic function with 1 method)\n\njulia> x = rand(5);\n\njulia> θ = Dict(:W => rand(2, 5), :b => rand(2))\nDict{Any,Any} with 2 entries:\n :b => [0.0430585, 0.530201]\n :W => [0.923268 … 0.589691]\n\n# Alternatively, use a named tuple or struct rather than a dict.\n# θ = (W = rand(2, 5), b = rand(2))\n\njulia> θ̄ = gradient(θ -> sum(linear(θ, x)), θ)[1]\nDict{Any,Any} with 2 entries:\n :b => [1.0, 1.0]\n :W => [0.628998 … 0.433006]","category":"page"},{"location":"#","page":"Home","title":"Home","text":"We can combine the role of the dictionary and the function here by making a callable struct which contains the parameters, equivalent to a closure. Passed explicitly to gradient, we get a named tuple with the same field names:","category":"page"},{"location":"#","page":"Home","title":"Home","text":"julia> struct Linear\n W\n b\n end\n\njulia> (l::Linear)(x) = l.W * x .+ l.b\n\njulia> model = Linear(rand(2, 5), rand(2))\nLinear([0.267663 … 0.334385], [0.0386873, 0.0203294])\n\njulia> x = rand(5);\n\njulia> dmodel = gradient(model -> sum(model(x)), model)[1]\n(W = [0.652543 … 0.683588], b = [1.0, 1.0])","category":"page"},{"location":"#","page":"Home","title":"Home","text":"Zygote also supports another way to take gradients, via implicit parameters. Here the loss function takes zero arguments, but the variables of interest are indicated by a special Params object. 
The function linear which depends on W and b is executed when the loss function () -> sum(linear(x)) is called, and hence this dependence is visible to Zygote:","category":"page"},{"location":"#","page":"Home","title":"Home","text":"gradient","category":"page"},{"location":"#Zygote.gradient","page":"Home","title":"Zygote.gradient","text":"gradient(() -> loss(), ps::Params) -> Grads\n\nGradient with implicit parameters. Takes a zero-argument function, and returns a dictionary-like container, whose keys are arrays x in ps.\n\nSee also withgradient to keep the value loss().\n\njulia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100];\n\njulia> g = gradient(Params([x, y])) do\n sum(x .* y .* z')\n end\nGrads(...)\n\njulia> g[x]\n2×3 Matrix{Float64}:\n 7.0 70.0 700.0\n 8.0 80.0 800.0\n\njulia> haskey(g, z) # only x and y are parameters\nfalse\n\n\n\n\n\n","category":"function"},{"location":"#","page":"Home","title":"Home","text":"julia> W = rand(2, 5); b = rand(2);\n\njulia> linear(x) = W * x .+ b\nlinear (generic function with 2 methods)\n\njulia> grads = gradient(() -> sum(linear(x)), Params([W, b]))\nGrads(...)\n\njulia> grads[W], grads[b] # access gradients using arrays as keys\n([0.652543 … 0.683588], [1.0, 1.0])","category":"page"},{"location":"#","page":"Home","title":"Home","text":"Here grads is a dictionary-like object, whose keys are the same parameters we indicated in Params. (In fact it wraps a dictionary using objectid(W) as keys, which does not change if the values in W are mutated).","category":"page"},{"location":"#","page":"Home","title":"Home","text":"This implicit style is the one presently used by Flux.jl, a closely related machine learning library. It uses structs like Linear above to define layers, and the function Flux.params(model) returns a Params object containing all the parameters of all layers. See its documentation for more details. When using Zygote for most other purposes, however, the explicit style is usually preferred.","category":"page"},{"location":"internals/#Internals-1","page":"Internals","title":"Internals","text":"","category":"section"},{"location":"internals/#What-Zygote-Does-1","page":"Internals","title":"What Zygote Does","text":"","category":"section"},{"location":"internals/#","page":"Internals","title":"Internals","text":"These notebooks and the Zygote paper provide useful background on Zygote's transform; this page is particularly focused on implementation details.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Before we think about AD, we'll consider some simple cases. We can start by defining a function that produces pullbacks, J, explicitly for some simple functions.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"J(::typeof(sin), x) = sin(x), ȳ -> ȳ*cos(x)\nJ(::typeof(cos), x) = cos(x), ȳ -> -ȳ*sin(x)\nJ(::typeof(*), a, b) = a*b, c̄ -> (b*c̄, a*c̄)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Now we can call J to take a gradient.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"gradient(f, x...) 
= J(f, x...)[2](1)\n\ngradient(sin, 1) # (0.540,)\n\ngradient(*, 2, 3) # (3, 2)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Now consider a composite function that calls two simple ones:","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"function foo(x)\n a = sin(x)\n b = cos(a)\n return b\nend","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"We can easily differentiate foo if we can differentiate the functions it calls. If we can get pullbacks via J, the pullback for foo looks as follows. Where the forward pass calculates x -> a -> b, the backwards takes b̄ -> ā -> x̄ via the pullbacks.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"function J(::typeof(foo), x)\n a, da = J(sin, x)\n b, db = J(cos, a)\n return b, function(b̄)\n ā, = db(b̄)\n x̄, = da(ā)\n return x̄\n end\nend\n\ngradient(foo, 1) # (-0.403,)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Things get just a little more complex when control flow is involved. You can see that the derived adjoint for pow mirrors the original function, except that the loop runs in reverse. The easiest way to see why it looks like this is to imagine unrolling the loop n times, working out the adjoint, and then turning it back into a loop.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"function pow(x, n) # x^n\n r = 1\n for _ = 1:n\n r *= x\n end\n return r\nend\n\nfunction J(::typeof(pow), x, n)\n r = 1\n Js = []\n for i = 1:n\n r, back = J(*, r, x)\n push!(Js, back)\n end\n return r, function(r̄)\n x̄ = 0\n for i = n:-1:1\n r̄, x̄′ = Js[i](r̄)\n x̄ += x̄′\n end\n return (x̄, 0)\n end\nend\n\ngradient(pow, 2, 3) # (12, 0)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Despite being reasonably fiddly, this is a fully mechanical transformation, so the only remaining thing is to automate it – a small matter of programming.","category":"page"},{"location":"internals/#Closures-1","page":"Internals","title":"Closures","text":"","category":"section"},{"location":"internals/#","page":"Internals","title":"Internals","text":"The J function here corresponds to pullback in Zygote. However, pullback is actually a wrapper around the lower level _pullback function.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> y, back = Zygote._pullback(sin, 0.5);\n\njulia> back(1)\n(nothing, 0.8775825618903728)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Why the extra nothing here? This actually represents the gradient of the function sin. 
This is often nothing, but when we have closures the function contains data we need gradients for.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> f = let a = 3; x -> x*a; end\n#19 (generic function with 1 method)\n\njulia> y, back = Zygote._pullback(f, 2);\n\njulia> back(1)\n((a = 2,), 3)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"This is a minor point for the most part, but _pullback will come up in future examples.","category":"page"},{"location":"internals/#Entry-Points-1","page":"Internals","title":"Entry Points","text":"","category":"section"},{"location":"internals/#","page":"Internals","title":"Internals","text":"You might notice that Zygote is, in effect, just a macro. We could happily implement Zygote by writing definitions like","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"@differentiable foo(x) = sin(cos(x))","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"which would expand to generate an appropriate overload to J. As long as every function we want to differentiate is annotated, this will work just fine. However, it's obviously not ideal to have to annotate every function inside every Julia package in order to make it differentiable.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"This is where generated functions come in. Making J a generated function allows us to apply the Zygote macro on an as-needed basis; calling J(f, x...) looks up the code for f(x...), transforms it, and then behaves as if you had defined J for that specific function ahead of time.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"When we look up the code, we actually get lowered (desugared) code rather than an AST.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> foo(x) = baz(bar(x))\nfoo (generic function with 1 method)\n\njulia> @code_lowered foo(1)\nCodeInfo(\n1 ─ %1 = (Main.bar)(x)\n│ %2 = (Main.baz)(%1)\n└── return %2","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"We convert the code to SSA form using Julia's built-in IR data structure, after which it looks like this.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> Zygote.@code_ir foo(1)\n1 1 ─ %1 = (Main.bar)(_2)::Any\n │ %2 = (Main.baz)(%1)::Any\n └── return %2","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"(There isn't much difference unless there's some control flow.)","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"The code is then differentiated by the code in compiler/reverse.jl. 
You can see the output with @code_adjoint.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> Zygote.@code_adjoint foo(1)\n1 1 ─ %1 = (Zygote._pullback)(_2, Zygote.unwrap, Main.bar)::Any\n │ %2 = (Base.getindex)(%1, 1)::Any\n │ (Base.getindex)(%1, 2)::Any\n │ %4 = (Zygote._pullback)(_2, %2, _4)::Any\n │ %5 = (Base.getindex)(%4, 1)::Any\n │ (Base.getindex)(%4, 2)::Any\n │ %7 = (Zygote._pullback)(_2, Zygote.unwrap, Main.baz)::Any\n │ %8 = (Base.getindex)(%7, 1)::Any\n │ (Base.getindex)(%7, 2)::Any\n │ %10 = (Zygote._pullback)(_2, %8, %5)::Any\n │ %11 = (Base.getindex)(%10, 1)::Any\n │ (Base.getindex)(%10, 2)::Any\n └── return %11\n 1 ─ %1 = Δ()::Any\n1 │ %2 = (@12)(%1)::Any\n │ %3 = (Zygote.gradindex)(%2, 1)::Any\n │ %4 = (Zygote.gradindex)(%2, 2)::Any\n │ (@9)(%3)::Any\n │ %6 = (@6)(%4)::Any\n │ %7 = (Zygote.gradindex)(%6, 1)::Any\n │ %8 = (Zygote.gradindex)(%6, 2)::Any\n │ (@3)(%7)::Any\n │ %10 = (Zygote.tuple)(nothing, %8)::Any\n └── return %10\n, [1])","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"This code is quite verbose, mainly due to all the tuple unpacking (gradindex is just like getindex, but handles nothing gracefully). There are two pieces of IR here, one for the modified forward pass and one for the pullback closure. The @ nodes allow the closure to refer to values from the forward pass, and the Δ() represents the incoming gradient ȳ. In essence, this is just what we wrote above by hand for J(::typeof(foo), x).","category":"page"}
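For intuition, gradindex behaves roughly like the following sketch (illustrative, not Zygote's exact source):

gradindex(x, i) = x[i]
gradindex(::Nothing, i) = nothing  # a nothing gradient stands for zero everywhere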
,{"location":"internals/#","page":"Internals","title":"Internals","text":"compiler/emit.jl lowers this code into runnable IR (e.g. by turning @ references into getfields and stacks), and it's then turned back into lowered code for Julia to run.","category":"page"},{"location":"internals/#Closure-Conversion-1","page":"Internals","title":"Closure Conversion","text":"","category":"section"},{"location":"internals/#","page":"Internals","title":"Internals","text":"There are no closures in lowered Julia code, so we can't actually emit one directly in lowered code. To work around this we have a trick: we have a generic struct like","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"struct Pullback{F}\n data\nend","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"We can put whatever we want in data, and the F will be the signature for the original call, like Tuple{typeof(foo),Int}. When the pullback gets called, it hits another generated function which emits the pullback code.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"In hand-written code this would look like:","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"struct Pullback{F}\n data\nend\n\nfunction J(::typeof(foo), x)\n a, da = J(sin, x)\n b, db = J(cos, a)\n return b, Pullback{typeof(foo)}((da, db))\nend\n\nfunction (p::Pullback{typeof(foo)})(b̄)\n da, db = p.data[1], p.data[2]\n ā = db(b̄)\n x̄ = da(ā)\n return x̄\nend","category":"page"},{"location":"internals/#Debugging-1","page":"Internals","title":"Debugging","text":"","category":"section"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Say some of our code is throwing an error.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"bad(x) = x\n\nZygote.@adjoint bad(x) = x, _ -> error("bad")\n\nfoo(x) = bad(sin(x))\n\ngradient(foo, 1) # error!","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Zygote can usually give a stacktrace pointing right to the issue here, but in some cases there are compiler crashes that make this harder. In these cases it's best to (a) use _pullback and (b) take advantage of Zygote's recursion to narrow down the problem function.","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"julia> y, back = Zygote._pullback(foo, 1);\n\njulia> back(1) # just make up a value here, it just needs to look similar to `y`\nERROR: bad\n\n# Ok, so we try functions that foo calls\n\njulia> y, back = Zygote._pullback(sin, 1);\n\njulia> back(1)\n(nothing, 0.5403023058681398)\n\n# Looks like that's fine\n\njulia> y, back = Zygote._pullback(bad, 1);\n\njulia> back(1) # ok, here's our issue. Lather, rinse, repeat.\nERROR: bad","category":"page"},{"location":"internals/#","page":"Internals","title":"Internals","text":"Of course, our goal is that you never have to do this, but until Zygote is more mature it can be a useful way to narrow down test cases.","category":"page"}] } diff --git a/dev/utils/index.html b/dev/utils/index.html index f7ff066f0..869b6f7d8 100644 --- a/dev/utils/index.html +++ b/dev/utils/index.html @@ -23,7 +23,7 @@ ([4 4 4], nothing) julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient understands the tuple -([4 4 4], (6, 1))source
      jacobian(loss, ::Params)

      Like gradient with implicit parameters, this method takes a zero-argument function and returns an IdDict-like object, now containing the Jacobian for each parameter.

      Examples

      julia> xs = [1 2; 3 4]; ys = [5,7,9];
       
       julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))
       Grads(...)
      @@ -36,7 +36,7 @@
       julia> Jxy[xs]
       2×4 Matrix{Int64}:
        2  6  4  8
 2  6  4  8

Zygote.hessian — Function

hessian(f, x)

Construct the Hessian ∂²f/∂x², where x is a real number or an array, and f(x) is a real number. When x is an array, the result is a matrix H[i,j] = ∂²f/∂x[i]∂x[j], using linear indexing x[i] even if the argument is higher-dimensional.

This uses forward over reverse, ForwardDiff over Zygote, calling hessian_dual(f, x). See hessian_reverse for an all-Zygote alternative.

See also diaghessian to compute only the diagonal part.

Examples

julia> hessian(x -> x[1]*x[2], randn(2))
       2×2 Matrix{Float64}:
        0.0  1.0
        1.0  0.0
      @@ -49,7 +49,7 @@
        0   0   0  24
       
       julia> hessian(sin, pi/2)
-1.0

Zygote.hessian_reverse — Function

hessian_reverse(f, x)

This should be equivalent to hessian(f, x), but implemented using reverse over reverse mode, all Zygote. (This is usually much slower, and more likely to find errors.)
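Since the two implementations should agree, a quick consistency check makes a natural smoke test; a minimal hypothetical sketch (mine, assuming the nested reverse pass handles this function):

julia> f(x) = x[1]^2 * x[2];

julia> Zygote.hessian(f, [1.0, 2.0])   # ForwardDiff over Zygote
2×2 Matrix{Float64}:
 4.0  2.0
 2.0  0.0

julia> Zygote.hessian_reverse(f, [1.0, 2.0]) ≈ Zygote.hessian(f, [1.0, 2.0])
true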
Zygote.diaghessian — Function

diaghessian(f, args...) -> Tuple

Diagonal part of the Hessian. Returns a tuple containing, for each argument x, h of the same shape with h[i] = Hᵢᵢ = ∂²y/∂x[i]∂x[i]. The original evaluation y = f(args...) must give a real number y.

For one vector argument x, this is equivalent to (diag(hessian(f,x)),). Like hessian it uses ForwardDiff over Zygote.

Warning

For arguments of any type except Number & AbstractArray, the result is nothing.

Examples

julia> diaghessian(x -> sum(x.^3), [1 2; 3 4])[1]
       2×2 Matrix{Int64}:
         6  12
        18  24
      @@ -66,7 +66,7 @@
       julia> hessian(xy -> atan(xy[1], xy[2]), [1, 2])  # full Hessian is not diagonal
       2×2 Matrix{Float64}:
        -0.16  -0.12
 -0.12   0.16

Zygote also provides a set of helpful utilities. These are all "user-level" tools – in other words you could have written them easily yourself, but they live in Zygote for convenience.

See ChainRules.ignore_derivatives if you want to exclude some of your code from the gradient calculation. This replaces the previous Zygote-specific ignore and dropgrad functionality.

Zygote.withgradient — Function

withgradient(f, args...)
withgradient(f, ::Params)

      Returns both the value of the function and the gradient, as a named tuple.

      julia> y, ∇ = withgradient(/, 1, 2)
       (val = 0.5, grad = (0.5, -0.25))
       
      @@ -87,8 +87,8 @@
       
       julia> res.grad[w]
       1-element Vector{Float64}:
 6.0
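To see the implicit-parameters form end-to-end, a small hypothetical example (mine, in the spirit of the elided one above):

julia> w = [3.0, 4.0];

julia> res = withgradient(() -> sum(abs2, w), Params([w]));

julia> res.val
25.0

julia> res.grad[w]
2-element Vector{Float64}:
 6.0
 8.0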
Zygote.withjacobian — Function

withjacobian(f, args...)

Returns both the value f(args...) and the jacobian as a named tuple.

julia> withjacobian(cumsum, [1,2,3])
(val = [1, 3, 6], grad = ([1 0 0; 1 1 0; 1 1 1],))

Zygote.@showgrad — Macro

@showgrad(x) -> x

Much like @show, but shows the gradient about to accumulate to x. Useful for debugging gradients.

julia> gradient(2, 3) do a, b
                @showgrad(a)*b
              end
       ∂(a) = 3
      @@ -103,7 +103,7 @@
                a*b
              end
       ∂(a) = nothing
(3, 2)
Zygote.hook — Function

hook(x̄ -> ..., x) -> x

Gradient hooks. Allows you to apply an arbitrary function to the gradient for x.

julia> gradient(2, 3) do a, b
                hook(ā -> @show(ā), a)*b
              end
       ā = 3
      @@ -112,7 +112,7 @@
       julia> gradient(2, 3) do a, b
                hook(-, a)*b
              end
(-3, 2)
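hook composes with ordinary broadcasting, which makes it handy for things like gradient clipping; a hypothetical sketch (mine, not from the docstring):

julia> clip(x) = hook(x̄ -> clamp.(x̄, -1, 1), x);

julia> gradient(x -> sum(clip(x) .* 100), [0.1, 0.2])   # raw gradient would be [100.0, 100.0]
([1.0, 1.0],)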
Zygote.Buffer — Type

Buffer(xs, ...)

Buffer is an array-like type which is mutable when taking gradients. You can construct a Buffer with the same syntax as similar (e.g. Buffer(xs, 5)) and then use normal indexing. Finally, use copy to get back a normal array.

For example:

julia> function vstack(xs)
                  buf = Buffer(xs, length(xs), 5)
                  for i = 1:5
                    buf[:, i] = xs
      @@ -128,7 +128,7 @@
        3  3  3  3  3
       
       julia> gradient(x -> sum(vstack(x)), [1, 2, 3])
([5.0, 5.0, 5.0],)

Buffer is not an AbstractArray and can't be used for linear algebra operations like matrix multiplication. This prevents it from being captured by pullbacks.

copy is a semantic copy, but does not allocate memory. Instead the Buffer is made immutable after copying.
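Another hypothetical sketch (mine): building a result element-by-element, then copying to return an ordinary, differentiable array:

julia> function squares(xs)
         buf = Buffer(xs)             # same shape and eltype as xs
         for i in eachindex(xs)
           buf[i] = xs[i]^2
         end
         return copy(buf)             # a plain array again
       end;

julia> gradient(x -> sum(squares(x)), [1.0, 2.0, 3.0])
([2.0, 4.0, 6.0],)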
Zygote.forwarddiff — Function

forwarddiff(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD) -> f(x)

Runs f(x) as usual, but instructs Zygote to differentiate f using forward mode, rather than the usual reverse mode. The chunk_threshold argument controls the maximum chunk size (cf. the ForwardDiff documentation).

Forward mode takes time linear in length(x) but only has constant memory overhead, and is very efficient for scalars, so in some cases this can be a useful optimisation.

julia> function pow(x, n)
                r = one(x)
                for i = 1:n
                  r *= x
      @@ -151,7 +151,7 @@
         forwarddiff([a, b]) do (a, b)
           a*b
         end
end
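The do-block form in the snippet above generalises; a minimal hypothetical sketch (mine) checking it against plain reverse mode:

julia> f(x) = forwarddiff(x) do x
         sum(abs2, x)
       end;

julia> gradient(f, [1.0, 2.0, 3.0])   # should match gradient(x -> sum(abs2, x), [1.0, 2.0, 3.0])
([2.0, 4.0, 6.0],)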
Zygote.checkpointed — Function

checkpointed(f, xs...)

Use gradient checkpointing on the call f(xs...). This means that checkpointed(f, xs...) === f(xs...), but when computing the derivative, intermediate results from the forward pass of f are not stored. Instead the forward pass is repeated when the derivative is computed. This saves memory at the cost of increased execution time.

Warning

If f is not a pure function, checkpointed will likely give wrong results.
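checkpointed has no example in the docstring, so here is a minimal hypothetical sketch (mine), wrapping one stage of a computation:

julia> expensive(x) = tanh.(x) .* 2;   # stand-in for a memory-hungry stage

julia> g1 = gradient(x -> sum(Zygote.checkpointed(expensive, x)), [0.0, 1.0]);

julia> g2 = gradient(x -> sum(expensive(x)), [0.0, 1.0]);

julia> g1[1] ≈ g2[1]   # checkpointing trades recomputation for memory, not accuracy
true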

      Params and Grads can be copied to and from arrays using the copy! function.
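A minimal hypothetical sketch (mine) of round-tripping gradients through a flat vector:

julia> w, b = rand(2), rand(2); ps = Params([w, b]);

julia> gs = gradient(() -> sum(w) + 2*sum(b), ps);

julia> v = zeros(4);

julia> copy!(v, gs);   # flatten the gradients into v, one slot per scalar parameter

julia> copy!(gs, v);   # ...or load them back from an array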

      Working with Grads

      Map, broadcast, and iteration are supported for the dictionary-like Grads objects. These operations are value based and preserve the keys.

      using Zygote, Test
       
       w, x1, x2, b = rand(2), rand(2), rand(2), rand(2)
       
      @@ -180,4 +180,4 @@
 # note that gradients must be w.r.t. the same parameter key set
       gs3 = gradient(() -> sum(tanh.(w .* x2)), Params([w]))
       # gs3 does not have the key b
@test_throws ArgumentError gs1 .+ gs3
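Conversely, when the key sets do match, broadcasting combines gradients key-by-key; a hypothetical follow-on (mine) to the snippet above, e.g. for gradient accumulation:

gs4 = gradient(() -> sum(tanh.(w .* x2)), Params([w, b]))
gacc = gs1 .+ gs4   # fine: both have keys w and b
@test gacc[w] ≈ gs1[w] .+ gs4[w]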