Fuse the activation and bias
avik-pal committed Mar 30, 2024
1 parent 9768d5a commit 1c1d5bc
Showing 6 changed files with 30 additions and 26 deletions.
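In a nutshell: the layers previously added the bias at the call site (`W*x .+ b`) and then broadcast the activation inside a separate helper, so the bias add and the activation each produced their own output array (broadcast fusion does not cross a function-call boundary). The new `apply_bias_activation` helper does both in a single fused broadcast. A minimal standalone sketch of the before/after pattern, with hypothetical names rather than the Lux internals:

```julia
# Old pattern: `x .+ b` is materialized before the helper is called,
# then the activation broadcast allocates a second output.
old_apply_activation(f, x) = f.(x)
old_forward(f, W, x, b) = old_apply_activation(f, W * x .+ b)

# New pattern: bias add and activation run in one fused broadcast.
new_apply_bias_activation(f, x, b) = @. f(x + b)
new_forward(f, W, x, b) = new_apply_bias_activation(f, W * x, b)

W, x, b = randn(3, 4), randn(4, 8), randn(3)
old_forward(tanh, W, x, b) ≈ new_forward(tanh, W, x, b)  # true: same values, one fewer temporary
```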
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.5.30"
version = "0.5.31"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
18 changes: 9 additions & 9 deletions src/layers/basic.jl
@@ -211,29 +211,29 @@ statelength(d::Dense) = 0
outputsize(d::Dense) = (d.out_dims,)

@inline function (d::Dense{false})(x::AbstractVecOrMat, ps, st::NamedTuple)
-return __apply_activation(d.activation, ps.weight * x), st
+return apply_activation(d.activation, ps.weight * x), st
end

@inline function (d::Dense{false})(x::AbstractArray, ps, st::NamedTuple)
x_reshaped = reshape(x, size(x, 1), :)
return (
-reshape(__apply_activation(d.activation, ps.weight * x_reshaped),
+reshape(apply_activation(d.activation, ps.weight * x_reshaped),
d.out_dims, size(x)[2:end]...),
st)
end

@inline function (d::Dense{true})(x::AbstractVector, ps, st::NamedTuple)
-return __apply_activation(d.activation, ps.weight * x .+ vec(ps.bias)), st
+return apply_bias_activation(d.activation, ps.weight * x, vec(ps.bias)), st
end

@inline function (d::Dense{true})(x::AbstractMatrix, ps, st::NamedTuple)
-return __apply_activation(d.activation, ps.weight * x .+ ps.bias), st
+return apply_bias_activation(d.activation, ps.weight * x, ps.bias), st
end

@inline function (d::Dense{true})(x::AbstractArray, ps, st::NamedTuple)
x_reshaped = reshape(x, size(x, 1), :)
return (
-reshape(__apply_activation(d.activation, ps.weight * x_reshaped .+ ps.bias),
+reshape(apply_bias_activation(d.activation, ps.weight * x_reshaped, ps.bias),
d.out_dims, size(x)[2:end]...),
st)
end
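Nothing changes for callers of `Dense`; the fusion is internal to the layer. A usage sketch, assuming the standard Lux explicit-parameter API (`Lux.setup`, then `model(x, ps, st)`):

```julia
using Lux, Random

rng = Random.default_rng()
model = Dense(4 => 3, relu)      # Dense{true}: carries a bias by default
ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 4, 16)   # features × batch
y, st_ = model(x, ps, st)        # internally: apply_bias_activation(relu, weight * x, bias)
size(y)                          # (3, 16)
```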
@@ -315,11 +315,11 @@ statelength(d::Scale) = 0
outputsize(d::Scale) = d.dims

function (d::Scale{true})(x::AbstractArray, ps, st::NamedTuple)
-return __apply_activation(d.activation, ps.weight .* x .+ ps.bias), st
+return apply_bias_activation(d.activation, ps.weight .* x, ps.bias), st
end

function (d::Scale{false})(x::AbstractArray, ps, st::NamedTuple)
-return __apply_activation(d.activation, ps.weight .* x), st
+return apply_activation(d.activation, ps.weight .* x), st
end

"""
@@ -431,9 +431,9 @@ function (b::Bilinear{use_bias})((x, y)::Tuple{<:AbstractVecOrMat, <:AbstractVec
Wyx = reshape(batched_mul(Wy, reshape(x, (d_x, 1, :))), (d_z, :))

if use_bias
-return __apply_activation(b.activation, Wyx .+ ps.bias), st
+return apply_bias_activation(b.activation, Wyx, ps.bias), st
else
-return __apply_activation(b.activation, Wyx), st
+return apply_activation(b.activation, Wyx), st
end
end

12 changes: 6 additions & 6 deletions src/layers/conv.jl
@@ -117,13 +117,13 @@ end
@inline function (c::Conv{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
-return __apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
+return apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
end

@inline function (c::Conv{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
-return __apply_activation(c.activation, _conv(x, ps.weight, cdims) .+ ps.bias), st
+return apply_bias_activation(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
end

function Base.show(io::IO, l::Conv)
@@ -620,13 +620,13 @@ end
@inline function (c::CrossCor{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
-return __apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
+return apply_activation(c.activation, _conv(x, ps.weight, cdims)), st

end

@inline function (c::CrossCor{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
-return __apply_activation(c.activation, _conv(x, ps.weight, cdims) .+ ps.bias), st
+return apply_bias_activation(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
end

function Base.show(io::IO, l::CrossCor)
@@ -752,14 +752,14 @@ end
x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = _conv_transpose_dims(
x, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
-return __apply_activation(c.activation, _conv_transpose(x, ps.weight, cdims)), st
+return apply_activation(c.activation, _conv_transpose(x, ps.weight, cdims)), st
end

@inline function (c::ConvTranspose{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = _conv_transpose_dims(
x, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
return (
-__apply_activation(c.activation, _conv_transpose(x, ps.weight, cdims) .+ ps.bias),
+apply_bias_activation(c.activation, _conv_transpose(x, ps.weight, cdims), ps.bias),
st)
end
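The convolution layers follow the same pattern: run the convolution, then fuse the per-channel bias with the activation. A standalone sketch using NNlib directly (the bias shape is an assumption chosen so it broadcasts over the spatial and batch dims; this is not the Lux-internal code path):

```julia
using NNlib

x = randn(Float32, 16, 16, 3, 2)   # width × height × C_in × batch
w = randn(Float32, 3, 3, 3, 8)     # kW × kH × C_in × C_out
b = randn(Float32, 1, 1, 8, 1)     # one bias per output channel

cdims = DenseConvDims(x, w; stride=1, padding=0, dilation=1)
y = conv(x, w, cdims)              # 14 × 14 × 8 × 2

# Old style: the bias add is materialized before the activation helper runs.
out_old = let t = y .+ b
    relu.(t)
end
# New style: one fused broadcast over the conv output and the bias.
out_new = @. relu(y + b)

out_old ≈ out_new                  # true
```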

8 changes: 4 additions & 4 deletions src/layers/normalize.jl
@@ -130,7 +130,7 @@ function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple)
@set! st.running_var = stats.running_var
end

-return __apply_activation(BN.activation, y), st
+return apply_activation(BN.activation, y), st
end

function Base.show(io::IO, l::BatchNorm)
@@ -229,7 +229,7 @@ parameterlength(l::GroupNorm) = _affine(l) ? (l.chs * 2) : 0
function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple)
y = groupnorm(x, _getproperty(ps, Val(:scale)),
_getproperty(ps, Val(:bias)); GN.groups, GN.epsilon)
-return __apply_activation(GN.activation, y), st
+return apply_activation(GN.activation, y), st
end

function Base.show(io::IO, l::GroupNorm)
@@ -335,7 +335,7 @@ parameterlength(l::InstanceNorm) = _affine(l) ? (l.chs * 2) : 0
function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple)
y, stats = instancenorm(x, _getproperty(ps, Val(:scale)),
_getproperty(ps, Val(:bias)); IN.epsilon, st.training)
-return __apply_activation(IN.activation, y), st
+return apply_activation(IN.activation, y), st
end

function Base.show(io::IO, l::InstanceNorm)
@@ -554,7 +554,7 @@ end
function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple)
y = layernorm(
x, _getproperty(ps, Val(:scale)), _getproperty(ps, Val(:bias)); l.dims, l.epsilon)
-return __apply_activation(l.activation, y), st
+return apply_activation(l.activation, y), st
end
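The normalization layers only pick up the rename to `apply_activation`: their affine scale and bias are applied inside the normalization kernel itself, so there is no separate bias left to fuse with the activation. A standalone sketch of that structure (not the LuxLib kernel):

```julia
using Statistics

function layernorm_like(x, γ, β, σ::F; dims=1, ϵ=1.0f-5) where {F}
    μ  = mean(x; dims)
    s² = var(x; dims, corrected=false)
    y  = @. γ * (x - μ) / sqrt(s² + ϵ) + β   # scale/bias happen inside the kernel
    return σ === identity ? y : σ.(y)        # mirrors apply_activation
end
```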

function Base.show(io::IO, l::LayerNorm)
4 changes: 2 additions & 2 deletions src/layers/recurrent.jl
@@ -287,13 +287,13 @@ const _RNNCellInputType = Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}}

function (rnn::RNNCell{true})((x, (hidden_state,))::_RNNCellInputType, ps, st::NamedTuple)
h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state .+ ps.bias
-h_new = __apply_activation(rnn.activation, h_new)
+h_new = apply_activation(rnn.activation, h_new)
return (h_new, (h_new,)), st
end

function (rnn::RNNCell{false})((x, (hidden_state,))::_RNNCellInputType, ps, st::NamedTuple)
h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state
-h_new = __apply_activation(rnn.activation, h_new)
+h_new = apply_activation(rnn.activation, h_new)
return (h_new, (h_new,)), st
end
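The recurrent cells only take the rename here; the bias is still added eagerly before `apply_activation`. Purely as an illustration (this is not what the commit does), a fully fused cell step would keep the two matmuls and then fold the bias and activation into one broadcast:

```julia
# Hypothetical fully fused RNN cell step -- NOT the committed code path,
# which keeps the eager `.+ bias` followed by a separate activation broadcast.
function rnn_step_fused(σ::F, Wih, Whh, b, x, h) where {F}
    ih = Wih * x
    hh = Whh * h
    return @. σ(ih + hh + b)   # one broadcast instead of two
end
```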

12 changes: 8 additions & 4 deletions src/utils.jl
@@ -118,10 +118,6 @@ function ∇_eachslice(Δ_raw, x::AbstractArray, ::Val{dims}) where {dims}
return CRC.ProjectTo(x)(Δ)
end

-# Activation Function
-@inline __apply_activation(::typeof(identity), x) = x
-@inline __apply_activation(f, x) = f.(x)

# Backend Integration
## Convolution
@inline _conv(x, weight, cdims) = conv(x, weight, cdims)
@@ -252,3 +248,11 @@ __named_tuple(nt::NamedTuple) = nt

# Nondifferentiable hasmethod. Avoiding type-piracy
@inline _hasmethod(f::F, args...) where {F} = hasmethod(f, args...)
+
+# Helpers for bias and activation functions
+## Just Activation Function
+@inline apply_activation(::typeof(identity), x) = x
+@inline apply_activation(f, x) = f.(x)
+
+@inline apply_bias_activation(::typeof(identity), x, b) = x .+ b
+@inline apply_bias_activation(f::F, x, b) where {F} = @. f(x + b)
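Two details of the new helpers worth noting: the `identity` method skips an unnecessary broadcast of `identity` and only adds the bias, and the `f::F where {F}` type parameter is the usual idiom to guarantee the method specializes on the activation function handed to the broadcast. A quick standalone check of the semantics (re-declaring copies of the definitions above so the snippet runs on its own):

```julia
apply_activation(::typeof(identity), x) = x
apply_activation(f, x) = f.(x)

apply_bias_activation(::typeof(identity), x, b) = x .+ b
apply_bias_activation(f::F, x, b) where {F} = @. f(x + b)

x, b = randn(Float32, 5, 7), randn(Float32, 5)
apply_bias_activation(identity, x, b) == x .+ b        # identity fast path: just the bias add
apply_bias_activation(tanh, x, b) ≈ tanh.(x .+ b)      # fused path: same values as the two-step form
apply_activation(identity, x) === x                    # no copy for identity
```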
