Compiled ReverseDiff for training on CPU #722
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
Coverage Diff | main | #722 | +/- |
---|---|---|---|
Coverage | 97.32% | 97.01% | -0.32% |
Files | 52 | 53 | +1 |
Lines | 2653 | 2711 | +58 |
Hits | 2582 | 2630 | +48 |
Misses | 71 | 81 | +10 |
Benchmark Results
Benchmark suite | Current: 1da2e52 | Previous: 72cfeec | Ratio |
---|---|---|---|
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3642.4444444444443 ns | 3665.625 ns | 0.99 |
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 7193.5 ns | 7216.916666666666 ns | 1.00 |
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 20929 ns | 20528 ns | 1.02 |
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9566.25 ns | 9693 ns | 0.99 |
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 8800.4 ns | 9036.875 ns | 0.97 |
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4482.125 ns | 4573.625 ns | 0.98 |
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) | 1151.5314685314686 ns | 1136.840579710145 ns | 1.01 |
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 1163.8613138686133 ns | 1149.4418604651162 ns | 1.01 |
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1187.9358974358975 ns | 1197.1639344262296 ns | 0.99 |
Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1773.6842105263158 ns | 1774.3684210526317 ns | 1.00 |
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 178.95827538247565 ns | 180.0305980528512 ns | 0.99 |
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17142 ns | 17322 ns | 0.99 |
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 16771 ns | 16971 ns | 0.99 |
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 37361 ns | 39253 ns | 0.95 |
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 29345 ns | 29565 ns | 0.99 |
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 19888 ns | 20178 ns | 0.99 |
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 17092 ns | 17172 ns | 1.00 |
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) | 4301.75 ns | 4349.571428571428 ns | 0.99 |
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 3826 ns | 3849.75 ns | 0.99 |
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 3949.875 ns | 3928.625 ns | 1.01 |
Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4998 ns | 4852 ns | 1.03 |
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1649.1 ns | 1656.1 ns | 1.00 |
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 38630810.5 ns | 39770040 ns | 0.97 |
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 57506079 ns | 57632628.5 ns | 1.00 |
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 75662796 ns | 76488023 ns | 0.99 |
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 88344391 ns | 89018322 ns | 0.99 |
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 70400778 ns | 72909138.5 ns | 0.97 |
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 11546033 ns | 11939374.5 ns | 0.97 |
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) | 17590212.5 ns | 17821257 ns | 0.99 |
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 6963002 ns | 6991541.5 ns | 1.00 |
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 6935185 ns | 6978706 ns | 0.99 |
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 9838443.5 ns | 10166161 ns | 0.97 |
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6383563 ns | 6397739 ns | 1.00 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 715133916 ns | 756350190 ns | 0.95 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2515094642 ns | 2555366182 ns | 0.98 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) | 127064350 ns | 141900399 ns | 0.90 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 768458977 ns | 776440989 ns | 0.99 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 2827206553 ns | 2952053106 ns | 0.96 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) | 194319154 ns | 205837315.5 ns | 0.94 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 700167308 ns | 687802033.5 ns | 1.02 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 2462646593 ns | 2441444521 ns | 1.01 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) | 119767580.5 ns | 132709456 ns | 0.90 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 173197641 ns | 174508592.5 ns | 0.99 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 638622575.5 ns | 642835050.5 ns | 0.99 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) | 43574003 ns | 34681759 ns | 1.26 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 162193021 ns | 183856278 ns | 0.88 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 633881021.5 ns | 640825595 ns | 0.99 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) | 29618317 ns | 30176411 ns | 0.98 |
vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 187271794 ns | 187538992 ns | 1.00 |
vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 752112090.5 ns | 728794962 ns | 1.03 |
vgg16/cpu/forward/Flux/(32, 32, 3, 2) | 35201017 ns | 37966134.5 ns | 0.93 |
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1168823578.5 ns | 1263007358.5 ns | 0.93 |
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1833508726.5 ns | 1859371232 ns | 0.99 |
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2259723916 ns | 2370919227 ns | 0.95 |
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2482580887 ns | 2548748419 ns | 0.97 |
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 1844324234 ns | 1815330207 ns | 1.02 |
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) | 555200248 ns | 558317841 ns | 0.99 |
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 311113812 ns | 320621158 ns | 0.97 |
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 310203294.5 ns | 324372940 ns | 0.96 |
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 339834501 ns | 455318653.5 ns | 0.75 |
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 11858461 ns | 12039979 ns | 0.98 |
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 17745505 ns | 17936234 ns | 0.99 |
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 19023835 ns | 19094503 ns | 1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 23770123 ns | 23827099 ns | 1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 17838621 ns | 17970370.5 ns | 0.99 |
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1158364 ns | 1164332 ns | 0.99 |
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) | 5801078 ns | 5815240.5 ns | 1.00 |
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 2031433 ns | 2060550.5 ns | 0.99 |
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 2009011 ns | 2038369.5 ns | 0.99 |
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 2047093 ns | 2079506 ns | 0.98 |
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 197651 ns | 202013.5 ns | 0.98 |
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 288421 ns | 293620 ns | 0.98 |
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 261706 ns | 267050 ns | 0.98 |
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 361944.5 ns | 370955 ns | 0.98 |
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 405561 ns | 412423 ns | 0.98 |
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 271790 ns | 276132.5 ns | 0.98 |
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 400576.5 ns | 408115 ns | 0.98 |
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) | 82875 ns | 83316 ns | 0.99 |
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 80241 ns | 81843 ns | 0.98 |
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 80292 ns | 82094 ns | 0.98 |
Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 85290.5 ns | 87504 ns | 0.97 |
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104317 ns | 104496 ns | 1.00 |
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 192064862 ns | 188734689.5 ns | 1.02 |
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 313776392.5 ns | 328709072 ns | 0.95 |
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 355826346 ns | 397785865 ns | 0.89 |
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 439716902 ns | 459733566 ns | 0.96 |
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 366480578.5 ns | 377160653 ns | 0.97 |
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 320362246.5 ns | 327525716.5 ns | 0.98 |
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) | 97737540 ns | 102501485 ns | 0.95 |
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 42420914 ns | 43867409 ns | 0.97 |
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 43021921 ns | 43652376 ns | 0.99 |
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 57964764.5 ns | 60057045 ns | 0.97 |
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 26657659 ns | 28379983.5 ns | 0.94 |
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 18608378 ns | 18956091 ns | 0.98 |
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19317586 ns | 19563861 ns | 0.99 |
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 22943211 ns | 23676947 ns | 0.97 |
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 23872969.5 ns | 24221720 ns | 0.99 |
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19406384 ns | 19663414 ns | 0.99 |
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) | 6357292 ns | 6539789 ns | 0.97 |
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6436783 ns | 6517423.5 ns | 0.99 |
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6420615 ns | 6503627.5 ns | 0.99 |
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6449739 ns | 6510008 ns | 0.99 |
This comment was automatically generated by workflow using github-action-benchmark.
- `AutoReverseDiff(; compile=true)` is not supported for Lux models with non-empty state
  `st`. Additionally, the returned stats must be empty (`NamedTuple()`). We catch these
  issues in most cases and throw an error.
[JuliaFormatter] reported by reviewdog 🐶
- `AutoReverseDiff(; compile=true)` is not supported for Lux models with non-empty state
  `st`. Additionally, the returned stats must be empty (`NamedTuple()`). We catch these
  issues in most cases and throw an error.
Compiled ReverseDiff for training on CPU
Fixes #642
For performance, SciML/ADTypes.jl#63 will be needed.
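For readers unfamiliar with what "compiled" means here, the sketch below uses ReverseDiff.jl directly, not the Lux training code from this PR. It records a gradient tape once, compiles it, and replays it on a new input of the same shape. The `loss` function, the data `x` and `y`, and the array sizes are hypothetical and chosen only for illustration. Because the recorded operations are replayed as-is, anything that changes between calls is not picked up by the compiled tape, which is presumably the reason for the state restriction noted in the docs.

```julia
using ReverseDiff

# Hypothetical fixed data and toy least-squares loss; not taken from this PR.
const x = [1.0, 2.0]
const y = [0.0, 0.0]

function loss(W)
    r = W * x .- y        # residual of a linear model
    return sum(r .* r)    # sum of squared residuals
end

# Record the operations of `loss` once on an input of the right shape,
# then compile the tape so later gradient calls only replay it.
W0 = zeros(2, 2)
tape = ReverseDiff.GradientTape(loss, W0)
compiled_tape = ReverseDiff.compile(tape)

# Reuse the compiled tape for a new input with the same shape.
W = [1.0 0.0; 0.0 1.0]
grad = similar(W)
ReverseDiff.gradient!(grad, compiled_tape, W)
```

The analytic gradient of this loss is `2 * (W * x .- y) * x'`, which gives a quick way to check the replayed result.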