Skip to content

Commit

Permalink
Use generated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 26, 2024
1 parent cf2b9a2 commit 3622541
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions src/fast_broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@ end
# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
struct NotaNumber <: Real end

@inline fast_apply_activation!!(::typeof(identity), x::AbstractArray) = x
@inline fast_apply_activation!!(::typeof(sigmoid_fast), x::AbstractArray) = sigmoid_fast.(x)
@inline fast_apply_activation!!(::typeof(sigmoid), x::AbstractArray) = sigmoid.(x)
@inline fast_apply_activation!!(::typeof(tanh_fast), x::AbstractArray) = tanh_fast.(x)
@inline function fast_apply_activation!!(f::F, x::AbstractArray) where {F}
return __fast_apply_activation_impl!!(f, x)
# Without @generated sometimes we get runtime dispatch
@generated function fast_apply_activation!!(f::F, x::AbstractArray) where {F}
F == typeof(identity) && return :(x)
F in (typeof(sigmoid_fast), typeof(sigmoid), typeof(tanh_fast)) && return :(f.(x))
return :(__fast_apply_activation_impl!!(f, x))
end

@inline function __fast_apply_activation_impl!!(f::F, x::AbstractArray) where {F}
Expand All @@ -43,13 +42,12 @@ function CRC.rrule(
end

# Bias Activation Fused
function fast_bias_activation!!(f::F, x::AbstractArray, b::AbstractArray) where {F}
f === identity && return fast_broadcast!!(+, x, b)
return __fast_bias_activation_impl!!(f, x, b)
end
## Don't dispatch on GPUArray else it wont cant struct of arrays like TrackedArray
function fast_bias_activation!!(::typeof(sigmoid_fast), x::AbstractArray, b::AbstractArray)
return __fast_bias_activation_impl!!(sigmoid, x, b)
@generated function fast_bias_activation!!(
f::F, x::AbstractArray, b::AbstractArray) where {F}
F == typeof(identity) && return :(fast_broadcast!!(+, x, b))
## Don't dispatch on GPUArray else it wont cant struct of arrays like TrackedArray
F == typeof(sigmoid_fast) && return :(fast_bias_activation!!(sigmoid, x, b))
return :(__fast_bias_activation_impl!!(f, x, b))
end

function __fast_bias_activation_impl!!(f::F, x::AbstractArray, b::AbstractArray) where {F}
Expand Down

0 comments on commit 3622541

Please sign in to comment.