From 1c1d5bc77f77346df8a32b80255eb61f5a0a8ec0 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 30 Mar 2024 00:40:24 -0400
Subject: [PATCH] Fuse the activation and bias

---
 Project.toml            |  2 +-
 src/layers/basic.jl     | 18 +++++++++---------
 src/layers/conv.jl      | 12 ++++++------
 src/layers/normalize.jl |  8 ++++----
 src/layers/recurrent.jl |  4 ++--
 src/utils.jl            | 12 ++++++++----
 6 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/Project.toml b/Project.toml
index 97911b16b..14463300a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal and contributors"]
-version = "0.5.30"
+version = "0.5.31"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 2beae2121..de99ef1ff 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -211,29 +211,29 @@ statelength(d::Dense) = 0
 outputsize(d::Dense) = (d.out_dims,)
 
 @inline function (d::Dense{false})(x::AbstractVecOrMat, ps, st::NamedTuple)
-    return __apply_activation(d.activation, ps.weight * x), st
+    return apply_activation(d.activation, ps.weight * x), st
 end
 
 @inline function (d::Dense{false})(x::AbstractArray, ps, st::NamedTuple)
     x_reshaped = reshape(x, size(x, 1), :)
     return (
-        reshape(__apply_activation(d.activation, ps.weight * x_reshaped),
+        reshape(apply_activation(d.activation, ps.weight * x_reshaped),
             d.out_dims, size(x)[2:end]...),
         st)
 end
 
 @inline function (d::Dense{true})(x::AbstractVector, ps, st::NamedTuple)
-    return __apply_activation(d.activation, ps.weight * x .+ vec(ps.bias)), st
+    return apply_bias_activation(d.activation, ps.weight * x, vec(ps.bias)), st
 end
 
 @inline function (d::Dense{true})(x::AbstractMatrix, ps, st::NamedTuple)
-    return __apply_activation(d.activation, ps.weight * x .+ ps.bias), st
+    return apply_bias_activation(d.activation, ps.weight * x, ps.bias), st
 end
 
 @inline function (d::Dense{true})(x::AbstractArray, ps, st::NamedTuple)
     x_reshaped = reshape(x, size(x, 1), :)
     return (
-        reshape(__apply_activation(d.activation, ps.weight * x_reshaped .+ ps.bias),
+        reshape(apply_bias_activation(d.activation, ps.weight * x_reshaped, ps.bias),
             d.out_dims, size(x)[2:end]...),
         st)
 end
@@ -315,11 +315,11 @@ statelength(d::Scale) = 0
 outputsize(d::Scale) = d.dims
 
 function (d::Scale{true})(x::AbstractArray, ps, st::NamedTuple)
-    return __apply_activation(d.activation, ps.weight .* x .+ ps.bias), st
+    return apply_bias_activation(d.activation, ps.weight .* x, ps.bias), st
 end
 
 function (d::Scale{false})(x::AbstractArray, ps, st::NamedTuple)
-    return __apply_activation(d.activation, ps.weight .* x), st
+    return apply_activation(d.activation, ps.weight .* x), st
 end
 
 """
@@ -431,9 +431,9 @@ function (b::Bilinear{use_bias})((x, y)::Tuple{<:AbstractVecOrMat, <:AbstractVec
     Wyx = reshape(batched_mul(Wy, reshape(x, (d_x, 1, :))), (d_z, :))
 
     if use_bias
-        return __apply_activation(b.activation, Wyx .+ ps.bias), st
+        return apply_bias_activation(b.activation, Wyx, ps.bias), st
     else
-        return __apply_activation(b.activation, Wyx), st
+        return apply_activation(b.activation, Wyx), st
     end
 end
 
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 0a0badd61..7476572ff 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -117,13 +117,13 @@ end
 @inline function (c::Conv{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
     cdims = DenseConvDims(
         x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
-    return __apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
+    return apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
 end
 
 @inline function (c::Conv{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
     cdims = DenseConvDims(
         x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
-    return __apply_activation(c.activation, _conv(x, ps.weight, cdims) .+ ps.bias), st
+    return apply_bias_activation(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
 end
 
 function Base.show(io::IO, l::Conv)
@@ -620,13 +620,13 @@ end
 @inline function (c::CrossCor{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
     cdims = DenseConvDims(
         DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
-    return __apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
+    return apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
 end
 
 @inline function (c::CrossCor{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
     cdims = DenseConvDims(
         DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
-    return __apply_activation(c.activation, _conv(x, ps.weight, cdims) .+ ps.bias), st
+    return apply_bias_activation(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
 end
 
 function Base.show(io::IO, l::CrossCor)
@@ -752,14 +752,14 @@ end
         x::AbstractArray, ps, st::NamedTuple) where {N}
     cdims = _conv_transpose_dims(
         x, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
-    return __apply_activation(c.activation, _conv_transpose(x, ps.weight, cdims)), st
+    return apply_activation(c.activation, _conv_transpose(x, ps.weight, cdims)), st
 end
 
 @inline function (c::ConvTranspose{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
     cdims = _conv_transpose_dims(
         x, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
     return (
-        __apply_activation(c.activation, _conv_transpose(x, ps.weight, cdims) .+ ps.bias),
+        apply_bias_activation(c.activation, _conv_transpose(x, ps.weight, cdims), ps.bias),
         st)
 end
 
diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl
index 9efbb73e2..38b587c5d 100644
--- a/src/layers/normalize.jl
+++ b/src/layers/normalize.jl
@@ -130,7 +130,7 @@ function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple)
         @set! st.running_var = stats.running_var
     end
 
-    return __apply_activation(BN.activation, y), st
+    return apply_activation(BN.activation, y), st
 end
 
 function Base.show(io::IO, l::BatchNorm)
@@ -229,7 +229,7 @@ parameterlength(l::GroupNorm) = _affine(l) ? (l.chs * 2) : 0
 function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple)
     y = groupnorm(x, _getproperty(ps, Val(:scale)), _getproperty(ps, Val(:bias));
         GN.groups, GN.epsilon)
-    return __apply_activation(GN.activation, y), st
+    return apply_activation(GN.activation, y), st
 end
 
 function Base.show(io::IO, l::GroupNorm)
@@ -335,7 +335,7 @@ parameterlength(l::InstanceNorm) = _affine(l) ? (l.chs * 2) : 0
 function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple)
     y, stats = instancenorm(x, _getproperty(ps, Val(:scale)),
         _getproperty(ps, Val(:bias)); IN.epsilon, st.training)
-    return __apply_activation(IN.activation, y), st
+    return apply_activation(IN.activation, y), st
 end
 
 function Base.show(io::IO, l::InstanceNorm)
@@ -554,7 +554,7 @@ end
 function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple)
     y = layernorm(
         x, _getproperty(ps, Val(:scale)), _getproperty(ps, Val(:bias)); l.dims, l.epsilon)
-    return __apply_activation(l.activation, y), st
+    return apply_activation(l.activation, y), st
 end
 
 function Base.show(io::IO, l::LayerNorm)
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index dc9e1836c..15b145706 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -287,13 +287,13 @@ const _RNNCellInputType = Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}}
 
 function (rnn::RNNCell{true})((x, (hidden_state,))::_RNNCellInputType, ps, st::NamedTuple)
     h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state .+ ps.bias
-    h_new = __apply_activation(rnn.activation, h_new)
+    h_new = apply_activation(rnn.activation, h_new)
     return (h_new, (h_new,)), st
 end
 
 function (rnn::RNNCell{false})((x, (hidden_state,))::_RNNCellInputType, ps, st::NamedTuple)
     h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state
-    h_new = __apply_activation(rnn.activation, h_new)
+    h_new = apply_activation(rnn.activation, h_new)
     return (h_new, (h_new,)), st
 end
 
diff --git a/src/utils.jl b/src/utils.jl
index e8aad7a0a..92ea20bfe 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -118,10 +118,6 @@ function ∇_eachslice(Δ_raw, x::AbstractArray, ::Val{dims}) where {dims}
     return CRC.ProjectTo(x)(Δ)
 end
 
-# Activation Function
-@inline __apply_activation(::typeof(identity), x) = x
-@inline __apply_activation(f, x) = f.(x)
-
 # Backend Integration
 ## Convolution
 @inline _conv(x, weight, cdims) = conv(x, weight, cdims)
@@ -252,3 +248,11 @@ __named_tuple(nt::NamedTuple) = nt
 
 # Nondifferentiable hasmethod. Avoiding type-piracy
 @inline _hasmethod(f::F, args...) where {F} = hasmethod(f, args...)
+
+# Helpers for bias and activation functions
+## Just Activation Function
+@inline apply_activation(::typeof(identity), x) = x
+@inline apply_activation(f, x) = f.(x)
+
+@inline apply_bias_activation(::typeof(identity), x, b) = x .+ b
+@inline apply_bias_activation(f::F, x, b) where {F} = @. f(x + b)
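
Note (not part of the patch): the point of the new apply_bias_activation helper is that the bias add and the activation run in a single fused broadcast, instead of first materializing `weight * x .+ bias` and then broadcasting the activation over it. Below is a minimal, standalone Julia sketch of the pattern using the helper definitions added to src/utils.jl above; the tanh activation and the array sizes are arbitrary choices for illustration.

    # Same definitions as the patch adds to src/utils.jl (without @inline, for brevity)
    apply_activation(::typeof(identity), x) = x
    apply_activation(f, x) = f.(x)

    apply_bias_activation(::typeof(identity), x, b) = x .+ b
    apply_bias_activation(f::F, x, b) where {F} = @. f(x + b)

    W, x, b = randn(Float32, 16, 8), randn(Float32, 8, 32), randn(Float32, 16)

    unfused = apply_activation(tanh, W * x .+ b)     # two broadcasts, two temporaries
    fused   = apply_bias_activation(tanh, W * x, b)  # one fused broadcast, one temporary
    @assert unfused ≈ fused

Both calls compute the same result; the fused form simply does the `+ b` and the activation in one pass over the output array, which is why the call sites in the diff pass the bias as a separate argument instead of pre-adding it.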