Skip to content

Commit

Permalink
Merge pull request #659 from LuxDL/ap/restore_merge
Browse files Browse the repository at this point in the history
Restore the rrule for merge
  • Loading branch information
avik-pal authored May 19, 2024
2 parents 2a866be + 3d1e0c6 commit 50a7d90
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@ CRC.@non_differentiable Base.printstyled(::Any...)
CRC.@non_differentiable fieldcount(::Any)

# Utilities
## DON'T REMOVE THIS CAUSES DOWNSTREAM FAILURES
function CRC.rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1, F2}
y = merge(nt1, nt2)
function ∇merge(dy)
dnt1 = NamedTuple((f1 => (f1 in F2 ? NoTangent() : getproperty(dy, f1))
for f1 in F1))
dnt2 = NamedTuple((f2 => getproperty(dy, f2) for f2 in F2))
return (NoTangent(), dnt1, dnt2)
end
∇merge(::Union{NoTangent, ZeroTangent}) = (NoTangent(), NoTangent(), NoTangent())
return y, ∇merge
end

function CRC.rrule(::typeof(_eachslice), x, d::Val)
return _eachslice(x, d), @closure->(NoTangent(), ∇_eachslice(Δ, x, d), NoTangent()))
end
Expand Down

4 comments on commit 50a7d90

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 50a7d90 Previous: 2a866be Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3638.125 ns 3669.375 ns 0.99
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7089.833333333333 ns 7280.333333333333 ns 0.97
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20438 ns 20599 ns 0.99
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9543.8 ns 9582 ns 1.00
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8842.6 ns 8914.6 ns 0.99
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4403.25 ns 4422 ns 1.00
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1199.107142857143 ns 1198.9590163934427 ns 1.00
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1110.107594936709 ns 1112.5786163522012 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1185.8015267175572 ns 1180.9856115107914 ns 1.00
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1784.9298245614036 ns 1794.1864406779662 ns 0.99
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 180.2403846153846 ns 179.262341325811 ns 1.01
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17372 ns 17322 ns 1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 17668.5 ns 17532 ns 1.01
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 36588 ns 36779 ns 0.99
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 28002 ns 28073 ns 1.00
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19777 ns 19877 ns 0.99
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 16891 ns 16801 ns 1.01
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4330.928571428572 ns 4333.857142857143 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3876 ns 3888.625 ns 1.00
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3965 ns 3981.25 ns 1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4853.428571428572 ns 4990.714285714285 ns 0.97
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1656.1 ns 1662.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 49597894 ns 38952154 ns 1.27
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57758294.5 ns 57640902 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 81508052 ns 71710714.5 ns 1.14
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 106359049 ns 88813041.5 ns 1.20
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 89809751 ns 72723443 ns 1.23
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11922280 ns 11766883 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 17867350 ns 8476633 ns 2.11
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7025899 ns 7046540 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 6987958.5 ns 7005955 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 12392999 ns 10530991 ns 1.18
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6382888 ns 6400580 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 703499940.5 ns 709679665 ns 0.99
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2841193404 ns 2835972125 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 143564651 ns 160659965 ns 0.89
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 886876592 ns 752407029 ns 1.18
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 2896203958 ns 2546120620 ns 1.14
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 203585797.5 ns 196895225 ns 1.03
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 681282715 ns 720159708 ns 0.95
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2862082835 ns 2731538703 ns 1.05
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 143102777 ns 124336484 ns 1.15
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 173810333 ns 172239449 ns 1.01
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 650530971 ns 642436307.5 ns 1.01
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 34529043 ns 45370320 ns 0.76
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 164099994.5 ns 164713553.5 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 631711995 ns 639941245 ns 0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 30330418.5 ns 30400796 ns 1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 229262989 ns 185024422.5 ns 1.24
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 897236420 ns 742411361 ns 1.21
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 40649604 ns 35336203 ns 1.15
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1194867773 ns 1221666322.5 ns 0.98
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1850182045 ns 1870625278 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2028694919.5 ns 2167914981.5 ns 0.94
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2584282324 ns 2317427712.5 ns 1.12
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1855635723 ns 1790838711.5 ns 1.04
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 552686649.5 ns 350970706 ns 1.57
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 320430812.5 ns 320997548.5 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 317906642.5 ns 321496032.5 ns 0.99
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 454196399 ns 355099855 ns 1.28
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11783521 ns 11838794 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17922356.5 ns 17898707 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19164946 ns 19130086 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23928314 ns 23811780.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17993335 ns 17973841 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1159677 ns 1169676 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 5882718 ns 2526472.5 ns 2.33
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2051093 ns 2058818.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2032137.5 ns 2040219 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2072117.5 ns 2075980 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 198215.5 ns 204222 ns 0.97
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 292988 ns 291976 ns 1.00
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 266703.5 ns 266062.5 ns 1.00
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 364792 ns 366104 ns 1.00
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 409846 ns 409981 ns 1.00
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 273721 ns 276541.5 ns 0.99
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 407552 ns 408022 ns 1.00
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83877 ns 83165 ns 1.01
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 82404 ns 81222 ns 1.01
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 83606 ns 81722 ns 1.02
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 87413.5 ns 87032 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104606 ns 104516 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 206158801.5 ns 188927480 ns 1.09
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 324418045.5 ns 324906661 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 407110724 ns 394040123 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 451569755 ns 479556483 ns 0.94
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 382491094 ns 372077273 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 334054331 ns 328660700.5 ns 1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 101678998.5 ns 51470201 ns 1.98
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 43669107 ns 43859960 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43553003 ns 43858980 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 61705314 ns 59466808 ns 1.04
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28736076 ns 28625276 ns 1.00
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18676289 ns 19049587 ns 0.98
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19453140 ns 19503862 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23179954.5 ns 23216947 ns 1.00
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 23993664 ns 24035928 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19538431 ns 19591788 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6514322 ns 6531167 ns 1.00
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6508011 ns 6514635 ns 1.00
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6479377 ns 6490360 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6489111 ns 6506068.5 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/107198

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.51 -m "<description of version>" 50a7d90b84c90e3822d2f25852440ed9910fb08b
git push origin v0.5.51

Please sign in to comment.