From 18a737dc9b5cb55ce0a81a8913bf025097360703 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 19 Aug 2023 20:06:25 -0400 Subject: [PATCH] Remove pretrained resnets for now --- .JuliaFormatter.toml | 1 - Artifacts.toml | 60 +++++++++++++-------------- README.md | 2 +- ext/BoltzMetalheadExt.jl | 5 --- src/Boltz.jl | 9 +--- src/utils.jl | 8 +--- src/vision/vgg.jl | 89 ++++++++++------------------------------ src/vision/vit.jl | 87 ++++++++++++--------------------------- test/vision.jl | 64 ++++++----------------------- 9 files changed, 95 insertions(+), 230 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index d134ef2..dbc3116 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -4,6 +4,5 @@ always_use_return = true margin = 92 indent = 4 format_docstrings = true -join_lines_based_on_source = false separate_kwargs_with_semicolon = true always_for_in = true diff --git a/Artifacts.toml b/Artifacts.toml index ae70f7c..7e02339 100644 --- a/Artifacts.toml +++ b/Artifacts.toml @@ -6,45 +6,45 @@ lazy = true sha256 = "e20107404aba1c2c0ed3fad4314033a2fa600cdc0c55d03bc1bfe4f8e5031105" url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/alexnet.tar.gz" -[resnet101] -git-tree-sha1 = "6c9143d40950726405b88db0cc021fa1dcbc0896" -lazy = true +# [resnet101] +# git-tree-sha1 = "6c9143d40950726405b88db0cc021fa1dcbc0896" +# lazy = true - [[resnet101.download]] - sha256 = "3840f05b3d996b2b3ea1e8fb6617775fd60ad6b8769402200fdc9c8b8dca246f" - url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet101.tar.gz" +# [[resnet101.download]] +# sha256 = "3840f05b3d996b2b3ea1e8fb6617775fd60ad6b8769402200fdc9c8b8dca246f" +# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet101.tar.gz" -[resnet152] -git-tree-sha1 = "892915c44de37537aad97da3de8a4458dfa36297" -lazy = true +# [resnet152] +# git-tree-sha1 = "892915c44de37537aad97da3de8a4458dfa36297" +# lazy = true - [[resnet152.download]] - sha256 = "6033a1ecc46d7f4ed1139067c5f9f5ea0d247656e9abbbe755c4702ec5a636d6" - url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet152.tar.gz" +# [[resnet152.download]] +# sha256 = "6033a1ecc46d7f4ed1139067c5f9f5ea0d247656e9abbbe755c4702ec5a636d6" +# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet152.tar.gz" -[resnet18] -git-tree-sha1 = "1d4a46fee1bb87eeef0ce2c85f63cfe0ff47d4de" -lazy = true +# [resnet18] +# git-tree-sha1 = "1d4a46fee1bb87eeef0ce2c85f63cfe0ff47d4de" +# lazy = true - [[resnet18.download]] - sha256 = "f4041ea1d1ec9bba86c7a5a519daaa49bb096a55fcd4ebf74f0743c8bdcb1c35" - url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet18.tar.gz" +# [[resnet18.download]] +# sha256 = "f4041ea1d1ec9bba86c7a5a519daaa49bb096a55fcd4ebf74f0743c8bdcb1c35" +# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet18.tar.gz" -[resnet34] -git-tree-sha1 = "306a8055ae9207ae2a316e31b376254557e481c9" -lazy = true +# [resnet34] +# git-tree-sha1 = "306a8055ae9207ae2a316e31b376254557e481c9" +# lazy = true - [[resnet34.download]] - sha256 = "d62e40ee9213ea9611e3fcedc958df4011da1fa108fb1537bac91e6b7778a3c8" - url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet34.tar.gz" +# [[resnet34.download]] +# sha256 = "d62e40ee9213ea9611e3fcedc958df4011da1fa108fb1537bac91e6b7778a3c8" +# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet34.tar.gz" -[resnet50] -git-tree-sha1 = "8c5866edb29b53f581a9ed7148efa1dbccde6133" -lazy = true +# [resnet50] +# 
git-tree-sha1 = "8c5866edb29b53f581a9ed7148efa1dbccde6133" +# lazy = true - [[resnet50.download]] - sha256 = "275365d76e592c6ea35574853a75ee068767641664e7817aedf394fcd7fea25a" - url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet50.tar.gz" +# [[resnet50.download]] +# sha256 = "275365d76e592c6ea35574853a75ee068767641664e7817aedf394fcd7fea25a" +# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet50.tar.gz" [vgg11] git-tree-sha1 = "ea7e8ef9399a0fe0aad2331781af5d6435950d36" diff --git a/README.md b/README.md index 6440ba6..ebc19db 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Pkg.add("Boltz") ## Getting Started ```julia -using Boltz, Lux +using Boltz, Lux, Metalhead model, ps, st = resnet(:resnet18; pretrained=true) ``` diff --git a/ext/BoltzMetalheadExt.jl b/ext/BoltzMetalheadExt.jl index ab6d7ba..ff3e3f3 100644 --- a/ext/BoltzMetalheadExt.jl +++ b/ext/BoltzMetalheadExt.jl @@ -30,11 +30,6 @@ function resnet(name::Symbol; pretrained=false, kwargs...) transform(ResNet(152).layers) end - # Compatibility with pretrained weights - if pretrained - model = Chain(model[1], model[2]) - end - return _initialize_model(name, model; pretrained, kwargs...) end diff --git a/src/Boltz.jl b/src/Boltz.jl index 403e4f9..4303cad 100644 --- a/src/Boltz.jl +++ b/src/Boltz.jl @@ -13,14 +13,7 @@ function __init__() end # Define functions. Methods defined in files or in extensions later -for f in (:alexnet, - :convmixer, - :densenet, - :googlenet, - :mobilenet, - :resnet, - :resnext, - :vgg, +for f in (:alexnet, :convmixer, :densenet, :googlenet, :mobilenet, :resnet, :resnext, :vgg, :vision_transformer) @eval function $(f) end end diff --git a/src/utils.jl b/src/utils.jl index 4c371d8..9c9d133 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -43,12 +43,8 @@ function _get_pretrained_weights_path(name::String) end end -function _initialize_model(name::Symbol, - model; - pretrained::Bool=false, - rng=nothing, - seed=0, - kwargs...) +function _initialize_model(name::Symbol, model; pretrained::Bool=false, rng=nothing, + seed=0, kwargs...) if pretrained path = _get_pretrained_weights_path(name) ps = load(joinpath(path, "$name.jld2"), "parameters") diff --git a/src/vision/vgg.jl b/src/vision/vgg.jl index 023cf59..05dbbe2 100644 --- a/src/vision/vgg.jl +++ b/src/vision/vgg.jl @@ -11,15 +11,12 @@ A VGG block of convolution layers ([reference](https://arxiv.org/abs/1409.1556v6 - `batchnorm`: set to `true` to include batch normalization after each convolution """ function _vgg_block(input_filters, output_filters, depth, batchnorm) - k = (3, 3) - p = (1, 1) layers = [] for _ in 1:depth push!(layers, - Conv(k, input_filters => output_filters, batchnorm ? identity : relu; pad=p)) - if batchnorm - push!(layers, BatchNorm(output_filters, relu)) - end + Conv((3, 3), input_filters => output_filters, + batchnorm ? identity : relu; pad=(1, 1))) + batchnorm && push!(layers, BatchNorm(output_filters, relu)) input_filters = output_filters end return Chain(layers...) 
@@ -63,12 +60,8 @@ Create VGG classifier (fully connected) layers - `dropout`: the dropout level between each fully connected layer """ function _vgg_classifier_layers(imsize, nclasses, fcsize, dropout) - return Chain(FlattenLayer(), - Dense(Int(prod(imsize)), fcsize, relu), - Dropout(dropout), - Dense(fcsize, fcsize, relu), - Dropout(dropout), - Dense(fcsize, nclasses)) + return Chain(FlattenLayer(), Dense(Int(prod(imsize)), fcsize, relu), Dropout(dropout), + Dense(fcsize, fcsize, relu), Dropout(dropout), Dense(fcsize, nclasses)) end """ @@ -104,69 +97,29 @@ function vgg(name::Symbol; kwargs...) assert_name_present_in(name, (:vgg11, :vgg11_bn, :vgg13, :vgg13_bn, :vgg16, :vgg16_bn, :vgg19, :vgg19_bn)) model = if name == :vgg11 - vgg((224, 224); - config=VGG_CONV_CONFIG[VGG_CONFIG[11]], - inchannels=3, - batchnorm=false, - nclasses=1000, - fcsize=4096, - dropout=0.5f0) + vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[11]], inchannels=3, + batchnorm=false, nclasses=1000, fcsize=4096, dropout=0.5f0) elseif name == :vgg11_bn - vgg((224, 224); - config=VGG_CONV_CONFIG[VGG_CONFIG[11]], - inchannels=3, - batchnorm=true, - nclasses=1000, - fcsize=4096, - dropout=0.5f0) + vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[11]], inchannels=3, + batchnorm=true, nclasses=1000, fcsize=4096, dropout=0.5f0) elseif name == :vgg13 - vgg((224, 224); - config=VGG_CONV_CONFIG[VGG_CONFIG[13]], - inchannels=3, - batchnorm=false, - nclasses=1000, - fcsize=4096, - dropout=0.5f0) + vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[13]], inchannels=3, + batchnorm=false, nclasses=1000, fcsize=4096, dropout=0.5f0) elseif name == :vgg13_bn - vgg((224, 224); - config=VGG_CONV_CONFIG[VGG_CONFIG[13]], - inchannels=3, - batchnorm=true, - nclasses=1000, - fcsize=4096, - dropout=0.5f0) + vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[13]], inchannels=3, + batchnorm=true, nclasses=1000, fcsize=4096, dropout=0.5f0) elseif name == :vgg16 - vgg((224, 224); - config=VGG_CONV_CONFIG[VGG_CONFIG[16]], - inchannels=3, - batchnorm=false, - nclasses=1000, - fcsize=4096, - dropout=0.5f0) + vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[16]], inchannels=3, + batchnorm=false, nclasses=1000, fcsize=4096, dropout=0.5f0) elseif name == :vgg16_bn - vgg((224, 224); - config=VGG_CONV_CONFIG[VGG_CONFIG[16]], - inchannels=3, - batchnorm=true, - nclasses=1000, - fcsize=4096, - dropout=0.5f0) + vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[16]], inchannels=3, + batchnorm=true, nclasses=1000, fcsize=4096, dropout=0.5f0) elseif name == :vgg19 - vgg((224, 224); - config=VGG_CONV_CONFIG[VGG_CONFIG[19]], - inchannels=3, - batchnorm=false, - nclasses=1000, - fcsize=4096, - dropout=0.5f0) + vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[19]], inchannels=3, + batchnorm=false, nclasses=1000, fcsize=4096, dropout=0.5f0) elseif name == :vgg19_bn - vgg((224, 224); - config=VGG_CONV_CONFIG[VGG_CONFIG[19]], - inchannels=3, - batchnorm=true, - nclasses=1000, - fcsize=4096, - dropout=0.5f0) + vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[19]], inchannels=3, + batchnorm=true, nclasses=1000, fcsize=4096, dropout=0.5f0) end return _initialize_model(name, model; kwargs...) 
end diff --git a/src/vision/vit.jl b/src/vision/vit.jl index b0ab029..e7a5196 100644 --- a/src/vision/vit.jl +++ b/src/vision/vit.jl @@ -14,11 +14,8 @@ struct MultiHeadAttention{Q, A, P} <: projection::P end -function MultiHeadAttention(in_planes::Int, - number_heads::Int; - qkv_bias::Bool=false, - attention_dropout_rate::T=0.0f0, - projection_dropout_rate::T=0.0f0) where {T} +function MultiHeadAttention(in_planes::Int, number_heads::Int; qkv_bias::Bool=false, + attention_dropout_rate::T=0.0f0, projection_dropout_rate::T=0.0f0) where {T} @assert in_planes % number_heads==0 "`in_planes` should be divisible by `number_heads`" qkv_layer = Dense(in_planes, in_planes * 3; use_bias=qkv_bias) attention_dropout = Dropout(attention_dropout_rate) @@ -32,37 +29,26 @@ function (m::MultiHeadAttention)(x::AbstractArray{T, 3}, ps, st) where {T} x_reshaped = reshape(x, nfeatures, seq_len * batch_size) qkv, st_qkv = m.qkv_layer(x_reshaped, ps.qkv_layer, st.qkv_layer) - qkv_reshaped = reshape(qkv, - nfeatures ÷ m.number_heads, - m.number_heads, - seq_len, + qkv_reshaped = reshape(qkv, nfeatures ÷ m.number_heads, m.number_heads, seq_len, 3 * batch_size) query, key, value = _fast_chunk(qkv_reshaped, Val(3), Val(4)) scale = convert(T, sqrt(size(query, 1) / m.number_heads)) - key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), - m.number_heads, - nfeatures ÷ m.number_heads, - seq_len * batch_size) - query_reshaped = reshape(query, - nfeatures ÷ m.number_heads, - m.number_heads, + key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.number_heads, + nfeatures ÷ m.number_heads, seq_len * batch_size) + query_reshaped = reshape(query, nfeatures ÷ m.number_heads, m.number_heads, seq_len * batch_size) attention = softmax(batched_mul(query_reshaped, key_reshaped) .* scale) - attention, st_attention = m.attention_dropout(attention, - ps.attention_dropout, + attention, st_attention = m.attention_dropout(attention, ps.attention_dropout, st.attention_dropout) - value_reshaped = reshape(value, - nfeatures ÷ m.number_heads, - m.number_heads, + value_reshaped = reshape(value, nfeatures ÷ m.number_heads, m.number_heads, seq_len * batch_size) pre_projection = reshape(batched_mul(attention, value_reshaped), (nfeatures, seq_len, batch_size)) y, st_projection = m.projection(reshape(pre_projection, size(pre_projection, 1), :), - ps.projection, - st.projection) + ps.projection, st.projection) st_ = (qkv_layer=st_qkv, attention=st_attention, projection=st_projection) return reshape(y, :, seq_len, batch_size), st_ @@ -131,25 +117,19 @@ Transformer as used in the base ViT architecture. 
- `mlp_ratio`: ratio of MLP layers to the number of input channels
- `dropout_rate`: dropout rate
"""
-function transformer_encoder(in_planes,
-    depth,
-    number_heads;
-    mlp_ratio=4.0f0,
+function transformer_encoder(in_planes, depth, number_heads; mlp_ratio=4.0f0,
    dropout_rate=0.0f0)
    hidden_planes = floor(Int, mlp_ratio * in_planes)
    layers = [Chain(SkipConnection(Chain(LayerNorm((in_planes, 1); affine=true),
-                MultiHeadAttention(in_planes,
-                    number_heads;
-                    attention_dropout_rate=dropout_rate,
-                    projection_dropout_rate=dropout_rate)),
-            +),
-        SkipConnection(Chain(LayerNorm((in_planes, 1); affine=true),
-                Chain(Dense(in_planes => hidden_planes, gelu),
-                    Dropout(dropout_rate),
-                    Dense(hidden_planes => in_planes),
-                    Dropout(dropout_rate));
-                disable_optimizations=true),
-            +)) for _ in 1:depth]
+                MultiHeadAttention(in_planes, number_heads;
+                    attention_dropout_rate=dropout_rate,
+                    projection_dropout_rate=dropout_rate)),
+            +),
+        SkipConnection(Chain(LayerNorm((in_planes, 1); affine=true),
+                Chain(Dense(in_planes => hidden_planes, gelu), Dropout(dropout_rate),
+                    Dense(hidden_planes => in_planes), Dropout(dropout_rate));
+                disable_optimizations=true),
+            +)) for _ in 1:depth]
    return Chain(layers...; disable_optimizations=true)
end

@@ -166,38 +146,25 @@ function patch_embedding(imsize::Tuple{<:Int, <:Int}=(224, 224);
        "Image dimensions must be divisible by the patch size."

    return Chain(Conv(patch_size, in_channels => embed_planes; stride=patch_size),
-        flatten ? _flatten_spatial : identity,
-        norm_layer(embed_planes))
+        flatten ? _flatten_spatial : identity, norm_layer(embed_planes))
end

# ViT Implementation
-function vision_transformer(;
-    imsize::Tuple{<:Int, <:Int}=(256, 256),
-    in_channels::Int=3,
-    patch_size::Tuple{<:Int, <:Int}=(16, 16),
-    embed_planes::Int=768,
-    depth::Int=6,
-    number_heads=16,
-    mlp_ratio=4.0f0,
-    dropout_rate=0.1f0,
-    embedding_dropout_rate=0.1f0,
-    pool::Symbol=:class,
-    num_classes::Int=1000,
-    kwargs...)
+function vision_transformer(; imsize::Tuple{<:Int, <:Int}=(256, 256), in_channels::Int=3,
+    patch_size::Tuple{<:Int, <:Int}=(16, 16), embed_planes::Int=768, depth::Int=6,
+    number_heads=16, mlp_ratio=4.0f0, dropout_rate=0.1f0, embedding_dropout_rate=0.1f0,
+    pool::Symbol=:class, num_classes::Int=1000, kwargs...)
    @assert pool in (:class, :mean)
    "Pool type must be either :class (class token) or :mean (mean pooling)"

    number_patches = prod(imsize .÷ patch_size)

    return Chain(Chain(patch_embedding(imsize; in_channels, patch_size, embed_planes),
-            ClassTokens(embed_planes),
-            ViPosEmbedding(embed_planes, number_patches + 1),
+            ClassTokens(embed_planes), ViPosEmbedding(embed_planes, number_patches + 1),
            Dropout(embedding_dropout_rate),
            transformer_encoder(embed_planes, depth, number_heads; mlp_ratio,
                dropout_rate),
            ((pool == :class) ? 
WrappedFunction(x -> x[:, 1, :]) : - WrappedFunction(_seconddimmean)); - disable_optimizations=true), + WrappedFunction(_seconddimmean)); disable_optimizations=true), Chain(LayerNorm((embed_planes,); affine=true), - Dense(embed_planes, num_classes, tanh); - disable_optimizations=true); + Dense(embed_planes, num_classes, tanh); disable_optimizations=true); disable_optimizations=true) end diff --git a/test/vision.jl b/test/vision.jl index 2bd194a..b3c5e0f 100644 --- a/test/vision.jl +++ b/test/vision.jl @@ -3,61 +3,23 @@ using Metalhead # Trigger Weak Dependency on Metalhead include("test_utils.jl") -models_available = Dict(alexnet => [(:alexnet, true), (:alexnet, false)], +models_available = Dict(alexnet => [(:alexnet, false)], convmixer => [(:small, false), (:base, false), (:large, false)], - densenet => [ - (:densenet121, false), - (:densenet161, false), - (:densenet169, false), - (:densenet201, false), - ], + densenet => [(:densenet121, false), (:densenet161, false), (:densenet169, false), + (:densenet201, false)], googlenet => [(:googlenet, false)], - mobilenet => [ - (:mobilenet_v1, false), - (:mobilenet_v2, false), - (:mobilenet_v3_small, false), - (:mobilenet_v3_large, false), - ], - resnet => [ - (:resnet18, true), - (:resnet18, false), - (:resnet34, true), - (:resnet34, false), - (:resnet50, true), - (:resnet50, false), - (:resnet101, true), - (:resnet101, false), - (:resnet152, true), - (:resnet152, false), - ], + mobilenet => [(:mobilenet_v1, false), (:mobilenet_v2, false), + (:mobilenet_v3_small, false), (:mobilenet_v3_large, false)], + resnet => [(:resnet18, false), (:resnet34, false), (:resnet50, false), + (:resnet101, false), (:resnet152, false)], resnext => [(:resnext50, false), (:resnext101, false), (:resnext152, false)], - vgg => [ - (:vgg11, false), - (:vgg11, true), - (:vgg11_bn, false), - (:vgg11_bn, true), - (:vgg13, false), - (:vgg13, true), - (:vgg13_bn, false), - (:vgg13_bn, true), - (:vgg16, false), - (:vgg16, true), - (:vgg16_bn, false), - (:vgg16_bn, true), - (:vgg19, false), - (:vgg19, true), - (:vgg19_bn, false), - (:vgg19_bn, true), - ], - vision_transformer => [ - (:tiny, false), - (:small, false), - (:base, false), + vgg => [(:vgg11, false), (:vgg11, true), (:vgg11_bn, false), (:vgg11_bn, true), + (:vgg13, false), (:vgg13, true), (:vgg13_bn, false), (:vgg13_bn, true), + (:vgg16, false), (:vgg16, true), (:vgg16_bn, false), (:vgg16_bn, true), + (:vgg19, false), (:vgg19, true), (:vgg19_bn, false), (:vgg19_bn, true)], + vision_transformer => [(:tiny, false), (:small, false), (:base, false), # CI cant handle these - # (:large, false), - # (:huge, false), - # (:giant, false), - # (:gigantic, false), + # (:large, false), (:huge, false), (:giant, false), (:gigantic, false), ]) @testset "$model_creator: $mode" for (mode, aType, device, ongpu) in MODES,
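For reference, usage after this patch follows the updated README snippet: `resnet` comes from the Metalhead extension, and with the ResNet artifacts commented out only `pretrained=false` is expected to work for now. A minimal sketch (the 224x224 input and 1000-class output are the usual ImageNet conventions, assumed here rather than asserted by the patch):

```julia
using Boltz, Lux, Metalhead, Random

# Pretrained ResNet weights are removed by this patch, so request random
# initialization; the AlexNet and VGG artifacts above remain available.
model, ps, st = resnet(:resnet18; pretrained=false)

x = randn(Random.default_rng(), Float32, 224, 224, 3, 1)
y, _ = model(x, ps, st)  # expected: 1000x1 matrix of class logits
```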