Finish implementation of compiled reversediff
avik-pal committed Jun 23, 2024
1 parent 66a5cda commit 8331b57
Showing 5 changed files with 160 additions and 33 deletions.
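To make the change concrete, here is a minimal usage sketch of the feature this commit completes, based on the tests added below (`Lux.Experimental.TrainState`, `Lux.Experimental.single_train_step!`, and `AutoReverseDiff(; compile=true)`). The model, loss, and data are illustrative and not part of the commit.

```julia
using ADTypes, Lux, Optimisers, Random, ReverseDiff

# Loss in the (model, ps, st, data) -> (loss, st, stats) form Lux's training API expects.
# Compiled ReverseDiff requires an empty state `st` and empty `stats` (see the checks below).
mse(model, ps, st, (x, y)) = (sum(abs2, first(model(x, ps, st)) .- y), st, (;))

rng = Random.default_rng()
model = Chain(Dense(4, 32, tanh), Dense(32, 4))     # stateless model
ps, st = Lux.setup(rng, model)
tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

x, y = randn(rng, Float32, 4, 32), randn(rng, Float32, 4, 32)

function train_loop(tstate, data, nsteps)
    for _ in 1:nsteps
        # The first step records and caches the ReverseDiff tape; later steps replay it.
        _, _, _, tstate = Lux.Experimental.single_train_step!(
            AutoReverseDiff(; compile=true), mse, data, tstate)
    end
    return tstate
end

tstate = train_loop(tstate, (x, y), 10)
```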
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.5.56"
version = "0.5.57"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
31 changes: 8 additions & 23 deletions ext/LuxEnzymeExt.jl
@@ -7,16 +7,15 @@ using Lux.Experimental: TrainingBackendCache, TrainState

# Case I: We have TrainingBackendCache{:Enzyme} and obj_fn is unchanged.
function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}, F}) where {F, FT}
dps = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
ts::TrainState{<:TrainingBackendCache{:Enzyme, false}, F}) where {F}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)

_, loss = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, ts.cache.extras.obj_fn, Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = __construct_new_trainstate(
ts.cache.extras.st_wrap[], ts.states, ts, obj_fn, dps,
ts.cache.extras.obj_fn, ts.cache.extras.st_wrap, ts.cache.extras.stats_wrap)
ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters,
ts.cache.extras.st_wrap[], ts.optimizer_state, ts.step)

return dps, loss, ts.cache.extras.stats_wrap[], ts_new
end
@@ -33,8 +33,10 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data,
Enzyme.ReverseWithPrimal, obj_fn_wrap, Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = __construct_new_trainstate(
st_wrap[], ts.states, ts, obj_fn, dps, obj_fn_wrap, st_wrap, stats_wrap)
cache = TrainingBackendCache{:Enzyme, false}(
dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap))
ts_new = TrainState(
cache, obj_fn, ts.model, ts.parameters, st_wrap[], ts.optimizer_state, ts.step)

return dps, loss, stats_wrap[], ts_new
end
@@ -50,20 +50,4 @@ function Lux.Experimental.compute_gradients(
return Lux.Experimental.compute_gradients(ad, obj_fn, data, ts_new)
end

# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not
# storing the objective function.
function __construct_new_trainstate(st_new::S, ::S, ts::TrainState, objective_fn::O, dps,
obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2}
cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap))
return TrainState(
cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
end

function __construct_new_trainstate(st_new, _, ts::TrainState, objective_fn::O, dps,
obj_fn::O2, st_wrap, stats_wrap) where {O, O2}
cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap))
return TrainState(
cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
end

end
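For readers unfamiliar with the call pattern used in this extension, the standalone sketch below shows the same `Enzyme.autodiff` invocation shape (reverse mode with the primal returned, parameters as `Duplicated`, everything else `Const`) on a toy loss, with plain vectors standing in for the model, parameters, and state. It is illustrative only and not part of this commit.

```julia
using Enzyme

# Toy scalar loss; stands in for the wrapped objective function in the extension.
loss(ps, x) = sum(abs2, ps .* x)

ps  = [1.0, 2.0, 3.0]
dps = zero(ps)            # shadow storage; Enzyme accumulates dloss/dps here
x   = [0.5, -1.0, 2.0]

# Reverse mode with the primal value returned; mirrors the call in the extension.
_, primal = Enzyme.autodiff(Enzyme.ReverseWithPrimal, loss, Enzyme.Active,
    Enzyme.Duplicated(ps, dps), Enzyme.Const(x))

# `dps` now holds the gradient and `primal` the loss value.
```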
80 changes: 71 additions & 9 deletions ext/LuxReverseDiffExt/training.jl
@@ -44,22 +44,84 @@ end
# Compiled ReverseDiff
@inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn,
data_cache = deepcopy(data)
ps_cache = deepcopy(ts.parameters)
extras = (; data_cache, ps_cache)

ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, extras), nothing,
ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
return __compiled_reverse_diff(obj_fn, data, ts_new)
end

## Tape hasn't been compiled yet
@inline function __compiled_reverse_diff(obj_fn::F,
data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT, P, Nothing}}) where {
F, FT, P}
## Tape hasn't been compiled yet / Function mismatch so recompile
@inline function __compiled_reverse_diff(obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT}
if Lux.statelength(ts.states) != 0
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \

Check warning on line 60 in ext/LuxReverseDiffExt/training.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReverseDiffExt/training.jl#L59-L60

Added lines #L59 - L60 were not covered by tests
models with non-empty state `st`."))
end

if FT # do a dry run
_, st_, stats = obj_fn(ts.model, ts.parameters, ts.states, data)
if stats != NamedTuple()
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \

Check warning on line 67 in ext/LuxReverseDiffExt/training.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReverseDiffExt/training.jl#L64-L67

Added lines #L64 - L67 were not covered by tests
loss functions that return non-empty `stats`."))
end
if Lux.statelength(st_) != 0
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \

Check warning on line 71 in ext/LuxReverseDiffExt/training.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReverseDiffExt/training.jl#L70-L71

Added lines #L70 - L71 were not covered by tests
models with non-empty state `st`."))
end
end

dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)

(; ps_cache, data_cache) = ts.cache.extras
if !FT
Lux.recursive_copyto!(ps_cache, ts.parameters)
Lux.recursive_copyto!(data_cache, data)
end

obj_fn_wrap = first ∘ obj_fn

tape = ReverseDiff.InstructionTape()
ps_tracked = Lux.recursive_map(
Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, dparams)
Lux.__Fix3(ReverseDiff.TrackedArray, tape), ps_cache, dparams)

loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data)
loss = obj_fn_wrap(ts.model, ps_tracked, ts.states, data_cache)
loss.deriv = true
ReverseDiff.reverse_pass!(tape)

forward_executor = [ReverseDiff.FunctionWrapper{Nothing, Tuple{}}(ReverseDiff.ForwardExecutor(instruction))
for instruction in tape]
reverse_executor = [ReverseDiff.FunctionWrapper{Nothing, Tuple{}}(ReverseDiff.ReverseExecutor(tape[i]))
for i in length(tape):-1:1]

compiled_extras = (;
ps_cache, data_cache, forward_executor, reverse_executor, output=loss)
ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, compiled_extras),
obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)

return dparams, ReverseDiff.value(loss), NamedTuple(), ts_new
end

@inline function __compiled_reverse_diff(obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, false}, F}) where {F}
(; ps_cache, data_cache, output) = ts.cache.extras

dparams = Lux.recursive_make_zero!!(ts.cache.dparameters)
Lux.recursive_copyto!(ps_cache, ts.parameters)
Lux.recursive_copyto!(data_cache, data)

for wrapper in ts.cache.extras.forward_executor
wrapper()
end
output.deriv = true
for wrapper in ts.cache.extras.reverse_executor
wrapper()
end

ts_new = TrainState(
ts.cache, obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
return dparams, ReverseDiff.value(output), NamedTuple(), ts_new
end
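Conceptually, the new code records a `ReverseDiff.InstructionTape` once and then replays it through the cached `ForwardExecutor`/`ReverseExecutor` wrappers after copying fresh parameters and data into the preallocated caches. The sketch below illustrates the same compile-once/replay idea using ReverseDiff's public compiled-tape API rather than the extension's internals; names and values are illustrative.

```julia
using ReverseDiff

# Toy objective; stands in for the wrapped Lux loss.
f(ps) = sum(abs2, ps)

ps = randn(Float32, 16)
grads = similar(ps)

# Record the operations once, then compile the tape for fast repeated replay.
tape  = ReverseDiff.GradientTape(f, ps)
ctape = ReverseDiff.compile(tape)

# Replaying the compiled tape with new values reuses the recorded instructions.
# This is only valid when control flow does not depend on the input values,
# the same caveat that applies to `AutoReverseDiff(; compile=true)` above.
new_ps = randn(Float32, 16)
ReverseDiff.gradient!(grads, ctape, new_ps)   # grads ≈ 2 .* new_ps
```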
9 changes: 9 additions & 0 deletions src/helpers/recursive_ops.jl
@@ -50,6 +50,15 @@ See also [`Lux.recursive_make_zero`](@ref) for fully out-of-place version.
"""
@inline recursive_make_zero!!(x) = recursive_map(__zero!!, x)::typeof(x)

"""
recursive_copyto!(x, y)

Recursively copy the leaves of the nested structure `y` into the corresponding leaves of
`x` (in place). In Functor language, this is equivalent to `fmap(copyto!, x, y)`, but this
implementation uses type-stable code for common cases. Note that any immutable leaf will
lead to an error.
"""
@inline recursive_copyto!(x, y) = recursive_map(copyto!, x, y)
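A short usage sketch of the new helper (the values are illustrative, not from the commit):

```julia
using Lux

dst = (; weight=zeros(Float32, 2, 3), bias=zeros(Float32, 2))
src = (; weight=ones(Float32, 2, 3), bias=ones(Float32, 2))

# Copies each leaf of `src` into the corresponding preallocated leaf of `dst`.
Lux.recursive_copyto!(dst, src)
dst.weight == src.weight   # true; the arrays in `dst` were overwritten in place
```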

"""
recursive_map(f, x, args...)
71 changes: 71 additions & 0 deletions test/contrib/training_tests.jl
@@ -177,3 +177,74 @@ end

@inferred Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)
end

@testitem "Compiled ReverseDiff" setup=[SharedTestSetup] tags=[:contrib] begin
using ADTypes, Optimisers, ReverseDiff

function mse1(model, ps, st, data)
x_data, y_data = data
y, st_ = model(x_data, ps, st)
return sum(abs2, y .- y_data), st_, (;)
end

function mse2(model, ps, st, data)
l, st_, stats = mse1(model, ps, st, data)
return l, st_, (; data=2.0f0)
end

rng = StableRNG(12345)

dataset = [(randn(rng, Float32, 4, 32), randn(rng, Float32, 4, 32)) for _ in 1:100]

@testset "Unhandled Cases" begin
model = Chain(Dense(4, 32, tanh), BatchNorm(32),
Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4))
ps, st = Lux.setup(rng, model)

tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

# Stateful models are not supported
@test_throws ArgumentError Lux.Experimental.compute_gradients(
AutoReverseDiff(; compile=true), mse1, dataset[1], tstate)

model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4))
ps, st = Lux.setup(rng, model)

tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

# Loss functions that return non-empty `stats` are not supported
@test_throws ArgumentError Lux.Experimental.compute_gradients(
AutoReverseDiff(; compile=true), mse2, dataset[1], tstate)

struct StrangeModel <: Lux.AbstractExplicitLayer end

function (m::StrangeModel)(x, ps, st)
return x, (; new_state=0.0)
end

model = StrangeModel()
ps, st = Lux.setup(rng, model)

tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

# Stateful models are not supported
@test_throws ArgumentError Lux.Experimental.compute_gradients(
AutoReverseDiff(; compile=true), mse1, dataset[1], tstate)
end

model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4))
ps, st = Lux.setup(rng, model)

tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

loss_initial = first(mse1(model, ps, st, dataset[1]))
for i in 1:100
for (x, y) in dataset
_, _, _, tstate = Lux.Experimental.single_train_step!(
AutoReverseDiff(; compile=true), mse1, (x, y), tstate)
end
end
loss_final = first(mse1(model, tstate.parameters, tstate.states, dataset[1]))

@test loss_final * 100 < loss_initial
end
