Skip to content

Commit

Permalink
Add custom loss to reduce memory pressure
Browse files Browse the repository at this point in the history
  • Loading branch information
roflmaostc committed Jun 4, 2024
1 parent b392b4c commit dab965c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
27 changes: 25 additions & 2 deletions src/loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ end



my_power_four(x) = x^4

"""
LossThresholdSparsity(;sum_f=abs2, thresholds=(0.65f0, 0.75f0), λ=0.001f0)
Expand All @@ -82,7 +84,7 @@ struct LossThresholdSparsity{F, T, F2} <: LossTarget
thresholds::Tuple{T, T}
λ::T
sparsity_sum_f::F2
function LossThresholdSparsity(; sum_f=abs2, thresholds=(0.8f0, 0.9f0), λ=1f-9, sparsity_sum_f=x -> x^4)
function LossThresholdSparsity(; sum_f=abs2, thresholds=(0.8f0, 0.9f0), λ=1f-9, sparsity_sum_f=my_power_four)
return new{typeof(sum_f), typeof(thresholds[1]), typeof(sparsity_sum_f)}(sum_f, thresholds, λ, sparsity_sum_f)
end
end
Expand All @@ -96,6 +98,7 @@ function (l::LossThresholdSparsity)(x::AbstractArray{T}, target, patterns) where
end



"""
lesson learnt from this: don't to x[isobject] where isobject would be a boolean array.
rather express it with arithmetics. this is much faster
Expand All @@ -121,7 +124,8 @@ end
custom rules for the abs2 loss function (default).
no real speed gain but much less memory consumption
"""
function ChainRulesCore.rrule(l::LossThreshold{typeof(abs2), TT}, x::AbstractArray{T}, target, patterns) where {T, TT}
function ChainRulesCore.rrule(l::LossThreshold{typeof(abs2), TT}, x::AbstractArray{T},
target, patterns) where {T, TT}
res = l(x, target, patterns)
function pb(y)
y = unthunk(y)
Expand All @@ -132,3 +136,22 @@ function ChainRulesCore.rrule(l::LossThreshold{typeof(abs2), TT}, x::AbstractArr
end
return res, pb
end


"""
custom rules for the abs2 loss function (default).
no real speed gain but much less memory consumption
"""
function ChainRulesCore.rrule(l::LossThresholdSparsity{typeof(abs2), TT, typeof(my_power_four)},
x::AbstractArray{T}, target, patterns) where {T, TT}
res = l(x, target, patterns)
function pb(y)
y = unthunk(y)
g = @inbounds (2 .* y .* ((.- SwissVAMyKnife.NNlib.relu.(T(l.thresholds[2]) .- x) .* target) .+
(SwissVAMyKnife.NNlib.relu.(x .- Int(1)) .* target) .+
(SwissVAMyKnife.NNlib.relu.(x .- T(l.thresholds[1])) .* (1 .- target))))
b = @inbounds (4 .* y .* l.λ .* patterns.^3)
return NoTangent(), g, NoTangent(), b
end
return res, pb
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,13 @@ end

@testset "test rrule of custom loss" begin
l = LossThreshold(sum_f=abs2, thresholds=(0.4, 0.94))
l2 = LossThresholdSparsity(sum_f=abs2, thresholds=(0.4, 0.94), λ=0.01)
x = randn((4,4,4))
x2 = randn((4,4,4))
target = x .> 0.5
test_rrule(l ChainRulesTestUtils.NoTangent(), x, target ChainRulesTestUtils.NoTangent(), x ChainRulesTestUtils.NoTangent())

test_rrule(l2 ChainRulesTestUtils.NoTangent(), x, target ChainRulesTestUtils.NoTangent(), x2)
end


Expand Down

0 comments on commit dab965c

Please sign in to comment.