refactor: use setfield and make make_zero!! type-stable
avik-pal committed Sep 13, 2024
1 parent 4cd193f commit 92b15fd
Showing 9 changed files with 93 additions and 100 deletions.
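
Throughout the diff, hand-written `TrainState(...)` reconstructions are replaced by `Setfield.@set!`, and the backend/`first_try` information moves from `TrainingBackendCache`'s type parameters into fields typed with `Static.True`/`Static.False`. As a minimal, hypothetical sketch of the first mechanism (the `Cache` and `State` structs below are made up, not Lux types), `@set!` rebuilds an immutable value with one field replaced and rebinds the variable, including through nested fields:

```julia
# Hypothetical types; not part of Lux. Shows how @set! rebuilds immutable
# structs, the pattern used on TrainState and TrainingBackendCache in this diff.
using Setfield

struct Cache
    first_try::Bool
end
struct State
    cache::Cache
    step::Int
end

s = State(Cache(true), 0)
@set! s.step = 1                  # rebinds `s` to State(Cache(true), 1)
@set! s.cache.first_try = false   # rebuilds the inner Cache and the outer State
@assert s == State(Cache(false), 1)
```
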
2 changes: 2 additions & 0 deletions ext/LuxEnzymeExt/LuxEnzymeExt.jl
@@ -3,6 +3,8 @@ module LuxEnzymeExt
using ADTypes: AutoEnzyme
using Enzyme: Enzyme, Active, Const, Duplicated
using EnzymeCore: EnzymeCore
using Setfield: @set!
using Static: False, True

using Lux: Lux
using Lux.Training: TrainingBackendCache, TrainState
53 changes: 26 additions & 27 deletions ext/LuxEnzymeExt/training.jl
@@ -1,42 +1,43 @@
function Lux.Training.compute_gradients(
::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
dps = Enzyme.make_zero(ts.parameters)
ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
dps = Lux.recursive_make_zero(ts.parameters)

obj_fn_wrap, st_wrap, stats_wrap = Lux.Training.wrap_objective_function(
obj_fn, ts.model, ts.parameters, ts.states, data, Val(true))
obj_fn, ts.model, ts.parameters, ts.states, data, True())

_, loss = Enzyme.autodiff(
EnzymeCore.ReverseWithPrimal, Const(obj_fn_wrap), Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

cache = TrainingBackendCache{:Enzyme, false}(
dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap))
ts_new = TrainState(cache, obj_fn, ts.model, ts.parameters, st_wrap[],
ts.optimizer, ts.optimizer_state, ts.step)

return dps, loss, stats_wrap[], ts_new
cache = TrainingBackendCache(
ad, False(), dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap))
@set! ts.cache = cache
@set! ts.objective_function = obj_fn
@set! ts.states = st_wrap[]
return dps, loss, stats_wrap[], ts
end

const AUTODIFF_CACHE_TYPE = TrainingBackendCache{
:Enzyme, false, PS, <:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS}
const AUTODIFF_CACHE_TYPE = TrainingBackendCache{<:AutoEnzyme, False, PS,
<:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS}

function Lux.Training.compute_gradients(
::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_CACHE_TYPE, F}) where {F}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
# dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
Enzyme.make_zero!(ts.cache.dparameters)
dps = ts.cache.dparameters

_, loss = Enzyme.autodiff(
EnzymeCore.ReverseWithPrimal, Const(ts.cache.extras.obj_fn), Active,
Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = TrainState(
ts.cache, obj_fn, ts.model, ts.parameters, ts.cache.extras.st_wrap[],
ts.optimizer, ts.optimizer_state, ts.step)
@set! ts.objective_function = obj_fn
@set! ts.states = ts.cache.extras.st_wrap[]

return dps, loss, ts.cache.extras.stats_wrap[], ts_new
return dps, loss, ts.cache.extras.stats_wrap[], ts
end

function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:Enzyme, false}}) where {F}
ts::TrainState{<:TrainingBackendCache{<:AutoEnzyme, False}}) where {F}
@warn "Detected calls to `compute_gradients(::AutoEnzyme, ...)` with objective \
function that is changing across function calls. This can lead to the \
generation of slow code" maxlog=1
@@ -46,15 +47,14 @@ function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data,
Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)},
Const{typeof(ts.states)}, Const{typeof(data)})

cache = TrainingBackendCache{:Enzyme, false}(ts.cache.dparameters, (; forward, reverse))
ts_new = TrainState(cache, obj_fn, ts.model, ts.parameters, ts.states,
ts.optimizer, ts.optimizer_state, ts.step)

return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new)
cache = TrainingBackendCache(ad, False(), ts.cache.dparameters, (; forward, reverse))
@set! ts.cache = cache
@set! ts.objective_function = obj_fn
return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
end

const AUTODIFF_THUNK_CACHE_TYPE = TrainingBackendCache{
:Enzyme, false, PS, <:NamedTuple{(:forward, :reverse)}} where {PS}
<:AutoEnzyme, False, PS, <:NamedTuple{(:forward, :reverse)}} where {PS}

function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:AUTODIFF_THUNK_CACHE_TYPE, F}) where {F}
@@ -67,8 +67,7 @@ function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data,
Const(obj_fn), Const(ts.model), params, Const(ts.states), Const(data),
(one(loss), Lux.recursive_make_zero(st_), Lux.recursive_make_zero(stats)), tape)

ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters, st_,
ts.optimizer, ts.optimizer_state, ts.step)

return dps, loss, stats, ts_new
@set! ts.objective_function = obj_fn
@set! ts.states = st_
return dps, loss, stats, ts
end
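
The cached path above now zeroes the stored gradient shadow in place with `Enzyme.make_zero!` instead of `Lux.recursive_make_zero!!`. A hedged sketch of that pattern, with a made-up parameter NamedTuple:

```julia
# Hedged sketch; the parameter NamedTuple is illustrative. Allocate a zero
# shadow once with Enzyme.make_zero, then reuse the same buffers across calls
# by re-zeroing them in place with Enzyme.make_zero!.
using Enzyme

ps  = (; weight=rand(3, 3), bias=rand(3))
dps = Enzyme.make_zero(ps)   # zeroed copy with the same structure as ps
dps.weight .+= 1             # pretend a reverse pass accumulated into it
Enzyme.make_zero!(dps)       # reset the existing arrays; no new allocation
@assert all(iszero, dps.weight) && all(iszero, dps.bias)
```
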
4 changes: 3 additions & 1 deletion ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -5,9 +5,11 @@ using ArrayInterface: ArrayInterface
using FunctionWrappers: FunctionWrapper
using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal,
@grad_from_chainrules
using Setfield: @set!
using Static: False, True

using Lux: Lux, Utils
using Lux.Training: TrainingBackendCache, TrainState
using Lux.Training: Training, TrainingBackendCache, TrainState
using LuxCore: LuxCore
using MLDataDevices: CPUDevice

73 changes: 32 additions & 41 deletions ext/LuxReverseDiffExt/training.jl
@@ -1,53 +1,50 @@
# Uncompiled ReverseDiff
function Lux.Training.compute_gradients(
ad::AutoReverseDiff{false}, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, ts.model,
ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step)
return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new)
@set! ts.cache = TrainingBackendCache(
ad, True(), Lux.recursive_make_zero(ts.parameters), nothing)
@set! ts.objective_function = obj_fn
return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
end

function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT}
dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{false}}}) where {F}
dparams = Training.dparameters(ts.cache)
tape = ReverseDiff.InstructionTape()
ps_tracked = Lux.recursive_map(Utils.Fix3(TrackedArray, tape), ts.parameters, dparams)

loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data)
loss.deriv = true
ReverseDiff.reverse_pass!(tape)

ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, nothing),
obj_fn, ts.model, ts.parameters, st, ts.optimizer, ts.optimizer_state, ts.step)

return ts.cache.dparameters, ReverseDiff.value(loss), stats, ts_new
@set! ts.cache.first_try = False()
@set! ts.objective_function = obj_fn
@set! ts.states = st
return dparams, ReverseDiff.value(loss), stats, ts
end

# Compiled ReverseDiff
function Lux.Training.compute_gradients(
ad::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
data_cache = deepcopy(data)
ps_cache = deepcopy(ts.parameters)
extras = (; data_cache, ps_cache)

ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, true}(grads, extras), nothing, ts.model,
ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step)
return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new)
@set! ts.cache = TrainingBackendCache(
ad, True(), Lux.recursive_make_zero(ts.parameters),
(; data_cache=deepcopy(data), ps_cache=deepcopy(ts.parameters)))
@set! ts.objective_function = nothing

return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
end

## Tape hasn't been compiled yet / Function mismatch so recompile
function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT}
function Lux.Training.compute_gradients(ad::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}}) where {F}
if LuxCore.statelength(ts.states) != 0
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \
models with non-empty state `st`."))
end

if FT # do a dry run
first_try = ts.cache.first_try isa True

if first_try # do a dry run
_, st_, stats = obj_fn(ts.model, ts.parameters, ts.states, data)
if stats != NamedTuple()
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \
@@ -59,20 +56,18 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data
end
end

dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
dparams = Training.dparameters(ts.cache)

(; ps_cache, data_cache) = ts.cache.extras
if !FT
if !first_try
Lux.recursive_copyto!(ps_cache, ts.parameters)
Lux.recursive_copyto!(data_cache, data)
end

obj_fn_wrap = first ∘ obj_fn

tape = ReverseDiff.InstructionTape()
ps_tracked = Lux.recursive_map(Utils.Fix3(TrackedArray, tape), ps_cache, dparams)

loss = obj_fn_wrap(ts.model, ps_tracked, ts.states, data_cache)
loss = first(obj_fn(ts.model, ps_tracked, ts.states, data_cache))
loss.deriv = true
ReverseDiff.reverse_pass!(tape)

@@ -81,18 +76,14 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data
reverse_executor = [FunctionWrapper{Nothing, Tuple{}}(ReverseExecutor(tape[i]))
for i in length(tape):-1:1]

compiled_extras = (;
ps_cache, data_cache, forward_executor, reverse_executor, output=loss)
ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, compiled_extras),
obj_fn, ts.model, ts.parameters, ts.states,
ts.optimizer, ts.optimizer_state, ts.step)

return dparams, ReverseDiff.value(loss), NamedTuple(), ts_new
@set! ts.cache = TrainingBackendCache(ad, False(), dparams,
(; ps_cache, data_cache, forward_executor, reverse_executor, output=loss))
@set! ts.objective_function = obj_fn
return dparams, ReverseDiff.value(loss), NamedTuple(), ts
end

function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, false}, F}) where {F}
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}, F}) where {F}
(; ps_cache, data_cache, output) = ts.cache.extras

dparams = Lux.recursive_make_zero!!(ts.cache.dparameters)
@@ -107,7 +98,7 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data
wrapper()
end

ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters, ts.states,
ts.optimizer, ts.optimizer_state, ts.step)
return dparams, ReverseDiff.value(output), NamedTuple(), ts_new
@set! ts.cache.first_try = False()
@set! ts.objective_function = obj_fn
return dparams, ReverseDiff.value(output), NamedTuple(), ts
end
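
For orientation, a hedged usage sketch of how these methods are reached from user code. The model, objective, and data are made up; the `(grads, loss, stats, ts)` return and the constraints on `compile=true` (empty model state, empty stats) follow the signatures and `ArgumentError`s in this file:

```julia
# Hedged usage sketch; model, objective, and data are illustrative.
using ADTypes, Lux, Optimisers, Random, ReverseDiff

model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model)
ts = Lux.Training.TrainState(model, ps, st, Adam(0.01f0))

# The objective returns (loss, st, stats); compile=true additionally requires
# an empty model state and empty stats, matching the checks above.
function mse(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2.(ŷ .- y)), st_, NamedTuple()
end

x, y = rand(Float32, 4, 8), rand(Float32, 2, 8)
grads, loss, stats, ts = Lux.Training.compute_gradients(
    AutoReverseDiff(; compile=true), mse, (x, y), ts)
```
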
4 changes: 3 additions & 1 deletion ext/LuxTrackerExt/LuxTrackerExt.jl
@@ -3,10 +3,12 @@ module LuxTrackerExt
using ADTypes: AbstractADType, AutoTracker
using ArrayInterface: ArrayInterface
using ChainRulesCore: ChainRulesCore
using Setfield: @set!
using Static: False, True
using Tracker: Tracker, TrackedArray, TrackedReal, @grad_from_chainrules

using Lux: Lux, Utils
using Lux.Training: TrainingBackendCache, TrainState
using Lux.Training: Training, TrainingBackendCache, TrainState

const CRC = ChainRulesCore

24 changes: 12 additions & 12 deletions ext/LuxTrackerExt/training.jl
@@ -1,23 +1,23 @@
function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT}
dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
ps_tracked = construct_tracked_params(ts.parameters, dparams)
ts::TrainState{<:TrainingBackendCache{AutoTracker}}) where {F}
dps = Training.dparameters(ts.cache)
ps_tracked = construct_tracked_params(ts.parameters, dps)

loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data)
Tracker.back!(loss)

ts_new = TrainState(
TrainingBackendCache{:Tracker, false}(ts.cache.dparameters, nothing), obj_fn,
ts.model, ts.parameters, st, ts.optimizer, ts.optimizer_state, ts.step)
@set! ts.cache.first_try = False()
@set! ts.objective_function = obj_fn
@set! ts.states = st

return dparams, loss.data, stats, ts_new
return dps, loss.data, stats, ts
end

function Lux.Training.compute_gradients(
::AutoTracker, obj_fn::F, data, ts::TrainState) where {F}
ad::AutoTracker, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = TrainState(
TrainingBackendCache{:Tracker, true}(grads, nothing), obj_fn, ts.model,
ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step)
return Lux.Training.compute_gradients(AutoTracker(), obj_fn, data, ts_new)
cache = TrainingBackendCache(ad, True(), grads, nothing)
@set! ts.cache = cache
@set! ts.objective_function = obj_fn
return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
end
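
The Tracker path keeps the same structure; as a hedged reminder of the underlying Tracker flow it drives (plain arrays below, not Lux internals):

```julia
# Hedged sketch with illustrative arrays, not Lux internals: wrap parameters as
# tracked values, back-propagate the scalar loss, and read the gradient back.
using Tracker

W  = rand(2, 3)
Wt = Tracker.param(W)            # tracked copy of the parameters
loss = sum(abs2.(Wt * ones(3)))  # TrackedReal
Tracker.back!(loss)
dW = Tracker.grad(Wt)            # gradient with the same shape as W
@assert size(dW) == size(W)
```
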
8 changes: 1 addition & 7 deletions src/helpers/recursive_ops.jl
@@ -102,18 +102,12 @@ function recursive_map(f::F, x::AbstractArray{T}, args...) where {F, T}
(T <: Number || isbitstype(T)) && return f(x, args...) # Not all Number types (BigFloat) are bitstype
return f.(x, args...)
end
function recursive_map(f::F, x::Tuple, args...) where {F}
function recursive_map(f::F, x::Union{NamedTuple, Tuple}, args...) where {F}
map_fn = let f = f
(args_...) -> recursive_map(f, args_...)
end
return map(map_fn, x, args...)
end
function recursive_map(f::F, x::NamedTuple{fields}, args...) where {F, fields}
map_fn = let f = f
(args_...) -> recursive_map(f, args_...)
end
return NamedTuple{fields}(map(map_fn, values(x), values.(args)...))
end
recursive_map(f::F, x, args...) where {F} = fmap(f, x, args...)

@compat(public,
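
The two `recursive_map` methods can collapse into one because `Base.map` already preserves field names for NamedTuples and accepts several NamedTuples with matching names. A small illustration in plain Julia:

```julia
# Plain-Base illustration of why a single Union{NamedTuple, Tuple} method works.
nt1 = (; a=1, b=2)
nt2 = (; a=10, b=20)
@assert map(+, nt1, nt2) == (; a=11, b=22)    # field names are preserved
@assert map(x -> 2x, (1, 2, 3)) == (2, 4, 6)  # Tuples behave the same way
```
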
23 changes: 13 additions & 10 deletions src/helpers/training.jl
@@ -6,6 +6,7 @@ using ConcreteStructs: @concrete
using FastClosures: @closure
using Optimisers: Optimisers
using Setfield: @set!
using Static: StaticBool, Static, False, True

using ..Lux: Lux
using LuxCore: LuxCore, AbstractLuxLayer
@@ -50,13 +51,10 @@ Constructor for [`TrainState`](@ref).
## Arguments
- `rng`: Random Number Generator.
- `ps`: Parameters of the model.
- `st`: States of the model.
- `model`: `Lux` model.
- `optimizer`: Optimizer from `Optimisers.jl`.
- `transform_variables`: Function to transform the variables of the model. Typically used
to transfer variables to GPU / CPU.
## Returns
@@ -67,12 +65,18 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr
return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
end

@concrete struct TrainingBackendCache{backend, first_try}
@concrete struct TrainingBackendCache
backend
first_try <: StaticBool
dparameters
extras
end

training_backend(::TrainingBackendCache{backend}) where {backend} = backend
dparameters(cache::TrainingBackendCache) = dparameters(cache, cache.first_try)
function dparameters(cache::TrainingBackendCache, ::False)
return Lux.recursive_make_zero!!(cache.dparameters)
end
dparameters(cache::TrainingBackendCache, ::True) = cache.dparameters

function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
println(io, "TrainState")
@@ -83,8 +87,7 @@ function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
print(io, " step: ", ts.step)
if ts.cache !== nothing
if ts.cache isa TrainingBackendCache
print(io,
"\n cache: $(nameof(typeof(ts.cache))){$(training_backend(ts.cache))}")
print(io, "\n cache: $(nameof(typeof(ts.cache)))($(ts.cache.backend))")
else
print(io, "\n cache: $(nameof(typeof(ts.cache)))")
end
@@ -198,7 +201,7 @@ for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme)
end
end

function generate_wrappers(::F, m, ps, st, data, ::Val{false}) where {F}
function generate_wrappers(::F, m, ps, st, data, ::False) where {F}
@warn "Detected function wrapper generation with function being updated between calls. \
This will generate type-unstable code. A possible reason for this is \
`TrainState` was compiled (first call to `compute_gradients`) with function \
@@ -208,13 +211,13 @@ function generate_wrappers(::F, m, ps, st, data, ::Val{false}) where {F}
end

# Run the code when trying to compile the function for the first time.
function generate_wrappers(objective_function::F, m, ps, st, data, ::Val{true}) where {F}
function generate_wrappers(objective_function::F, m, ps, st, data, ::True) where {F}
_, stₙ, statsₙ = objective_function(m, ps, st, data)
return Ref{typeof(stₙ)}(stₙ), Ref{typeof(statsₙ)}(statsₙ)
end

function wrap_objective_function(
objective_function::F, m, ps, st, data, first_try::Val) where {F}
objective_function::F, m, ps, st, data, first_try::StaticBool) where {F}
st_updated, stats = generate_wrappers(objective_function, m, ps, st, data, first_try)

wrapped_objective_function = @closure (model, ps, st, data) -> begin
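
The new `dparameters` helper dispatches on the `Static.True`/`Static.False` value stored in the cache, so the choice between reusing the freshly allocated shadow (first call) and re-zeroing it in place (later calls) is resolved by dispatch rather than a hand-spelled type parameter. A hedged sketch with a simplified stand-in type:

```julia
# Hedged sketch; MiniCache is a stand-in, not the real TrainingBackendCache.
using Static: StaticBool, True, False

struct MiniCache{B <: StaticBool, P}
    first_try::B
    dparameters::P
end

dparams(c::MiniCache) = dparams(c, c.first_try)
dparams(c::MiniCache, ::True)  = c.dparameters                        # first call: buffer is already zero
dparams(c::MiniCache, ::False) = (c.dparameters .= 0; c.dparameters)  # later calls: reset in place

c = MiniCache(False(), ones(3))
@assert all(iszero, dparams(c))
```
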