Fuse the activation and bias #570

Merged · 1 commit · Mar 30, 2024
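This PR fuses the bias addition and the activation into a single broadcast: every call site of the form `__apply_activation(act, y .+ bias)` becomes `apply_bias_activation(act, y, bias)`, which evaluates `@. act(y + bias)` in one pass instead of materializing the intermediate `y .+ bias` and then broadcasting the activation over it. A minimal standalone sketch of the before/after semantics (`x` and `b` are placeholder arrays for illustration, not part of the diff):

```julia
x = randn(Float32, 4, 8)   # pre-activation values, e.g. weight * input
b = randn(Float32, 4)      # bias, broadcast along columns

# Before: two broadcasts, allocating a temporary array for `x .+ b`.
y_unfused = tanh.(x .+ b)

# After: one fused broadcast over the same expression.
y_fused = @. tanh(x + b)

y_unfused ≈ y_fused        # true: identical result, one fewer pass and allocation
```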
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "0.5.30"
+version = "0.5.31"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
18 changes: 9 additions & 9 deletions src/layers/basic.jl
@@ -211,29 +211,29 @@ statelength(d::Dense) = 0
outputsize(d::Dense) = (d.out_dims,)

@inline function (d::Dense{false})(x::AbstractVecOrMat, ps, st::NamedTuple)
-return __apply_activation(d.activation, ps.weight * x), st
+return apply_activation(d.activation, ps.weight * x), st
end

@inline function (d::Dense{false})(x::AbstractArray, ps, st::NamedTuple)
x_reshaped = reshape(x, size(x, 1), :)
return (
-reshape(__apply_activation(d.activation, ps.weight * x_reshaped),
+reshape(apply_activation(d.activation, ps.weight * x_reshaped),
d.out_dims, size(x)[2:end]...),
st)
end

@inline function (d::Dense{true})(x::AbstractVector, ps, st::NamedTuple)
-return __apply_activation(d.activation, ps.weight * x .+ vec(ps.bias)), st
+return apply_bias_activation(d.activation, ps.weight * x, vec(ps.bias)), st
end

@inline function (d::Dense{true})(x::AbstractMatrix, ps, st::NamedTuple)
-return __apply_activation(d.activation, ps.weight * x .+ ps.bias), st
+return apply_bias_activation(d.activation, ps.weight * x, ps.bias), st
end

@inline function (d::Dense{true})(x::AbstractArray, ps, st::NamedTuple)
x_reshaped = reshape(x, size(x, 1), :)
return (
-reshape(__apply_activation(d.activation, ps.weight * x_reshaped .+ ps.bias),
+reshape(apply_bias_activation(d.activation, ps.weight * x_reshaped, ps.bias),
d.out_dims, size(x)[2:end]...),
st)
end
@@ -315,11 +315,11 @@ statelength(d::Scale) = 0
outputsize(d::Scale) = d.dims

function (d::Scale{true})(x::AbstractArray, ps, st::NamedTuple)
-return __apply_activation(d.activation, ps.weight .* x .+ ps.bias), st
+return apply_bias_activation(d.activation, ps.weight .* x, ps.bias), st
end

function (d::Scale{false})(x::AbstractArray, ps, st::NamedTuple)
-return __apply_activation(d.activation, ps.weight .* x), st
+return apply_activation(d.activation, ps.weight .* x), st
end

"""
@@ -431,9 +431,9 @@ function (b::Bilinear{use_bias})((x, y)::Tuple{<:AbstractVecOrMat, <:AbstractVec
Wyx = reshape(batched_mul(Wy, reshape(x, (d_x, 1, :))), (d_z, :))

if use_bias
-return __apply_activation(b.activation, Wyx .+ ps.bias), st
+return apply_bias_activation(b.activation, Wyx, ps.bias), st
else
-return __apply_activation(b.activation, Wyx), st
+return apply_activation(b.activation, Wyx), st
end
end

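The call-site behavior of the layers above is unchanged; only the internal helper differs. A hedged usage sketch with the public Lux API (layer sizes, RNG, and batch shape are arbitrary choices for illustration, not part of this PR):

```julia
using Lux, Random

rng = Random.default_rng()
model = Dense(2 => 3, relu)        # bias is enabled by default
ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 2, 8)      # 2 features, batch of 8
y, st = model(x, ps, st)           # internally calls apply_bias_activation(relu, W*x, b)
size(y)                            # (3, 8)
```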
12 changes: 6 additions & 6 deletions src/layers/conv.jl
@@ -117,13 +117,13 @@
@inline function (c::Conv{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
-return __apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
+return apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
end

@inline function (c::Conv{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
-return __apply_activation(c.activation, _conv(x, ps.weight, cdims) .+ ps.bias), st
+return apply_bias_activation(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
end

function Base.show(io::IO, l::Conv)
@@ -620,13 +620,13 @@
@inline function (c::CrossCor{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
-return __apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
+return apply_activation(c.activation, _conv(x, ps.weight, cdims)), st

[Codecov / codecov/patch check warning: added line src/layers/conv.jl#L623 was not covered by tests]
end

@inline function (c::CrossCor{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
-return __apply_activation(c.activation, _conv(x, ps.weight, cdims) .+ ps.bias), st
+return apply_bias_activation(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
end

function Base.show(io::IO, l::CrossCor)
@@ -752,14 +752,14 @@
x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = _conv_transpose_dims(
x, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
-return __apply_activation(c.activation, _conv_transpose(x, ps.weight, cdims)), st
+return apply_activation(c.activation, _conv_transpose(x, ps.weight, cdims)), st
end

@inline function (c::ConvTranspose{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = _conv_transpose_dims(
x, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
return (
-__apply_activation(c.activation, _conv_transpose(x, ps.weight, cdims) .+ ps.bias),
+apply_bias_activation(c.activation, _conv_transpose(x, ps.weight, cdims), ps.bias),
st)
end

8 changes: 4 additions & 4 deletions src/layers/normalize.jl
@@ -130,7 +130,7 @@ function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple)
@set! st.running_var = stats.running_var
end

-return __apply_activation(BN.activation, y), st
+return apply_activation(BN.activation, y), st
end

function Base.show(io::IO, l::BatchNorm)
@@ -229,7 +229,7 @@ parameterlength(l::GroupNorm) = _affine(l) ? (l.chs * 2) : 0
function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple)
y = groupnorm(x, _getproperty(ps, Val(:scale)),
_getproperty(ps, Val(:bias)); GN.groups, GN.epsilon)
-return __apply_activation(GN.activation, y), st
+return apply_activation(GN.activation, y), st
end

function Base.show(io::IO, l::GroupNorm)
@@ -335,7 +335,7 @@ parameterlength(l::InstanceNorm) = _affine(l) ? (l.chs * 2) : 0
function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple)
y, stats = instancenorm(x, _getproperty(ps, Val(:scale)),
_getproperty(ps, Val(:bias)); IN.epsilon, st.training)
-return __apply_activation(IN.activation, y), st
+return apply_activation(IN.activation, y), st
end

function Base.show(io::IO, l::InstanceNorm)
@@ -554,7 +554,7 @@ end
function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple)
y = layernorm(
x, _getproperty(ps, Val(:scale)), _getproperty(ps, Val(:bias)); l.dims, l.epsilon)
-return __apply_activation(l.activation, y), st
+return apply_activation(l.activation, y), st
end

function Base.show(io::IO, l::LayerNorm)
4 changes: 2 additions & 2 deletions src/layers/recurrent.jl
@@ -287,13 +287,13 @@ const _RNNCellInputType = Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}}

function (rnn::RNNCell{true})((x, (hidden_state,))::_RNNCellInputType, ps, st::NamedTuple)
h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state .+ ps.bias
-h_new = __apply_activation(rnn.activation, h_new)
+h_new = apply_activation(rnn.activation, h_new)
return (h_new, (h_new,)), st
end

function (rnn::RNNCell{false})((x, (hidden_state,))::_RNNCellInputType, ps, st::NamedTuple)
h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state
-h_new = __apply_activation(rnn.activation, h_new)
+h_new = apply_activation(rnn.activation, h_new)
return (h_new, (h_new,)), st
end

12 changes: 8 additions & 4 deletions src/utils.jl
@@ -118,10 +118,6 @@ function ∇_eachslice(Δ_raw, x::AbstractArray, ::Val{dims}) where {dims}
return CRC.ProjectTo(x)(Δ)
end

-# Activation Function
-@inline __apply_activation(::typeof(identity), x) = x
-@inline __apply_activation(f, x) = f.(x)
-
# Backend Integration
## Convolution
@inline _conv(x, weight, cdims) = conv(x, weight, cdims)
@@ -252,3 +248,11 @@ __named_tuple(nt::NamedTuple) = nt

# Nondifferentiable hasmethod. Avoiding type-piracy
@inline _hasmethod(f::F, args...) where {F} = hasmethod(f, args...)

+# Helpers for bias and activation functions
+## Just Activation Function
+@inline apply_activation(::typeof(identity), x) = x
+@inline apply_activation(f, x) = f.(x)
+
+@inline apply_bias_activation(::typeof(identity), x, b) = x .+ b
+@inline apply_bias_activation(f::F, x, b) where {F} = @. f(x + b)
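As a quick sanity check, the fused helper should match the old two-step formulation; a standalone sketch with the helper definitions copied from the diff above so it runs outside the package:

```julia
apply_activation(::typeof(identity), x) = x
apply_activation(f, x) = f.(x)
apply_bias_activation(::typeof(identity), x, b) = x .+ b
apply_bias_activation(f::F, x, b) where {F} = @. f(x + b)

x = randn(Float32, 3, 4)
b = randn(Float32, 3)

# The fused form agrees with the previous `f.(x .+ b)` pattern.
@assert apply_bias_activation(tanh, x, b) ≈ apply_activation(tanh, x .+ b)
@assert apply_bias_activation(identity, x, b) == x .+ b
```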