Skip to content

Commit

Permalink
test: warp @inferred with @test
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 21, 2024
1 parent 4dc688f commit f2df403
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 31 deletions.
3 changes: 2 additions & 1 deletion test/contrib/training_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ end
_, _, _, tstate_new = @inferred Lux.Experimental.compute_gradients(
AutoEnzyme(), mse, (x, x), tstate)

@inferred Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)
@test @inferred(Lux.Experimental.compute_gradients(
AutoEnzyme(), mse, (x, x), tstate_new)) isa Any

_, _, _, tstate_new2 = @inferred Lux.Experimental.compute_gradients(
AutoEnzyme(), mse2, (x, x), tstate_new)
Expand Down
4 changes: 2 additions & 2 deletions test/helpers/compact_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,10 @@
_, st_new = model(x, ps, st_new)
@test st_new.incr == 100

@inferred model(x, ps, st)
@test @inferred(model(x, ps, st)) isa Any

__f = (m, x, ps, st) -> sum(abs2, first(m(x, ps, st)))
@inferred Zygote.gradient(__f, model, x, ps, st)
@test @inferred(Zygote.gradient(__f, model, x, ps, st)) isa Any
end

@testset "Multiple @return" begin
Expand Down
37 changes: 15 additions & 22 deletions test/helpers/loss_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
∂x2 = Zygote.gradient(xlogx, 2.0)[1]
@test ∂x1 ∂x2

@inferred xlogx(2)
@inferred xlogx(0)
@test @inferred(xlogx(2)) isa Number
@test @inferred(xlogx(0)) isa Number
@jet xlogx(2)

@test iszero(xlogy(0, 1))
Expand All @@ -27,11 +27,12 @@
@test ∂x1 ∂x2 ∂x3
@test ∂y1 ∂y2 ∂y3

@inferred xlogy(2, 3)
@inferred xlogy(0, 1)
@test @inferred(xlogy(2, 3)) isa Number
@test @inferred(xlogy(0, 1)) isa Number
@jet xlogy(2, 3)

@inferred Enzyme.autodiff(Enzyme.Reverse, xlogy, Active, Active(2.0), Active(3.0))
@test @inferred(Enzyme.autodiff(
Enzyme.Reverse, xlogy, Active, Active(2.0), Active(3.0))) isa Any

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
x = rand(10) |> aType
Expand Down Expand Up @@ -94,7 +95,7 @@ end
@test loss_sum(ŷ, y) loss_res * 4
@test loss_sum2(ŷ, y) loss_res * 4

@inferred Zygote.gradient(loss_mean, ŷ, y)
@test @inferred(Zygote.gradient(loss_mean, ŷ, y)) isa Any

@jet loss_mean(ŷ, y)
@jet loss_sum(ŷ, y)
Expand Down Expand Up @@ -172,7 +173,7 @@ end
@jet celoss(ŷ, y)
@jet celoss_smooth(ŷ, y)

@inferred Zygote.gradient(celoss, ŷ, y)
@test @inferred(Zygote.gradient(celoss, ŷ, y)) isa Any

__f = Base.Fix2(celoss, y)
!ongpu && test_enzyme_gradient(__f, ŷ)
Expand All @@ -196,7 +197,7 @@ end
@jet logitceloss(logŷ, y)
@jet logitceloss_smooth(logŷ, y)

@inferred Zygote.gradient(logitceloss, logŷ, y)
@test @inferred(Zygote.gradient(logitceloss, logŷ, y)) isa Any

__f = Base.Fix2(logitceloss, y)
!ongpu && test_enzyme_gradient(__f, logŷ)
Expand Down Expand Up @@ -225,7 +226,7 @@ end
@jet bceloss(σ.(logŷ), y)
@jet bceloss_smooth(σ.(logŷ), y)

@inferred Zygote.gradient(bceloss, σ.(logŷ), y)
@test @inferred(Zygote.gradient(bceloss, σ.(logŷ), y)) isa Any

__f = Base.Fix2(bceloss, y)
σlogŷ = σ.(logŷ)
Expand All @@ -248,7 +249,7 @@ end
@jet logitbceloss(logŷ, y)
@jet logitbceloss_smooth(logŷ, y)

