Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compiled ReverseDiff for training on CPU #722

Merged
merged 6 commits into from
Jun 23, 2024
Merged

Conversation

avik-pal
Copy link
Member

@avik-pal avik-pal commented Jun 22, 2024

Fixes #642

For performance SciML/ADTypes.jl#63 will be needed

  • Consolidated caching for parameters
  • Caching Parameters
    • Tracker
    • ReverseDiff
  • Compiled ReverseDiff
  • Type-Piracy of Tracker has been removed. @test_gradients might fail need to verify.

Copy link

codecov bot commented Jun 22, 2024

Codecov Report

Attention: Patch coverage is 93.33333% with 9 lines in your changes missing coverage. Please review.

Project coverage is 97.01%. Comparing base (66bd131) to head (9eb4def).

Current head 9eb4def differs from pull request most recent head 1da2e52

Please upload reports for the commit 1da2e52 to get more accurate results.

Files Patch % Lines
ext/LuxReverseDiffExt/training.jl 93.75% 4 Missing ⚠️
src/contrib/training.jl 20.00% 4 Missing ⚠️
src/utils.jl 92.85% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #722      +/-   ##
==========================================
- Coverage   97.32%   97.01%   -0.32%     
==========================================
  Files          52       53       +1     
  Lines        2653     2711      +58     
==========================================
+ Hits         2582     2630      +48     
- Misses         71       81      +10     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 1da2e52 Previous: 72cfeec Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3642.4444444444443 ns 3665.625 ns 0.99
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7193.5 ns 7216.916666666666 ns 1.00
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20929 ns 20528 ns 1.02
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9566.25 ns 9693 ns 0.99
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8800.4 ns 9036.875 ns 0.97
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4482.125 ns 4573.625 ns 0.98
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1151.5314685314686 ns 1136.840579710145 ns 1.01
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1163.8613138686133 ns 1149.4418604651162 ns 1.01
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1187.9358974358975 ns 1197.1639344262296 ns 0.99
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1773.6842105263158 ns 1774.3684210526317 ns 1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 178.95827538247565 ns 180.0305980528512 ns 0.99
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17142 ns 17322 ns 0.99
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16771 ns 16971 ns 0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37361 ns 39253 ns 0.95
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29345 ns 29565 ns 0.99
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19888 ns 20178 ns 0.99
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17092 ns 17172 ns 1.00
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4301.75 ns 4349.571428571428 ns 0.99
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3826 ns 3849.75 ns 0.99
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3949.875 ns 3928.625 ns 1.01
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4998 ns 4852 ns 1.03
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1649.1 ns 1656.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 38630810.5 ns 39770040 ns 0.97
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57506079 ns 57632628.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 75662796 ns 76488023 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 88344391 ns 89018322 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 70400778 ns 72909138.5 ns 0.97
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11546033 ns 11939374.5 ns 0.97
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 17590212.5 ns 17821257 ns 0.99
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 6963002 ns 6991541.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 6935185 ns 6978706 ns 0.99
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 9838443.5 ns 10166161 ns 0.97
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6383563 ns 6397739 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 715133916 ns 756350190 ns 0.95
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2515094642 ns 2555366182 ns 0.98
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 127064350 ns 141900399 ns 0.90
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 768458977 ns 776440989 ns 0.99
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 2827206553 ns 2952053106 ns 0.96
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 194319154 ns 205837315.5 ns 0.94
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 700167308 ns 687802033.5 ns 1.02
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2462646593 ns 2441444521 ns 1.01
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 119767580.5 ns 132709456 ns 0.90
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 173197641 ns 174508592.5 ns 0.99
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 638622575.5 ns 642835050.5 ns 0.99
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 43574003 ns 34681759 ns 1.26
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 162193021 ns 183856278 ns 0.88
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 633881021.5 ns 640825595 ns 0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 29618317 ns 30176411 ns 0.98
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 187271794 ns 187538992 ns 1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 752112090.5 ns 728794962 ns 1.03
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 35201017 ns 37966134.5 ns 0.93
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1168823578.5 ns 1263007358.5 ns 0.93
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1833508726.5 ns 1859371232 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2259723916 ns 2370919227 ns 0.95
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2482580887 ns 2548748419 ns 0.97
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1844324234 ns 1815330207 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 555200248 ns 558317841 ns 0.99
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 311113812 ns 320621158 ns 0.97
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 310203294.5 ns 324372940 ns 0.96
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 339834501 ns 455318653.5 ns 0.75
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11858461 ns 12039979 ns 0.98
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17745505 ns 17936234 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19023835 ns 19094503 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23770123 ns 23827099 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17838621 ns 17970370.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1158364 ns 1164332 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 5801078 ns 5815240.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2031433 ns 2060550.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2009011 ns 2038369.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2047093 ns 2079506 ns 0.98
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 197651 ns 202013.5 ns 0.98
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 288421 ns 293620 ns 0.98
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 261706 ns 267050 ns 0.98
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 361944.5 ns 370955 ns 0.98
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 405561 ns 412423 ns 0.98
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 271790 ns 276132.5 ns 0.98
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 400576.5 ns 408115 ns 0.98
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 82875 ns 83316 ns 0.99
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 80241 ns 81843 ns 0.98
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 80292 ns 82094 ns 0.98
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 85290.5 ns 87504 ns 0.97
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104317 ns 104496 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 192064862 ns 188734689.5 ns 1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 313776392.5 ns 328709072 ns 0.95
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 355826346 ns 397785865 ns 0.89
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 439716902 ns 459733566 ns 0.96
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 366480578.5 ns 377160653 ns 0.97
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 320362246.5 ns 327525716.5 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 97737540 ns 102501485 ns 0.95
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 42420914 ns 43867409 ns 0.97
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43021921 ns 43652376 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 57964764.5 ns 60057045 ns 0.97
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 26657659 ns 28379983.5 ns 0.94
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18608378 ns 18956091 ns 0.98
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19317586 ns 19563861 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 22943211 ns 23676947 ns 0.97
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 23872969.5 ns 24221720 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19406384 ns 19663414 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6357292 ns 6539789 ns 0.97
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6436783 ns 6517423.5 ns 0.99
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6420615 ns 6503627.5 ns 0.99
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6449739 ns 6510008 ns 0.99

