Add a faster activation path #558

Closed · wants to merge 6 commits
2 changes: 2 additions & 0 deletions Project.toml
@@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -59,6 +60,7 @@ ChainRulesCore = "1.21"
ComponentArrays = "0.15.11"
ConcreteStructs = "0.2.3"
ConstructionBase = "1.5"
FastBroadcast = "0.2.8"
FastClosures = "0.3.2"
ExplicitImports = "1.1.1"
Flux = "0.14.11"
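
For context on the new dependency: FastBroadcast's `@..` is a drop-in for `@.` that emits loops specialized for arrays sharing the same axes, the common case for activation broadcasts. A rough sketch of the two forms that src/fast_broadcast.jl below relies on, assuming the arrays already agree in shape:

```julia
using FastBroadcast: @..

x = randn(Float32, 128, 32)
y = similar(x)

@.. y = tanh(x)      # in-place, like `@. y = tanh(x)` but with a specialized loop
z = @.. tanh(x)      # out-of-place form, mirroring `@..(f(x, ys...))` below
```
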
3 changes: 3 additions & 0 deletions src/Lux.jl
@@ -36,6 +36,9 @@ const NAME_TYPE = Union{Nothing, String, Symbol}
# Utilities
include("utils.jl")

+# Backend Functionality
+include("fast_broadcast.jl")
+
# Layer Implementations
include("layers/basic.jl")
include("layers/containers.jl")
108 changes: 108 additions & 0 deletions src/fast_broadcast.jl
@@ -0,0 +1,108 @@
# This file temporarily exists here. It will be moved to LuxLib.jl in the future.
using FastBroadcast: @..

# Adapted from NNlib.jl
# This just saves typing `only.(only.(` many times:
# `sigmoid_fast` fails if we use the fast path; we don't know why, so we simply avoid the
# fast gradient path for it
@inline function __only_derivative(y, f::F, x) where {F}
return only(only(CRC.derivatives_given_output(y, f, x)))
end

# This has no methods; it is used for testing whether `derivatives_given_output(Ω, f, x)`
# is independent of `x`, since `_return_type` returns `Union{}` when the call would be an error.
struct NotaNumber <: Real end

@inline fast_apply_activation!!(::typeof(identity), x::AbstractArray) = x
@inline fast_apply_activation!!(::typeof(sigmoid_fast), x::AbstractArray) = sigmoid_fast.(x)
@inline fast_apply_activation!!(::typeof(sigmoid), x::AbstractArray) = sigmoid.(x)
@inline function fast_apply_activation!!(f::F, x::AbstractArray) where {F}
return __fast_apply_activation_impl!!(f, x)
end

@inline function __fast_apply_activation_impl!!(f::F, x::AbstractArray) where {F}
return fast_fast_broadcast!!(f, x)
end

function CRC.rrule(
cfg::RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fast_apply_activation_impl!!),
f::F, x::AbstractArray{T}) where {F, T}
# Fast path: it is now safe to overwrite x, since x is not needed for the gradient of f
if isconcretetype(Core.Compiler._return_type(
__only_derivative, Tuple{T, F, NotaNumber}))
Ω = __fast_apply_activation_impl!!(f, x)
__∇fast_apply_activation_impl!!_fast = @closure Δ -> begin
∂x = __only_derivative.(Ω, f, NotaNumber()) .* CRC.unthunk(Δ)
return NoTangent(), NoTangent(), ∂x
end
return Ω, __∇fast_apply_activation_impl!!_fast
end

return CRC.rrule_via_ad(cfg, broadcast, f, x)
end

# Bias Activation Fused
function fast_bias_activation!!(f::F, x::AbstractArray, b::AbstractArray) where {F}
f === identity && return fast_broadcast!!(+, x, b)
return __fast_bias_activation_impl!!(f, x, b)
end
function fast_bias_activation!!(::typeof(sigmoid_fast), x::GPUArraysCore.AbstractGPUArray,
b::GPUArraysCore.AbstractGPUArray)
return __fast_bias_activation_impl!!(sigmoid, x, b)
end

function __fast_bias_activation_impl!!(f::F, x::AbstractArray, b::AbstractArray) where {F}
return fast_fast_broadcast!!(f ∘ +, x, b)
end

function CRC.rrule(
cfg::RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fast_bias_activation_impl!!),
f::F, x::AbstractArray{T, N}, b::AbstractArray) where {F, T, N}
# Summing over ndims(x)+1 is a trick to make `dims` type-stable
dims = ntuple(d -> ifelse(size(b, d) == 1, d, N + 1), N)
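# e.g. N = 3, size(x) = (4, 3, 2), size(b) = (4, 1, 1)  =>  dims = (4, 2, 3);
# summing over dims 2 and 3 (dim 4 is a no-op trailing singleton) leaves an
# array that reshapes back to size(b)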
∇bias(dx) = reshape(sum(dx; dims), size(b))

if f !== sigmoid_fast && isconcretetype(Core.Compiler._return_type(
__only_derivative, Tuple{T, F, NotaNumber}))
Ω = fast_bias_activation!!(f, x, b)
__∇fast_bias_activation_impl!!_fast = @closure Δ -> begin
∂x = __only_derivative.(Ω, f, NotaNumber()) .* CRC.unthunk(Δ)
return NoTangent(), NoTangent(), ∂x, ∇bias(∂x)
end
return Ω, __∇fast_bias_activation_impl!!_fast
end

return CRC.rrule_via_ad(cfg, fast_broadcast!!, f ∘ +, x, b)
end

# FastBroadcast.jl is efficient only for arrays that share the same axes
@inline fast_broadcast!!(f::F, x) where {F} = fast_fast_broadcast!!(f, x)
@inline function fast_broadcast!!(f::F, x, ys...) where {F}
ax = axes(x)
all(x -> axes(x) == ax, ys) && return fast_fast_broadcast!!(f, x, ys...)
return fast_generic_broadcast!!(f, x, ys...)
end

## Just use non-mutating version for the broadcast
function CRC.rrule(cfg::RuleConfig{>:CRC.HasReverseMode},
::typeof(fast_broadcast!!), f::F, x, ys...) where {F}
return CRC.rrule_via_ad(cfg, broadcast, f, x, ys...)
end

@inline function fast_fast_broadcast!!(f::F, x, ys...) where {F}
ArrayInterface.can_setindex(x) && return @..(x=f(x, ys...))
return @..(f(x, ys...))
end

@inline function fast_generic_broadcast!!(f::F, x, ys...) where {F}
if all(ArrayInterface.fast_scalar_indexing, (x, ys...))
bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, ys...))
ArrayInterface.can_setindex(x) || return copy(bc)
@simd ivdep for idx in eachindex(bc)
@inbounds x[idx] = bc[idx]
end
return x
end
ArrayInterface.can_setindex(x) && return @.(x=f(x, ys...))
return @.(f(x, ys...))
end
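
Both rrules above hinge on a single compile-time question: can the activation's derivative be computed from the output Ω alone? Below is a standalone restatement of that check, the same `NotaNumber`/`_return_type` trick adapted from NNlib.jl; the `expected` results are assumptions about which ChainRules scalar rules are written purely in terms of Ω, not guarantees.

```julia
using ChainRules                 # supplies scalar rules, e.g. d(tanh)/dx written as 1 - Ω^2
using ChainRulesCore: derivatives_given_output

# Marker with no arithmetic methods: if the rule truly needs the input `x`,
# calling it with `NotaNumber()` would error, so `_return_type` reports `Union{}`.
struct NotaNumber <: Real end

only_deriv(Ω, f, x) = only(only(derivatives_given_output(Ω, f, x)))

function output_only_derivative(f::F, ::Type{T}) where {F, T}
    return isconcretetype(Core.Compiler._return_type(
        only_deriv, Tuple{T, F, NotaNumber}))
end

output_only_derivative(tanh, Float32)  # expected: true  (rule uses only Ω)
output_only_derivative(abs2, Float32)  # expected: false (rule needs x, i.e. 2x)
```

When the answer is true, the pullback never has to re-run or store the pre-activation values, which is what makes the in-place fast path safe.
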
18 changes: 9 additions & 9 deletions src/layers/basic.jl
@@ -211,29 +211,29 @@ statelength(d::Dense) = 0
outputsize(d::Dense) = (d.out_dims,)

@inline function (d::Dense{false})(x::AbstractVecOrMat, ps, st::NamedTuple)
-return apply_activation(d.activation, ps.weight * x), st
+return fast_apply_activation!!(d.activation, ps.weight * x), st
end

@inline function (d::Dense{false})(x::AbstractArray, ps, st::NamedTuple)
x_reshaped = reshape(x, size(x, 1), :)
return (
-reshape(apply_activation(d.activation, ps.weight * x_reshaped),
+reshape(fast_apply_activation!!(d.activation, ps.weight * x_reshaped),
d.out_dims, size(x)[2:end]...),
st)
end

@inline function (d::Dense{true})(x::AbstractVector, ps, st::NamedTuple)
-return apply_bias_activation(d.activation, ps.weight * x, vec(ps.bias)), st
+return fast_bias_activation!!(d.activation, ps.weight * x, vec(ps.bias)), st
end

@inline function (d::Dense{true})(x::AbstractMatrix, ps, st::NamedTuple)
-return apply_bias_activation(d.activation, ps.weight * x, ps.bias), st
+return fast_bias_activation!!(d.activation, ps.weight * x, ps.bias), st
end

@inline function (d::Dense{true})(x::AbstractArray, ps, st::NamedTuple)
x_reshaped = reshape(x, size(x, 1), :)
return (
-reshape(apply_bias_activation(d.activation, ps.weight * x_reshaped, ps.bias),
+reshape(fast_bias_activation!!(d.activation, ps.weight * x_reshaped, ps.bias),
d.out_dims, size(x)[2:end]...),
st)
end
@@ -315,11 +315,11 @@ statelength(d::Scale) = 0
outputsize(d::Scale) = d.dims

function (d::Scale{true})(x::AbstractArray, ps, st::NamedTuple)
-return apply_bias_activation(d.activation, ps.weight .* x, ps.bias), st
+return fast_apply_activation!!(d.activation, ps.weight .* x .+ ps.bias), st
end

function (d::Scale{false})(x::AbstractArray, ps, st::NamedTuple)
-return apply_activation(d.activation, ps.weight .* x), st
+return fast_apply_activation!!(d.activation, ps.weight .* x), st
end

"""
@@ -431,9 +431,9 @@ function (b::Bilinear{use_bias})((x, y)::Tuple{<:AbstractVecOrMat, <:AbstractVec
Wyx = reshape(batched_mul(Wy, reshape(x, (d_x, 1, :))), (d_z, :))

if use_bias
-return apply_bias_activation(b.activation, Wyx, ps.bias), st
+return fast_bias_activation!!(b.activation, Wyx, ps.bias), st
else
-return apply_activation(b.activation, Wyx), st
+return fast_apply_activation!!(b.activation, Wyx), st
end
end

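
Since every call site above keeps the same signature, nothing changes for users of `Dense` and friends; only the internal helper is swapped. A minimal usage sketch (standard Lux API, no new names assumed), where the trailing `!!` signals that the freshly allocated `ps.weight * x` buffer may be overwritten in place, never the caller's `x`:

```julia
using Lux, Random

rng = Random.default_rng()
model = Dense(3 => 4, tanh)          # use_bias = true by default
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 3, 16)

# Internally this now runs fast_bias_activation!!(tanh, ps.weight * x, ps.bias)
# instead of apply_bias_activation, reusing the matmul output buffer when possible.
y, st_ = model(x, ps, st)
```
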
4 changes: 2 additions & 2 deletions src/layers/containers.jl
Expand Up @@ -479,8 +479,8 @@ _flatten_model(x) = x
@generated function applychain(
layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}) where {fields}
N = length(fields)
-x_symbols = vcat([:x], [gensym() for _ in 1:N])
-st_symbols = [gensym() for _ in 1:N]
+x_symbols = vcat([:x], [gensym("x") for _ in 1:N])
+st_symbols = [gensym("st") for _ in 1:N]
calls = [:(($(x_symbols[i + 1]), $(st_symbols[i])) = Lux.apply(
layers.$(fields[i]), $(x_symbols[i]), ps.$(fields[i]), st.$(fields[i])))
for i in 1:N]
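
The only change here is passing a label to `gensym`. A quick REPL sketch of the effect (the trailing numbers vary per session); the labelled symbols make the `@generated` `applychain` expansion, and any stack trace through it, distinguish chain outputs from layer states:

```julia
julia> gensym(), gensym("x"), gensym("st")
(Symbol("##257"), Symbol("##x#258"), Symbol("##st#259"))
```
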
12 changes: 6 additions & 6 deletions src/layers/conv.jl
@@ -117,13 +117,13 @@ end
@inline function (c::Conv{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
-return apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
+return fast_apply_activation!!(c.activation, _conv(x, ps.weight, cdims)), st
end

@inline function (c::Conv{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
-return apply_bias_activation(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
+return fast_bias_activation!!(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
end

function Base.show(io::IO, l::Conv)
@@ -620,13 +620,13 @@ end
@inline function (c::CrossCor{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
-return apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
+return fast_apply_activation!!(c.activation, _conv(x, ps.weight, cdims)), st
end

@inline function (c::CrossCor{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
-return apply_bias_activation(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
+return fast_bias_activation!!(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
end

function Base.show(io::IO, l::CrossCor)
@@ -752,14 +752,14 @@ end
x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = _conv_transpose_dims(
x, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
-return apply_activation(c.activation, _conv_transpose(x, ps.weight, cdims)), st
+return fast_apply_activation!!(c.activation, _conv_transpose(x, ps.weight, cdims)), st
end

@inline function (c::ConvTranspose{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = _conv_transpose_dims(
x, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
return (
-apply_bias_activation(c.activation, _conv_transpose(x, ps.weight, cdims), ps.bias),
+fast_bias_activation!!(c.activation, _conv_transpose(x, ps.weight, cdims), ps.bias),
st)
end

8 changes: 4 additions & 4 deletions src/layers/normalize.jl
@@ -130,7 +130,7 @@ function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple)
@set! st.running_var = stats.running_var
end

-return apply_activation(BN.activation, y), st
+return fast_apply_activation!!(BN.activation, y), st
end

function Base.show(io::IO, l::BatchNorm)
@@ -229,7 +229,7 @@ parameterlength(l::GroupNorm) = _affine(l) ? (l.chs * 2) : 0
function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple)
y = groupnorm(x, _getproperty(ps, Val(:scale)),
_getproperty(ps, Val(:bias)); GN.groups, GN.epsilon)
-return apply_activation(GN.activation, y), st
+return fast_apply_activation!!(GN.activation, y), st
end

function Base.show(io::IO, l::GroupNorm)
@@ -335,7 +335,7 @@ parameterlength(l::InstanceNorm) = _affine(l) ? (l.chs * 2) : 0
function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple)
y, stats = instancenorm(x, _getproperty(ps, Val(:scale)),
_getproperty(ps, Val(:bias)); IN.epsilon, st.training)
-return apply_activation(IN.activation, y), st
+return fast_apply_activation!!(IN.activation, y), st
end

function Base.show(io::IO, l::InstanceNorm)
@@ -554,7 +554,7 @@ end
function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple)
y = layernorm(
x, _getproperty(ps, Val(:scale)), _getproperty(ps, Val(:bias)); l.dims, l.epsilon)
-return apply_activation(l.activation, y), st
+return fast_apply_activation!!(l.activation, y), st
end

function Base.show(io::IO, l::LayerNorm)
4 changes: 2 additions & 2 deletions src/layers/recurrent.jl
@@ -287,13 +287,13 @@ const _RNNCellInputType = Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}}

function (rnn::RNNCell{true})((x, (hidden_state,))::_RNNCellInputType, ps, st::NamedTuple)
h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state .+ ps.bias
-h_new = apply_activation(rnn.activation, h_new)
+h_new = fast_apply_activation!!(rnn.activation, h_new)
return (h_new, (h_new,)), st
end

function (rnn::RNNCell{false})((x, (hidden_state,))::_RNNCellInputType, ps, st::NamedTuple)
h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state
-h_new = apply_activation(rnn.activation, h_new)
+h_new = fast_apply_activation!!(rnn.activation, h_new)
return (h_new, (h_new,)), st
end

8 changes: 0 additions & 8 deletions src/utils.jl
@@ -248,11 +248,3 @@ __named_tuple(nt::NamedTuple) = nt

# Nondifferentiable hasmethod. Avoiding type-piracy
@inline _hasmethod(f::F, args...) where {F} = hasmethod(f, args...)
-
-# Helpers for bias and activation functions
-## Just Activation Function
-@inline apply_activation(::typeof(identity), x) = x
-@inline apply_activation(f, x) = f.(x)
-
-@inline apply_bias_activation(::typeof(identity), x, b) = x .+ b
-@inline apply_bias_activation(f::F, x, b) where {F} = @. f(x + b)