@inferred Zygote.gradient(logitbceloss, logŷ, y)
@test @inferred(Zygote.gradient(logitbceloss, logŷ, y)) isa Any

__f = Base.Fix2(logitbceloss, y)
!ongpu && test_enzyme_gradient(__f, logŷ)
Expand All @@ -272,11 +273,7 @@ end

@jet BinaryFocalLoss()(ŷ, y)

if ongpu
@test_broken @inferred Zygote.gradient(BinaryFocalLoss(), ŷ, y)
else
@inferred Zygote.gradient(BinaryFocalLoss(), ŷ, y)
end
@test @inferred(Zygote.gradient(BinaryFocalLoss(), ŷ, y)) isa Any broken=ongpu

__f = Base.Fix2(BinaryFocalLoss(), y)
!ongpu && test_enzyme_gradient(__f, ŷ)
Expand All @@ -301,11 +298,7 @@ end

@jet FocalLoss()(ŷ, y)

if ongpu
@test_broken @inferred Zygote.gradient(FocalLoss(), ŷ, y)
else
@inferred Zygote.gradient(FocalLoss(), ŷ, y)
end
@test @inferred(Zygote.gradient(FocalLoss(), ŷ, y)) isa Any broken=ongpu

__f = Base.Fix2(FocalLoss(), y)
!ongpu && test_enzyme_gradient(__f, ŷ)
Expand All @@ -329,7 +322,7 @@ end
@test KLDivergenceLoss()(y, y) 0

@jet KLDivergenceLoss()(ŷ, y)
@inferred Zygote.gradient(KLDivergenceLoss(), ŷ, y)
@test @inferred(Zygote.gradient(KLDivergenceLoss(), ŷ, y)) isa Any

__f = Base.Fix2(KLDivergenceLoss(), y)
!ongpu && test_enzyme_gradient(__f, ŷ)
Expand All @@ -344,7 +337,7 @@ end
@test Lux.HingeLoss()(y, 0.5 .* y) 0.125

@jet Lux.HingeLoss()(ŷ, y)
@inferred Zygote.gradient(Lux.HingeLoss(), ŷ, y)
@test @inferred(Zygote.gradient(Lux.HingeLoss(), ŷ, y)) isa Any

__f = Base.Fix2(Lux.HingeLoss(), y)
!ongpu && test_enzyme_gradient(__f, ŷ)
Expand Down
4 changes: 2 additions & 2 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
x = randn(rng, 2, 3) |> aType

@test layer(x, ps, st)[1] x .* x
@inferred layer(x, ps, st)
@test @inferred(layer(x, ps, st)) isa Any

f12(x, ps, st) = x .+ 1, st

Expand All @@ -117,7 +117,7 @@
x = randn(rng, 2, 3) |> aType

@test layer(x, ps, st)[1] x .+ 1
@inferred layer(x, ps, st)
@test @inferred(layer(x, ps, st)) isa Any
end

@testset "PeriodicEmbedding" begin
Expand Down
9 changes: 5 additions & 4 deletions test/layers/type_stability_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,13 @@
ps, st = Lux.setup(rng, model) |> dev
x = input |> dev

@inferred model(x, ps, st)
@inferred loss_function(model, x, ps, st)
@test @inferred(model(x, ps, st)) isa Any
@test @inferred(loss_function(model, x, ps, st)) isa Any
if mode == "amdgpu" && (model isa Conv || model isa CrossCor)
@test_broken @inferred Zygote.gradient(loss_function, model, x, ps, st)
@test_broken @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa
Any
else
@inferred Zygote.gradient(loss_function, model, x, ps, st)
@test @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa Any
end
end
end
Expand Down

