Further ChainRulesCore.rrule Integration (#254)
* Bump patch version

* Fix usage with benchmarktools

* Initial pass

* Bump patch

* Unit test to_tapir_tangent and to_cr_tangent

* Make use of macro

* More testing and tidying up

* Add some basic type checking and a test

* Improve formatting and commenting

* Formatting

* Improve documentation

* Explain how not to use rrule functionality

* Add rules for BLAS utilities

* Initial NNlib integration

* Thunks and batched_mul

* More rules + kwargs + rename

* Fix link in docs

* Rename chain_rules_macro to chain_rules_interop

* Complete rename of chain rules interop file

* Refactor chain rules interop

* Add more nnlib functionality

* Remove old tests

* Some work

* Remove errant show statement

* Remove redundant test

* Support where

* Make use of where params

* Improve kwarg interface

* Default kwargs test

* Improve docstring

* Some work

* Some work

* Better conv support in nnlib rules

* More LuxLib rules

* Permit :meta nodes in IR

* Remove redundant test

* Uncomment some tests

* Rename chain rules doc

* Add notes to docs on rule writing strategies

* Add mooncake_overlay

* Add simpler method of build_rrule

* Fix dispatch problem

* Tidy up

* Tidy up build_rrule calls

* Improve zero_adjoint docs

* Improve documentation of from_rrule

* Fix formatting

* Explain what is new

* Improve from_rrule documentation

* Formatting

* Fix formatting

* Add compat for ChainRulesCore in docs

* Tidy up mooncake_method_table usage

* Add extra luxlib test

* Add another luxlib test

* Bump patch version

* Update ext/MooncakeNNlibExt.jl

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

* Restrict CI to 1.10 for now

* Apply suggestions from code review

Co-authored-by: Markus Hauru <markus@mhauru.org>

* Restrict version consistently

* Fix typo in docstring

* Shove all testing functionality inside module

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: Markus Hauru <markus@mhauru.org>
3 people authored Oct 8, 2024
1 parent d4285ef commit d6110f0
Showing 38 changed files with 1,264 additions and 261 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
@@ -5,7 +5,7 @@ steps:
 - label: "Julia v1"
 plugins:
 - JuliaCI/julia#v1:
-version: "1"
+version: "1.10"
 - JuliaCI/julia-test#v1: ~
 - JuliaCI/julia-coverage#v1:
 dirs:
6 changes: 3 additions & 3 deletions .github/workflows/CI.yml
@@ -35,7 +35,7 @@ jobs:
 - uses: actions/checkout@v4
 - uses: julia-actions/setup-julia@v2
 with:
-version: '1'
+version: '1.10'
 arch: x64
 include-all-prereleases: false
 - uses: julia-actions/cache@v2
@@ -62,7 +62,7 @@ jobs:
 - uses: actions/checkout@v4
 - uses: julia-actions/setup-julia@v2
 with:
-version: '1'
+version: '1.10'
 arch: x64
 include-all-prereleases: false
 - uses: julia-actions/cache@v2
@@ -81,7 +81,7 @@ jobs:
 - uses: actions/checkout@v4
 - uses: julia-actions/setup-julia@v2
 with:
-version: '1'
+version: '1.10'
 arch: x64
 include-all-prereleases: false
 - uses: julia-actions/cache@v2
2 changes: 1 addition & 1 deletion .github/workflows/documentation.yml
@@ -19,7 +19,7 @@ jobs:
 - uses: actions/checkout@v4
 - uses: julia-actions/setup-julia@v2
 with:
-version: '1'
+version: '1.10'
 arch: x64
 include-all-prereleases: false
 - name: Install dependencies
13 changes: 11 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
 name = "Mooncake"
 uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 authors = ["Will Tebbutt, Hong Ge, and contributors"]
-version = "0.4.9"
+version = "0.4.10"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -22,13 +22,17 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
+LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

 [extensions]
 MooncakeCUDAExt = "CUDA"
 MooncakeDynamicPPLExt = "DynamicPPL"
 MooncakeJETExt = "JET"
 MooncakeLogDensityProblemsADExt = "LogDensityProblemsAD"
+MooncakeLuxLibExt = "LuxLib"
+MooncakeNNlibExt = "NNlib"
 MooncakeSpecialFunctionsExt = "SpecialFunctions"

 [compat]
@@ -46,7 +50,9 @@ FillArrays = "1"
 Graphs = "1"
 JET = "0.9"
 LogDensityProblemsAD = "1"
+LuxLib = "1.2"
 MistyClosures = "1"
+NNlib = "0.9"
 PDMats = "0.11"
 Setfield = "1"
 SpecialFunctions = "2"
@@ -66,11 +72,14 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"]
+test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "Lux", "LuxLib", "NNlib", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"]
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -1,9 +1,11 @@
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

 [compat]
+ChainRulesCore = "1"
 Documenter = "1"
 Mooncake = "0.4.0"
7 changes: 5 additions & 2 deletions docs/make.jl
@@ -32,9 +32,12 @@ makedocs(
 "Algorithmic Differentiation" => "algorithmic_differentiation.md",
 "Mooncake.jl's Rule System" => "mathematical_interpretation.md",
 ],
+"Utilities" => [
+"Tools for Rules" => "tools_for_rules.md",
+"Debug Mode" => "debug_mode.md",
+"Debugging and MWEs" => "debugging_and_mwes.md",
+],
 "Known Limitations" => "known_limitations.md",
-"Debug Mode" => "debug_mode.md",
-"Debugging and MWEs" => "debugging_and_mwes.md",
 ]
 )

5 changes: 5 additions & 0 deletions docs/src/index.md
@@ -2,6 +2,11 @@

 Documentation for Mooncake.jl is on its way!

+Note (03/10/2024): Various bits of utility functionality are now carefully documented. This
+includes how to change the code which Mooncake sees, declare that the derivative of a
+function is zero, make use of existing `ChainRules.rrule`s to quickly create new rules in
+Mooncake, and more.
+
 Note (02/07/2024): The first round of documentation has arrived.
 This is largely targeted at those who are interested in contributing to Mooncake.jl -- you can find this work in the "Understanding Mooncake.jl" section of the docs.
 There is more to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.
2 changes: 1 addition & 1 deletion docs/src/known_limitations.md
@@ -131,7 +131,7 @@ function foo(x::Vector{Float64})
     return unsafe_load(p)
 end
-rule = build_rrule(get_interpreter(), Tuple{typeof(foo), Vector{Float64}})
+rule = build_rrule(Tuple{typeof(foo), Vector{Float64}})
 Mooncake.value_and_gradient!!(rule, foo, [5.0, 4.0])
 # output
33 changes: 33 additions & 0 deletions docs/src/tools_for_rules.md
@@ -0,0 +1,33 @@
# Tools for Rules

Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported.
However, this does not always necessitate writing your own `rrule!!` from scratch.
In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations.

## Simplifying Code via Overlays

```@docs
Mooncake.@mooncake_overlay
```
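
As a rough illustration of the macro documented above, the sketch below overlays a toy function. `checked_square` is a hypothetical name invented for this example, and we only pretend that its regular method contains something Mooncake cannot differentiate.

```julia
using Mooncake

# Hypothetical function; imagine its body used an unsupported language feature.
checked_square(x::Float64) = x^2

# A method which only Mooncake sees when it differentiates `checked_square`.
Mooncake.@mooncake_overlay function checked_square(x::Float64)
    return x * x
end

# Build a rule for the function and compute its value and gradient at 3.0.
rule = Mooncake.build_rrule(Tuple{typeof(checked_square), Float64})
val, grads = Mooncake.value_and_gradient!!(rule, checked_square, 3.0)
```

Ordinary calls to `checked_square` still hit the original method, so the overlay should compute the same value as the method it stands in for.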

## Functions with Zero Adjoint

If the above strategy does not work, but you find yourself in the surprisingly common
situation that the adjoint of the derivative of your function is always zero, you can very
straightforwardly write a rule by making use of the following:
```@docs
Mooncake.@zero_adjoint
Mooncake.zero_adjoint
```
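
For instance, a function which returns an integer count has no useful derivative with respect to its floating point inputs. The following is a minimal sketch under that assumption; `count_positive` and `weighted_sum` are hypothetical names made up for this example.

```julia
using Mooncake

# Hypothetical function: its output is an integer, so its derivative is zero
# wherever it is defined.
count_positive(x::Vector{Float64}) = count(>(0), x)

# Declare `count_positive` a primitive whose adjoint is always zero.
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(count_positive), Vector{Float64}}

# A function which calls it. Mooncake no longer looks inside `count_positive`.
weighted_sum(x::Vector{Float64}) = sum(x) * count_positive(x)

rule = Mooncake.build_rrule(Tuple{typeof(weighted_sum), Vector{Float64}})
val, grads = Mooncake.value_and_gradient!!(rule, weighted_sum, [1.0, -2.0, 3.0])
```

Only use this when the zero adjoint is mathematically justified, since Mooncake does not check the declaration.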

## Using ChainRules.jl

[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides a large number of rules for differentiating functions in reverse-mode.
These rules are methods of the `ChainRulesCore.rrule` function.
There are some instances where it is most convenient to implement a `Mooncake.rrule!!` by wrapping an existing `ChainRulesCore.rrule`.

There is enough similarity between these two systems that most of the boilerplate code can be avoided.

```@docs
Mooncake.@from_rrule
```
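
The sketch below shows roughly how this looks for a scalar function. `scale_by_five` and its `ChainRulesCore.rrule` are hypothetical and written purely for illustration; in practice the `rrule` would usually already exist in another package.

```julia
using Mooncake
import ChainRulesCore

# Hypothetical function with a hand-written ChainRulesCore.rrule.
scale_by_five(x::Real) = 5x

function ChainRulesCore.rrule(::typeof(scale_by_five), x::Real)
    # The pullback returns a tangent for the function itself, then one per argument.
    scale_by_five_pullback(dy::Real) = ChainRulesCore.NoTangent(), 5dy
    return scale_by_five(x), scale_by_five_pullback
end

# Generate a Mooncake rule which wraps the rrule above.
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(scale_by_five), Base.IEEEFloat}

rule = Mooncake.build_rrule(Tuple{typeof(scale_by_five), Float64})
val, grads = Mooncake.value_and_gradient!!(rule, scale_by_five, 2.0)
```

The signature passed to `@from_rrule` is deliberately narrower than the `rrule` method's own (`IEEEFloat` rather than `Real`), which keeps the wrapper to argument types whose tangents are simple to translate between the two systems.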
14 changes: 3 additions & 11 deletions ext/MooncakeDynamicPPLExt.jl
@@ -1,17 +1,9 @@
 module MooncakeDynamicPPLExt

-if isdefined(Base, :get_extension)
-    using DynamicPPL: DynamicPPL, istrans
-    using Mooncake: Mooncake
-else
-    using ..DynamicPPL: DynamicPPL, istrans
-    using ..Mooncake: Mooncake
-end
-
-using Mooncake: DefaultCtx, CoDual, simple_zero_adjoint
+using DynamicPPL: DynamicPPL, istrans
+using Mooncake: Mooncake

 # This is purely an optimisation.
-Mooncake.@is_primitive DefaultCtx Tuple{typeof(istrans), Vararg}
-Mooncake.rrule!!(f::CoDual{typeof(istrans)}, x::CoDual...) = simple_zero_adjoint(f, x...)
+Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans), Vararg}

 end # module
173 changes: 173 additions & 0 deletions ext/MooncakeLuxLibExt.jl
@@ -0,0 +1,173 @@
module MooncakeLuxLibExt

using LuxLib, Random, Mooncake
using Base: IEEEFloat

import LuxLib: Impl
import LuxLib.Utils: static_training_mode_check
import Mooncake:
    @from_rrule,
    DefaultCtx,
    @mooncake_overlay,
    CoDual

@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul), Array{P}, Array{P}} where {P<:IEEEFloat})
@from_rrule(
    DefaultCtx,
    Tuple{typeof(Impl.matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat},
)
@from_rrule(
    DefaultCtx,
    Tuple{typeof(Impl.batched_matmul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat},
)

# Re-implement a bunch of methods to ensure that Mooncake can differentiate them.
@mooncake_overlay function LuxLib.Impl.fused_dense(
    opmode,
    act::F,
    weight::AbstractMatrix,
    x::AbstractMatrix,
    b::LuxLib.Optional{<:AbstractVector},
) where {F}
    return bias_activation(act, Impl.matmul(weight, x), b)
end

@mooncake_overlay function LuxLib.Impl.bias_activation_loop!(
    y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector
) where {F, xT, yT}
    return LuxLib.Impl.bias_activation_simd_loop!(y, σ, x, bias)
end

@mooncake_overlay function LuxLib.Impl.activation_loop!(
    y::AbstractArray, σ::F, x::AbstractArray
) where {F}
    return LuxLib.Impl.activation_simd_loop!(y, σ, x)
end

@mooncake_overlay function LuxLib.Impl.fused_conv(
    ::LuxLib.Impl.AbstractInternalArrayOpMode,
    act::F,
    weight::AbstractArray{wT, N},
    x::AbstractArray{xT, N},
    bias::LuxLib.Optional{<:AbstractVector},
    cdims::LuxLib.Impl.ConvDims,
) where {F, wT, xT, N}
    return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias)
end

for f in [
    Impl.SLEEFActivations.sigmoid_fast,
    Impl.SLEEFActivations.softplus,
    Impl.SLEEFActivations.logsigmoid,
    Impl.SLEEFActivations.swish,
    Impl.SLEEFActivations.lisht,
    Impl.SLEEFActivations.tanh,
    Impl.SLEEFActivations.tanh_fast,
]
    @from_rrule DefaultCtx Tuple{typeof(f), IEEEFloat}
    @from_rrule(
        DefaultCtx,
        Tuple{typeof(Broadcast.broadcasted), typeof(f), Union{IEEEFloat, Array{<:IEEEFloat}}},
    )
end

Mooncake.@zero_adjoint DefaultCtx Tuple{typeof(static_training_mode_check), Vararg}

# This is a really horrible hack that we need to do until Mooncake is able to support the
# call-back-into-ad interface that ChainRules exposes.

import LuxLib.Impl:
    safe_eltype,
    batchnorm_affine_normalize_internal,
    batchnorm_affine_normalize_internal!,
    ∇batchnorm_affine_normalize,
    AbstractInternalArrayOpMode

import ChainRulesCore as CRC

function CRC.rrule(
    ::typeof(batchnorm_affine_normalize_internal),
    opmode::AbstractInternalArrayOpMode,
    ::typeof(identity),
    x::AbstractArray{T, N},
    μ::AbstractVector,
    σ²::AbstractVector,
    γ::LuxLib.Optional{<:AbstractVector},
    β::LuxLib.Optional{<:AbstractVector},
    ϵ::Real,
) where {T, N}
    y = similar(
        x,
        promote_type(
            safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β)
        )
    )
    γ′ = similar(
        x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1)
    )

    batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′)

    𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²)
    𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ)
    𝒫β = β === nothing ? identity : CRC.ProjectTo(β)

    ∇batchnorm_affine_normalize_internal = LuxLib.Impl.@closure Δ -> begin
        ∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇batchnorm_affine_normalize(opmode, Δ, x, μ, σ², γ, β, ϵ, γ′)
        ∂∅ = CRC.NoTangent()
        return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅
    end

    return y, ∇batchnorm_affine_normalize_internal
end

@from_rrule(
    DefaultCtx,
    Tuple{
        typeof(batchnorm_affine_normalize_internal),
        AbstractInternalArrayOpMode,
        typeof(identity),
        AbstractArray,
        AbstractVector,
        AbstractVector,
        LuxLib.Optional{<:AbstractVector},
        LuxLib.Optional{<:AbstractVector},
        Real,
    },
)

@mooncake_overlay function batchnorm_affine_normalize_internal(
    opmode::LuxLib.AbstractInternalArrayOpMode,
    act::F,
    x::AbstractArray{xT, 3},
    μ::AbstractVector,
    σ²::AbstractVector,
    γ::Union{Nothing, AbstractVector},
    β::Union{Nothing, AbstractVector},
    ϵ::Real,
) where {F, xT}
    y = batchnorm_affine_normalize_internal(opmode, identity, x, μ, σ², γ, β, ϵ)
    LuxLib.Impl.activation!(y, opmode, act, y)
    return y
end

@mooncake_overlay function batchnorm_affine_normalize_internal(
    opmode::LuxLib.AbstractInternalArrayOpMode,
    ::typeof(identity),
    x::AbstractArray{xT, 3},
    μ::AbstractVector,
    σ²::AbstractVector,
    γ::Union{Nothing, AbstractVector},
    β::Union{Nothing, AbstractVector},
    ϵ::Real,
) where {xT}
    y = similar(x,
        promote_type(
            safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β)
        )
    )
    batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ)
    return y
end

end

2 comments on commit d6110f0

@willtebbutt

@JuliaRegistrator register()

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/116820

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.10 -m "<description of version>" d6110f04c54e021b2491df31f12bb0fb6ee3dd2e
git push origin v0.4.10
