From d3d3b2102b9b94af04a254f0c3f5e58ff452cac5 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 23 Aug 2024 16:54:49 -0700
Subject: [PATCH] refactor: `MultiHeadSelfAttention` Layer

---
 .JuliaFormatter.toml    |  2 +-
 src/initialize.jl       |  4 ++--
 src/layers/Layers.jl    |  9 +++++++--
 src/layers/attention.jl | 39 ++++++++++++++++++++++++++-------------
 src/vision/Vision.jl    |  3 ++-
 test/runtests.jl        |  2 +-
 6 files changed, 39 insertions(+), 20 deletions(-)

diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
index 5ad5726..2f19034 100644
--- a/.JuliaFormatter.toml
+++ b/.JuliaFormatter.toml
@@ -5,5 +5,5 @@ indent = 4
 format_docstrings = true
 separate_kwargs_with_semicolon = true
 always_for_in = true
-join_lines_based_on_source = false
+join_lines_based_on_source = true
 annotate_untyped_fields_with_any = false
diff --git a/src/initialize.jl b/src/initialize.jl
index e4e2a7c..1cb4232 100644
--- a/src/initialize.jl
+++ b/src/initialize.jl
@@ -24,8 +24,8 @@ function initialize_model(
         name::Symbol, model; pretrained::Bool=false, rng=nothing, seed=0, kwargs...)
     if pretrained
         path = get_pretrained_weights_path(name)
-        ps = load(joinpath(path, "$name.jld2"), "parameters")
-        st = load(joinpath(path, "$name.jld2"), "states")
+        ps = JLD2.load(joinpath(path, "$name.jld2"), "parameters")
+        st = JLD2.load(joinpath(path, "$name.jld2"), "states")
         return ps, st
     end
     if rng === nothing
diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl
index a692318..e0fe7e4 100644
--- a/src/layers/Layers.jl
+++ b/src/layers/Layers.jl
@@ -2,15 +2,18 @@ module Layers
 
 using ArgCheck: @argcheck
 using ADTypes: AutoForwardDiff, AutoZygote
+using Compat: @compat
 using ConcreteStructs: @concrete
 using ChainRulesCore: ChainRulesCore
+using Markdown: @doc_str
+using Random: AbstractRNG
+
 using ForwardDiff: ForwardDiff
+
 using Lux: Lux, StatefulLuxLayer
 using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
-using Markdown: @doc_str
 using MLDataDevices: get_device_type, CPUDevice, CUDADevice
 using NNlib: NNlib
-using Random: AbstractRNG
 using WeightInitializers: zeros32, randn32
 
 using ..Boltz: Boltz
@@ -35,4 +38,6 @@ include("mlp.jl")
 include("spline.jl")
 include("tensor_product.jl")
 
+@compat public MultiHeadSelfAttention
+
 end
diff --git a/src/layers/attention.jl b/src/layers/attention.jl
index 5d90fff..0c6b686 100644
--- a/src/layers/attention.jl
+++ b/src/layers/attention.jl
@@ -12,21 +12,34 @@ Multi-head self-attention layer
   - `attn_dropout_prob`: dropout probability after the self-attention layer
   - `proj_dropout_prob`: dropout probability after the projection layer
 """
+@concrete struct MultiHeadSelfAttention <:
+                 AbstractExplicitContainerLayer{(:qkv_layer, :dropout, :projection)}
+    qkv_layer
+    dropout
+    projection
+    nheads::Int
+end
+
+# TODO[BREAKING]: rename `qkv_bias` to `use_qkv_bias`
 function MultiHeadSelfAttention(in_planes::Int, number_heads::Int; qkv_bias::Bool=false,
         attention_dropout_rate::T=0.0f0, projection_dropout_rate::T=0.0f0) where {T}
     @argcheck in_planes % number_heads == 0
+    return MultiHeadSelfAttention(
+        Lux.Dense(in_planes, in_planes * 3; use_bias=qkv_bias),
+        Lux.Dropout(attention_dropout_rate),
+        Lux.Chain(Lux.Dense(in_planes => in_planes), Lux.Dropout(projection_dropout_rate)),
+        number_heads
+    )
+end
+
+function (mhsa::MultiHeadSelfAttention)(x::AbstractArray{T, 3}, ps, st) where {T}
+    qkv, st_qkv = mhsa.qkv_layer(x, ps.qkv_layer, st.qkv_layer)
+    q, k, v = fast_chunk(qkv, Val(3), Val(1))
+
+    attn_dropout = Lux.StatefulLuxLayer{true}(mhsa.dropout, ps.dropout, st.dropout)
+    y, _ = NNlib.dot_product_attention(q, k, v; fdrop=attn_dropout, mhsa.nheads)
+
+    z, st_proj = mhsa.projection(y, ps.projection, st.projection)
 
-    qkv_layer = Lux.Dense(in_planes, in_planes * 3; use_bias=qkv_bias)
-    attention_dropout = Lux.Dropout(attention_dropout_rate)
-    projection = Lux.Chain(
-        Lux.Dense(in_planes => in_planes), Lux.Dropout(projection_dropout_rate))
-
-    return Lux.@compact(; number_heads, qkv_layer, attention_dropout,
-        projection, dispatch=:MultiHeadSelfAttention) do x::AbstractArray{<:Real, 3}
-        qkv = qkv_layer(x)
-        q, k, v = fast_chunk(qkv, Val(3), Val(1))
-        y, _ = NNlib.dot_product_attention(
-            q, k, v; fdrop=attention_dropout, nheads=number_heads)
-        @return projection(y)
-    end
+    return z, (; qkv_layer=st_qkv, dropout=attn_dropout.st, projection=st_proj)
 end
diff --git a/src/vision/Vision.jl b/src/vision/Vision.jl
index b6aea1f..70eb8b2 100644
--- a/src/vision/Vision.jl
+++ b/src/vision/Vision.jl
@@ -2,10 +2,11 @@ module Vision
 
 using ArgCheck: @argcheck
 using Compat: @compat
+using ConcreteStructs: @concrete
 using Random: Xoshiro
 
 using Lux: Lux
-using LuxCore: LuxCore, AbstractExplicitLayer
+using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
 using NNlib: relu
 
 using ..InitializeModels: maybe_initialize_model, INITIALIZE_KWARGS
diff --git a/test/runtests.jl b/test/runtests.jl
index c7d4513..5ed2d20 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,6 @@
 using ReTestItems, Pkg, InteractiveUtils, Hwloc
 
-@info sprint(io -> versioninfo(io; verbose=true))
+@info sprint(versioninfo)
 
 const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))
 const EXTRA_PKGS = String[]
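
Usage sketch (illustrative, not part of the patch): a minimal example of how the
refactored layer would be constructed and called through the standard Lux
setup/apply workflow after this change. The `Boltz.Layers` qualification, the
8-channel / 2-head sizes, and the random input below are assumptions for the
sake of the example, not code taken from the repository.

    using Boltz, Lux, Random

    # 8 embedding channels split across 2 heads; `qkv_bias` keeps its default (false).
    mhsa = Boltz.Layers.MultiHeadSelfAttention(
        8, 2; attention_dropout_rate=0.1f0, projection_dropout_rate=0.1f0)

    rng = Random.default_rng()
    ps, st = Lux.setup(rng, mhsa)   # parameters/states for qkv_layer, dropout, projection

    x = randn(Float32, 8, 16, 4)    # (embed_dim, seq_len, batch)
    y, st_new = mhsa(x, ps, st)     # output keeps the (8, 16, 4) shape; updated states returned

Because the layer is now a plain `AbstractExplicitContainerLayer` rather than a
`Lux.@compact` closure, `ps` and `st` are ordinary named tuples keyed by
`:qkv_layer`, `:dropout`, and `:projection`, which is exactly what the new
forward pass threads through and returns.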