1 comment on commit f2df403

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: f2df403 Previous: 59840df Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3654.375 ns 3882.25 ns 0.94
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7153.357142857143 ns 7106.666666666667 ns 1.01
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20799 ns 20689 ns 1.01
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9912.5 ns 9690.1 ns 1.02
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8970.6 ns 8872.6 ns 1.01
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4464.625 ns 4422 ns 1.01
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1166.0921985815603 ns 1155.6573426573427 ns 1.01
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1102.0060975609756 ns 1103.4197530864199 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1190.7692307692307 ns 1169.1376811594203 ns 1.02
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1804.78 ns 1774.9833333333333 ns 1.02
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 184.81372549019608 ns 179.60225669957686 ns 1.03
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17232 ns 17222 ns 1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16872 ns 16711 ns 1.01
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37210 ns 36699 ns 1.01
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 28413 ns 29125 ns 0.98
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 20037 ns 19957.5 ns 1.00
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17343 ns 17152 ns 1.01
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4329.571428571428 ns 4308.142857142857 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3883.5 ns 3862.25 ns 1.01
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3951.8125 ns 3942.375 ns 1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4982.142857142857 ns 4940.571428571428 ns 1.01
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1659.1 ns 1663.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 40947012 ns 40597578 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 58034514.5 ns 58420422 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 80730603 ns 81996635 ns 0.98
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 92327301 ns 84719853 ns 1.09
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 78411107 ns 78243311 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11940034.5 ns 12253538 ns 0.97
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 7040440 ns 7139073 ns 0.99
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7193096.5 ns 7295300 ns 0.99
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7117252 ns 7121320.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 11651567 ns 11957784 ns 0.97
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6422658 ns 6428859 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 701030424 ns 694576554 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2590201983 ns 2544347623 ns 1.02
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 140190010.5 ns 144049664.5 ns 0.97
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 876551202 ns 799910013 ns 1.10
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3356633010 ns 3396813511 ns 0.99
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 197466482 ns 209331850 ns 0.94
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 692417063 ns 832924048 ns 0.83
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2718741579 ns 2799907824 ns 0.97
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 140643582 ns 140673147 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 175950163.5 ns 174392387 ns 1.01
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 652447286 ns 655041232 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 42722902.5 ns 45398404.5 ns 0.94
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 165562400 ns 164957412 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 646079758.5 ns 640983277 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 30503986 ns 29703983.5 ns 1.03
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 216087564.5 ns 185859364 ns 1.16
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 758513421.5 ns 764621834 ns 0.99
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 38353748.5 ns 37578537 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1219847712.5 ns 1194862835 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1887996340 ns 1879498836.5 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2447955597 ns 2346332159 ns 1.04
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2626651116 ns 2642927488 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1934007715 ns 1834750459 ns 1.05
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 335642586.5 ns 331610709 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 332570295 ns 332970921 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 329527060 ns 326938939 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 374730028.5 ns 350802967.5 ns 1.07
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11990295.5 ns 12028952 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 18049600 ns 18002968 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19367354 ns 19260830 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 24028261 ns 23917481.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 18064582 ns 18020134 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1230918 ns 1175374 ns 1.05
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 2101198.5 ns 2068210 ns 1.02
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2102362 ns 2080086.5 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2105777 ns 2083427 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2098950 ns 2069321 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 214702 ns 201762.5 ns 1.06
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 296645.5 ns 293549 ns 1.01
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 266860 ns 264285 ns 1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 369252 ns 364136 ns 1.01
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 412393 ns 406996.5 ns 1.01
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 276438 ns 273742 ns 1.01
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 414357 ns 406200 ns 1.02
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83792.5 ns 83246 ns 1.01
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81814 ns 81412 ns 1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 82705 ns 81432 ns 1.02
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 87064 ns 86733 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104616 ns 104576 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 192507909.5 ns 193197051 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 323234152 ns 327253734.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 413620121 ns 390124076 ns 1.06
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 460075216.5 ns 460359190 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 388836774 ns 365899084 ns 1.06
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 355864946 ns 341128190 ns 1.04
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 44389027 ns 44910463.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 44440540.5 ns 45017706 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 44195892 ns 44060014 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 60815754 ns 51951901 ns 1.17
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28722749 ns 27897102 ns 1.03
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19205844.5 ns 19603784 ns 0.98
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19859609 ns 19681809.5 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23705159.5 ns 23485714 ns 1.01
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24353181.5 ns 24215234 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19858547 ns 19767029 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6604286 ns 6564517 ns 1.01
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6592334.5 ns 6553486 ns 1.01
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6529285 ns 6533774 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6570212.5 ns 6524160 ns 1.01

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.