diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl
index fd947497a..1af021e4c 100644
--- a/src/helpers/losses.jl
+++ b/src/helpers/losses.jl
@@ -120,7 +120,8 @@ function huber_loss(x::T1, y::T2, δ::T3) where {T1, T2, T3}
     T = promote_type(T1, T2, T3)
     diff = x - y
     abs_diff = abs(diff)
-    return ifelse(abs_diff ≤ δ, T(0.5) * abs2(diff), δ * (abs_diff - T(0.5) * δ))
+    return ifelse(
+        abs_diff ≤ δ, convert(T, 0.5) * abs2(diff), δ * (abs_diff - convert(T, 0.5) * δ))
 end
 has_custom_derivative(::typeof(huber_loss)) = true
 function derivative(::typeof(huber_loss), x::T, y::T2, δ::T3) where {T, T2, T3}
diff --git a/test/reactant/training_tests.jl b/test/reactant/training_tests.jl
index 427f3a171..b3b27969c 100644
--- a/test/reactant/training_tests.jl
+++ b/test/reactant/training_tests.jl
@@ -13,8 +13,6 @@
 
         end
 
-        # TODO: Test for compute_gradients
-
         xdev = xla_device(; force=true)
 
         @testset "MLP Training: $(version)" for version in (:iip, :oop)
@@ -41,17 +39,13 @@
 
            train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
 
-            # FIXME: Use MSELoss <-- currently fails due to Enzyme
-            function sse(model, ps, st, (x, y))
-                z, stₙ = model(x, ps, st)
-                return sum(abs2, z .- y), stₙ, (;)
-            end
-
             for epoch in 1:100, (xᵢ, yᵢ) in dataloader
                 grads, loss, stats, train_state = if version === :iip
-                    Training.single_train_step!(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state)
+                    Training.single_train_step!(
+                        AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
                 elseif version === :oop
-                    Training.single_train_step(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state)
+                    Training.single_train_step(
+                        AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
                 else
                     error("Invalid version: $(version)")
                 end
@@ -64,7 +58,5 @@
 
             @test total_final_loss < 100 * total_initial_loss
         end
-
-        # TODO: Training a CNN
     end
 end
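
The first hunk rewrites the Huber loss so the `0.5` constant goes through `convert(T, 0.5)` in the promoted type `T`; the second file swaps the ad-hoc `sse` closure for the built-in `MSELoss()` in the Reactant training tests. Below is a hedged, standalone sketch (not part of the patch) that mirrors the patched formula and checks it against the piecewise Huber definition: quadratic for `|x - y| ≤ δ`, linear otherwise. The helper name `huber_loss_sketch` and the test values are illustrative assumptions, not code from this repository.

```julia
# Standalone, illustrative re-implementation of the patched Huber loss above.
# `huber_loss_sketch` and the checked values are assumptions for demonstration only.
function huber_loss_sketch(x::T1, y::T2, δ::T3) where {T1, T2, T3}
    T = promote_type(T1, T2, T3)
    diff = x - y
    abs_diff = abs(diff)
    # Quadratic branch for |x - y| ≤ δ, linear branch otherwise.
    return ifelse(
        abs_diff ≤ δ, convert(T, 0.5) * abs2(diff), δ * (abs_diff - convert(T, 0.5) * δ))
end

# Sanity checks against the closed-form piecewise definition.
@assert huber_loss_sketch(1.0f0, 0.25f0, 1.0f0) ≈ 0.5f0 * abs2(0.75f0)    # quadratic branch
@assert huber_loss_sketch(3.0f0, 0.0f0, 1.0f0) ≈ 1.0f0 * (3.0f0 - 0.5f0)  # linear branch
```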