Skip to content

Commit

Permalink
Merge pull request #576 from LuxDL/ap/regression
Browse files: browse the repository at this point in the history
Make the AD benchmarks type stable
  • Loading branch information
avik-pal authored Apr 6, 2024
2 parents 2151446 + 85dac6a commit 774141a
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 23 deletions.
5 changes: 3 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,15 @@ steps:
echo $BUILDKITE_PULL_REQUEST
echo $BUILDKITE_TAG
julia --project --code-coverage=user --color=yes --threads=auto --project=docs -e '
julia --project --code-coverage=user --color=yes --project=docs -e '
println("--- :julia: Instantiating project")
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()
println("+++ :julia: Building tutorials")
include("docs/tutorials.jl")'
julia --project --code-coverage=user --color=yes --threads=auto --project=docs -e '
julia --project --code-coverage=user --color=yes --project=docs -e '
println("--- :julia: Instantiating project")
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Expand All @@ -210,6 +210,7 @@ steps:
DATADEPS_ALWAYS_ACCEPT: true
JULIA_DEBUG: "Documenter"
DEBUG: true
JULIA_NUM_THREADS: 6
GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988
if: build.message !~ /\[skip docs\]/ && !build.pull_request.draft
timeout_in_minutes: 240
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Run benchmark
run: |
cd bench
julia --project --color=yes -e '
julia --project --threads=2 --color=yes -e '
using Pkg;
Pkg.develop(PackageSpec(path=joinpath(pwd(), "..")));
Pkg.instantiate();
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ ChainRulesCore = "1.21"
ComponentArrays = "0.15.11"
ConcreteStructs = "0.2.3"
ConstructionBase = "1.5"
FastClosures = "0.3.2"
ExplicitImports = "1.1.1"
FastClosures = "0.3.2"
Flux = "0.14.11"
Functors = "0.4.4"
GPUArraysCore = "0.1.6"
Expand Down
2 changes: 2 additions & 0 deletions bench/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
37 changes: 22 additions & 15 deletions bench/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ function __benchmark_reverse_pass(
end
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoTracker, model, x_dims)
SUITE[tag]["cpu"]["reverse"]["Tracker"][end_tag] = @benchmarkable Tracker.gradient(
f, ps_ca) setup=begin
SUITE[tag]["cpu"]["reverse"]["Tracker"][end_tag] = @benchmarkable begin
ps_tracked = fmap(Tracker.param, ps)
x_tracked = Tracker.param(x)
loss = sum(abs2, first(Lux.apply($model, x_tracked, ps_tracked, st)))
Tracker.back!(loss)
end setup=begin
(x, ps, st) = general_setup($model, $x_dims)
ps_ca = ComponentArray(ps)
f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
end
return
end
Expand All @@ -93,39 +95,44 @@ function __benchmark_reverse_pass(
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, ps_ca))
end
else
SUITE[tag]["cpu"]["reverse"]["ReverseDiff"][end_tag] = @benchmarkable ReverseDiff.gradient(
f, ps_ca) setup=begin
SUITE[tag]["cpu"]["reverse"]["ReverseDiff"][end_tag] = @benchmarkable begin
tape = ReverseDiff.InstructionTape()
∂ps = fmap(zero, ps)
ps_tracked = fmap((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ps, ∂ps)
∂x = zero(x)
x_tracked = ReverseDiff.TrackedArray(x, ∂x, tape)
loss = sum(abs2, first(Lux.apply($model, x_tracked, ps_tracked, st)))
loss.deriv = true
ReverseDiff.reverse_pass!(tape)
end setup=begin
(x, ps, st) = general_setup($model, $x_dims)
ps_ca = ComponentArray(ps)
f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
end
end
end
function __benchmark_reverse_pass(tag::String, end_tag::String, ::AutoZygote, model, x_dims)
SUITE[tag]["cpu"]["reverse"]["Zygote"][end_tag] = @benchmarkable Zygote.gradient(
f, ps_ca) setup=begin
f, $model, x, ps, st) setup=begin
(x, ps, st) = general_setup($model, $x_dims)
ps_ca = ComponentArray(ps)
f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
f = @closure((model, x, p, st)->sum(abs2, first(Lux.apply(model, x, p, st))))
end
return
end
function __benchmark_reverse_pass_simple_chains(
tag::String, end_tag::String, ::AutoZygote, model, x_dims)
SUITE[tag]["cpu"]["reverse"]["SimpleChains"][end_tag] = @benchmarkable Zygote.gradient(
f, ps) setup=begin
f, $model, x, ps, st) setup=begin
(x, ps, st) = general_setup($model, $x_dims)
f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
f = @closure((model, x, p, st)->sum(abs2, first(Lux.apply(model, x, p, st))))
end
return
end
function __benchmark_reverse_pass_flux(
tag::String, end_tag::String, ::AutoZygote, model, x_dims)
SUITE[tag]["cpu"]["reverse"]["Flux"][end_tag] = @benchmarkable Zygote.gradient(
f, m) setup=begin
f, m, x) setup=begin
x = randn(StableRNG(0), Float32, $x_dims)
m = $(model)()
f = @closure(m->sum(abs2, m(x)))
f = @closure((m, x)->sum(abs2, m(x)))
end
return
end
9 changes: 5 additions & 4 deletions bench/runbenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using ComponentArrays: ComponentArray
using InteractiveUtils: versioninfo
using FastClosures: @closure
using Flux: Flux
using Functors: fmap
using LinearAlgebra: BLAS
using Lux: Lux, BatchNorm, Chain, Conv, Dense, Dropout, FlattenLayer, MaxPool
using NNlib: relu
using SimpleChains: SimpleChains, static
Expand All @@ -17,13 +19,12 @@ using Tapir: Tapir
using Tracker: Tracker
using Zygote: Zygote

# BenchmarkTools Parameters
BenchmarkTools.DEFAULT_PARAMETERS.samples = 100
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.25

struct AutoTapir <: ADTypes.AbstractReverseMode end

BLAS.set_num_threads(min(4, Threads.nthreads()))

@info sprint(versioninfo)
@info "BLAS threads: $(BLAS.get_num_threads())"

const SUITE = BenchmarkGroup()

Expand Down

0 comments on commit 774141a

Please sign in to comment.