From e0d262a62bc5f5ec2d7400ace4b5de9bc38346d1 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 22 Jun 2024 14:04:26 -0700
Subject: [PATCH] Use recursive_map

---
 docs/src/api/Lux/utilities.md              |  1 +
 ext/LuxReverseDiffExt/LuxReverseDiffExt.jl |  2 +-
 ext/LuxReverseDiffExt/training.jl          | 41 ++++++++++--
 src/helpers/recursive_ops.jl               | 78 ++++++++++++++--------
 src/utils.jl                               | 17 +++++
 5 files changed, 105 insertions(+), 34 deletions(-)

diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md
index 78461ac4a..e73b31530 100644
--- a/docs/src/api/Lux/utilities.md
+++ b/docs/src/api/Lux/utilities.md
@@ -59,6 +59,7 @@ Lux.xlogx
 ## Recursive Operations
 
 ```@docs
+Lux.recursive_map
 Lux.recursive_add!!
 Lux.recursive_eltype
 Lux.recursive_make_zero
diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
index 706b24b90..d22c39e5c 100644
--- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
+++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -1,6 +1,6 @@
 module LuxReverseDiffExt
 
-using ADTypes: AutoReverseDiff
+using ADTypes: ADTypes, AutoReverseDiff
 using ArrayInterface: ArrayInterface
 using Functors: fmap
 using Lux: Lux, LuxCPUDevice
diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl
index 6e56f899a..e41b15e85 100644
--- a/ext/LuxReverseDiffExt/training.jl
+++ b/ext/LuxReverseDiffExt/training.jl
@@ -1,11 +1,44 @@
-function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data,
-        ts::Lux.Experimental.TrainState) where {F}
+@static if pkgversion(ADTypes) < v"1.5"
+    # older versions did not have `compile` type parameter. Use slower type-unstable code
+    function Lux.Experimental.compute_gradients(ad::AutoReverseDiff, objective_function::F,
+            data, ts::Lux.Experimental.TrainState) where {F}
+        ad.compile && return __compiled_reverse_diff(objective_function, data, ts)
+        return __uncompiled_reverse_diff(objective_function, data, ts)
+    end
+else
+    for compiled in (false, true)
+        fname = compiled ? :__compiled_reverse_diff : :__uncompiled_reverse_diff
+        @eval function Lux.Experimental.compute_gradients(
+                ::AutoReverseDiff{$(compiled)}, objective_function::F,
+                data, ts::Lux.Experimental.TrainState) where {F}
+            return $(fname)(objective_function, data, ts)
+        end
+    end
+end
+
+@inline function __uncompiled_reverse_diff(
+        objective_function::F, data, ts::Lux.Experimental.TrainState) where {F}
     tape = ReverseDiff.InstructionTape()
-    grads = fmap(zero, ts.parameters)
-    ps_tracked = fmap((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ts.parameters, grads)
+    grads = Lux.recursive_make_zero(ts.parameters)
+    ps_tracked = Lux.recursive_map(
+        Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads)
     loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
     loss.deriv = true
     ReverseDiff.reverse_pass!(tape)
     @set! ts.states = st
     return grads, ReverseDiff.value(loss), stats, ts
 end
+
+@inline function __compiled_reverse_diff(
+        objective_function::F, data, ts::Lux.Experimental.TrainState) where {F}
+    # tape = ReverseDiff.InstructionTape()
+    # grads = Lux.recursive_make_zero(ts.parameters)
+    # ps_tracked = Lux.recursive_map(
+    #     Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads)
+    # loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
+    # loss.deriv = true
+    # ReverseDiff.reverse_pass!(tape)
+    # @set! ts.states = st
+    # return grads, ReverseDiff.value(loss), stats, ts
+    error("Not implemented yet")
+end
diff --git a/src/helpers/recursive_ops.jl b/src/helpers/recursive_ops.jl
index f358cb9bb..42460348f 100644
--- a/src/helpers/recursive_ops.jl
+++ b/src/helpers/recursive_ops.jl
@@ -7,17 +7,7 @@ common cases.
 
 Any leaves of `x` that are arrays and allow in-place addition will be modified in place.
 """
-function recursive_add!!(x::AbstractArray, y::AbstractArray)
-    ArrayInterface.can_setindex(x) || return x .+ y
-    @. x += y
-    return x
-end
-recursive_add!!(x::Tuple, y::Tuple) = map(recursive_add!!, x, y)
-recursive_add!!(::Nothing, ::Nothing) = nothing
-function recursive_add!!(x::NamedTuple{F}, y::NamedTuple{F}) where {F}
-    return NamedTuple{F}(map(recursive_add!!, values(x), values(y)))
-end
-recursive_add!!(x, y) = fmap(recursive_add!!, x, y)
+recursive_add!!(x, y) = recursive_map(__add!!, x, y)
 
 """
     recursive_eltype(x)
@@ -48,15 +38,7 @@ Recursively create a zero value for a nested structure `x`. This is equivalent t
 
 See also [`Lux.recursive_make_zero!!`](@ref).
 """
-@inline recursive_make_zero(x::Number) = zero(x)
-@inline recursive_make_zero(x::AbstractArray{<:Number}) = zero(x)
-@inline recursive_make_zero(x::AbstractArray) = map(recursive_make_zero, x)
-@inline recursive_make_zero(x::Tuple) = map(recursive_make_zero, x)
-@inline recursive_make_zero(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
-    recursive_make_zero, values(x)))
-@inline recursive_make_zero(::Nothing) = nothing
-@inline recursive_make_zero(v::Val) = v
-@inline recursive_make_zero(x) = fmap(recursive_make_zero, x)
+@inline recursive_make_zero(x) = recursive_map(__zero, x)
 
 """
     recursive_make_zero!!(x)
@@ -66,12 +48,50 @@ in-place zeroing will be modified in place.
 
 See also [`Lux.recursive_make_zero`](@ref) for fully out-of-place version.
 """
-@inline recursive_make_zero!!(x::Number) = zero(x)
-@inline recursive_make_zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
-@inline recursive_make_zero!!(x::AbstractArray) = map(recursive_make_zero!!, x)
-@inline recursive_make_zero!!(x::Tuple) = map(recursive_make_zero!!, x)
-@inline recursive_make_zero!!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
-    recursive_make_zero!!, values(x)))
-@inline recursive_make_zero!!(::Nothing) = nothing
-@inline recursive_make_zero!!(x::Val) = x
-@inline recursive_make_zero!!(x) = fmap(recursive_make_zero!!, x)
+@inline recursive_make_zero!!(x) = recursive_map(__zero!!, x)
+
+"""
+    recursive_map(f, x, args...)
+
+Similar to `fmap(f, x, args...)` but with restricted support for the notion of "leaf" types.
+However, this allows for more efficient and type stable implementations of recursive
+operations.
+
+## How this works?
+
+For the following types it directly defines recursion rules:
+
+  1. `AbstractArray`: If eltype is a `Number` or `isbitstype`, then `f` is applied to the
+     array, else we recurse on each element of the array.
+  2. `Tuple/NamedTuple`: We recurse on the values.
+  3. `Number/Val/Nothing`: We directly apply `f`.
+  4. For all other types, we recurse on the fields using `Functors.fmap`.
+
+!!! note
+
+    In most cases, users should gravitate towards `Functors.fmap` if it is being used
+    outside of hot loops. Even for other cases, it is always recommended to verify the
+    correctness of this implementation for specific usecases.
+"""
+function recursive_map end
+
+for direct_call in (Number, Val, Nothing)
+    @eval @inline recursive_map(f::F, x::$(direct_call), args...) where {F} = f(x, args...)
+end
+@inline function recursive_map(f::F, x::AbstractArray{T}, args...) where {F, T}
+    (T <: Number || isbitstype(T)) && return f(x, args...) # e.g. BigFloat is not isbits
+    return broadcast((xs...) -> recursive_map(f, xs...), x, args...)
+end
+@inline function recursive_map(f::F, x::Tuple, args...) where {F}
+    map_fn = let f = f
+        (args_...) -> recursive_map(f, args_...)
+    end
+    return map(map_fn, x, args...)
+end
+@inline function recursive_map(f::F, x::NamedTuple{fields}, args...) where {F, fields}
+    map_fn = let f = f
+        (args_...) -> recursive_map(f, args_...)
+    end
+    return NamedTuple{fields}(map(map_fn, values(x), values.(args)...))
+end
+@inline recursive_map(f::F, x, args...) where {F} = fmap(f, x, args...)
diff --git a/src/utils.jl b/src/utils.jl
index d5a99d3e3..e2b0f0df4 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -354,3 +354,20 @@ end
 
 @inline __get_dims(::AbstractVector) = Colon()
 @inline __get_dims(::AbstractArray{T, N}) where {T, N} = 1:(N - 1)
+
+@inline __zero(x) = zero(x)
+@inline __zero(::Nothing) = nothing
+@inline __zero(x::Val) = x
+
+@inline __zero!!(x::Number) = zero(x)
+@inline __zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
+@inline __zero!!(::Nothing) = nothing
+@inline __zero!!(x::Val) = x
+
+@inline function __add!!(x::AbstractArray{<:Number}, y::AbstractArray{<:Number})
+    ArrayInterface.can_setindex(x) || return x .+ y
+    @. x += y
+    return x
+end
+@inline __add!!(x::Number, y::Number) = x + y
+@inline __add!!(::Nothing, ::Nothing) = nothing