Commit: Benchmark Tapir
avik-pal committed Apr 6, 2024
1 parent 6c0efeb commit 8c83975
Showing 8 changed files with 121 additions and 56 deletions.
1 change: 1 addition & 0 deletions bench/Project.toml
@@ -4,6 +4,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
102 changes: 71 additions & 31 deletions bench/helpers.jl
@@ -1,41 +1,44 @@
# TODO: Special Handling for GPU Arrays with @sync
function benchmark_forward_pass(tag::String, end_tag::String, model, x, ps_nt::NamedTuple,
st; simple_chains=nothing)
function benchmark_forward_pass(
tag::String, end_tag::String, model, x_dims; simple_chains=nothing,
flux_model=nothing)
SUITE[tag]["cpu"]["forward"]["NamedTuple"][end_tag] = @benchmarkable Lux.apply(
$model, $x, $ps_nt, $st)
$model, x, ps_nt, st) setup=((x, ps_nt, st) = general_setup($model, $x_dims))

ps_ca = ComponentArray(ps_nt)
SUITE[tag]["cpu"]["forward"]["ComponentArray"][end_tag] = @benchmarkable Lux.apply(
$model, $x, $ps_ca, $st)
$model, x, ps_ca, st) setup=((x, ps_nt, st) = general_setup($model, $x_dims); ps_ca = ComponentArray(ps_nt))

if simple_chains !== nothing
simple_chains_model = simple_chains(model)
ps_simple_chains, st_simple_chains = general_setup(simple_chains_model, nothing)
SUITE[tag]["cpu"]["forward"]["SimpleChains"][end_tag] = @benchmarkable Lux.apply(
$simple_chains_model, $x, $ps_simple_chains, $st_simple_chains)
$simple_chains_model, x, ps_simple_chains, st_simple_chains) setup=((x, ps_simple_chains, st_simple_chains) = general_setup(
$simple_chains_model, $x_dims))
end

if flux_model !== nothing
SUITE[tag]["cpu"]["forward"]["Flux"][end_tag] = @benchmarkable fmodel(x) setup=(x = randn(
StableRNG(0), Float32, $x_dims);
fmodel = $(flux_model()))
end

return
end

function benchmark_reverse_pass(
tag::String, end_tag::String, backends, model, x, ps_nt::NamedTuple, st;
simple_chains=nothing)
# Not everyone can handle NamedTuples so convert to ComponentArray
__f = @closure ps -> sum(abs2, first(Lux.apply(model, x, ps, st)))
ps_ca = ComponentArray(ps_nt)

tag::String, end_tag::String, backends, model, x_dims;
simple_chains=nothing, flux_model=nothing)
for backend in backends
__benchmark_reverse_pass(tag, end_tag, backend, __f, ps_ca)
__benchmark_reverse_pass(tag, end_tag, backend, model, x_dims)
end

if simple_chains !== nothing
simple_chains_model = simple_chains(model)
ps_simple_chains, st_simple_chains = general_setup(simple_chains_model, nothing)
__f = @closure ps -> sum(
abs2, first(Lux.apply(simple_chains_model, x, ps, st_simple_chains)))
__benchmark_reverse_pass_simple_chains(
tag, end_tag, AutoZygote(), __f, ps_simple_chains)
tag, end_tag, AutoZygote(), simple_chains_model, x_dims)
end

if flux_model !== nothing
__benchmark_reverse_pass_flux(tag, end_tag, AutoZygote(), flux_model, x_dims)
end

return
@@ -51,41 +54,78 @@ end

# TODO: Remove these once DifferentiationInterface has been released
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoEnzyme, f::F, x; kwargs...) where {F}
tag::String, end_tag::String, ::AutoEnzyme, model, x_dims)
# TODO: Enable this. Enzyme doesn't seem to handle closures well...
# SUITE[tag]["cpu"]["reverse"]["Enzyme"][end_tag] = @benchmarkable Enzyme.gradient(
# $Enzyme.Reverse, $f, $x)
return error("Enzyme backend hasn't been implemented yet.")
end
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoTapir, f::F, x; kwargs...) where {F}
tag::String, end_tag::String, ::AutoTapir, model, x_dims)
SUITE[tag]["cpu"]["reverse"]["Tapir"][end_tag] = @benchmarkable Tapir.value_and_pullback!!(
trrule, 1.0f0, f, ps_ca) setup=begin
(x, ps, st) = general_setup($model, $x_dims)
ps_ca = ComponentArray(ps)
f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
trrule = Tapir.build_rrule(f, ps_ca)
end
return
end
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoTracker, f::F, x; kwargs...) where {F}
tag::String, end_tag::String, ::AutoTracker, model, x_dims)
SUITE[tag]["cpu"]["reverse"]["Tracker"][end_tag] = @benchmarkable Tracker.gradient(
$f, $x)
f, ps_ca) setup=begin
(x, ps, st) = general_setup($model, $x_dims)
ps_ca = ComponentArray(ps)
f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
end
return
end
function __benchmark_reverse_pass(
tag::String, end_tag::String, ad::AutoReverseDiff, f::F, x; kwargs...) where {F}
tag::String, end_tag::String, ad::AutoReverseDiff, model, x_dims)
if ad.compile
SUITE[tag]["cpu"]["reverse"]["ReverseDiff (compiled)"][end_tag] = @benchmarkable ReverseDiff.gradient!(
∂x, tape, $x) setup=(∂x = similar($x);
tape = ReverseDiff.compile(ReverseDiff.GradientTape($f, $x)))
∂ps, tape, ps_ca) setup=begin
(x, ps, st) = general_setup($model, $x_dims)
ps_ca = ComponentArray(ps)
∂ps = similar(ps_ca)
f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, ps_ca))
end
else
SUITE[tag]["cpu"]["reverse"]["ReverseDiff"][end_tag] = @benchmarkable ReverseDiff.gradient(
$f, $x)
f, ps_ca) setup=begin
(x, ps, st) = general_setup($model, $x_dims)
ps_ca = ComponentArray(ps)
f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
end
end
end
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoZygote, f::F, x; kwargs...) where {F}
function __benchmark_reverse_pass(tag::String, end_tag::String, ::AutoZygote, model, x_dims)
SUITE[tag]["cpu"]["reverse"]["Zygote"][end_tag] = @benchmarkable Zygote.gradient(
$f, $x)
f, ps_ca) setup=begin
(x, ps, st) = general_setup($model, $x_dims)
ps_ca = ComponentArray(ps)
f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
end
return
end
function __benchmark_reverse_pass_simple_chains(
tag::String, end_tag::String, ::AutoZygote, f::F, x; kwargs...) where {F}
tag::String, end_tag::String, ::AutoZygote, model, x_dims)
SUITE[tag]["cpu"]["reverse"]["SimpleChains"][end_tag] = @benchmarkable Zygote.gradient(
$f, $x)
f, ps) setup=begin
(x, ps, st) = general_setup($model, $x_dims)
f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
end
return
end
function __benchmark_reverse_pass_flux(
tag::String, end_tag::String, ::AutoZygote, model, x_dims)
SUITE[tag]["cpu"]["reverse"]["Flux"][end_tag] = @benchmarkable Zygote.gradient(
f, m) setup=begin
x = randn(StableRNG(0), Float32, $x_dims)
m = $(model)()
f = @closure(m->sum(abs2, m(x)))
end
return
end
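
The setup blocks above all call general_setup(model, x_dims), a helper defined elsewhere in the bench suite and not shown in this diff. A minimal sketch of what it presumably does, assuming it seeds a StableRNG, initializes the Lux parameters and states, builds the input from x_dims, and returns only (ps, st) when x_dims is nothing (as the old SimpleChains code relied on):

# Hypothetical reconstruction of the helper used in the setup blocks above; not part of this commit.
function general_setup(model, x_dims)
    rng = StableRNG(0)
    ps, st = Lux.setup(rng, model)           # initialize parameters and states
    x_dims === nothing && return ps, st      # the old SimpleChains path passed `nothing`
    x = randn(rng, Float32, x_dims)          # fixed random input of the requested size
    return x, ps, st
end
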
23 changes: 12 additions & 11 deletions bench/layers.jl
@@ -1,15 +1,14 @@
function add_dense_benchmarks!()
for n in (2, 20, 200, 2000)
layer = Dense(n => n)
x, ps, st = general_setup(layer, (n, 128))
simple_chains = Lux.ToSimpleChainsAdaptor((static(n),))
simple_chains = n ≤ 200 ? Lux.ToSimpleChainsAdaptor((static(n),)) : nothing
flux_model = () -> Flux.Dense(n => n)
benchmark_forward_pass(
"Dense($n => $n)", "($n, 128)", layer, x, ps, st; simple_chains)
"Dense($n => $n)", "($n, 128)", layer, (n, 128); simple_chains, flux_model)
benchmark_reverse_pass(
"Dense($n => $n)", "($n, 128)",
(AutoTapir(), AutoTracker(), AutoReverseDiff(),
AutoReverseDiff(true), AutoZygote()),
layer, x, ps, st; simple_chains)
(AutoTracker(), AutoReverseDiff(), AutoReverseDiff(true), AutoZygote()),
layer, (n, 128); simple_chains, flux_model)
end

return
@@ -18,13 +17,15 @@ end
function add_conv_benchmarks!()
for ch in (1, 3, 16, 64)
layer = Conv((3, 3), ch => ch)
x, ps, st = general_setup(layer, (64, 64, ch, 128))
simple_chains = Lux.ToSimpleChainsAdaptor((static(64), static(64), static(ch)))
simple_chains = ch ≤ 16 ?
Lux.ToSimpleChainsAdaptor((static(64), static(64), static(ch))) :
nothing
flux_model = () -> Flux.Conv((3, 3), ch => ch)
benchmark_forward_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)",
layer, x, ps, st; simple_chains)
layer, (64, 64, ch, 128); simple_chains, flux_model)
benchmark_reverse_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)",
(AutoTapir(), AutoTracker(), AutoReverseDiff(),
AutoReverseDiff(true), AutoZygote()), layer, x, ps, st; simple_chains)
(AutoTracker(), AutoReverseDiff(), AutoReverseDiff(true), AutoZygote()),
layer, (64, 64, ch, 128); simple_chains, flux_model)
end
end
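
The Flux baselines are passed as zero-argument closures (flux_model = () -> ...) so the model is only constructed inside each benchmark's setup phase. Unlike Lux, a Flux layer stores its parameters internally, so the Flux reverse pass differentiates with respect to the model itself rather than an explicit parameter vector. A small illustrative sketch of that difference (layer size and RNG seed chosen only for illustration):

using Flux, Zygote, StableRNGs

m = Flux.Dense(2 => 2)                              # parameters live inside the model
x = randn(StableRNG(0), Float32, 2, 128)
g_flux = Zygote.gradient(m -> sum(abs2, m(x)), m)   # gradient with respect to the model

# The Lux benchmarks in helpers.jl instead close over (x, st) and differentiate
# an explicit ComponentArray of parameters.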

4 changes: 3 additions & 1 deletion bench/runbenchmarks.jl
@@ -3,6 +3,7 @@ using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @btime, @benchmarkable
using ComponentArrays: ComponentArray
using InteractiveUtils: versioninfo
using FastClosures: @closure
using Flux: Flux
using Lux: Lux, BatchNorm, Chain, Conv, Dense, Dropout, FlattenLayer, MaxPool
using NNlib: relu
using SimpleChains: SimpleChains, static
@@ -27,10 +28,11 @@ struct AutoTapir <: ADTypes.AbstractReverseMode end
const SUITE = BenchmarkGroup()

include("helpers.jl")
# include("vgg.jl")
include("vgg.jl")
include("layers.jl")

BenchmarkTools.tune!(SUITE; verbose=true)
results = BenchmarkTools.run(SUITE; verbose=true)
display(median(results))

BenchmarkTools.save(joinpath(@__DIR__, "benchmark_results.json"), median(results))
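
Outside of BenchmarkTools, the new Tapir path in helpers.jl roughly amounts to the following. This is only a sketch: it reuses the hypothetical general_setup from above and assumes Tapir.jl is available in the bench environment.

using Tapir, ComponentArrays, Lux

model = Dense(2 => 2)
x, ps, st = general_setup(model, (2, 128))
ps_ca = ComponentArray(ps)
f = p -> sum(abs2, first(Lux.apply(model, x, p, st)))

rule = Tapir.build_rrule(f, ps_ca)                        # built once, in the setup phase
out = Tapir.value_and_pullback!!(rule, 1.0f0, f, ps_ca)   # the call that @benchmarkable times
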
29 changes: 25 additions & 4 deletions bench/vgg.jl
@@ -17,12 +17,33 @@ function add_vgg_benchmarks!()
BatchNorm(512), MaxPool((2, 2)), FlattenLayer(), Dense(512, 4096, relu),
Dropout(0.5), Dense(4096, 4096, relu), Dropout(0.5), Dense(4096, 10))

flux_model = () -> Flux.Chain(
Flux.Conv((3, 3), 3 => 64, relu; pad=(1, 1), stride=(1, 1)),
Flux.BatchNorm(64), Flux.Conv((3, 3), 64 => 64, relu; pad=(1, 1), stride=(1, 1)),
Flux.BatchNorm(64), Flux.MaxPool((2, 2)),
Flux.Conv((3, 3), 64 => 128, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(128),
Flux.Conv((3, 3), 128 => 128, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(128),
Flux.MaxPool((2, 2)),
Flux.Conv((3, 3), 128 => 256, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(256),
Flux.Conv((3, 3), 256 => 256, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(256),
Flux.Conv((3, 3), 256 => 256, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(256),
Flux.MaxPool((2, 2)),
Flux.Conv((3, 3), 256 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512),
Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512),
Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512),
Flux.MaxPool((2, 2)),
Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512),
Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512),
Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512),
Flux.MaxPool((2, 2)), Flux.flatten, Flux.Dense(512, 4096, relu), Flux.Dropout(0.5),
Flux.Dense(4096, 4096, relu), Flux.Dropout(0.5), Flux.Dense(4096, 10))

for bsize in (1, 16, 64)
x, ps, st = general_setup(vgg16, (32, 32, 3, bsize))
benchmark_forward_pass("vgg16", "(32, 32, 3, $bsize)", vgg16, x, ps, st)
benchmark_forward_pass(
"vgg16", "(32, 32, 3, $bsize)", vgg16, (32, 32, 3, bsize); flux_model)
benchmark_reverse_pass(
"vgg16", "(32, 32, 3, $bsize)",
(AutoTapir(), AutoTracker(), AutoZygote()), vgg16, x, ps, st)
"vgg16", "(32, 32, 3, $bsize)", (AutoTracker(), AutoZygote()),
vgg16, (32, 32, 3, bsize); flux_model)
end

return
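
Once runbenchmarks.jl has written the median timings, the VGG-16 entries registered above can be inspected by reloading the JSON file. A usage sketch (the index keys follow the tag/end_tag strings used in the calls above):

using BenchmarkTools

results = BenchmarkTools.load(joinpath(@__DIR__, "benchmark_results.json"))[1]
results["vgg16"]["cpu"]["forward"]["Flux"]["(32, 32, 3, 16)"]    # Flux forward pass, batch size 16
results["vgg16"]["cpu"]["reverse"]["Zygote"]["(32, 32, 3, 1)"]   # Lux + Zygote reverse pass, batch size 1
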
12 changes: 6 additions & 6 deletions docs/src/.vitepress/config.mts
@@ -62,9 +62,9 @@ export default defineConfig({
},
nav: [
{ text: 'Home', link: '/' },
{ text: 'Getting Started', link: '/introduction/index' },
{ text: 'Getting Started', link: '/introduction' },
{ text: 'Benchmarks', link: 'https://lux.csail.mit.edu/benchmarks/' },
{ text: 'Tutorials', link: '/tutorials/index' },
{ text: 'Tutorials', link: '/tutorials' },
{ text: 'Manual', link: '/manual/interface' },
{
text: 'API', items: [
@@ -104,22 +104,22 @@
},
{
text: 'Versions', items: [
{ text: 'Stable', link: 'https://lux.csail.mit.edu/stable/' },
{ text: 'Dev', link: 'https://lux.csail.mit.edu/dev/' }
{ text: 'Stable', link: 'https://lux.csail.mit.edu/stable' },
{ text: 'Dev', link: 'https://lux.csail.mit.edu/dev' }
]
}
],
sidebar: {
"/introduction/": {
text: 'Getting Started', collapsed: false, items: [
{ text: 'Introduction', link: '/introduction/index' },
{ text: 'Introduction', link: '/introduction' },
{ text: 'Overview', link: '/introduction/overview' },
{ text: 'Resources', link: '/introduction/resources' },
{ text: 'Citation', link: '/introduction/citation' }]
},
"/tutorials/": {
text: 'Tutorials', collapsed: false, items: [
{ text: 'Overview', link: '/tutorials/index' },
{ text: 'Overview', link: '/tutorials' },
{
text: 'Beginner', collapsed: false, items: [
{ text: 'Julia & Lux for the Uninitiated', link: '/tutorials/beginner/1_Basics' },
4 changes: 2 additions & 2 deletions docs/src/index.md
@@ -10,7 +10,7 @@ hero:
actions:
- theme: brand
text: Tutorials
link: /tutorials/
link: /tutorials
- theme: alt
text: Ecosystem
link: /ecosystem
@@ -28,7 +28,7 @@ features:
- icon: 🚀
title: Fast & Extendible
details: Lux.jl is written in Julia itself, making it extremely extendible. CUDA and AMDGPU are supported first-class, with experimental support for Metal Hardware.
link: /introduction/
link: /introduction
- icon: 🧑‍🔬
title: SciML ❤️ Lux
2 changes: 1 addition & 1 deletion test/layers/conv_tests.jl
@@ -89,7 +89,7 @@
@test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], k)
@jet layer(x, ps, st)
__f = (x, ps) -> sum(first(layer(x, ps, st)))
@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu
@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true
end
end
end
