Add warmup step
avik-pal committed Apr 5, 2024
1 parent 537bdba commit 6c0efeb
Showing 4 changed files with 31 additions and 11 deletions.
30 changes: 23 additions & 7 deletions bench/helpers.jl
@@ -19,7 +19,8 @@ function benchmark_forward_pass(tag::String, end_tag::String, model, x, ps_nt::N
 end
 
 function benchmark_reverse_pass(
-        tag::String, end_tag::String, backends, model, x, ps_nt::NamedTuple, st)
+        tag::String, end_tag::String, backends, model, x, ps_nt::NamedTuple, st;
+        simple_chains=nothing)
     # Not everyone can handle NamedTuples so convert to ComponentArray
     __f = @closure ps -> sum(abs2, first(Lux.apply(model, x, ps, st)))
     ps_ca = ComponentArray(ps_nt)
@@ -28,6 +29,15 @@ function benchmark_reverse_pass(
         __benchmark_reverse_pass(tag, end_tag, backend, __f, ps_ca)
     end
 
+    if simple_chains !== nothing
+        simple_chains_model = simple_chains(model)
+        ps_simple_chains, st_simple_chains = general_setup(simple_chains_model, nothing)
+        __f = @closure ps -> sum(
+            abs2, first(Lux.apply(simple_chains_model, x, ps, st_simple_chains)))
+        __benchmark_reverse_pass_simple_chains(
+            tag, end_tag, AutoZygote(), __f, ps_simple_chains)
+    end
+
     return
 end
 
@@ -41,23 +51,23 @@ end
 
 # TODO: Remove these once DifferentiationInterface has been released
 function __benchmark_reverse_pass(
-        tag::String, end_tag::String, ::AutoEnzyme, f::F, x) where {F}
+        tag::String, end_tag::String, ::AutoEnzyme, f::F, x; kwargs...) where {F}
     # TODO: Enable this. But enzyme doesn't handle closures well it seems...
     # SUITE[tag]["cpu"]["reverse"]["Enzyme"][end_tag] = @benchmarkable Enzyme.gradient(
     #     $Enzyme.Reverse, $f, $x)
-    error("Enzyme backend hasn't been implemented yet.")
+    return error("Enzyme backend hasn't been implemented yet.")
 end
 function __benchmark_reverse_pass(
-        tag::String, end_tag::String, ::AutoTapir, f::F, x) where {F}
+        tag::String, end_tag::String, ::AutoTapir, f::F, x; kwargs...) where {F}
 end
 function __benchmark_reverse_pass(
-        tag::String, end_tag::String, ::AutoTracker, f::F, x) where {F}
+        tag::String, end_tag::String, ::AutoTracker, f::F, x; kwargs...) where {F}
     SUITE[tag]["cpu"]["reverse"]["Tracker"][end_tag] = @benchmarkable Tracker.gradient(
         $f, $x)
     return
 end
 function __benchmark_reverse_pass(
-        tag::String, end_tag::String, ad::AutoReverseDiff, f::F, x) where {F}
+        tag::String, end_tag::String, ad::AutoReverseDiff, f::F, x; kwargs...) where {F}
     if ad.compile
         SUITE[tag]["cpu"]["reverse"]["ReverseDiff (compiled)"][end_tag] = @benchmarkable ReverseDiff.gradient!(
             ∂x, tape, $x) setup=(∂x = similar($x);
@@ -68,8 +78,14 @@ function __benchmark_reverse_pass(
     end
 end
 function __benchmark_reverse_pass(
-        tag::String, end_tag::String, ::AutoZygote, f::F, x) where {F}
+        tag::String, end_tag::String, ::AutoZygote, f::F, x; kwargs...) where {F}
     SUITE[tag]["cpu"]["reverse"]["Zygote"][end_tag] = @benchmarkable Zygote.gradient(
         $f, $x)
     return
 end
+function __benchmark_reverse_pass_simple_chains(
+        tag::String, end_tag::String, ::AutoZygote, f::F, x; kwargs...) where {F}
+    SUITE[tag]["cpu"]["reverse"]["SimpleChains"][end_tag] = @benchmarkable Zygote.gradient(
+        $f, $x)
+    return
+end
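
Note on how these helpers plug into the harness: each method registers its backend's reverse pass under nested keys of the global SUITE (a BenchmarkTools.BenchmarkGroup created in bench/runbenchmarks.jl), and the whole group is then tuned and run in one go. A minimal, self-contained sketch of that registration-and-run pattern, using a toy closure and made-up tags rather than anything from this suite:

using BenchmarkTools, Zygote

suite = BenchmarkGroup()
# Mirror the SUITE[tag]["cpu"]["reverse"][backend][end_tag] nesting used above.
suite["Dense(2 => 2)"] = BenchmarkGroup()
suite["Dense(2 => 2)"]["cpu"] = BenchmarkGroup()
suite["Dense(2 => 2)"]["cpu"]["reverse"] = BenchmarkGroup()
suite["Dense(2 => 2)"]["cpu"]["reverse"]["Zygote"] = BenchmarkGroup()

f = ps -> sum(abs2, ps)   # toy loss standing in for the Lux closure
ps = rand(Float32, 16)

# Interpolating f and ps into @benchmarkable avoids timing global-variable lookups.
suite["Dense(2 => 2)"]["cpu"]["reverse"]["Zygote"]["(2, 128)"] = @benchmarkable Zygote.gradient($f, $ps)

tune!(suite)                        # pick evaluation counts per benchmark
results = run(suite; verbose=true)  # a BenchmarkGroup of trial results
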
5 changes: 2 additions & 3 deletions bench/layers.jl
@@ -9,7 +9,7 @@ function add_dense_benchmarks!()
             "Dense($n => $n)", "($n, 128)",
             (AutoTapir(), AutoTracker(), AutoReverseDiff(),
                 AutoReverseDiff(true), AutoZygote()),
-            layer, x, ps, st)
+            layer, x, ps, st; simple_chains)
     end
 
     return
@@ -22,10 +22,9 @@ function add_conv_benchmarks!()
         simple_chains = Lux.ToSimpleChainsAdaptor((static(64), static(64), static(ch)))
         benchmark_forward_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)",
             layer, x, ps, st; simple_chains)
-
         benchmark_reverse_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)",
             (AutoTapir(), AutoTracker(), AutoReverseDiff(),
-                AutoReverseDiff(true), AutoZygote()), layer, x, ps, st)
+                AutoReverseDiff(true), AutoZygote()), layer, x, ps, st; simple_chains)
     end
 end

2 changes: 1 addition & 1 deletion bench/runbenchmarks.jl
@@ -27,7 +27,7 @@ struct AutoTapir <: ADTypes.AbstractReverseMode end
 const SUITE = BenchmarkGroup()
 
 include("helpers.jl")
-include("vgg.jl")
+# include("vgg.jl")
 include("layers.jl")
 
 BenchmarkTools.tune!(SUITE; verbose=true)
5 changes: 5 additions & 0 deletions examples/SimpleChains/main.jl
@@ -78,6 +78,11 @@ function train(model; rng=Xoshiro(0), kwargs...)
     train_state = Lux.Experimental.TrainState(
         rng, model, Adam(3.0f-4); transform_variables=identity)
 
+    ### Warmup the model
+    x_proto = randn(rng, Float32, 28, 28, 1, 1)
+    y_proto = onehotbatch([1], 0:9)
+    Lux.Experimental.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)
+
     ### Lets train the model
     nepochs = 10
     for epoch in 1:nepochs
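
For context, the warmup added above exists because Julia compiles the forward and reverse passes on their first invocation, so without it the first epoch would absorb that one-time compilation cost. A stripped-down sketch of the same pattern with plain Zygote; the model, loss, and data here are placeholders, not taken from this example:

using Zygote

predict(w, x) = w * x                          # placeholder "model"
loss(w, x, y) = sum(abs2, predict(w, x) .- y)  # placeholder loss

w = rand(Float32, 10, 4)
x_proto, y_proto = rand(Float32, 4, 1), rand(Float32, 10, 1)

# Warmup: the first gradient call triggers compilation of the reverse pass.
Zygote.gradient(w -> loss(w, x_proto, y_proto), w)

# Calls inside the actual training loop then reuse the compiled code,
# so their cost reflects execution rather than compilation.
@time Zygote.gradient(w -> loss(w, x_proto, y_proto), w)
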