diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl
index 261cd48d7..4f81fa2fe 100644
--- a/test/contrib/training_tests.jl
+++ b/test/contrib/training_tests.jl
@@ -168,7 +168,8 @@ end
         _, _, _, tstate_new = @inferred Lux.Experimental.compute_gradients(
             AutoEnzyme(), mse, (x, x), tstate)
 
-        @inferred Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)
+        @test @inferred(Lux.Experimental.compute_gradients(
+            AutoEnzyme(), mse, (x, x), tstate_new)) isa Any
 
         _, _, _, tstate_new2 = @inferred Lux.Experimental.compute_gradients(
             AutoEnzyme(), mse2, (x, x), tstate_new)
diff --git a/test/helpers/compact_tests.jl b/test/helpers/compact_tests.jl
index 1eef8cddd..603795eee 100644
--- a/test/helpers/compact_tests.jl
+++ b/test/helpers/compact_tests.jl
@@ -355,10 +355,10 @@
         _, st_new = model(x, ps, st_new)
         @test st_new.incr == 100
 
-        @inferred model(x, ps, st)
+        @test @inferred(model(x, ps, st)) isa Any
 
         __f = (m, x, ps, st) -> sum(abs2, first(m(x, ps, st)))
-        @inferred Zygote.gradient(__f, model, x, ps, st)
+        @test @inferred(Zygote.gradient(__f, model, x, ps, st)) isa Any
     end
 
     @testset "Multiple @return" begin
diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl
index 91c3edced..afc38b92d 100644
--- a/test/helpers/loss_tests.jl
+++ b/test/helpers/loss_tests.jl
@@ -10,8 +10,8 @@
     ∂x2 = Zygote.gradient(xlogx, 2.0)[1]
     @test ∂x1 ≈ ∂x2
 
-    @inferred xlogx(2)
-    @inferred xlogx(0)
+    @test @inferred(xlogx(2)) isa Number
+    @test @inferred(xlogx(0)) isa Number
     @jet xlogx(2)
 
     @test iszero(xlogy(0, 1))
@@ -27,11 +27,12 @@
     @test ∂x1 ≈ ∂x2 ≈ ∂x3
     @test ∂y1 ≈ ∂y2 ≈ ∂y3
 
-    @inferred xlogy(2, 3)
-    @inferred xlogy(0, 1)
+    @test @inferred(xlogy(2, 3)) isa Number
+    @test @inferred(xlogy(0, 1)) isa Number
     @jet xlogy(2, 3)
 
-    @inferred Enzyme.autodiff(Enzyme.Reverse, xlogy, Active, Active(2.0), Active(3.0))
+    @test @inferred(Enzyme.autodiff(
+        Enzyme.Reverse, xlogy, Active, Active(2.0), Active(3.0))) isa Any
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         x = rand(10) |> aType
@@ -94,7 +95,7 @@ end
         @test loss_sum(ŷ, y) ≈ loss_res * 4
         @test loss_sum2(ŷ, y) ≈ loss_res * 4
 
-        @inferred Zygote.gradient(loss_mean, ŷ, y)
+        @test @inferred(Zygote.gradient(loss_mean, ŷ, y)) isa Any
 
         @jet loss_mean(ŷ, y)
         @jet loss_sum(ŷ, y)
@@ -172,7 +173,7 @@ end
         @jet celoss(ŷ, y)
         @jet celoss_smooth(ŷ, y)
 
-        @inferred Zygote.gradient(celoss, ŷ, y)
+        @test @inferred(Zygote.gradient(celoss, ŷ, y)) isa Any
 
         __f = Base.Fix2(celoss, y)
         !ongpu && test_enzyme_gradient(__f, ŷ)
@@ -196,7 +197,7 @@ end
         @jet logitceloss(logŷ, y)
         @jet logitceloss_smooth(logŷ, y)
 
-        @inferred Zygote.gradient(logitceloss, logŷ, y)
+        @test @inferred(Zygote.gradient(logitceloss, logŷ, y)) isa Any
 
         __f = Base.Fix2(logitceloss, y)
         !ongpu && test_enzyme_gradient(__f, logŷ)
@@ -225,7 +226,7 @@ end
         @jet bceloss(σ.(logŷ), y)
         @jet bceloss_smooth(σ.(logŷ), y)
 
-        @inferred Zygote.gradient(bceloss, σ.(logŷ), y)
+        @test @inferred(Zygote.gradient(bceloss, σ.(logŷ), y)) isa Any
 
         __f = Base.Fix2(bceloss, y)
         σlogŷ = σ.(logŷ)
@@ -248,7 +249,7 @@ end
         @jet logitbceloss(logŷ, y)
         @jet logitbceloss_smooth(logŷ, y)
 
-        @inferred Zygote.gradient(logitbceloss, logŷ, y)
+        @test @inferred(Zygote.gradient(logitbceloss, logŷ, y)) isa Any
 
         __f = Base.Fix2(logitbceloss, y)
         !ongpu && test_enzyme_gradient(__f, logŷ)
@@ -272,11 +273,7 @@ end
 
         @jet BinaryFocalLoss()(ŷ, y)
 
-        if ongpu
-            @test_broken @inferred Zygote.gradient(BinaryFocalLoss(), ŷ, y)
-        else
-            @inferred Zygote.gradient(BinaryFocalLoss(), ŷ, y)
-        end
+        @test @inferred(Zygote.gradient(BinaryFocalLoss(), ŷ, y)) isa Any broken=ongpu
 
         __f = Base.Fix2(BinaryFocalLoss(), y)
         !ongpu && test_enzyme_gradient(__f, ŷ)
@@ -301,11 +298,7 @@ end
 
         @jet FocalLoss()(ŷ, y)
 
-        if ongpu
-            @test_broken @inferred Zygote.gradient(FocalLoss(), ŷ, y)
-        else
-            @inferred Zygote.gradient(FocalLoss(), ŷ, y)
-        end
+        @test @inferred(Zygote.gradient(FocalLoss(), ŷ, y)) isa Any broken=ongpu
 
         __f = Base.Fix2(FocalLoss(), y)
         !ongpu && test_enzyme_gradient(__f, ŷ)
@@ -329,7 +322,7 @@ end
         @test KLDivergenceLoss()(y, y) ≈ 0
 
         @jet KLDivergenceLoss()(ŷ, y)
-        @inferred Zygote.gradient(KLDivergenceLoss(), ŷ, y)
+        @test @inferred(Zygote.gradient(KLDivergenceLoss(), ŷ, y)) isa Any
 
         __f = Base.Fix2(KLDivergenceLoss(), y)
         !ongpu && test_enzyme_gradient(__f, ŷ)
@@ -344,7 +337,7 @@ end
         @test Lux.HingeLoss()(y, 0.5 .* y) ≈ 0.125
 
         @jet Lux.HingeLoss()(ŷ, y)
-        @inferred Zygote.gradient(Lux.HingeLoss(), ŷ, y)
+        @test @inferred(Zygote.gradient(Lux.HingeLoss(), ŷ, y)) isa Any
 
         __f = Base.Fix2(Lux.HingeLoss(), y)
         !ongpu && test_enzyme_gradient(__f, ŷ)
diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl
index b4b642c70..47ce64fbf 100644
--- a/test/layers/basic_tests.jl
+++ b/test/layers/basic_tests.jl
@@ -107,7 +107,7 @@
         x = randn(rng, 2, 3) |> aType
 
         @test layer(x, ps, st)[1] ≈ x .* x
-        @inferred layer(x, ps, st)
+        @test @inferred(layer(x, ps, st)) isa Any
 
         f12(x, ps, st) = x .+ 1, st
 
@@ -117,7 +117,7 @@
         x = randn(rng, 2, 3) |> aType
 
         @test layer(x, ps, st)[1] ≈ x .+ 1
-        @inferred layer(x, ps, st)
+        @test @inferred(layer(x, ps, st)) isa Any
     end
 
     @testset "PeriodicEmbedding" begin
diff --git a/test/layers/type_stability_tests.jl b/test/layers/type_stability_tests.jl
index 2ef7a902c..55e89cd8e 100644
--- a/test/layers/type_stability_tests.jl
+++ b/test/layers/type_stability_tests.jl
@@ -74,12 +74,13 @@
         ps, st = Lux.setup(rng, model) |> dev
         x = input |> dev
-        @inferred model(x, ps, st)
-        @inferred loss_function(model, x, ps, st)
+        @test @inferred(model(x, ps, st)) isa Any
+        @test @inferred(loss_function(model, x, ps, st)) isa Any
 
         if mode == "amdgpu" && (model isa Conv || model isa CrossCor)
-            @test_broken @inferred Zygote.gradient(loss_function, model, x, ps, st)
+            @test_broken @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa
+                         Any
         else
-            @inferred Zygote.gradient(loss_function, model, x, ps, st)
+            @test @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa Any
         end
     end
 end
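
Note on the pattern applied throughout this diff: a bare `@inferred expr` throws when inference fails, which aborts the enclosing testset, whereas `@test @inferred(expr) isa Any` records the same failure as an ordinary failed test and lets the rest of the testset run. The `if ongpu ... @test_broken ... else ... end` branches collapse into the `broken=ongpu` keyword of `@test`. Below is a minimal self-contained sketch of the pattern; `f_stable` and `f_unstable` are hypothetical stand-ins, not functions from the Lux test suite.

using Test

# Hypothetical stand-ins, not part of the Lux test suite.
f_stable(x) = x + 1              # return type is inferable
f_unstable(x) = x > 0 ? 1 : ""   # returns Union{Int, String}; not inferable

@testset "inference failures as test failures" begin
    # A bare `@inferred f_stable(2)` would throw on an inference failure and
    # kill the testset; wrapped in `@test ... isa Any` it becomes a pass/fail.
    @test @inferred(f_stable(2)) isa Any

    # `broken=cond` (Julia >= 1.7) records a known failure as Broken rather
    # than Failed, mirroring the `broken=ongpu` usage in the diff above:
    @test @inferred(f_unstable(2)) isa Any broken=true
end

Running this yields one pass and one broken result; if inference for the `broken=true` case ever starts succeeding, the testset reports an unexpected pass, prompting removal of the annotation.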