Skip to content

Commit

Permalink
test: more tests got fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 9, 2024
1 parent 8d14e0c commit a8fcd1f
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions test/helpers/loss_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,7 @@ end

@jet MSLELoss()(ŷ, y)

if VERSION v"1.11-"
@test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any
else
@test_broken @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any
end
@test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any broken=ongpu

__f = Base.Fix2(MSLELoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
Expand Down Expand Up @@ -343,7 +339,7 @@ end
@test Lux.PoissonLoss()(y, y) 0.5044459776946685

@jet Lux.PoissonLoss()(ŷ, y)
@test_broken @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y)
@test @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) isa Any

__f = Base.Fix2(Lux.PoissonLoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
Expand All @@ -357,7 +353,7 @@ end
@test DiceCoeffLoss()(y, y) 0.0

@jet DiceCoeffLoss()(ŷ, y)
@test_broken @inferred Zygote.gradient(DiceCoeffLoss(), ŷ, y)
@test @inferred(Zygote.gradient(DiceCoeffLoss(), ŷ, y)) isa Any broken=true

__f = Base.Fix2(DiceCoeffLoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3,
Expand Down

0 comments on commit a8fcd1f

Please sign in to comment.