Skip to content

Commit

Permalink
Use generated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 26, 2024
1 parent cf2b9a2 commit a8f04bb
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 16 deletions.
24 changes: 11 additions & 13 deletions src/fast_broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@ end
# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
struct NotaNumber <: Real end

@inline fast_apply_activation!!(::typeof(identity), x::AbstractArray) = x
@inline fast_apply_activation!!(::typeof(sigmoid_fast), x::AbstractArray) = sigmoid_fast.(x)
@inline fast_apply_activation!!(::typeof(sigmoid), x::AbstractArray) = sigmoid.(x)
@inline fast_apply_activation!!(::typeof(tanh_fast), x::AbstractArray) = tanh_fast.(x)
@inline function fast_apply_activation!!(f::F, x::AbstractArray) where {F}
return __fast_apply_activation_impl!!(f, x)
# Without @generated sometimes we get runtime dispatch
@generated function fast_apply_activation!!(f::F, x::AbstractArray) where {F}
F == typeof(identity) && return :(x)
F in (typeof(sigmoid_fast), typeof(sigmoid), typeof(tanh_fast)) && return :(f.(x))
return :(__fast_apply_activation_impl!!(f, x))
end

@inline function __fast_apply_activation_impl!!(f::F, x::AbstractArray) where {F}
Expand All @@ -43,13 +42,12 @@ function CRC.rrule(
end

# Bias Activation Fused
function fast_bias_activation!!(f::F, x::AbstractArray, b::AbstractArray) where {F}
f === identity && return fast_broadcast!!(+, x, b)
return __fast_bias_activation_impl!!(f, x, b)
end
## Don't dispatch on GPUArray else it wont cant struct of arrays like TrackedArray
function fast_bias_activation!!(::typeof(sigmoid_fast), x::AbstractArray, b::AbstractArray)
return __fast_bias_activation_impl!!(sigmoid, x, b)
@generated function fast_bias_activation!!(
f::F, x::AbstractArray, b::AbstractArray) where {F}
F == typeof(identity) && return :(fast_broadcast!!(+, x, b))
## Don't dispatch on GPUArray else it wont cant struct of arrays like TrackedArray
F == typeof(sigmoid_fast) && return :(fast_bias_activation!!(sigmoid, x, b))
return :(__fast_bias_activation_impl!!(f, x, b))
end

function __fast_bias_activation_impl!!(f::F, x::AbstractArray, b::AbstractArray) where {F}
Expand Down
6 changes: 3 additions & 3 deletions test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@testset "Named Tuple Parameters" begin
@test_nowarn test_f(x, ps)

@test_broken begin
@test begin
y, back = Zygote.pullback(test_f, x, ps)
∂x, ∂ps = back(one(y))
∂x !== nothing && ∂ps !== nothing
Expand All @@ -35,13 +35,13 @@
@testset "Component Array Parameters" begin
@test_nowarn test_f(x, ps_ca)

@test_broken begin
@test begin
y, back = Zygote.pullback(test_f, x, ps_ca)
∂x, ∂ps = back(one(y))
∂x !== nothing && ∂ps !== nothing
end

@test_broken begin
@test begin
∂x, ∂ps = Zygote.jacobian(test_f, x, ps_ca)
∂x !== nothing && ∂ps !== nothing
end
Expand Down

0 comments on commit a8f04bb

Please sign in to comment.