From a8fcd1f2e10df20c70e32f3c180afa951a84e437 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 9 Oct 2024 15:20:31 -0400 Subject: [PATCH] test: more tests got fixed --- test/helpers/loss_tests.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl index 2435adcb9..9ef21d91d 100644 --- a/test/helpers/loss_tests.jl +++ b/test/helpers/loss_tests.jl @@ -91,11 +91,7 @@ end @jet MSLELoss()(ŷ, y) - if VERSION ≥ v"1.11-" - @test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any - else - @test_broken @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any - end + @test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any broken=ongpu __f = Base.Fix2(MSLELoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) @@ -343,7 +339,7 @@ end @test Lux.PoissonLoss()(y, y) ≈ 0.5044459776946685 @jet Lux.PoissonLoss()(ŷ, y) - @test_broken @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) + @test @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) isa Any __f = Base.Fix2(Lux.PoissonLoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) @@ -357,7 +353,7 @@ end @test DiceCoeffLoss()(y, y) ≈ 0.0 @jet DiceCoeffLoss()(ŷ, y) - @test_broken @inferred Zygote.gradient(DiceCoeffLoss(), ŷ, y) + @test @inferred(Zygote.gradient(DiceCoeffLoss(), ŷ, y)) isa Any broken=true __f = Base.Fix2(DiceCoeffLoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3,