-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Further ChainRulesCore.rrule Integration (#254)
* Bump patch version * Fix usage with benchmarktools * Initial pass * Bump patch * Unit test to_tapir_tangent and to_cr_tangent * Make use of macro * More testing and tidying up * Add some basic type checking and a test * Improve formatting and commenting * Formatting * Improve documentation * Explain how not to use rrule functionality * Add rules for BLAS utilities * Initial NNlib integration * Thunks and batched_mul * More rules + kwargs + rename * Fix link in docs * Rename chain_rules_macro to chain_rules_interop * Complete rename of chain rules interop file * Refactor chain rules interop * Add more nnlib functionality * Remove old tests * Some work * Remove errant show statment * Remove redundant test * Support where * Make use of where params * Improve kwarg interface * Default kwargs test * Improve docstring * Some work * Some work * Better conv support in nnlib rules * More LuxLib rules * Permit :meta nodes in IR * Remove redundant test * Uncomment some tests * Rename chain rules doc * Add notes to docs on rule writing strategies * Add mooncake_overlay * Add simpler method of build_rrule * Fix dispatch problem * Tidy up * Tidy up build_rrule calls * Improve zero_adjoint docs * Improve documentation of from_rrule * Fix formatting * Explain what is new * Improve from_rrule documentation * Formatting * Fix formatting * Add compat for ChainRulesCore in docs * Tidy up mooncake_method_table usage * Add extra luxlib test * Add another luxlib test * Bump patch version * Update ext/MooncakeNNlibExt.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Restrict CI to 1.10 for now * Apply suggestions from code review Co-authored-by: Markus Hauru <markus@mhauru.org> * Restrict version consistently * Fix typo in docstring * Shove all testing functionality inside module --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: Markus Hauru <markus@mhauru.org>
- Loading branch information
1 parent
d4285ef
commit d6110f0
Showing
38 changed files
with
1,264 additions
and
261 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" | ||
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" | ||
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" | ||
|
||
[compat] | ||
ChainRulesCore = "1" | ||
Documenter = "1" | ||
Mooncake = "0.4.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Tools for Rules | ||
|
||
Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. | ||
However, this does not always necessitate writing your own `rrule!!` from scratch. | ||
In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations. | ||
|
||
## Simplfiying Code via Overlays | ||
|
||
```@docs | ||
Mooncake.@mooncake_overlay | ||
``` | ||
|
||
## Functions with Zero Adjoint | ||
|
||
If the above strategy does not work, but you find yourself in the surprisingly common | ||
situation that the adjoint of the derivative of your function is always zero, you can very | ||
straightforwardly write a rule by making use of the following: | ||
```@docs | ||
Mooncake.@zero_adjoint | ||
Mooncake.zero_adjoint | ||
``` | ||
|
||
## Using ChainRules.jl | ||
|
||
[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides a large number of rules for differentiating functions in reverse-mode. | ||
These rules are methods of the `ChainRulesCore.rrule` function. | ||
There are some instances where it is most convenient to implement a `Mooncake.rrule!!` by wrapping an existing `ChainRulesCore.rrule`. | ||
|
||
There is enough similarity between these two systems that most of the boilerplate code can be avoided. | ||
|
||
```@docs | ||
Mooncake.@from_rrule | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,9 @@ | ||
module MooncakeDynamicPPLExt | ||
|
||
if isdefined(Base, :get_extension) | ||
using DynamicPPL: DynamicPPL, istrans | ||
using Mooncake: Mooncake | ||
else | ||
using ..DynamicPPL: DynamicPPL, istrans | ||
using ..Mooncake: Mooncake | ||
end | ||
|
||
using Mooncake: DefaultCtx, CoDual, simple_zero_adjoint | ||
using DynamicPPL: DynamicPPL, istrans | ||
using Mooncake: Mooncake | ||
|
||
# This is purely an optimisation. | ||
Mooncake.@is_primitive DefaultCtx Tuple{typeof(istrans), Vararg} | ||
Mooncake.rrule!!(f::CoDual{typeof(istrans)}, x::CoDual...) = simple_zero_adjoint(f, x...) | ||
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans), Vararg} | ||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
module MooncakeLuxLibExt | ||
|
||
using LuxLib, Random, Mooncake | ||
using Base: IEEEFloat | ||
|
||
import LuxLib: Impl | ||
import LuxLib.Utils: static_training_mode_check | ||
import Mooncake: | ||
@from_rrule, | ||
DefaultCtx, | ||
@mooncake_overlay, | ||
CoDual | ||
|
||
@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) | ||
@from_rrule( | ||
DefaultCtx, | ||
Tuple{typeof(Impl.matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, | ||
) | ||
@from_rrule( | ||
DefaultCtx, | ||
Tuple{typeof(Impl.batched_matmul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, | ||
) | ||
|
||
# Re-implement a bunch of methods to ensure that Mooncake can differentiate them. | ||
@mooncake_overlay function LuxLib.Impl.fused_dense( | ||
opmode, | ||
act::F, | ||
weight::AbstractMatrix, | ||
x::AbstractMatrix, | ||
b::LuxLib.Optional{<:AbstractVector}, | ||
) where {F} | ||
return bias_activation(act, Impl.matmul(weight, x), b) | ||
end | ||
|
||
@mooncake_overlay function LuxLib.Impl.bias_activation_loop!( | ||
y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector | ||
) where {F, xT, yT} | ||
return LuxLib.Impl.bias_activation_simd_loop!(y, σ, x, bias) | ||
end | ||
|
||
@mooncake_overlay function LuxLib.Impl.activation_loop!( | ||
y::AbstractArray, σ::F, x::AbstractArray | ||
) where {F} | ||
return LuxLib.Impl.activation_simd_loop!(y, σ, x) | ||
end | ||
|
||
@mooncake_overlay function LuxLib.Impl.fused_conv( | ||
::LuxLib.Impl.AbstractInternalArrayOpMode, | ||
act::F, | ||
weight::AbstractArray{wT, N}, | ||
x::AbstractArray{xT, N}, | ||
bias::LuxLib.Optional{<:AbstractVector}, | ||
cdims::LuxLib.Impl.ConvDims, | ||
) where {F, wT, xT, N} | ||
return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias) | ||
end | ||
|
||
for f in [ | ||
Impl.SLEEFActivations.sigmoid_fast, | ||
Impl.SLEEFActivations.softplus, | ||
Impl.SLEEFActivations.logsigmoid, | ||
Impl.SLEEFActivations.swish, | ||
Impl.SLEEFActivations.lisht, | ||
Impl.SLEEFActivations.tanh, | ||
Impl.SLEEFActivations.tanh_fast, | ||
] | ||
@from_rrule DefaultCtx Tuple{typeof(f), IEEEFloat} | ||
@from_rrule( | ||
DefaultCtx, | ||
Tuple{typeof(Broadcast.broadcasted), typeof(f), Union{IEEEFloat, Array{<:IEEEFloat}}}, | ||
) | ||
end | ||
|
||
Mooncake.@zero_adjoint DefaultCtx Tuple{typeof(static_training_mode_check), Vararg} | ||
|
||
# This is a really horrible hack that we need to do until Mooncake is able to support the | ||
# call-back-into-ad interface that ChainRules exposes. | ||
|
||
import LuxLib.Impl: | ||
safe_eltype, | ||
batchnorm_affine_normalize_internal, | ||
batchnorm_affine_normalize_internal!, | ||
∇batchnorm_affine_normalize, | ||
AbstractInternalArrayOpMode | ||
|
||
import ChainRulesCore as CRC | ||
|
||
function CRC.rrule( | ||
::typeof(batchnorm_affine_normalize_internal), | ||
opmode::AbstractInternalArrayOpMode, | ||
::typeof(identity), | ||
x::AbstractArray{T, N}, | ||
μ::AbstractVector, | ||
σ²::AbstractVector, | ||
γ::LuxLib.Optional{<:AbstractVector}, | ||
β::LuxLib.Optional{<:AbstractVector}, | ||
ϵ::Real, | ||
) where {T, N} | ||
y = similar( | ||
x, | ||
promote_type( | ||
safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β) | ||
) | ||
) | ||
γ′ = similar( | ||
x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1) | ||
) | ||
|
||
batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′) | ||
|
||
𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) | ||
𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) | ||
𝒫β = β === nothing ? identity : CRC.ProjectTo(β) | ||
|
||
∇batchnorm_affine_normalize_internal = LuxLib.Impl.@closure Δ -> begin | ||
∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇batchnorm_affine_normalize(opmode, Δ, x, μ, σ², γ, β, ϵ, γ′) | ||
∂∅ = CRC.NoTangent() | ||
return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅ | ||
end | ||
|
||
return y, ∇batchnorm_affine_normalize_internal | ||
end | ||
|
||
@from_rrule( | ||
DefaultCtx, | ||
Tuple{ | ||
typeof(batchnorm_affine_normalize_internal), | ||
AbstractInternalArrayOpMode, | ||
typeof(identity), | ||
AbstractArray, | ||
AbstractVector, | ||
AbstractVector, | ||
LuxLib.Optional{<:AbstractVector}, | ||
LuxLib.Optional{<:AbstractVector}, | ||
Real, | ||
}, | ||
) | ||
|
||
@mooncake_overlay function batchnorm_affine_normalize_internal( | ||
opmode::LuxLib.AbstractInternalArrayOpMode, | ||
act::F, | ||
x::AbstractArray{xT, 3}, | ||
μ::AbstractVector, | ||
σ²::AbstractVector, | ||
γ::Union{Nothing, AbstractVector}, | ||
β::Union{Nothing, AbstractVector}, | ||
ϵ::Real, | ||
) where {F, xT} | ||
y = batchnorm_affine_normalize_internal(opmode, identity, x, μ, σ², γ, β, ϵ) | ||
LuxLib.Impl.activation!(y, opmode, act, y) | ||
return y | ||
end | ||
|
||
@mooncake_overlay function batchnorm_affine_normalize_internal( | ||
opmode::LuxLib.AbstractInternalArrayOpMode, | ||
::typeof(identity), | ||
x::AbstractArray{xT, 3}, | ||
μ::AbstractVector, | ||
σ²::AbstractVector, | ||
γ::Union{Nothing, AbstractVector}, | ||
β::Union{Nothing, AbstractVector}, | ||
ϵ::Real, | ||
) where {xT} | ||
y = similar(x, | ||
promote_type( | ||
safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β) | ||
) | ||
) | ||
batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ) | ||
return y | ||
end | ||
|
||
end |
Oops, something went wrong.
d6110f0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register()
d6110f0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/116820
Tip: Release Notes
Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.
To add them here just re-invoke and the PR will be updated.
Tagging
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: