Skip to content

Commit

Permalink
test: add tests for the new macros
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 20, 2024
1 parent d5a34e4 commit 59840df
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions test/helpers/compact_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,60 @@
∂x2 = only(Zygote.gradient(x -> sum(first(model(x, ps, st))), -2.0))
@test ∂x2 === nothing
end

@testset "Init Functions" begin
model = @compact(; a=@init_fn(rng->randn(rng, 3, 2))) do x
@return a * x
end

ps, st = Lux.setup(rng, model) |> device
@test ps.a isa AbstractMatrix
@test size(ps.a) == (3, 2)

x = ones(2, 10) |> aType
y, _ = model(x, ps, st)
@test y isa AbstractMatrix
@test size(y) == (3, 10)

model = @compact(; a=@init_fn(rng->randn(rng, 3, 2), :parameter),
b=2, c=@init_fn(rng->randn(rng, 3), :state)) do x
@return a * x
end

ps, st = Lux.setup(rng, model) |> device
@test ps.a isa AbstractMatrix && st.b isa Number && st.c isa AbstractVector

x = ones(2, 10) |> aType
y, _ = model(x, ps, st)
@test y isa AbstractMatrix
@test size(y) == (3, 10)

@testset "Error Checks" begin
# This should work
model = @compact(; a=@init_fn(rng->randn(rng, 3, 2), parameter)) do x
@return a * x
end

# This should not work
@test_throws ArgumentError @macroexpand(@init_fn(rng->randn(rng, 3, 2),
param))
end
end

@testset "Non-Trainable" begin
model = @compact(; a=@non_trainable(randn(3, 2))) do x
@return a * x
end

ps, st = Lux.setup(rng, model) |> device
@test st.a isa AbstractMatrix
@test size(st.a) == (3, 2)

x = ones(2, 10) |> aType
y, _ = model(x, ps, st)
@test y isa AbstractMatrix
@test size(y) == (3, 10)
end
end
end

Expand Down

1 comment on commit 59840df

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 59840df Previous: d99d823 Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3882.25 ns 3650.625 ns 1.06
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7106.666666666667 ns 7099.833333333333 ns 1.00
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20689 ns 20969 ns 0.99
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9690.1 ns 9688 ns 1.00
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8872.6 ns 8941.75 ns 0.99
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4422 ns 4464.625 ns 0.99
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1155.6573426573427 ns 1153.062937062937 ns 1.00
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1103.4197530864199 ns 1104.0858895705521 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1169.1376811594203 ns 1188.625 ns 0.98
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1774.9833333333333 ns 1766.051724137931 ns 1.01
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.60225669957686 ns 178.99435825105783 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17222 ns 17262 ns 1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16711 ns 16691 ns 1.00
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 36699 ns 36799 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29125 ns 29359.5 ns 0.99
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19957.5 ns 19826 ns 1.01
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17152 ns 17141.5 ns 1.00
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4308.142857142857 ns 4335.285714285715 ns 0.99
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3862.25 ns 3860.9375 ns 1.00
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3942.375 ns 3921 ns 1.01
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4940.571428571428 ns 4937.714285714285 ns 1.00
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1663.1 ns 1662.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 40597578 ns 38380312 ns 1.06
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 58420422 ns 58080861 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 81996635 ns 75706671 ns 1.08
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 84719853 ns 88430576 ns 0.96
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 78243311 ns 72678861 ns 1.08
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 12253538 ns 11629181 ns 1.05
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 7139073 ns 7074423 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7295300 ns 7237640 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7121320.5 ns 7031304 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 11957784 ns 9924675.5 ns 1.20
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6428859 ns 6376916 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 694576554 ns 685503650 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2544347623 ns 2532175381 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 144049664.5 ns 137312502 ns 1.05
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 799910013 ns 791018541 ns 1.01
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3396813511 ns 3118945691 ns 1.09
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 209331850 ns 187532696 ns 1.12
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 832924048 ns 653534251 ns 1.27
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2799907824 ns 2584662949 ns 1.08
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 140673147 ns 124578751.5 ns 1.13
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 174392387 ns 173451500 ns 1.01
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 655041232 ns 669889358.5 ns 0.98
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 45398404.5 ns 45443081 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 164957412 ns 164233714 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 640983277 ns 641495324 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 29703983.5 ns 29662692.5 ns 1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 185859364 ns 209510938 ns 0.89
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 764621834 ns 732300671.5 ns 1.04
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 37578537 ns 35133655 ns 1.07
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1194862835 ns 1229955419 ns 0.97
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1879498836.5 ns 1870868488 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2346332159 ns 2357631843 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2642927488 ns 2548939820 ns 1.04
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1834750459 ns 1854903249 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 331610709 ns 324822924.5 ns 1.02
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 332970921 ns 321525303 ns 1.04
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 326938939 ns 318220641 ns 1.03
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 350802967.5 ns 435306909 ns 0.81
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 12028952 ns 11851270 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 18002968 ns 17992689 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19260830 ns 19131076 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23917481.5 ns 23860945 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 18020134 ns 18004156 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1175374 ns 1149174.5 ns 1.02
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 2068210 ns 2062229 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2080086.5 ns 2071025 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2083427 ns 2074542 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2069321 ns 2063622 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 201762.5 ns 195595 ns 1.03
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 293549 ns 292706 ns 1.00
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 264285 ns 263472 ns 1.00
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 364136 ns 363559 ns 1.00
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 406996.5 ns 411138 ns 0.99
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 273742 ns 272924.5 ns 1.00
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 406200 ns 405708 ns 1.00
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83246 ns 83085 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81412 ns 80541 ns 1.01
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81432 ns 81091 ns 1.00
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86733 ns 86131 ns 1.01
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104576 ns 104796 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 193197051 ns 186741939 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 327253734.5 ns 327447810 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 390124076 ns 397249133 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 460359190 ns 484679103 ns 0.95
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 365899084 ns 377393182 ns 0.97
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 341128190 ns 321862177 ns 1.06
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 44910463.5 ns 44743313 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 45017706 ns 44850549.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 44060014 ns 43924254.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 51951901 ns 53418487 ns 0.97
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 27897102 ns 28061921 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19603784 ns 19028106.5 ns 1.03
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19681809.5 ns 19552218 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23485714 ns 23488414 ns 1.00
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24215234 ns 24184792 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19767029 ns 19687526 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6564517 ns 6493300 ns 1.01
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6553486 ns 6514744.5 ns 1.01
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6533774 ns 6500552 ns 1.01
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6524160 ns 6483934 ns 1.01

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.