Algorithmic Differentiation

Reverse-mode AD can be understood by considering the composition $f = \mathcal{l} \circ g$, where $\mathcal{l}(y) := \langle \bar{y}, y \rangle$. Applying the chain rule, and then the definition of the adjoint, gives

\[\begin{align}
D f [x] (\dot{x}) &= [(D \mathcal{l} [g(x)]) \circ (D g [x])](\dot{x}) \nonumber \\
&= \langle \bar{y}, D g [x] (\dot{x}) \rangle \nonumber \\
&= \langle D g [x]^\ast (\bar{y}), \dot{x} \rangle, \nonumber
\end{align}\]

from which we conclude that $D g [x]^\ast (\bar{y})$ is the gradient of the composition $l \circ g$ at $x$.

The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.

The above shows that if $\mathcal{Y} = \RR$ and $g$ is the function we wish to compute the gradient of, we can simply set $\bar{y} = 1$ and compute $D g [x]^\ast (\bar{y})$ to obtain the gradient of $g$ at $x$.
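As a quick check of this recipe, consider $g(x) = \sum_{n=1}^N x_n^2$, so that $\mathcal{Y} = \RR$. Then

\[D g [x] (\dot{x}) = \sum_{n=1}^N 2 x_n \dot{x}_n = \langle 2 x, \dot{x} \rangle ,\]

so setting $\bar{y} = 1$ gives $D g [x]^\ast (\bar{y}) = 2 x$, which is exactly the gradient of $g$ at $x$.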

Summary

This document explains the core mathematical foundations of AD. It explains separately what AD does, and how it goes about it. Some basic examples are given which show how these mathematical foundations can be applied to differentiate functions of matrices and Julia functions.

Subsequent sections will build on these foundations to provide a more general explanation of what AD looks like for a Julia program.

Asides

How does Forwards-Mode AD work?

Forwards-mode AD computes $D f [x] (\dot{x})$ by breaking down $f$ into the composition $f = f_N \circ \dots \circ f_1$, where each $f_n$ is a simple function whose derivative (function) $D f_n [x_n]$ we know for any given $x_n$. By the chain rule, we have that

\[D f [x] (\dot{x}) = D f_N [x_N] \circ \dots \circ D f_1 [x_1] (\dot{x})\]

which suggests the following algorithm:

  1. let $x_1 = x$, $\dot{x}_1 = \dot{x}$, and $n = 1$
  2. let $\dot{x}_{n+1} = D f_n [x_n] (\dot{x}_n)$
  3. let $x_{n+1} = f_n(x_n)$
  4. let $n = n + 1$
  5. if $n = N+1$ then return $\dot{x}_{N+1}$, otherwise go to 2.

When each function $f_n$ maps between Euclidean spaces, the applications of derivatives $D f_n [x_n] (\dot{x}_n)$ are given by $J_n \dot{x}_n$ where $J_n$ is the Jacobian of $f_n$ at $x_n$.
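The following is a minimal sketch of this loop for the Euclidean case. It is purely illustrative: fs holds the simple functions $f_1, \dots, f_N$, and jacobian(f, x) is a hypothetical helper (not part of Mooncake.jl) returning the Jacobian of f at x.

function forwards_mode(fs, jacobian, x, xdot)
    for f in fs
        xdot = jacobian(f, x) * xdot  # step 2: the tangent update D f_n [x_n] (xdot_n) = J_n * xdot_n
        x = f(x)                      # step 3: the primal update x_{n+1} = f_n(x_n)
    end
    return x, xdot                    # f(x) and D f [x] applied to the original tangent
end

Note that, exactly as in steps 2 and 3 above, the primal value must only be updated after it has been used to propagate the tangent.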


Debug Mode

The Problem

A major source of potential problems in AD systems is rules returning the wrong type of tangent / fdata / rdata for a given primal value. For example, if someone writes a rule like

function rrule!!(::CoDual{typeof(+)}, x::CoDual{<:Real}, y::CoDual{<:Real})
    plus_reverse_pass(dz::Real) = NoRData(), dz, dz
    return zero_fcodual(primal(x) + primal(y)), plus_reverse_pass
end

and calls

rrule!!(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))

then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, so the appropriate rdata type is also Float32. This error may cause the reverse pass to fail loudly, but it may also fail silently. It may cause an error much later in the reverse pass, making it hard to determine that the source of the error was the above rule. Worst of all, in some cases it could plausibly cause a segfault, which is more-or-less the worst kind of outcome possible.
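One way to avoid this class of bug is to have the pullback convert the incoming rdata to the rdata type of each primal. The following is a sketch only (it is not how Mooncake.jl's built-in rule for + is written), and assumes floating-point primals, for which the rdata type equals the primal type:

function rrule!!(::CoDual{typeof(+)}, x::CoDual{P}, y::CoDual{Q}) where {P<:Base.IEEEFloat, Q<:Base.IEEEFloat}
    # Convert dz to the primal type of each argument, so a Float32 argument receives Float32 rdata.
    plus_reverse_pass(dz::Real) = NoRData(), convert(P, dz), convert(Q, dz)
    return zero_fcodual(primal(x) + primal(y)), plus_reverse_pass
end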

The Solution

Check that the types of the fdata / rdata associated to arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.
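Concretely, the expected rdata type for each primal can be queried directly. A small sketch (these values are discussed further in the FData and RData section of these docs):

julia> using Mooncake: rdata_type, tangent_type

julia> rdata_type(tangent_type(Float64)), rdata_type(tangent_type(Float32))
(Float64, Float32)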

This is implemented via DebugRRule:

Mooncake.DebugRRule — Type
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array is of the same size as its primal).

If rule returns y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.


You can straightforwardly enable it when building a rule via the debug_mode kwarg in the following:

Mooncake.build_rrule — Function
build_rrule(args...; debug_mode=false)

Helper method. Only uses static information from args.

build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}

Returns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.

If debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.


When using ADTypes.jl, you can choose whether or not to use debug mode via the debug_mode kwarg:

julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))
AutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))
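A sketch of then using this backend through DifferentiationInterface.jl (the use of DifferentiationInterface and its gradient(f, backend, x) entry point is an assumption here, not something covered elsewhere in these docs):

using ADTypes, DifferentiationInterface
import Mooncake

backend = AutoMooncake(; config=Mooncake.Config(; debug_mode=true))
DifferentiationInterface.gradient(x -> sum(abs2, x), backend, [1.0, 2.0, 3.0])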

When Should You Use Debug Mode?

Only use debug_mode when debugging a problem. This is because it has substantial performance implications.

Debugging and MWEs

Mooncake.TestUtils.test_rule — Function

test_rule(
    rng, x...;
    ...,
    perf_flag::Symbol=:none,
    interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(),
    debug_mode::Bool=false,
)

Run standardised tests on the rule for x. The first element of x should be the primal function to test, and each other element a positional argument. In most cases, elements of x can just be the primal values, and randn_tangent can be relied upon to generate an appropriate tangent to test. Some notable exceptions exist though, in particular Ptrs. In this case, the argument for which randn_tangent cannot be readily defined should be a CoDual containing the primal, and a manually constructed tangent field.

This function uses Mooncake.build_rrule to construct a rule. This will use an rrule!! if one exists, and derive a rule otherwise.


This approach is convenient because it can

  1. check whether AD runs at all,
  2. check whether AD produces the correct answers,
  3. check whether AD is performant, and
  4. be used without having to manually generate tangents.

Example

For example

using Mooncake, Random

f(x) = Core.bitcast(Float64, x)
Mooncake.TestUtils.test_rule(Random.Xoshiro(123), f, 3; is_primitive=false)

will error. (In this case, the error is caused by Mooncake.jl preventing you from doing (potentially) unsafe casting. Here, Mooncake.jl simply fails to compile the function, but in other instances other things can happen.)

In any case, the point here is that Mooncake.TestUtils.test_rule provides a convenient way to produce and report an error.

Segfaults

These are everyone's least favourite kind of problem, and they should be extremely rare in Mooncake.jl. However, if you are unfortunate enough to encounter one, please re-run your problem with the debug_mode kwarg set to true. See Debug Mode for more info. In general, this will catch problems before they become segfaults, at which point the above strategy for debugging and error reporting should work well.
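For instance, a sketch of re-running a computation with debug mode enabled via the rule interface (the function and inputs here are placeholders):

import Mooncake

f(x) = sum(abs2, x)
x = [1.0, 2.0]

# Build a rule with the additional type checks enabled, then run it.
rule = Mooncake.build_rrule(f, x; debug_mode=true)
Mooncake.value_and_gradient!!(rule, f, x)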


Mooncake.jl

Documentation for Mooncake.jl is on its way!

Note (02/07/2024): The first round of documentation has arrived. This is largely targeted at those who are interested in contributing to Mooncake.jl – you can find this work in the "Understanding Mooncake.jl" section of the docs. There is more to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.

Note (29/05/2024): I (Will) am currently actively working on the documentation. It will be merged in chunks over the next month or so as good first drafts of sections are completed. Please don't be alarmed that not all of it is here!

Known Limitations

Mooncake.value_and_gradient!!(rule, foo, [5.0, 4.0])

# output

(4.0, (NoTangent(), [0.0, 1.0]))

The Solution

This is only really a problem for tangent / fdata / rdata generation functionality, such as zero_tangent. As a work-around, AD testing functionality permits users to pass in CoDuals. So if you are testing something involving a pointer, you will need to construct its tangent yourself, and pass a CoDual to e.g. Mooncake.TestUtils.test_rule.
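For example, a hedged sketch of manually pairing a pointer primal with a pointer tangent, so that it can be passed to Mooncake.TestUtils.test_rule in place of a raw Ptr (recall that the tangent type of a Ptr{Float64} is Ptr{Float64}):

import Mooncake

x  = randn(4)   # primal memory
dx = zero(x)    # memory in which the gradient will be accumulated

# x and dx must be kept alive for as long as these pointers are in use.
codual_ptr = Mooncake.CoDual(pointer(x), pointer(dx))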

While pointers tend to be a low-level implementation detail in Julia code, you could in principle actually be interested in differentiating a function of a pointer. In this case, you will not be able to use Mooncake.value_and_gradient!! as this requires the use of zero_tangent. Instead, you will need to use lower-level (internal) functionality, such as Mooncake.__value_and_gradient!!, or use the rule interface directly.

Honestly, your best bet is just to avoid differentiating functions whose arguments are pointers if you can.

Mathematical Interpretation

Given a mutable composite type such as

mutable struct Bar
    x::Float64
end

you will find that it is

julia> tangent_type(Bar)
 MutableTangent{@NamedTuple{x::Float64}}

Primitive Types

We've already seen a couple of primitive types (Float64 and Int). The basic story here is that all primitive types require an explicit specification of what their tangent type must be.

One interesting case is Ptr types. The tangent type of a Ptr{P} is Ptr{T}, where T = tangent_type(P). For example

julia> tangent_type(Ptr{Float64})
Ptr{Float64}

FData and RData

While tangents are the things used to represent gradients and are what high-level interfaces will return, they are not what gets propagated forwards and backwards by rules during AD.

Rather, during AD, Mooncake.jl makes a fundamental distinction between data which is identified by its address in memory (Arrays, mutable structs, etc), and data which is identified by its value (is-bits types such as Float64, Int, and structs thereof). In particular, memory which is identified by its address gets assigned a unique location in memory in which its gradient lives (that this "unique gradient address" system is essential will become apparent when we discuss aliasing later on). Conversely, the gradient w.r.t. a value type resides in another value type.

The following docstring provides the best in-depth explanation.

Mooncake.fdata_type — Method
fdata_type(T)

Returns the type of the forwards data associated to a tangent of type T.

Extended help

Rules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, that we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable struct or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.

Given a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.

Given a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that

tangent(fdata(t), rdata(t)) === t

The need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.

Int

Ints are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore

julia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))
 (NoFData, NoRData)

Float64

The tangent type of Float64 is Float64. Float64s are identified by their value / have no fixed address, so

julia> (fdata_type(Float64), rdata_type(Float64))
 (NoFData, Float64)

Vector{Float64}

The tangent type of Vector{Float64} is Vector{Float64}. A Vector{Float64} is identified by its address, so

julia> (fdata_type(Vector{Float64}), rdata_type(Vector{Float64}))
 (Vector{Float64}, NoRData)

Tuple{Float64, Vector{Float64}, Int}

This is an example of a type which has both fdata and rdata. The tangent type for Tuple{Float64, Vector{Float64}, Int} is Tuple{Float64, Vector{Float64}, NoTangent}. Tuples have no fixed memory address, so we interrogate each field on its own. We have already established the fdata and rdata types for each element, so we recurse to obtain:

julia> T = tangent_type(Tuple{Float64, Vector{Float64}, Int})
Tuple{Float64, Vector{Float64}, NoTangent}

A composite type with a mixture of differentiable and non-differentiable fields, such as

mutable struct Bar
    x::Float64
    y
    z::Int
end

has tangent type

julia> tangent_type(Bar)
 MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}

and fdata / rdata types

julia> (fdata_type(tangent_type(Bar)), rdata_type(tangent_type(Bar)))
(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)

Primitive Types

As with tangents, each primitive type must specify what its fdata and rdata are. See specific examples for details.


CoDuals

CoDuals are simply used to bundle together a primal and an associated fdata. Occasionally, depending upon context, they are instead used to pair together a primal and a tangent.
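A small sketch of both uses (zero_fcodual and zero_tangent are the Mooncake.jl helpers referred to elsewhere in these docs):

using Mooncake: CoDual, zero_fcodual, zero_tangent

zero_fcodual(5.0)                              # a primal bundled with its (zero) fdata
CoDual([1.0, 2.0], zero_tangent([1.0, 2.0]))   # a primal bundled with a full tangent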

A quick aside: Non-Differentiable Data

In the introduction to algorithmic differentiation, we assumed that the domain / range of a function are the same as those of its derivative. Unfortunately, this story is only partly true. Matters are complicated by the fact that not all data types in Julia can reasonably be thought of as forming a Hilbert space, e.g. the String type.

Consequently we introduce the special type NoTangent, instances of which can be thought of as representing the set containing only a $0$ tangent. Morally speaking, for any non-differentiable data x, x + NoTangent() == x.

Other than non-differentiable data, the model of data in Julia as living in a real-valued finite dimensional Hilbert space is quite reasonable. Therefore, we hope readers will forgive us for largely ignoring the distinction between the domain and range of a function and that of its derivative in mathematical discussions, while simultaneously drawing a distinction when discussing code.

TODO: update this to cast e.g. each possible String as its own vector space containing only the 0 element. This works, even if it seems a little contrived.

The Rule Interface (Round 2)

Now that you've seen what data structures are used to represent gradients, we can describe in more detail how fdata and rdata are used to propagate gradients backwards on the reverse pass.

Consider the function

julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2])
 foo (generic function with 1 method)

The fdata for x is a Tuple{NoFData, Vector{Float64}}, and its rdata is a Tuple{Float64, NoRData}. The function returns a Float64, which has no fdata, and whose rdata is Float64. So on the forwards pass there is really nothing that needs to happen with the fdata for x.
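A small sketch of this split for a tangent of that type, using the fdata / rdata types stated above (NoFData / NoRData are Mooncake.jl's "nothing to propagate" markers):

using Mooncake: fdata, rdata, tangent

t = (5.0, [1.0, 2.0])   # a tangent for a Tuple{Float64, Vector{Float64}}
f = fdata(t)            # isa Tuple{NoFData, Vector{Float64}} – carried on the forwards-pass
r = rdata(t)            # isa Tuple{Float64, NoRData} – returned on the reverse-pass
tangent(f, r) === t     # true: the split is lossless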

Under the framework introduced above, the model for this function is

\[f(x) = (x, x_1 + \sum_{n=1}^N (x_2)_n)\]

where the vector in the second element of x is of length $N$. Now, following our usual steps, the derivative is

\[D f [x](\dot{x}) = (\dot{x}, \dot{x}_1 + \sum_{n=1}^N (\dot{x}_2)_n)\]

A gradient for this is a tuple $(\bar{y}_x, \bar{y}_a)$ where $\bar{y}_a \in \RR$ and $\bar{y}_x \in \RR \times \RR^N$. A quick derivation will show that the adjoint is

\[D f [x]^\ast(\bar{y}) = ((\bar{y}_x)_1 + \bar{y}_a, (\bar{y}_x)_2 + \bar{y}_a \mathbf{1})\]

where $\mathbf{1}$ is the vector of length $N$ in which each element is equal to $1$. (Observe that this agrees with the result we derived earlier for functions which don't mutate their arguments).

Now that we know what the adjoint is, we'll write down the rrule!!, and then explain what is going on in terms of the adjoint. This hand-written implementation is to aid your understanding – Mooncake.jl should be relied upon to generate this code automatically in practice.

julia> function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}})
            dx_fdata = x.dx
            function dfoo_adjoint(dy::Float64)
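For reference, a sketch of how the remainder of such a hand-written rule might look, following the adjoint derived above (this is illustrative only – the code Mooncake.jl actually derives will differ):

function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}})
    dx_fdata = x.dx   # fdata for x: a Tuple{NoFData, Vector{Float64}}
    function dfoo_adjoint(dy::Float64)
        # Vector component: increment the fdata in-place by dy * 1, per the adjoint above.
        dx_fdata[2] .+= dy
        # rdata: nothing for foo itself; (dy, NoRData()) for x, since only the Float64
        # component of x has rdata.
        return NoRData(), (dy, NoRData())
    end
    x_p = primal(x)
    return zero_fcodual(x_p[1] + sum(x_p[2])), dfoo_adjoint
end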
A related subtlety concerns data which a function uses, but which does not enter it via its arguments. Suppose that g refers to some variable a defined outside of it, and consider

typeof(g)

# output

typeof(g) (singleton type of function g, subtype of Function)

Neither the value nor the type of a is present in g. Since a doesn't enter g via its arguments, it is unclear how it should be handled in general.


Mooncake.jl and Reverse-Mode AD

The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically.

In what follows, we:

  1. recap what AD is, and introduce the mathematics necessary to understand it,
  2. explain how this mathematics relates to functions and data structures in Julia, and
  3. explain how this is handled in Mooncake.jl.

Since Mooncake.jl supports in-place operations / mutation, this discussion will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.

Who Are These Docs For?

These are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone who is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to read them in order to make use of this package.

Prerequisites and Resources

This introduction assumes familiarity with the differentiation of vector-valued functions – in particular, with gradients and Jacobian matrices.

In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and "transposed Jacobian". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the "Matrix Calculus for Machine Learning and Beyond" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.

In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and "transposed Jacobian". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the "Matrix Calculus for Machine Learning and Beyond" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.