From 77eb5fbf4b12f4597ee4324f8871818c9eb3552e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 9 Oct 2024 17:06:54 -0400 Subject: [PATCH] refactor: make `LossFunctions` an optional dep (#976) * refactor: make LossFunctions an optional dep * feat: add custom derivative fast paths * test: more tests got fixed --- Project.toml | 5 +- ext/LuxLossFunctionsExt.jl | 71 ++++++++++++++++++ src/Lux.jl | 1 - src/helpers/losses.jl | 150 +++++++++++++++++++++++++------------ test/helpers/loss_tests.jl | 10 +-- 5 files changed, 180 insertions(+), 57 deletions(-) create mode 100644 ext/LuxLossFunctionsExt.jl diff --git a/Project.toml b/Project.toml index fcaa5ed7c..00435a516 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.1.0" +version = "1.2.0-DEV" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -18,7 +18,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" @@ -43,6 +42,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" @@ -55,6 +55,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" LuxComponentArraysExt = "ComponentArrays" LuxEnzymeExt = "Enzyme" LuxFluxExt = "Flux" +LuxLossFunctionsExt = "LossFunctions" LuxMLUtilsExt = "MLUtils" LuxMPIExt = "MPI" LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"] diff --git a/ext/LuxLossFunctionsExt.jl b/ext/LuxLossFunctionsExt.jl new file mode 100644 index 000000000..ef4bddd5f --- /dev/null +++ b/ext/LuxLossFunctionsExt.jl @@ -0,0 +1,71 @@ +module LuxLossFunctionsExt + +using ArrayInterface: fast_scalar_indexing +using ChainRulesCore: ChainRulesCore, NoTangent, @thunk +using EnzymeCore: EnzymeCore, EnzymeRules +using FastClosures: @closure +using LossFunctions: LossFunctions +using Statistics: mean + +using Lux: Lux, LossFunctionImpl + +const CRC = ChainRulesCore + +function LossFunctionImpl.fused_agg( + ::typeof(mean), lfn::LossFunctions.Traits.Loss, x::AbstractArray, y::AbstractArray) + return LossFunctionImpl.fused_agg(sum, lfn, x, y) / length(x) +end + +function LossFunctionImpl.fused_agg( + ::typeof(sum), lfn::LossFunctions.Traits.Loss, x::Number, y::Number) + return lfn(x, y) +end +function LossFunctionImpl.fused_agg( + ::typeof(sum), lfn::LossFunctions.Traits.Loss, x::AbstractArray, y::AbstractArray) + fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y) + return sum(lfn.(x, y)) +end + +function CRC.rrule( + ::CRC.RuleConfig{>:CRC.HasReverseMode}, + ::typeof(LossFunctionImpl.fused_agg), ::typeof(sum), + lfn::LossFunctions.Traits.Loss, x, y) + ∇fused_agg = @closure Δ -> begin + ∂x = @thunk LossFunctions.deriv.(Ref(lfn), x, y) .* Δ + return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent() + end + return LossFunctionImpl.fused_agg(sum, lfn, x, y), ∇fused_agg +end + +function 
EnzymeRules.augmented_primal( + cfg::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(LossFunctionImpl.fused_agg)}, + ::Type{<:EnzymeCore.Active}, agg_f::EnzymeCore.Const{typeof(sum)}, + lfn::EnzymeCore.Const{<:LossFunctions.Traits.Loss}, + x::EnzymeCore.Annotation{<:AbstractArray}, y::EnzymeCore.Const) + primal = EnzymeRules.needs_primal(cfg) ? func.val(agg_f.val, lfn.val, x.val, y.val) : + nothing + + cache_x = EnzymeRules.overwritten(cfg)[4] ? copy(x.val) : nothing + cache_y = EnzymeRules.overwritten(cfg)[5] ? copy(y.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, nothing, (cache_x, cache_y)) +end + +function EnzymeRules.reverse( + cfg::EnzymeRules.RevConfigWidth{1}, + ::EnzymeCore.Const{typeof(LossFunctionImpl.fused_agg)}, + dret::EnzymeCore.Active, (cache_x, cache_y), agg_f::EnzymeCore.Const{typeof(sum)}, + lfn::EnzymeCore.Const{<:LossFunctions.Traits.Loss}, + x::EnzymeCore.Annotation{<:AbstractArray}, y::EnzymeCore.Const) + EnzymeRules.overwritten(cfg)[4] || (cache_x = x.val) + EnzymeRules.overwritten(cfg)[5] || (cache_y = y.val) + + if !(typeof(x) <: EnzymeCore.Const) + @. x.dval = LossFunctions.deriv(lfn.val, cache_x, cache_y) * dret.val + end + + return ntuple(Returns(nothing), 4) +end + +end diff --git a/src/Lux.jl b/src/Lux.jl index 972b8aa41..a9d1ac552 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -11,7 +11,6 @@ using ConcreteStructs: @concrete using FastClosures: @closure using Functors: Functors, fmap using GPUArraysCore: @allowscalar -using LossFunctions: LossFunctions using Markdown: @doc_str using NNlib: NNlib using Optimisers: Optimisers diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index a5bdaeceb..1f38c36ea 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -6,11 +6,11 @@ module LossFunctionImpl using ArrayInterface: fast_scalar_indexing using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable, @thunk -using EnzymeCore: EnzymeCore, EnzymeRules using FastClosures: @closure -using LossFunctions: LossFunctions +using ForwardDiff: ForwardDiff, Dual, Partials using Statistics: mean +using ..Utils: Utils using ..LuxOps: xlogy const CRC = ChainRulesCore @@ -30,59 +30,66 @@ check_sizes(_, __) = nothing # Aggregation. We are able to define custom aggregation fast paths fused_agg(::typeof(mean), op::OP, x) where {OP} = fused_agg(sum, op, x) / length(x) -function fused_agg(::typeof(mean), lfn::LossFunctions.Traits.Loss, x, y) - return fused_agg(sum, lfn, x, y) / length(x) -end fused_agg(::typeof(sum), op::OP, x::Number) where {OP} = op(x) fused_agg(::typeof(sum), op::OP, x) where {OP} = sum(op, x) -fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x::Number, y::Number) = lfn(x, y) -function fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) - fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y) - # mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) leads to slowdowns, better to - # allocate a new array - return sum(lfn.(x, y)) +fused_agg(::typeof(mean), op::OP, x::Number, y::Number) where {OP} = op(x, y) +function fused_agg(::typeof(mean), op::OP, x::AbstractArray, y::AbstractArray) where {OP} + return fused_agg(sum, op, x, y) / length(x) end -fused_agg(::Nothing, op::OP, args...) where {OP} = op.(args...) -fused_agg(f::F, op::OP, args...) 
where {F, OP} = f(op.(args...)) - -function CRC.rrule(::typeof(fused_agg), ::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) - ∇fused_agg = @closure Δ -> begin - ∂x = @thunk LossFunctions.deriv.(Ref(lfn), x, y) .* Δ - return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent() +fused_agg(::typeof(sum), op::OP, x::Number, y::Number) where {OP} = op(x, y) +function fused_agg(::typeof(sum), op::OP, x::AbstractArray, y::AbstractArray) where {OP} + if fast_scalar_indexing(x) && fast_scalar_indexing(y) + res = Core.Compiler._return_type(op, Tuple{eltype(x), eltype(y)})(0) + @simd ivdep for i in eachindex(x, y) + @inbounds res += op(x[i], y[i]) + end + return res end - return fused_agg(sum, lfn, x, y), ∇fused_agg + return fallback_fused_agg(sum, op, x, y) end -function EnzymeRules.augmented_primal( - cfg::EnzymeRules.RevConfigWidth{1}, func::EnzymeCore.Const{typeof(fused_agg)}, - ::Type{<:EnzymeCore.Active}, agg_f::EnzymeCore.Const{typeof(sum)}, - lfn::EnzymeCore.Const{<:LossFunctions.Traits.Loss}, - x::EnzymeCore.Annotation{<:AbstractArray}, y::EnzymeCore.Const) - primal = EnzymeRules.needs_primal(cfg) ? func.val(agg_f.val, lfn.val, x.val, y.val) : - nothing - - cache_x = EnzymeRules.overwritten(cfg)[4] ? copy(x.val) : nothing - cache_y = EnzymeRules.overwritten(cfg)[5] ? copy(y.val) : nothing +fused_agg(::Nothing, op::OP, args...) where {OP} = op.(args...) +fused_agg(f::F, op::OP, args...) where {F, OP} = fallback_fused_agg(f, op, args...) - return EnzymeRules.AugmentedReturn(primal, nothing, (cache_x, cache_y)) -end +@inline fallback_fused_agg(f::F, op::OP, args...) where {F, OP} = f(op.(args...)) -function EnzymeRules.reverse( - cfg::EnzymeRules.RevConfigWidth{1}, ::EnzymeCore.Const{typeof(fused_agg)}, - dret::EnzymeCore.Active, (cache_x, cache_y), agg_f::EnzymeCore.Const{typeof(sum)}, - lfn::EnzymeCore.Const{<:LossFunctions.Traits.Loss}, - x::EnzymeCore.Annotation{<:AbstractArray}, y::EnzymeCore.Const) - EnzymeRules.overwritten(cfg)[4] || (cache_x = x.val) - EnzymeRules.overwritten(cfg)[5] || (cache_y = y.val) +function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, + ::typeof(fused_agg), ::typeof(sum), op::OP, x, y) where {OP} + if has_custom_derivative(op) + res = fused_agg(sum, op, x, y) + ∇fused_agg_custom_derivative = Δ -> begin + ∂x = @thunk derivative.(Ref(op), x, y) .* Δ + return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent() + end + return res, ∇fused_agg_custom_derivative + end - if !(typeof(x) <: EnzymeCore.Const) - @. 
x.dval = LossFunctions.deriv(lfn.val, cache_x, cache_y) * dret.val + # Without custom derivatives use ForwardDiff for the looped implementation + if fast_scalar_indexing(x) && fast_scalar_indexing(y) + x_dual = Dual{ + Nothing, eltype(x), 1}.(x, (Partials{1, eltype(x)}((one(eltype(x)),)),)) + x_partials = similar(x) + T = eltype(x) + res = Core.Compiler._return_type(op, Tuple{T, eltype(y)})(0) + @inbounds @simd for i in eachindex(x_partials, x, y) + x_dual = Dual{Nothing, T, 1}(x[i], Partials{1, T}((one(T),))) + tmp = op(x_dual, y[i]) + x_partials[i] = ForwardDiff.partials(tmp, 1) + res += ForwardDiff.value(tmp) + end + ∇fused_agg_loop = Δ -> begin + @simd ivdep for i in eachindex(x_partials) + @inbounds x_partials[i] *= Δ + end + return NoTangent(), NoTangent(), NoTangent(), x_partials, NoTangent() + end + return res, ∇fused_agg_loop end - return ntuple(Returns(nothing), 4) + return CRC.rrule_via_ad(cfg, fallback_fused_agg, sum, op, x, y) end get_ϵ(::Type{T}, ϵ::Real) where {T} = T(ϵ) @@ -91,9 +98,57 @@ get_ϵ(::Type{T}, ::Nothing) where {T} = eps(float(T)) get_loss_dims(::AbstractVector) = Colon() get_loss_dims(::AbstractArray{T, N}) where {T, N} = 1:(N - 1) +has_custom_derivative(::F) where {F} = false + +has_custom_derivative(f::Utils.Fix3) = has_custom_derivative(f.f) +derivative(f::Utils.Fix3, x, y) = derivative(f.f, x, y, f.x) + # Functional forms of losses +l1_distance_loss(x::T1, y::T2) where {T1, T2} = abs(x - y) +has_custom_derivative(::typeof(l1_distance_loss)) = true +function derivative(::typeof(l1_distance_loss), x::T1, y::T2) where {T1, T2} + return convert(T1, sign(x - y)) +end + +l2_distance_loss(x::T1, y::T2) where {T1, T2} = abs2(x - y) +has_custom_derivative(::typeof(l2_distance_loss)) = true +function derivative(::typeof(l2_distance_loss), x::T1, y::T2) where {T1, T2} + return convert(T1, 2 * (x - y)) +end + +function huber_loss(x::T1, y::T2, δ::T3) where {T1, T2, T3} + T = promote_type(T1, T2, T3) + diff = x - y + abs_diff = abs(diff) + return ifelse(abs_diff ≤ δ, T(0.5) * abs2(diff), δ * (abs_diff - T(0.5) * δ)) +end +has_custom_derivative(::typeof(huber_loss)) = true +function derivative(::typeof(huber_loss), x::T, y::T2, δ::T3) where {T, T2, T3} + diff = x - y + return ifelse(abs(diff) ≤ δ, T(diff), T(δ) * convert(T, sign(diff))) +end + +function l1_hinge_loss(x::T1, y::T2) where {T1, T2} + agreement = x * y + return max(oftype(agreement, false), true - agreement) +end +has_custom_derivative(::typeof(l1_hinge_loss)) = true +function derivative(::typeof(l1_hinge_loss), x::T1, y::T2) where {T1, T2} + return T1(ifelse(x * y ≥ 1, false, true)) +end + +function l2_hinge_loss(x::T1, y::T2) where {T1, T2} + agreement = x * y + return ifelse(agreement ≥ 1, oftype(agreement, false), abs2(true - agreement)) +end +has_custom_derivative(::typeof(l2_hinge_loss)) = true +function derivative(::typeof(l2_hinge_loss), x::T1, y::T2) where {T1, T2} + agreement = x * y + return T1(ifelse(agreement ≥ 1, false, 2 * (agreement - true))) +end + function siamese_contrastive_loss(x::T1, y::T2, margin=true) where {T1, T2} - return (1 - y) * x^2 + y * max(promote_type(T1, T2)(0), margin - x)^2 + return (true - y) * x^2 + y * max(promote_type(T1, T2)(false), margin - x)^2 end poisson_loss(x::T1, y::T2, ϵ) where {T1, T2} = x - xlogy(y, x + get_ϵ(T1, ϵ)) @@ -462,7 +517,7 @@ julia> loss(y_pred, y_true) ≈ 0.55 true ``` """ -HingeLoss(; agg=mean) = GenericLossFunction(LossFunctions.L1HingeLoss(); agg) +HingeLoss(; agg=mean) = GenericLossFunction(LossFunctionImpl.l1_hinge_loss; agg) @doc doc""" 
HuberLoss(; delta = 1, agg = mean) @@ -490,7 +545,8 @@ true """ function HuberLoss(; delta::Union{Nothing, AbstractFloat}=nothing, agg=mean) return GenericLossFunction( - LossFunctions.HuberLoss(ifelse(delta === nothing, Float16(1), delta)); agg) + Utils.Fix3(LossFunctionImpl.huber_loss, ifelse(delta === nothing, true, delta)); + agg) end @doc doc""" @@ -566,7 +622,7 @@ julia> loss(y_model, 1:3) ≈ 0.1 true ``` """ -MAELoss(; agg=mean) = GenericLossFunction(LossFunctions.L1DistLoss(); agg) +MAELoss(; agg=mean) = GenericLossFunction(LossFunctionImpl.l1_distance_loss; agg) const L1Loss = MAELoss @@ -588,7 +644,7 @@ julia> loss(y_model, 1:3) ≈ 0.01 true ``` """ -MSELoss(; agg=mean) = GenericLossFunction(LossFunctions.L2DistLoss(); agg) +MSELoss(; agg=mean) = GenericLossFunction(LossFunctionImpl.l2_distance_loss; agg) const L2Loss = MSELoss @@ -696,7 +752,7 @@ julia> loss(y_pred, y_true) ≈ 0.625 true ``` """ -SquaredHingeLoss(; agg=mean) = GenericLossFunction(LossFunctions.L2HingeLoss(); agg) +SquaredHingeLoss(; agg=mean) = GenericLossFunction(LossFunctionImpl.l2_hinge_loss; agg) @doc doc""" GenericLossFunction(loss_fn; agg = mean) diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl index 2435adcb9..9ef21d91d 100644 --- a/test/helpers/loss_tests.jl +++ b/test/helpers/loss_tests.jl @@ -91,11 +91,7 @@ end @jet MSLELoss()(ŷ, y) - if VERSION ≥ v"1.11-" - @test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any - else - @test_broken @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any - end + @test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any broken=ongpu __f = Base.Fix2(MSLELoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) @@ -343,7 +339,7 @@ end @test Lux.PoissonLoss()(y, y) ≈ 0.5044459776946685 @jet Lux.PoissonLoss()(ŷ, y) - @test_broken @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) + @test @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) isa Any __f = Base.Fix2(Lux.PoissonLoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) @@ -357,7 +353,7 @@ end @test DiceCoeffLoss()(y, y) ≈ 0.0 @jet DiceCoeffLoss()(ŷ, y) - @test_broken @inferred Zygote.gradient(DiceCoeffLoss(), ŷ, y) + @test @inferred(Zygote.gradient(DiceCoeffLoss(), ŷ, y)) isa Any broken=true __f = Base.Fix2(DiceCoeffLoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3,
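A usage sketch of the intended effect of this patch (not part of the diff; it assumes the packages resolve as declared in [weakdeps] above). The built-in loss wrappers now run without LossFunctions.jl installed, and loading LossFunctions.jl only activates LuxLossFunctionsExt so that losses defined there keep working through GenericLossFunction:

using Lux, Zygote, Statistics          # LossFunctions.jl is not needed for these

ŷ, y = rand(Float32, 8), rand(Float32, 8)

mae = MAELoss()      # now backed by LossFunctionImpl.l1_distance_loss
mse = MSELoss()      # now backed by LossFunctionImpl.l2_distance_loss
huber = HuberLoss()  # huber_loss with δ closed over via Utils.Fix3

mae(ŷ, y), mse(ŷ, y), huber(ŷ, y)

# The CRC.rrule on fused_agg supplies the pullback, so Zygote takes the
# custom-derivative fast path instead of differentiating the aggregation loop.
Zygote.gradient(ŷ -> mse(ŷ, y), ŷ)

# Loading LossFunctions.jl activates LuxLossFunctionsExt; a LossFunctions.jl
# loss can still be routed through the fused aggregation path.
using LossFunctions
hinge_lf = GenericLossFunction(LossFunctions.L1HingeLoss(); agg=mean)
hinge_lf(ŷ, y)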
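The custom-derivative mechanism in src/helpers/losses.jl, restated as a self-contained sketch. Everything below (fused_sum_and_grad, the locally defined has_custom_derivative/derivative, l2_loss) is illustrative and only mirrors the pattern in the patch; it is not Lux API. A loss that registers an analytic derivative takes the fast path; anything else falls back to the per-element single-partial ForwardDiff dual loop that the new rrule uses:

using ForwardDiff: ForwardDiff, Dual, Partials

has_custom_derivative(::F) where {F} = false
l2_loss(x, y) = abs2(x - y)
has_custom_derivative(::typeof(l2_loss)) = true
derivative(::typeof(l2_loss), x, y) = 2 * (x - y)

# Fused "sum of elementwise losses" together with its gradient w.r.t. x.
function fused_sum_and_grad(op::OP, x::AbstractVector, y::AbstractVector) where {OP}
    if has_custom_derivative(op)
        return sum(op.(x, y)), derivative.(Ref(op), x, y)   # analytic fast path
    end
    # Fallback: one single-partial dual number per element, as in the looped rrule.
    T = eltype(x)
    grad = similar(x)
    res = zero(op(zero(T), zero(eltype(y))))
    for i in eachindex(x, y)
        d = op(Dual{Nothing, T, 1}(x[i], Partials{1, T}((one(T),))), y[i])
        res += ForwardDiff.value(d)
        grad[i] = ForwardDiff.partials(d, 1)
    end
    return res, grad
end

fused_sum_and_grad(l2_loss, rand(4), rand(4))                  # analytic branch
fused_sum_and_grad((x, y) -> abs(x - y)^3, rand(4), rand(4))   # ForwardDiff branch

In the patch itself, a loss with an extra parameter (huber_loss and its δ) is closed over with Utils.Fix3, and derivative(f::Utils.Fix3, x, y) forwards to derivative(f.f, x, y, f.x).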
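The EnzymeRules methods added in ext/LuxLossFunctionsExt.jl serve the same purpose for Enzyme's reverse mode: the analytic LossFunctions.deriv is used instead of differentiating through the fused loop. A sketch of how that path would be exercised; the Enzyme calls below follow the autodiff/Duplicated API as I understand it and should be treated as an assumption, not as a doctest from this patch:

using Enzyme, Lux, LossFunctions, Statistics

loss = GenericLossFunction(LossFunctions.L2DistLoss(); agg=mean)
ŷ, y = rand(Float32, 4), rand(Float32, 4)

dŷ = Enzyme.make_zero(ŷ)
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Const(loss), Enzyme.Active,
    Enzyme.Duplicated(ŷ, dŷ), Enzyme.Const(y))
# dŷ now holds ∂loss/∂ŷ, written by the EnzymeRules.reverse method above.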