
Commit

Consolidate the Backend parameter caching
avik-pal committed Jun 22, 2024
1 parent e0d262a commit ebe293b
Showing 6 changed files with 104 additions and 64 deletions.
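The point of the consolidation: each AD extension previously carried its own cache struct (e.g. `CachedEnzymeExtras` in the Enzyme extension below); they are now replaced by a single `Lux.Experimental.TrainingBackendCache{backend, first_try}` shared by every backend. The standalone sketch below mirrors the new type and how the extensions construct it — illustration only; the real definition uses `ConcreteStructs.@concrete` and lives in `src/contrib/training.jl`, and the field values here are made up.

```julia
# Plain mirror of the consolidated cache (illustrative, not the actual Lux definition).
struct TrainingBackendCache{backend, first_try, D, O, E}
    dparameters::D          # preallocated gradient buffer, reused across calls
    objective_function::O   # cached objective function (or `nothing` until known)
    extras::E               # backend-specific leftovers, e.g. (; st_wrap, stats_wrap)
end

function TrainingBackendCache{backend, first_try}(dps, obj, extras) where {backend, first_try}
    return TrainingBackendCache{backend, first_try, typeof(dps), typeof(obj), typeof(extras)}(
        dps, obj, extras)
end

# First call: dparameters are freshly allocated zeros, so first_try = true.
c1 = TrainingBackendCache{:Enzyme, true}(zeros(Float32, 4), nothing,
    (; st_wrap=nothing, stats_wrap=nothing))

# Later calls keep the same buffer with first_try = false; the extensions then zero
# it in place (Lux.recursive_make_zero!!) before every backward pass.
c2 = TrainingBackendCache{:Enzyme, false}(c1.dparameters, nothing,
    (; st_wrap=nothing, stats_wrap=nothing))
```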
29 changes: 13 additions & 16 deletions ext/LuxEnzymeExt.jl
@@ -1,31 +1,27 @@
module LuxEnzymeExt

using ADTypes: AutoEnzyme
using ConcreteStructs: @concrete
using Enzyme: Enzyme, Active, Const, Duplicated
using Lux: Lux

@concrete struct CachedEnzymeExtras{FT}
dparameters
objective_function
st_wrap
stats_wrap
end
using Lux.Experimental: TrainingBackendCache

# Case I: We have CachedEnzymeExtras and objective_function is unchanged.
function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras{FT}, F}) where {F, FT}
function Lux.Experimental.compute_gradients(::AutoEnzyme,
objective_function::F,
data,
ts::Lux.Experimental.TrainState{<:TrainingBackendCache{:Enzyme, FT}, F}) where {
F, FT}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)

_, loss = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, ts.cache.objective_function, Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = __construct_new_trainstate(
ts.cache.st_wrap[], ts.states, ts, objective_function, dps,
ts.cache.objective_function, ts.cache.st_wrap, ts.cache.stats_wrap)
ts.cache.extras.st_wrap[], ts.states, ts, objective_function, dps,
ts.cache.objective_function, ts.cache.extras.st_wrap, ts.cache.extras.stats_wrap)

return dps, loss, ts.cache.stats_wrap[], ts_new
return dps, loss, ts.cache.extras.stats_wrap[], ts_new

end

# Case II: We have CachedEnzymeExtras and objective_function is changed.
@@ -49,7 +45,8 @@ end
function Lux.Experimental.compute_gradients(ad::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState) where {F}
dps = Lux.recursive_make_zero(ts.parameters)
cache = CachedEnzymeExtras{true}(dps, nothing, nothing, nothing)
cache = TrainingBackendCache{:Enzyme, true}(
dps, nothing, (; st_wrap=nothing, stats_wrap=nothing))
ts_new = Lux.Experimental.TrainState(
cache, nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
return Lux.Experimental.compute_gradients(ad, objective_function, data, ts_new)
@@ -60,15 +57,15 @@ end
function __construct_new_trainstate(
st_new::S, ::S, ts::Lux.Experimental.TrainState, objective_fn::O,
dps, obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2}
cache = CachedEnzymeExtras{false}(dps, obj_fn, st_wrap, stats_wrap)
cache = TrainingBackendCache{:Enzyme, false}(dps, obj_fn, (; st_wrap, stats_wrap))

return Lux.Experimental.TrainState(
cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
end

function __construct_new_trainstate(
st_new, _, ts::Lux.Experimental.TrainState, objective_fn::O,
dps, obj_fn::O2, st_wrap, stats_wrap) where {O, O2}
cache = CachedEnzymeExtras{false}(dps, nothing, nothing, nothing)
cache = TrainingBackendCache{:Enzyme, false}(dps, nothing, (; st_wrap, stats_wrap))

return Lux.Experimental.TrainState(
cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
end
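For context, a minimal training-step sketch of how the Enzyme methods above interact: the first `compute_gradients` call hits the generic method, allocates a `TrainingBackendCache{:Enzyme, true}` and re-dispatches; every later call lands on the cached-state method and reuses (re-zeroes) `dparameters`. The model, loss function, and the `TrainState(rng, model, opt)` constructor form are assumptions for illustration, not part of this diff.

```julia
using ADTypes: AutoEnzyme
using Enzyme            # loads the LuxEnzymeExt extension
using Lux, Optimisers, Random

model = Dense(2 => 2)
ts = Lux.Experimental.TrainState(Random.default_rng(), model, Adam(0.01f0))

# Objective signature used throughout Lux.Experimental: (model, ps, st, data) -> (loss, st, stats)
function mse(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)
end

x, y = rand(Float32, 2, 8), rand(Float32, 2, 8)

# 1st call: generic method -> builds TrainingBackendCache{:Enzyme, true}, then re-dispatches.
grads, loss, stats, ts = Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, y), ts)

# 2nd call: TrainState{<:TrainingBackendCache{:Enzyme}} method -> cached dparameters reused.
grads, loss, stats, ts = Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, y), ts)
```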
3 changes: 1 addition & 2 deletions ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -2,10 +2,9 @@ module LuxReverseDiffExt

using ADTypes: ADTypes, AutoReverseDiff
using ArrayInterface: ArrayInterface
using Functors: fmap
using Lux: Lux, LuxCPUDevice
using Lux.Experimental: TrainingBackendCache
using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules
using Setfield: @set!

# AoS to SoA conversion
function Lux.apply(

44 changes: 29 additions & 15 deletions ext/LuxReverseDiffExt/training.jl
@@ -16,29 +16,43 @@ else
end
end

@inline function __uncompiled_reverse_diff(
objective_function::F, data, ts::Lux.Experimental.TrainState) where {F}
# Uncompiled ReverseDiff
@inline function __uncompiled_reverse_diff(objective_function::F, data,
ts::Lux.Experimental.TrainState{<:TrainingBackendCache{:ReverseDiff}}) where {F}
tape = ReverseDiff.InstructionTape()
grads = Lux.recursive_make_zero(ts.parameters)
ps_tracked = Lux.recursive_map(
Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads)
Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, ts.cache.dparameters)

loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
loss.deriv = true
ReverseDiff.reverse_pass!(tape)

@set! ts.states = st
return grads, ReverseDiff.value(loss), stats, ts

ts_new = Lux.Experimental.TrainState(
TrainingBackendCache{:ReverseDiff, false}(
ts.cache.dparameters, objective_function, nothing),
objective_function,
ts.model,
ts.parameters,
st,
ts.optimizer_state,
ts.step)

return ts.cache.dparameters, ReverseDiff.value(loss), stats, ts_new

end

# First call, nothing is cached
@inline function __uncompiled_reverse_diff(
objective_function::F, data, ts::Lux.Experimental.TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = Lux.Experimental.TrainState(
TrainingBackendCache{:ReverseDiff, true}(grads, objective_function, nothing),
objective_function, ts.model, ts.parameters,
ts.states, ts.optimizer_state, ts.step)
return __uncompiled_reverse_diff(objective_function, data, ts_new)

end

# Compiled ReverseDiff
@inline function __compiled_reverse_diff(
objective_function::F, data, ts::Lux.Experimental.TrainState) where {F}
# tape = ReverseDiff.InstructionTape()
# grads = Lux.recursive_make_zero(ts.parameters)
# ps_tracked = Lux.recursive_map(
# Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads)
# loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
# loss.deriv = true
# ReverseDiff.reverse_pass!(tape)
# @set! ts.states = st
# return grads, ReverseDiff.value(loss), stats, ts
error("Not implemented yet")

end
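The key trick in the uncompiled path above is that the cached `dparameters` buffer is passed straight in as the derivative storage of each `ReverseDiff.TrackedArray`, so the reverse pass writes gradients into the cache with no separate extraction step. A standalone sketch of that pattern, using plain ReverseDiff without any of the Lux helpers:

```julia
using ReverseDiff

p    = rand(Float32, 3)
dp   = zeros(Float32, 3)                      # stands in for ts.cache.dparameters
tape = ReverseDiff.InstructionTape()
tp   = ReverseDiff.TrackedArray(p, dp, tape)  # derivative storage aliased to dp

loss = sum(tp .^ 2)     # recorded on `tape`
loss.deriv = true       # seed the output, as in the extension above
ReverseDiff.reverse_pass!(tape)

dp                      # ≈ 2 .* p, written in place by the reverse pass
```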
72 changes: 43 additions & 29 deletions ext/LuxTrackerExt.jl
@@ -3,49 +3,61 @@ module LuxTrackerExt
using ADTypes: AutoTracker
using ArrayInterface: ArrayInterface
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
using Functors: fmap
using Lux: Lux, LuxCPUDevice
using Lux.Experimental: TrainingBackendCache
using LuxCore: LuxCore
using Setfield: @set!
using Tracker: Tracker, TrackedArray, @grad_from_chainrules

const CRC = ChainRulesCore

# Type Piracy: Need to upstream
Tracker.param(nt::NamedTuple{F}) where {F} = NamedTuple{F}(Tracker.param.(values(nt)))
Tracker.param(t::Tuple) = map(Tracker.param, t)
Tracker.param(l::LuxCore.AbstractExplicitLayer) = l

Tracker.zero_grad!(nt::NamedTuple) = Tracker.zero_grad!.(values(nt))
Tracker.zero_grad!(::LuxCore.AbstractExplicitLayer) = nothing

function Tracker.extract_grad!(nt::NamedTuple{F}) where {F}
return NamedTuple{F}(Tracker.extract_grad!.(values(nt)))
end
Tracker.extract_grad!(t::Tuple) = map(Tracker.extract_grad!, t)
Tracker.extract_grad!(::LuxCore.AbstractExplicitLayer) = nothing

Tracker.data(nt::NamedTuple) = fmap(Tracker.data, nt)
Tracker.data(t::Tuple) = map(Tracker.data, t)
Tracker.data(l::LuxCore.AbstractExplicitLayer) = l

# Weight Norm Patch
@inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims))

# multigate chain rules
@inline Lux._gate(x::Tracker.TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)]
@inline Lux._gate(x::Tracker.TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :]

function __construct_tracked_params(ps, dps)
map_fn = (p, dp) -> Tracker.TrackedArray(Tracker.Call(), p, dp)
return Lux.recursive_map(map_fn, ps, dps)

end

# Lux.Training
function Lux.Experimental.compute_gradients(::AutoTracker, objective_function::F, data,
ts::Lux.Experimental.TrainState) where {F}
ps_tracked = fmap(Tracker.param, ts.parameters)
## Use the cached gradient parameters
function Lux.Experimental.compute_gradients(::AutoTracker,
objective_function::F,
data,
ts::Lux.Experimental.TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT}
dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
ps_tracked = __construct_tracked_params(ts.parameters, dparams)

loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
Tracker.back!(loss)
@set! ts.states = st
grads = fmap(Tracker.grad, ps_tracked)
return grads, Tracker.value(loss), stats, ts

ts_new = Lux.Experimental.TrainState(
TrainingBackendCache{:Tracker, false}(
ts.cache.dparameters, objective_function, nothing),
objective_function,
ts.model,
ts.parameters,
st,
ts.optimizer_state,
ts.step)

return dparams, Tracker.value(loss), stats, ts_new

end

## First call, nothing is cached
function Lux.Experimental.compute_gradients(::AutoTracker, objective_function::F, data,
ts::Lux.Experimental.TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = Lux.Experimental.TrainState(
TrainingBackendCache{:Tracker, true}(grads, objective_function, nothing),
objective_function, ts.model, ts.parameters,
ts.states, ts.optimizer_state, ts.step)
return Lux.Experimental.compute_gradients(
AutoTracker(), objective_function, data, ts_new)
end

# AoS to SoA conversion
@@ -77,9 +89,11 @@ Tracker.@grad function Lux.__apply_simple_chain(layer, x, ps, ::LuxCPUDevice)
As such please test your model with FiniteDifferences or Zygote before using \
`Tracker.jl` for your model." maxlog=1
y, pb_f = CRC.rrule(layer, Tracker.data(x), Tracker.data(ps))
__∇apply_simple_chain = @closure Δ -> begin
_, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ)))
return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing))
__∇apply_simple_chain = let pb_f = pb_f
Δ -> begin
_, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ)))
return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing))

end
end
# Tracker is not great at handling arbitrary types, so we convert to Array
return Array(y), __∇apply_simple_chain
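`__construct_tracked_params` above plays the same role for Tracker: `Tracker.TrackedArray(Tracker.Call(), p, dp)` makes the cached `dp` the gradient slot of the tracked parameter, so `Tracker.back!` accumulates into `ts.cache.dparameters` directly and the old `fmap(Tracker.grad, ps_tracked)` extraction pass disappears. A minimal standalone sketch of the aliasing:

```julia
using Tracker

p  = rand(Float32, 3)
dp = zeros(Float32, 3)                            # stands in for ts.cache.dparameters
tp = Tracker.TrackedArray(Tracker.Call(), p, dp)  # gradient storage aliased to dp

Tracker.back!(sum(tp .^ 2))

dp  # ≈ 2 .* p, accumulated in place by the backward pass
```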
16 changes: 15 additions & 1 deletion src/contrib/training.jl
@@ -29,14 +29,28 @@ Internal fields:
step::Int
end

@concrete struct TrainingBackendCache{backend, first_try}
dparameters
objective_function
extras
end

@inline __backend(::TrainingBackendCache{backend}) where {backend} = backend

function Base.show(io::IO, ts::TrainState)
println(io, "TrainState")
println(io, " model: ", ts.model)
println(io, " # of parameters: ", Lux.parameterlength(ts.parameters))
println(io, " # of states: ", Lux.statelength(ts.states))
println(io, " optimizer_state: ", ts.optimizer_state)
print(io, " step: ", ts.step)
ts.cache !== nothing && print(io, "\n cache: ", nameof(typeof(ts.cache)))
if ts.cache !== nothing
if ts.cache isa TrainingBackendCache
print(io, "\n cache: $(nameof(typeof(ts.cache))){$(__backend(ts.cache))}")
else
print(io, "\n cache: $(nameof(typeof(ts.cache)))")
end
end
ts.objective_function !== nothing &&
print(io, "\n objective_function: ", nameof(typeof(ts.objective_function)))
end
4 changes: 3 additions & 1 deletion src/utils.jl
@@ -331,7 +331,9 @@ end
end
@inline function __fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y)
fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y)
return mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y)
# mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) leads to slowdowns, better to
# allocate a new array
return sum(lfn.(x, y))
end

@inline __fused_agg(::Nothing, op::OP, args...) where {OP} = op.(args...)
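On the `__fused_agg` change just above: both formulations reduce the elementwise loss; the new one simply materializes the broadcast before summing, which sidesteps the slow `mapreduce` over a `BroadcastFunction` on arrays without fast scalar indexing. A quick CPU sanity sketch — the `L2DistLoss` choice is illustrative, not taken from this diff:

```julia
using LossFunctions

lfn  = L2DistLoss()
x, y = rand(Float32, 1024), rand(Float32, 1024)

old = mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y)  # previous formulation
new = sum(lfn.(x, y))                                       # allocate the broadcast, then sum

old ≈ new  # true
```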
