Skip to content

Commit

Permalink
fix: don't reuse pullback for safety
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 20, 2024
1 parent 937efad commit d99d823
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
1 change: 1 addition & 0 deletions ext/LuxZygoteExt/LuxZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using ADTypes: AutoZygote
using ChainRulesCore: ChainRulesCore
using ForwardDiff: ForwardDiff
using Lux: Lux
using LuxDeviceUtils: get_device_type, LuxCPUDevice
using Setfield: @set!
using Zygote: Zygote

Expand Down
20 changes: 17 additions & 3 deletions ext/LuxZygoteExt/batched_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
end

function Lux.__batched_jacobian_impl(f::F, ::AutoZygote, x) where {F}
# It's not safe to run `pb_f` multiple times. We run the first time to be able to
# construct the Jacobian
y, pb_f = Zygote.pullback(f, x)

@argcheck y isa AbstractArray MethodError
Expand All @@ -13,9 +15,21 @@ function Lux.__batched_jacobian_impl(f::F, ::AutoZygote, x) where {F}

J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]),
prod(size(x)[1:(end - 1)]), size(x, ndims(x)))

for i in eachindex(axes(J, 1))
__fill_chunked_jacobian!(J, i, f, pb_f, y, x)
__fill_chunked_jacobian!(J, 1, f, pb_f, y, x)

if get_device_type(x) <: LuxCPUDevice # Use threads
tasks = map(2:size(J, 1)) do i
Threads.@spawn begin
yᵢ, pb_fᵢ = Zygote.pullback(f, x)
__fill_chunked_jacobian!(J, i, f, pb_fᵢ, yᵢ, x)
end
end
map(fetch, tasks)
else # Threading has some issues with cuDNN. we are being safe and not using threads
map(2:size(J, 1)) do i
yᵢ, pb_fᵢ = Zygote.pullback(f, x)
__fill_chunked_jacobian!(J, i, f, pb_fᵢ, yᵢ, x)
end
end

return J
Expand Down

1 comment on commit d99d823

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: d99d823 Previous: 937efad Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3650.625 ns 3669.375 ns 0.99
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7099.833333333333 ns 7200 ns 0.99
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20969 ns 21054.5 ns 1.00
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9688 ns 9781.3 ns 0.99
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8941.75 ns 8908.6 ns 1.00
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4464.625 ns 4449.5 ns 1.00
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1153.062937062937 ns 1157.9930555555557 ns 1.00
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1104.0858895705521 ns 1171.8985507246377 ns 0.94
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1188.625 ns 1185.8125 ns 1.00
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1766.051724137931 ns 1791.6440677966102 ns 0.99
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 178.99435825105783 ns 180.02920723226703 ns 0.99
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17262 ns 17283 ns 1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16691 ns 16841 ns 0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 36799 ns 37390 ns 0.98
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29359.5 ns 28994 ns 1.01
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19826 ns 20027 ns 0.99
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17141.5 ns 17182 ns 1.00
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4335.285714285715 ns 4328 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3860.9375 ns 3870.875 ns 1.00
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3921 ns 3954.875 ns 0.99
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4937.714285714285 ns 4847.642857142857 ns 1.02
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1662.1 ns 1658.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 38380312 ns 40009085 ns 0.96
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 58080861 ns 58044866 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 75706671 ns 81652371 ns 0.93
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 88430576 ns 84691568 ns 1.04
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 72678861 ns 75567907 ns 0.96
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11629181 ns 11702113 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 7074423 ns 6962047.5 ns 1.02
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7237640 ns 7131554 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7031304 ns 7068401 ns 0.99
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 9924675.5 ns 12275161.5 ns 0.81
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6376916 ns 6382810.5 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 685503650 ns 714086754 ns 0.96
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2532175381 ns 2579434858 ns 0.98
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 137312502 ns 148383143 ns 0.93
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 791018541 ns 891741497 ns 0.89
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3118945691 ns 3088439700 ns 1.01
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 187532696 ns 212058426 ns 0.88
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 653534251 ns 653749242 ns 1.00
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2584662949 ns 2642538349 ns 0.98
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 124578751.5 ns 128699402 ns 0.97
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 173451500 ns 175013829.5 ns 0.99
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 669889358.5 ns 654368735 ns 1.02
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 45443081 ns 45671588 ns 0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 164233714 ns 165691299 ns 0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 641495324 ns 644989507 ns 0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 29662692.5 ns 45474454 ns 0.65
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 209510938 ns 227248905 ns 0.92
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 732300671.5 ns 775116259 ns 0.94
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 35133655 ns 40575003.5 ns 0.87
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1229955419 ns 1266387049.5 ns 0.97
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1870868488 ns 1873245890 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2357631843 ns 2337050962 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2548939820 ns 2598176961 ns 0.98
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1854903249 ns 1906864622.5 ns 0.97
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 324822924.5 ns 323660356.5 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 321525303 ns 323114250 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 318220641 ns 319854782 ns 0.99
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 435306909 ns 431236413 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11851270 ns 11856906 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17992689 ns 17995791 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19131076 ns 19308830 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23860945 ns 24002610 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 18004156 ns 18021078 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1149174.5 ns 1168729 ns 0.98
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 2062229 ns 2067715 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2071025 ns 2082044 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2074542 ns 2083795 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2063622 ns 2078152 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 195595 ns 198179 ns 0.99
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 292706 ns 294960 ns 0.99
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 263472 ns 268370 ns 0.98
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 363559 ns 372856 ns 0.98
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 411138 ns 414518.5 ns 0.99
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 272924.5 ns 278368 ns 0.98
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 405708 ns 410746 ns 0.99
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83085 ns 83776 ns 0.99
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 80541 ns 81953 ns 0.98
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81091 ns 83195 ns 0.97
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86131 ns 87764 ns 0.98
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104796 ns 104685 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 186741939 ns 198244459 ns 0.94
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 327447810 ns 329497937.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 397249133 ns 407311865.5 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 484679103 ns 452728713 ns 1.07
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 377393182 ns 388579645 ns 0.97
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 321862177 ns 319868232.5 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 44743313 ns 44289039 ns 1.01
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 44850549.5 ns 44316868 ns 1.01
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43924254.5 ns 43960351 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 53418487 ns 51507810 ns 1.04
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28061921 ns 28435005 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19028106.5 ns 18968878 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19552218 ns 19614768 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23488414 ns 23514329 ns 1.00
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24184792 ns 24219437.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19687526 ns 19657817 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6493300 ns 6533333 ns 0.99
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6514744.5 ns 6543909 ns 1.00
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6500552 ns 6522365 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6483934 ns 6506154.5 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.