From 6c0efebb306b4287a6b6a5a25a32e98c1666f592 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 5 Apr 2024 18:00:24 -0400 Subject: [PATCH] Add warmup step --- bench/helpers.jl | 30 +++++++++++++++++++++++------- bench/layers.jl | 5 ++--- bench/runbenchmarks.jl | 2 +- examples/SimpleChains/main.jl | 5 +++++ 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/bench/helpers.jl b/bench/helpers.jl index 485e0b08d..5aa79988c 100644 --- a/bench/helpers.jl +++ b/bench/helpers.jl @@ -19,7 +19,8 @@ function benchmark_forward_pass(tag::String, end_tag::String, model, x, ps_nt::N end function benchmark_reverse_pass( - tag::String, end_tag::String, backends, model, x, ps_nt::NamedTuple, st) + tag::String, end_tag::String, backends, model, x, ps_nt::NamedTuple, st; + simple_chains=nothing) # Not everyone can handle NamedTuples so convert to ComponentArray __f = @closure ps -> sum(abs2, first(Lux.apply(model, x, ps, st))) ps_ca = ComponentArray(ps_nt) @@ -28,6 +29,15 @@ function benchmark_reverse_pass( __benchmark_reverse_pass(tag, end_tag, backend, __f, ps_ca) end + if simple_chains !== nothing + simple_chains_model = simple_chains(model) + ps_simple_chains, st_simple_chains = general_setup(simple_chains_model, nothing) + __f = @closure ps -> sum( + abs2, first(Lux.apply(simple_chains_model, x, ps, st_simple_chains))) + __benchmark_reverse_pass_simple_chains( + tag, end_tag, AutoZygote(), __f, ps_simple_chains) + end + return end @@ -41,23 +51,23 @@ end # TODO: Remove these once DifferentiationInterface has been released function __benchmark_reverse_pass( - tag::String, end_tag::String, ::AutoEnzyme, f::F, x) where {F} + tag::String, end_tag::String, ::AutoEnzyme, f::F, x; kwargs...) where {F} # TODO: Enable this. But enzyme doesn't handle closures well it seems... # SUITE[tag]["cpu"]["reverse"]["Enzyme"][end_tag] = @benchmarkable Enzyme.gradient( # $Enzyme.Reverse, $f, $x) - error("Enzyme backend hasn't been implemented yet.") + return error("Enzyme backend hasn't been implemented yet.") end function __benchmark_reverse_pass( - tag::String, end_tag::String, ::AutoTapir, f::F, x) where {F} + tag::String, end_tag::String, ::AutoTapir, f::F, x; kwargs...) where {F} end function __benchmark_reverse_pass( - tag::String, end_tag::String, ::AutoTracker, f::F, x) where {F} + tag::String, end_tag::String, ::AutoTracker, f::F, x; kwargs...) where {F} SUITE[tag]["cpu"]["reverse"]["Tracker"][end_tag] = @benchmarkable Tracker.gradient( $f, $x) return end function __benchmark_reverse_pass( - tag::String, end_tag::String, ad::AutoReverseDiff, f::F, x) where {F} + tag::String, end_tag::String, ad::AutoReverseDiff, f::F, x; kwargs...) where {F} if ad.compile SUITE[tag]["cpu"]["reverse"]["ReverseDiff (compiled)"][end_tag] = @benchmarkable ReverseDiff.gradient!( ∂x, tape, $x) setup=(∂x = similar($x); @@ -68,8 +78,14 @@ function __benchmark_reverse_pass( end end function __benchmark_reverse_pass( - tag::String, end_tag::String, ::AutoZygote, f::F, x) where {F} + tag::String, end_tag::String, ::AutoZygote, f::F, x; kwargs...) where {F} SUITE[tag]["cpu"]["reverse"]["Zygote"][end_tag] = @benchmarkable Zygote.gradient( $f, $x) return end +function __benchmark_reverse_pass_simple_chains( + tag::String, end_tag::String, ::AutoZygote, f::F, x; kwargs...) where {F} + SUITE[tag]["cpu"]["reverse"]["SimpleChains"][end_tag] = @benchmarkable Zygote.gradient( + $f, $x) + return +end diff --git a/bench/layers.jl b/bench/layers.jl index e38979703..c1d54c8df 100644 --- a/bench/layers.jl +++ b/bench/layers.jl @@ -9,7 +9,7 @@ function add_dense_benchmarks!() "Dense($n => $n)", "($n, 128)", (AutoTapir(), AutoTracker(), AutoReverseDiff(), AutoReverseDiff(true), AutoZygote()), - layer, x, ps, st) + layer, x, ps, st; simple_chains) end return @@ -22,10 +22,9 @@ function add_conv_benchmarks!() simple_chains = Lux.ToSimpleChainsAdaptor((static(64), static(64), static(ch))) benchmark_forward_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)", layer, x, ps, st; simple_chains) - benchmark_reverse_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)", (AutoTapir(), AutoTracker(), AutoReverseDiff(), - AutoReverseDiff(true), AutoZygote()), layer, x, ps, st) + AutoReverseDiff(true), AutoZygote()), layer, x, ps, st; simple_chains) end end diff --git a/bench/runbenchmarks.jl b/bench/runbenchmarks.jl index 7f2f6067f..e6f8ae5a2 100644 --- a/bench/runbenchmarks.jl +++ b/bench/runbenchmarks.jl @@ -27,7 +27,7 @@ struct AutoTapir <: ADTypes.AbstractReverseMode end const SUITE = BenchmarkGroup() include("helpers.jl") -include("vgg.jl") +# include("vgg.jl") include("layers.jl") BenchmarkTools.tune!(SUITE; verbose=true) diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl index 321c83767..a80b1b93c 100644 --- a/examples/SimpleChains/main.jl +++ b/examples/SimpleChains/main.jl @@ -78,6 +78,11 @@ function train(model; rng=Xoshiro(0), kwargs...) train_state = Lux.Experimental.TrainState( rng, model, Adam(3.0f-4); transform_variables=identity) + ### Warmup the model + x_proto = randn(rng, Float32, 28, 28, 1, 1) + y_proto = onehotbatch([1], 0:9) + Lux.Experimental.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state) + ### Lets train the model nepochs = 10 for epoch in 1:nepochs