This comment was automatically generated by workflow using github-action-benchmark.

ext/LuxReverseDiffExt/training.jl Outdated Show resolved Hide resolved
ext/LuxReverseDiffExt/training.jl Outdated Show resolved Hide resolved
test/contrib/training_tests.jl Outdated Show resolved Hide resolved
test/contrib/training_tests.jl Outdated Show resolved Hide resolved
@avik-pal avik-pal changed the title [WIP] Compiled ReverseDiff for training on CPU Compiled ReverseDiff for training on CPU Jun 23, 2024
@avik-pal avik-pal force-pushed the ap/compile_reversediff branch 3 times, most recently from 8331b57 to 9eb4def Compare June 23, 2024 02:33
@avik-pal avik-pal merged commit 226563f into main Jun 23, 2024
11 of 14 checks passed
@avik-pal avik-pal deleted the ap/compile_reversediff branch June 23, 2024 03:51
Comment on lines +123 to +125
- `AutoReverseDiff(; compile=true)` is not supported for Lux models with empty state
`st`. Additionally the returned stats must be empty (`NamedTuple()`). We catch these
issues in most cases and throw an error.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- `AutoReverseDiff(; compile=true)` is not supported for Lux models with empty state
`st`. Additionally the returned stats must be empty (`NamedTuple()`). We catch these
issues in most cases and throw an error.
- `AutoReverseDiff(; compile=true)` is not supported for Lux models with empty state
`st`. Additionally the returned stats must be empty (`NamedTuple()`). We catch these
issues in most cases and throw an error.

avik-pal added a commit that referenced this pull request Jun 23, 2024
Compiled ReverseDiff for training on CPU
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add a compiled tape version for ReverseDiff
1 participant