From a8f04bb73c571c1f0695784dc3baddc120bbe64b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 26 Mar 2024 14:15:01 -0400 Subject: [PATCH] Use generated functions --- src/fast_broadcast.jl | 24 +++++++++++------------- test/misc_tests.jl | 6 +++--- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/fast_broadcast.jl b/src/fast_broadcast.jl index f249b6207..b13da503a 100644 --- a/src/fast_broadcast.jl +++ b/src/fast_broadcast.jl @@ -13,12 +13,11 @@ end # is independent of `x`, as `_return_type` says `Union{}` when calling is an error. struct NotaNumber <: Real end -@inline fast_apply_activation!!(::typeof(identity), x::AbstractArray) = x -@inline fast_apply_activation!!(::typeof(sigmoid_fast), x::AbstractArray) = sigmoid_fast.(x) -@inline fast_apply_activation!!(::typeof(sigmoid), x::AbstractArray) = sigmoid.(x) -@inline fast_apply_activation!!(::typeof(tanh_fast), x::AbstractArray) = tanh_fast.(x) -@inline function fast_apply_activation!!(f::F, x::AbstractArray) where {F} - return __fast_apply_activation_impl!!(f, x) +# Without @generated sometimes we get runtime dispatch +@generated function fast_apply_activation!!(f::F, x::AbstractArray) where {F} + F == typeof(identity) && return :(x) + F in (typeof(sigmoid_fast), typeof(sigmoid), typeof(tanh_fast)) && return :(f.(x)) + return :(__fast_apply_activation_impl!!(f, x)) end @inline function __fast_apply_activation_impl!!(f::F, x::AbstractArray) where {F} @@ -43,13 +42,12 @@ function CRC.rrule( end # Bias Activation Fused -function fast_bias_activation!!(f::F, x::AbstractArray, b::AbstractArray) where {F} - f === identity && return fast_broadcast!!(+, x, b) - return __fast_bias_activation_impl!!(f, x, b) -end -## Don't dispatch on GPUArray else it wont cant struct of arrays like TrackedArray -function fast_bias_activation!!(::typeof(sigmoid_fast), x::AbstractArray, b::AbstractArray) - return __fast_bias_activation_impl!!(sigmoid, x, b) +@generated function fast_bias_activation!!( + f::F, x::AbstractArray, b::AbstractArray) where {F} + F == typeof(identity) && return :(fast_broadcast!!(+, x, b)) + ## Don't dispatch on GPUArray else it wont cant struct of arrays like TrackedArray + F == typeof(sigmoid_fast) && return :(fast_bias_activation!!(sigmoid, x, b)) + return :(__fast_bias_activation_impl!!(f, x, b)) end function __fast_bias_activation_impl!!(f::F, x::AbstractArray, b::AbstractArray) where {F} diff --git a/test/misc_tests.jl b/test/misc_tests.jl index ade4ea936..524dd44ac 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -19,7 +19,7 @@ @testset "Named Tuple Parameters" begin @test_nowarn test_f(x, ps) - @test_broken begin + @test begin y, back = Zygote.pullback(test_f, x, ps) ∂x, ∂ps = back(one(y)) ∂x !== nothing && ∂ps !== nothing @@ -35,13 +35,13 @@ @testset "Component Array Parameters" begin @test_nowarn test_f(x, ps_ca) - @test_broken begin + @test begin y, back = Zygote.pullback(test_f, x, ps_ca) ∂x, ∂ps = back(one(y)) ∂x !== nothing && ∂ps !== nothing end - @test_broken begin + @test begin ∂x, ∂ps = Zygote.jacobian(test_f, x, ps_ca) ∂x !== nothing && ∂ps !== nothing end