Skip to content

Commit

Permalink
test: try using MSELoss directly
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 9, 2024
1 parent d958283 commit 38ed312
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
3 changes: 2 additions & 1 deletion src/helpers/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ function huber_loss(x::T1, y::T2, δ::T3) where {T1, T2, T3}
T = promote_type(T1, T2, T3)
diff = x - y
abs_diff = abs(diff)
return ifelse(abs_diff ≤ δ, T(0.5) * abs2(diff), δ * (abs_diff - T(0.5) * δ))
return ifelse(
abs_diff ≤ δ, convert(T, 0.5) * abs2(diff), δ * (abs_diff - convert(T, 0.5) * δ))
end
has_custom_derivative(::typeof(huber_loss)) = true
function derivative(::typeof(huber_loss), x::T, y::T2, δ::T3) where {T, T2, T3}
Expand Down
16 changes: 4 additions & 12 deletions test/reactant/training_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
Reactant.set_default_backend("cpu")
end

# TODO: Test for compute_gradients

xdev = xla_device(; force=true)

@testset "MLP Training: $(version)" for version in (:iip, :oop)
Expand All @@ -41,17 +39,13 @@

train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

# FIXME: Use MSELoss <-- currently fails due to Enzyme
function sse(model, ps, st, (x, y))
z, stₙ = model(x, ps, st)
return sum(abs2, z .- y), stₙ, (;)
end

for epoch in 1:100, (xᵢ, yᵢ) in dataloader
grads, loss, stats, train_state = if version === :iip
Training.single_train_step!(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state)
Training.single_train_step!(
AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
elseif version === :oop
Training.single_train_step(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state)
Training.single_train_step(
AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
else
error("Invalid version: $(version)")
end
Expand All @@ -64,7 +58,5 @@

@test total_final_loss < 100 * total_initial_loss
end

# TODO: Training a CNN
end
end

0 comments on commit 38ed312

Please sign in to comment.