Compiled ReverseDiff for training on CPU #722

Merged · 6 commits · Jun 23, 2024
Changes from 4 commits
1 change: 1 addition & 0 deletions docs/src/api/Lux/utilities.md
@@ -59,6 +59,7 @@ Lux.xlogx
## Recursive Operations

```@docs
Lux.recursive_map
Lux.recursive_add!!
Lux.recursive_eltype
Lux.recursive_make_zero
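The newly documented `Lux.recursive_map` is what the ReverseDiff and Tracker backends below use to walk nested parameter trees. A minimal sketch of the intended semantics, assuming only the recursive utilities listed above (the NamedTuple layout is illustrative, not taken from this PR):

```julia
using Lux

# Nested parameters in the usual Lux NamedTuple layout (illustrative).
ps = (; dense = (; weight = rand(Float32, 2, 3), bias = zeros(Float32, 2)))

dps = Lux.recursive_make_zero(ps)      # same structure, zeroed leaves
Lux.recursive_eltype(ps)               # Float32

# Apply a function leaf-wise across matching trees, as the training backends do
# when they pair each parameter array with its gradient buffer.
summed = Lux.recursive_map((p, dp) -> p .+ dp, ps, dps)
```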
71 changes: 32 additions & 39 deletions ext/LuxEnzymeExt.jl
@@ -1,75 +1,68 @@
module LuxEnzymeExt

using ADTypes: AutoEnzyme
using ConcreteStructs: @concrete
using Enzyme: Enzyme, Active, Const, Duplicated
using Lux: Lux
using Lux.Experimental: TrainingBackendCache, TrainState

@concrete struct CachedEnzymeExtras{FT}
dparameters
objective_function
st_wrap
stats_wrap
end

# Case I: We have CachedEnzymeExtras and objective_function is unchanged.
function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras{FT}, F}) where {F, FT}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
# Case I: We have TrainingBackendCache{:Enzyme} and obj_fn is unchanged.
function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}, F}) where {F, FT}
dps = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)

_, loss = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, ts.cache.objective_function, Active, Const(ts.model),
Enzyme.ReverseWithPrimal, ts.cache.extras.obj_fn, Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = __construct_new_trainstate(
ts.cache.st_wrap[], ts.states, ts, objective_function, dps,
ts.cache.objective_function, ts.cache.st_wrap, ts.cache.stats_wrap)
ts.cache.extras.st_wrap[], ts.states, ts, obj_fn, dps,
ts.cache.extras.obj_fn, ts.cache.extras.st_wrap, ts.cache.extras.stats_wrap)

return dps, loss, ts.cache.stats_wrap[], ts_new
return dps, loss, ts.cache.extras.stats_wrap[], ts_new
end

# Case II: We have CachedEnzymeExtras and objective_function is changed.
function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras{FT}}) where {F, FT}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
# Case II: We have TrainingBackendCache{:Enzyme} and obj_fn is changed.
function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}}) where {F, FT}
dps = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)

obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function(
objective_function, ts.model, ts.parameters, ts.states, data, Val(FT))
obj_fn_wrap, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function(
obj_fn, ts.model, ts.parameters, ts.states, data, Val(FT))

_, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model),
_, loss = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, obj_fn_wrap, Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = __construct_new_trainstate(
st_wrap[], ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap)
st_wrap[], ts.states, ts, obj_fn, dps, obj_fn_wrap, st_wrap, stats_wrap)

return dps, loss, stats_wrap[], ts_new
end

# Case III: Nothing is cached. First call to `compute_gradients`
function Lux.Experimental.compute_gradients(ad::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState) where {F}
function Lux.Experimental.compute_gradients(
ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
dps = Lux.recursive_make_zero(ts.parameters)
cache = CachedEnzymeExtras{true}(dps, nothing, nothing, nothing)
ts_new = Lux.Experimental.TrainState(
cache = TrainingBackendCache{:Enzyme, true}(
dps, (; obj_fn=nothing, st_wrap=nothing, stats_wrap=nothing))
ts_new = TrainState(
cache, nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
return Lux.Experimental.compute_gradients(ad, objective_function, data, ts_new)
return Lux.Experimental.compute_gradients(ad, obj_fn, data, ts_new)
end

# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not
# storing the objective function.
function __construct_new_trainstate(
st_new::S, ::S, ts::Lux.Experimental.TrainState, objective_fn::O,
dps, obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2}
cache = CachedEnzymeExtras{false}(dps, obj_fn, st_wrap, stats_wrap)
return Lux.Experimental.TrainState(
function __construct_new_trainstate(st_new::S, ::S, ts::TrainState, objective_fn::O, dps,
obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2}
cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap))
return TrainState(
cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
end

function __construct_new_trainstate(
st_new, _, ts::Lux.Experimental.TrainState, objective_fn::O,
dps, obj_fn::O2, st_wrap, stats_wrap) where {O, O2}
cache = CachedEnzymeExtras{false}(dps, nothing, nothing, nothing)
return Lux.Experimental.TrainState(
function __construct_new_trainstate(st_new, _, ts::TrainState, objective_fn::O, dps,
obj_fn::O2, st_wrap, stats_wrap) where {O, O2}
cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap))
return TrainState(
cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
end

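The three dispatches above implement a simple protocol: the first `compute_gradients` call builds a `TrainingBackendCache{:Enzyme}` with zeroed shadow parameters (Case III), and every later call with the same objective function reuses that cache (Case I), only re-zeroing `dparameters` when needed. A hedged sketch of the calling pattern; the `TrainState` constructor, `Dense`, and `apply_gradients` usage are assumed from the Lux.Experimental training API of this period rather than taken from this diff:

```julia
using ADTypes: AutoEnzyme
using Lux, Optimisers, Random

model = Dense(4 => 2)
ts = Lux.Experimental.TrainState(Random.default_rng(), model, Adam(0.01f0))

function loss_fn(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)   # (loss, states, stats), as the backends expect
end

x, y = rand(Float32, 4, 8), rand(Float32, 2, 8)

# First call: Case III, the Enzyme cache is created.
grads, loss, _, ts = Lux.Experimental.compute_gradients(AutoEnzyme(), loss_fn, (x, y), ts)
ts = Lux.Experimental.apply_gradients(ts, grads)

# Same objective again: Case I, the cached wrapper and shadow parameters are reused.
grads, loss, _, ts = Lux.Experimental.compute_gradients(AutoEnzyme(), loss_fn, (x, y), ts)
```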
53 changes: 0 additions & 53 deletions ext/LuxReverseDiffExt.jl

This file was deleted.

28 changes: 28 additions & 0 deletions ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -0,0 +1,28 @@
module LuxReverseDiffExt

using ADTypes: ADTypes, AutoReverseDiff
using ArrayInterface: ArrayInterface
using Lux: Lux, LuxCPUDevice
using Lux.Experimental: TrainingBackendCache, TrainState
using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules

# AoS to SoA conversion
function Lux.apply(
m::Lux.AbstractExplicitLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st)
@warn "Lux.apply(m::Lux.AbstractExplicitLayer, \
x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \
Lux.apply(m::Lux.AbstractExplicitLayer, x::ReverseDiff.TrackedArray, ps, \
st).\n\n\
1. If this was not the desired behavior overload the dispatch on `m`.\n\n\
2. This might have performance implications. Check which layer was causing this \
problem using `Lux.Experimental.@debug_mode`." maxlog=1
return Lux.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st)
end

## Prevent an infinite loop
Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)

include("rules.jl")
include("training.jl")

end
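The `Lux.apply` overload above exists because indexing or scalar-broadcasting a ReverseDiff `TrackedArray` produces an `Array` of `TrackedReal`s (array-of-structs), which most layers handle slowly; `aos_to_soa` repacks it into a single tracked array, and the fallback reshapes it back to the original size. A hedged illustration (the exact return type depends on ArrayInterface's ReverseDiff extension):

```julia
using ArrayInterface, ReverseDiff

x = ReverseDiff.track(rand(Float32, 3, 2))                 # TrackedArray
x_aos = [x[i, j] for i in 1:3, j in 1:2]                   # Array{<:TrackedReal}: the slow "AoS" form
x_soa = reshape(ArrayInterface.aos_to_soa(x_aos), size(x_aos))  # repacked "SoA" form, as in the fallback above
```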
14 changes: 14 additions & 0 deletions ext/LuxReverseDiffExt/rules.jl
@@ -0,0 +1,14 @@
# SimpleChains.jl
@grad_from_chainrules Lux.__apply_simple_chain(layer, x::TrackedArray, ps, ::LuxCPUDevice)
@grad_from_chainrules Lux.__apply_simple_chain(layer, x, ps::TrackedArray, ::LuxCPUDevice)
@grad_from_chainrules Lux.__apply_simple_chain(
layer, x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice)

# DynamicExpressions.jl
@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr,
operator_enum, x::TrackedArray, ps, ::LuxCPUDevice)
@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr,
operator_enum, x, ps::TrackedArray, ::LuxCPUDevice)
@grad_from_chainrules Lux.__apply_dynamic_expression(
de::Lux.DynamicExpressionsLayer, expr, operator_enum,
x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice)
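These `@grad_from_chainrules` declarations tell ReverseDiff not to trace through `__apply_simple_chain` / `__apply_dynamic_expression`, but to reuse the ChainRules `rrule`s Lux already defines for them. A minimal sketch of the same pattern with a hypothetical function `my_scale` (not part of Lux), to show what the macro does:

```julia
using ChainRulesCore, ReverseDiff
using ReverseDiff: TrackedArray, @grad_from_chainrules

my_scale(a, x) = a .* x

function ChainRulesCore.rrule(::typeof(my_scale), a::Real, x::AbstractArray)
    y = my_scale(a, x)
    pullback(ȳ) = (NoTangent(), sum(ȳ .* x), a .* ȳ)
    return y, pullback
end

# Route the tracked-argument method through the rrule instead of tracing `.*`.
@grad_from_chainrules my_scale(a::Real, x::TrackedArray)

ReverseDiff.gradient(x -> sum(my_scale(2.0, x)), rand(3))   # == fill(2.0, 3)
```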
55 changes: 55 additions & 0 deletions ext/LuxReverseDiffExt/training.jl
@@ -0,0 +1,55 @@
@static if pkgversion(ADTypes) < v"1.5"
# older versions did not have `compile` type parameter. Use slower type-unstable code
function Lux.Experimental.compute_gradients(
ad::AutoReverseDiff, obj_fn::F, data, ts::TrainState) where {F}
ad.compile && return __compiled_reverse_diff(obj_fn, data, ts)
return __uncompiled_reverse_diff(obj_fn, data, ts)
end
else
for compiled in (false, true)
fname = compiled ? :__compiled_reverse_diff : :__uncompiled_reverse_diff
@eval function Lux.Experimental.compute_gradients(
::AutoReverseDiff{$(compiled)}, obj_fn::F, data, ts::TrainState) where {F}
return $(fname)(obj_fn, data, ts)
end
end
end

# Uncompiled ReverseDiff
@inline function __uncompiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing),
obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
return __uncompiled_reverse_diff(obj_fn, data, ts_new)
end

@inline function __uncompiled_reverse_diff(obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT}
tape = ReverseDiff.InstructionTape()
dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
ps_tracked = Lux.recursive_map(
Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, dparams)

loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data)
loss.deriv = true
ReverseDiff.reverse_pass!(tape)

ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, nothing),
obj_fn, ts.model, ts.parameters, st, ts.optimizer_state, ts.step)

return ts.cache.dparameters, ReverseDiff.value(loss), stats, ts_new
end

# Compiled ReverseDiff
@inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing),
obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
return __compiled_reverse_diff(obj_fn, data, ts_new)
end

@inline function __compiled_reverse_diff(obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT}
error(1)
end
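With ADTypes ≥ 1.5 the `compile` flag is a type parameter of `AutoReverseDiff`, so the generated methods above pick the compiled or uncompiled path purely by dispatch; on older ADTypes the runtime branch on `ad.compile` is kept. Note that at this point in the PR (changes from 4 commits) `__compiled_reverse_diff` with a populated cache is still a stub (`error(1)`). A hedged usage sketch, with the `TrainState` setup assumed as in the Enzyme example earlier:

```julia
using ADTypes: AutoReverseDiff
using Lux, Optimisers, Random

model = Dense(4 => 2)
ts = Lux.Experimental.TrainState(Random.default_rng(), model, Adam(0.01f0))

function loss_fn(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)
end

data = (rand(Float32, 4, 8), rand(Float32, 2, 8))

# Uncompiled tape: builds the :ReverseDiff cache on the first call and reuses it afterwards.
grads, loss, _, ts = Lux.Experimental.compute_gradients(AutoReverseDiff(), loss_fn, data, ts)

# Compiled tape, selected by the type parameter; left commented out because the
# compiled branch is not implemented yet at this commit.
# Lux.Experimental.compute_gradients(AutoReverseDiff(; compile=true), loss_fn, data, ts)
```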
62 changes: 31 additions & 31 deletions ext/LuxTrackerExt.jl
@@ -3,49 +3,47 @@
using ADTypes: AutoTracker
using ArrayInterface: ArrayInterface
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
using Functors: fmap
using Lux: Lux, LuxCPUDevice
using Lux.Experimental: TrainingBackendCache, TrainState
using LuxCore: LuxCore
using Setfield: @set!
using Tracker: Tracker, TrackedArray, @grad_from_chainrules

const CRC = ChainRulesCore

# Type Piracy: Need to upstream
Tracker.param(nt::NamedTuple{F}) where {F} = NamedTuple{F}(Tracker.param.(values(nt)))
Tracker.param(t::Tuple) = map(Tracker.param, t)
Tracker.param(l::LuxCore.AbstractExplicitLayer) = l

Tracker.zero_grad!(nt::NamedTuple) = Tracker.zero_grad!.(values(nt))
Tracker.zero_grad!(::LuxCore.AbstractExplicitLayer) = nothing

function Tracker.extract_grad!(nt::NamedTuple{F}) where {F}
return NamedTuple{F}(Tracker.extract_grad!.(values(nt)))
end
Tracker.extract_grad!(t::Tuple) = map(Tracker.extract_grad!, t)
Tracker.extract_grad!(::LuxCore.AbstractExplicitLayer) = nothing

Tracker.data(nt::NamedTuple) = fmap(Tracker.data, nt)
Tracker.data(t::Tuple) = map(Tracker.data, t)
Tracker.data(l::LuxCore.AbstractExplicitLayer) = l

# Weight Norm Patch
@inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims))

# multigate chain rules
@inline Lux._gate(x::Tracker.TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)]
@inline Lux._gate(x::Tracker.TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :]

function __construct_tracked_params(ps, dps)
map_fn = (p, dp) -> Tracker.TrackedArray(Tracker.Call(), p, dp)
return Lux.recursive_map(map_fn, ps, dps)
end

# Lux.Training
function Lux.Experimental.compute_gradients(::AutoTracker, objective_function::F, data,
ts::Lux.Experimental.TrainState) where {F}
ps_tracked = fmap(Tracker.param, ts.parameters)
loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
function Lux.Experimental.compute_gradients(::AutoTracker, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT}
dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
ps_tracked = __construct_tracked_params(ts.parameters, dparams)

loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data)
Tracker.back!(loss)
@set! ts.states = st
grads = fmap(Tracker.grad, ps_tracked)
return grads, Tracker.value(loss), stats, ts

ts_new = TrainState(
TrainingBackendCache{:Tracker, false}(ts.cache.dparameters, nothing),
obj_fn, ts.model, ts.parameters, st, ts.optimizer_state, ts.step)

return dparams, Tracker.value(loss), stats, ts_new
end

function Lux.Experimental.compute_gradients(
::AutoTracker, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = TrainState(TrainingBackendCache{:Tracker, true}(grads, nothing), obj_fn,
ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
return Lux.Experimental.compute_gradients(AutoTracker(), obj_fn, data, ts_new)
end

# AoS to SoA conversion
@@ -77,9 +75,11 @@
As such please test your model with FiniteDifferences or Zygote before using \
`Tracker.jl` for your model." maxlog=1
y, pb_f = CRC.rrule(layer, Tracker.data(x), Tracker.data(ps))
__∇apply_simple_chain = @closure Δ -> begin
_, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ)))
return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing))
__∇apply_simple_chain = let pb_f = pb_f
Δ -> begin
_, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ)))
return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing))
end
end
# Tracker is not great at handling arbitrary types, so we convert to Array
return Array(y), __∇apply_simple_chain
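The `__construct_tracked_params` change earlier in this file is the core of the Tracker rework: each parameter array is wrapped as a `TrackedArray` around a preallocated gradient buffer, so `Tracker.back!` accumulates directly into the cached `dparameters` and the old `fmap(Tracker.grad, ps_tracked)` extraction step disappears. A hedged, standalone illustration of that mechanism (values are illustrative; the constructor is the same one the diff uses):

```julia
using Tracker

p  = rand(Float32, 3)
dp = zeros(Float32, 3)                       # preallocated gradient buffer (the cached dparameters)
pt = Tracker.TrackedArray(Tracker.Call(), p, dp)

loss = sum(abs2.(pt))
Tracker.back!(loss)

dp                                           # ≈ 2 .* p, written in place by back!
```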