
refactor: MultiHeadSelfAttention Layer
avik-pal committed Aug 23, 2024
1 parent fcc727f commit d3d3b21
Showing 6 changed files with 39 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .JuliaFormatter.toml
@@ -5,5 +5,5 @@ indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
always_for_in = true
-join_lines_based_on_source = false
+join_lines_based_on_source = true
annotate_untyped_fields_with_any = false
4 changes: 2 additions & 2 deletions src/initialize.jl
@@ -24,8 +24,8 @@ function initialize_model(
name::Symbol, model; pretrained::Bool=false, rng=nothing, seed=0, kwargs...)
if pretrained
path = get_pretrained_weights_path(name)
-        ps = load(joinpath(path, "$name.jld2"), "parameters")
-        st = load(joinpath(path, "$name.jld2"), "states")
+        ps = JLD2.load(joinpath(path, "$name.jld2"), "parameters")
+        st = JLD2.load(joinpath(path, "$name.jld2"), "states")
return ps, st
end
if rng === nothing
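As an aside, a small sketch of the JLD2 round trip that the pretrained-weights branch above depends on. The file name, layer, and `jldsave` call are illustrative; only the "parameters"/"states" keys come from the diff itself:

```julia
using JLD2, Lux, Random

model = Lux.Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model)

# Store parameters and states under the keys that initialize_model reads back.
jldsave("dense.jld2"; parameters=ps, states=st)

ps_loaded = JLD2.load("dense.jld2", "parameters")
st_loaded = JLD2.load("dense.jld2", "states")
```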
9 changes: 7 additions & 2 deletions src/layers/Layers.jl
@@ -2,15 +2,18 @@ module Layers

using ArgCheck: @argcheck
using ADTypes: AutoForwardDiff, AutoZygote
using Compat: @compat
using ConcreteStructs: @concrete
using ChainRulesCore: ChainRulesCore
using Markdown: @doc_str
using Random: AbstractRNG

using ForwardDiff: ForwardDiff

using Lux: Lux, StatefulLuxLayer
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Markdown: @doc_str
using MLDataDevices: get_device_type, CPUDevice, CUDADevice
using NNlib: NNlib
using Random: AbstractRNG
using WeightInitializers: zeros32, randn32

using ..Boltz: Boltz
@@ -35,4 +38,6 @@ include("mlp.jl")
include("spline.jl")
include("tensor_product.jl")

@compat public MultiHeadSelfAttention

end
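For reference, the `@compat public` line above relies on Compat.jl: on Julia 1.11 and newer it expands to the `public` keyword, and on older versions it is a no-op, so `MultiHeadSelfAttention` stays unexported but is declared part of the public API. A minimal sketch of the pattern, with placeholder names:

```julia
module SomeLayers

using Compat: @compat

struct SomeLayer end

# On Julia >= 1.11 this lowers to `public SomeLayer`; on older versions it is a no-op.
# Either way the name stays unexported and is accessed as SomeLayers.SomeLayer.
@compat public SomeLayer

end
```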
39 changes: 26 additions & 13 deletions src/layers/attention.jl
@@ -12,21 +12,34 @@ Multi-head self-attention layer
- `attn_dropout_prob`: dropout probability after the self-attention layer
- `proj_dropout_prob`: dropout probability after the projection layer
"""
@concrete struct MultiHeadSelfAttention <:
AbstractExplicitContainerLayer{(:qkv_layer, :dropout, :projection)}
qkv_layer
dropout
projection
nheads::Int
end

# TODO[BREAKING]: rename `qkv_bias` to `use_qkv_bias`
function MultiHeadSelfAttention(in_planes::Int, number_heads::Int; qkv_bias::Bool=false,
attention_dropout_rate::T=0.0f0, projection_dropout_rate::T=0.0f0) where {T}
@argcheck in_planes % number_heads == 0
return MultiHeadSelfAttention(
Lux.Dense(in_planes, in_planes * 3; use_bias=qkv_bias),
Lux.Dropout(attention_dropout_rate),
Lux.Chain(Lux.Dense(in_planes => in_planes), Lux.Dropout(projection_dropout_rate)),
number_heads
)
end

function (mhsa::MultiHeadSelfAttention)(x::AbstractArray{T, 3}, ps, st) where {T}
qkv, st_qkv = mhsa.qkv_layer(x, ps.qkv_layer, st.qkv_layer)
q, k, v = fast_chunk(qkv, Val(3), Val(1))

attn_dropout = Lux.StatefulLuxLayer{true}(mhsa.dropout, ps.dropout, st.dropout)
y, _ = NNlib.dot_product_attention(q, k, v; fdrop=attn_dropout, mhsa.nheads)

z, st_proj = mhsa.projection(y, ps.projection, st.projection)

qkv_layer = Lux.Dense(in_planes, in_planes * 3; use_bias=qkv_bias)
attention_dropout = Lux.Dropout(attention_dropout_rate)
projection = Lux.Chain(
Lux.Dense(in_planes => in_planes), Lux.Dropout(projection_dropout_rate))

return Lux.@compact(; number_heads, qkv_layer, attention_dropout,
projection, dispatch=:MultiHeadSelfAttention) do x::AbstractArray{<:Real, 3}
qkv = qkv_layer(x)
q, k, v = fast_chunk(qkv, Val(3), Val(1))
y, _ = NNlib.dot_product_attention(
q, k, v; fdrop=attention_dropout, nheads=number_heads)
@return projection(y)
end
return z, (; qkv_layer=st_qkv, dropout=attn_dropout.st, projection=st_proj)
end
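A quick usage sketch of the layer: its Lux call interface (construct, `Lux.setup`, then call with input, parameters, and state) is the same before and after this refactor. The feature size 48, head count 4, and the 16-token, batch-of-2 input are illustrative, and the layer is assumed to be reachable as `Boltz.Layers.MultiHeadSelfAttention`:

```julia
using Boltz, Lux, Random

rng = Random.default_rng()

# Feature size 48 with 4 heads (48 % 4 == 0, as the @argcheck above requires).
mhsa = Boltz.Layers.MultiHeadSelfAttention(48, 4; qkv_bias=true)

ps, st = Lux.setup(rng, mhsa)
x = randn(Float32, 48, 16, 2)   # (features, sequence length, batch)
y, st_new = mhsa(x, ps, st)     # standard Lux layer call; returns output and updated state
@assert size(y) == (48, 16, 2)
```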
3 changes: 2 additions & 1 deletion src/vision/Vision.jl
@@ -2,10 +2,11 @@ module Vision

using ArgCheck: @argcheck
using Compat: @compat
using ConcreteStructs: @concrete
using Random: Xoshiro

using Lux: Lux
using LuxCore: LuxCore, AbstractExplicitLayer
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using NNlib: relu

using ..InitializeModels: maybe_initialize_model, INITIALIZE_KWARGS
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -1,6 +1,6 @@
using ReTestItems, Pkg, InteractiveUtils, Hwloc

-@info sprint(io -> versioninfo(io; verbose=true))
+@info sprint(versioninfo)

const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))
const EXTRA_PKGS = String[]
