From d95828386e3a149bf4e42eb591164da810d96c8f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 9 Oct 2024 15:22:12 -0400 Subject: [PATCH] fix: remove old LossFunctions.jl dispatches --- ext/LuxReactantExt/LuxReactantExt.jl | 5 +---- ext/LuxReactantExt/overrides.jl | 24 ------------------------ 2 files changed, 1 insertion(+), 28 deletions(-) delete mode 100644 ext/LuxReactantExt/overrides.jl diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 9d41fa53e..ce0e0cd06 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -1,17 +1,14 @@ module LuxReactantExt using Enzyme: Enzyme, Const, Duplicated, Active -using LossFunctions: LossFunctions using Optimisers: Optimisers using Reactant: Reactant, @compile, TracedRArray using Setfield: @set! using Static: False -using Statistics: mean -using Lux: Lux, LuxOps, Training, LossFunctionImpl +using Lux: Lux, LuxOps, Training using Lux.Training: TrainingBackendCache, ReactantBackend -include("overrides.jl") include("training.jl") end diff --git a/ext/LuxReactantExt/overrides.jl b/ext/LuxReactantExt/overrides.jl deleted file mode 100644 index 3044fabf5..000000000 --- a/ext/LuxReactantExt/overrides.jl +++ /dev/null @@ -1,24 +0,0 @@ -# Loss Functions -for fnType in (typeof(sum), Any, typeof(mean)) - @eval begin - function LossFunctionImpl.fused_agg( - fn::$(fnType), lfn::LossFunctions.HuberLoss{T1}, x::TracedRArray{T2, N}, - y::TracedRArray{T3, N}) where {T1, T2, T3, N} - T = promote_type(T1, T2, T3) - delta = T(lfn.d) - diff = x .- y - abs_diff = abs.(diff) - quadratic = abs2.(diff) ./ 2 - linear = (delta .* abs_diff) .- T(0.5) .* abs2(delta) - return fn(ifelse.(abs_diff .≤ delta, quadratic, linear)) - end - - function LossFunctionImpl.fused_agg( - fn::$(fnType), ::LossFunctions.L2HingeLoss, x::TracedRArray{T1, N}, - y::TracedRArray{T2, N}) where {T1, T2, N} - T = promote_type(T1, T2) - agreement = x .* y - return fn(ifelse.(agreement .≥ ones(T), zero(T), abs2.(one(T) .- agreement))) - end - end -end