Skip to content

Commit

Permalink
Use recursive_map
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 22, 2024
1 parent 6b81817 commit e0d262a
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 34 deletions.
1 change: 1 addition & 0 deletions docs/src/api/Lux/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Lux.xlogx
## Recursive Operations

```@docs
Lux.recursive_map
Lux.recursive_add!!
Lux.recursive_eltype
Lux.recursive_make_zero
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module LuxReverseDiffExt

using ADTypes: AutoReverseDiff
using ADTypes: ADTypes, AutoReverseDiff
using ArrayInterface: ArrayInterface
using Functors: fmap
using Lux: Lux, LuxCPUDevice
Expand Down
41 changes: 37 additions & 4 deletions ext/LuxReverseDiffExt/training.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,44 @@
function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data,
ts::Lux.Experimental.TrainState) where {F}
@static if pkgversion(ADTypes) < v"1.5"
# older versions did not have `compile` type parameter. Use slower type-unstable code
function Lux.Experimental.compute_gradients(ad::AutoReverseDiff, objective_function::F,
data, ts::Lux.Experimental.TrainState) where {F}
ad.compile && return __compiled_reverse_diff(objective_function, data, ts)
return __uncompiled_reverse_diff(objective_function, data, ts)
end
else
for compiled in (false, true)
fname = compiled ? :__compiled_reverse_diff : :__uncompiled_reverse_diff
@eval function Lux.Experimental.compute_gradients(

Check warning on line 11 in ext/LuxReverseDiffExt/training.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReverseDiffExt/training.jl#L11

Added line #L11 was not covered by tests
::AutoReverseDiff{$(compiled)}, objective_function::F,
data, ts::Lux.Experimental.TrainState) where {F}
return $(fname)(objective_function, data, ts)

Check warning on line 14 in ext/LuxReverseDiffExt/training.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReverseDiffExt/training.jl#L14

Added line #L14 was not covered by tests
end
end
end

@inline function __uncompiled_reverse_diff(
objective_function::F, data, ts::Lux.Experimental.TrainState) where {F}
tape = ReverseDiff.InstructionTape()
grads = fmap(zero, ts.parameters)
ps_tracked = fmap((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ts.parameters, grads)
grads = Lux.recursive_make_zero(ts.parameters)
ps_tracked = Lux.recursive_map(
Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads)
loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
loss.deriv = true
ReverseDiff.reverse_pass!(tape)
@set! ts.states = st
return grads, ReverseDiff.value(loss), stats, ts
end

@inline function __compiled_reverse_diff(

Check warning on line 32 in ext/LuxReverseDiffExt/training.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReverseDiffExt/training.jl#L32

Added line #L32 was not covered by tests
objective_function::F, data, ts::Lux.Experimental.TrainState) where {F}
# tape = ReverseDiff.InstructionTape()
# grads = Lux.recursive_make_zero(ts.parameters)
# ps_tracked = Lux.recursive_map(
# Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads)
# loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
# loss.deriv = true
# ReverseDiff.reverse_pass!(tape)
# @set! ts.states = st
# return grads, ReverseDiff.value(loss), stats, ts
error("Not implemented yet")

Check warning on line 43 in ext/LuxReverseDiffExt/training.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReverseDiffExt/training.jl#L43

Added line #L43 was not covered by tests
end
78 changes: 49 additions & 29 deletions src/helpers/recursive_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,7 @@ common cases.
Any leaves of `x` that are arrays and allow in-place addition will be modified in place.
"""
function recursive_add!!(x::AbstractArray, y::AbstractArray)
ArrayInterface.can_setindex(x) || return x .+ y
@. x += y
return x
end
recursive_add!!(x::Tuple, y::Tuple) = map(recursive_add!!, x, y)
recursive_add!!(::Nothing, ::Nothing) = nothing
function recursive_add!!(x::NamedTuple{F}, y::NamedTuple{F}) where {F}
return NamedTuple{F}(map(recursive_add!!, values(x), values(y)))
end
recursive_add!!(x, y) = fmap(recursive_add!!, x, y)
recursive_add!!(x, y) = recursive_map(__add!!, x, y)

Check warning on line 10 in src/helpers/recursive_ops.jl

View check run for this annotation

Codecov / codecov/patch

src/helpers/recursive_ops.jl#L10

Added line #L10 was not covered by tests

"""
recursive_eltype(x)
Expand Down Expand Up @@ -48,15 +38,7 @@ Recursively create a zero value for a nested structure `x`. This is equivalent t
See also [`Lux.recursive_make_zero!!`](@ref).
"""
@inline recursive_make_zero(x::Number) = zero(x)
@inline recursive_make_zero(x::AbstractArray{<:Number}) = zero(x)
@inline recursive_make_zero(x::AbstractArray) = map(recursive_make_zero, x)
@inline recursive_make_zero(x::Tuple) = map(recursive_make_zero, x)
@inline recursive_make_zero(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
recursive_make_zero, values(x)))
@inline recursive_make_zero(::Nothing) = nothing
@inline recursive_make_zero(v::Val) = v
@inline recursive_make_zero(x) = fmap(recursive_make_zero, x)
@inline recursive_make_zero(x) = recursive_map(__zero, x)

"""
recursive_make_zero!!(x)
Expand All @@ -66,12 +48,50 @@ in-place zeroing will be modified in place.
See also [`Lux.recursive_make_zero`](@ref) for fully out-of-place version.
"""
@inline recursive_make_zero!!(x::Number) = zero(x)
@inline recursive_make_zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
@inline recursive_make_zero!!(x::AbstractArray) = map(recursive_make_zero!!, x)
@inline recursive_make_zero!!(x::Tuple) = map(recursive_make_zero!!, x)
@inline recursive_make_zero!!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
recursive_make_zero!!, values(x)))
@inline recursive_make_zero!!(::Nothing) = nothing
@inline recursive_make_zero!!(x::Val) = x
@inline recursive_make_zero!!(x) = fmap(recursive_make_zero!!, x)
@inline recursive_make_zero!!(x) = recursive_map(__zero!!, x)

"""
recursive_map(f, x, args...)
Similar to `fmap(f, args...)` but with restricted support for the notion of "leaf" types.
However, this allows for more efficient and type stable implementations of recursive
operations.
## How this works?
For the following types it directly defines recursion rules:
1. `AbstractArray`: If eltype is `isbitstype`, then `f` is applied to the array, else we
recurse on the array.
2. `Tuple/NamedTuple`: We recurse on the values.
3. `Number/Val/Nothing`: We directly apply `f`.
4. For all other types, we recurse on the fields using `Functors.fmap`.
!!! note
In most cases, users should gravitate towards `Functors.fmap` if it is being used
outside of hot loops. Even for other cases, it is always recommended to verify the
correctness of this implementation for specific usecases.
"""
function recursive_map end

for direct_call in (Number, Val, Nothing)
@eval @inline recursive_map(f::F, x::$(direct_call), args...) where {F} = f(x, args...)

Check warning on line 79 in src/helpers/recursive_ops.jl

View check run for this annotation

Codecov / codecov/patch

src/helpers/recursive_ops.jl#L79

Added line #L79 was not covered by tests
end
@inline function recursive_map(f::F, x::AbstractArray{T}, args...) where {F, T}
isbitstype(T) && return f(x, args...)
return f.(x, args...)

Check warning on line 83 in src/helpers/recursive_ops.jl

View check run for this annotation

Codecov / codecov/patch

src/helpers/recursive_ops.jl#L83

Added line #L83 was not covered by tests
end
@inline function recursive_map(f::F, x::Tuple, args...) where {F}
map_fn = let f = f
(args_...) -> recursive_map(f, args_...)

Check warning on line 87 in src/helpers/recursive_ops.jl

View check run for this annotation

Codecov / codecov/patch

src/helpers/recursive_ops.jl#L85-L87

Added lines #L85 - L87 were not covered by tests
end
return map(map_fn, x, args...)

Check warning on line 89 in src/helpers/recursive_ops.jl

View check run for this annotation

Codecov / codecov/patch

src/helpers/recursive_ops.jl#L89

Added line #L89 was not covered by tests
end
@inline function recursive_map(f::F, x::NamedTuple{fields}, args...) where {F, fields}
map_fn = let f = f
(args_...) -> recursive_map(f, args_...)
end
return NamedTuple{fields}(map(map_fn, values(x), values.(args)...))
end
@inline recursive_map(f::F, x, args...) where {F} = fmap(f, x, args...)

Check warning on line 97 in src/helpers/recursive_ops.jl

View check run for this annotation

Codecov / codecov/patch

src/helpers/recursive_ops.jl#L97

Added line #L97 was not covered by tests
17 changes: 17 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,20 @@ end

@inline __get_dims(::AbstractVector) = Colon()
@inline __get_dims(::AbstractArray{T, N}) where {T, N} = 1:(N - 1)

@inline __zero(x) = zero(x)
@inline __zero(::Nothing) = nothing
@inline __zero(x::Val) = x

Check warning on line 360 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L359-L360

Added lines #L359 - L360 were not covered by tests

@inline __zero!!(x::Number) = zero(x)

Check warning on line 362 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L362

Added line #L362 was not covered by tests
@inline __zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
@inline __zero!!(::Nothing) = nothing
@inline __zero!!(x::Val) = x

Check warning on line 365 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L364-L365

Added lines #L364 - L365 were not covered by tests

@inline function __add!!(x::AbstractArray{<:Number}, y::AbstractArray{<:Number})
ArrayInterface.can_setindex(x) || return x .+ y
@. x += y
return x

Check warning on line 370 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L367-L370

Added lines #L367 - L370 were not covered by tests
end
@inline __add!!(x::Number, y::Number) = x + y

Check warning on line 372 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L372

Added line #L372 was not covered by tests
@inline __add!!(::Nothing, ::Nothing) = nothing

0 comments on commit e0d262a

Please sign in to comment.