Remove pretrained resnets for now
avik-pal committed Aug 20, 2023
1 parent 0fc8496 commit 18a737d
Showing 9 changed files with 95 additions and 230 deletions.
1 change: 0 additions & 1 deletion .JuliaFormatter.toml
@@ -4,6 +4,5 @@ always_use_return = true
margin = 92
indent = 4
format_docstrings = true
join_lines_based_on_source = false
separate_kwargs_with_semicolon = true
always_for_in = true
60 changes: 30 additions & 30 deletions Artifacts.toml
@@ -6,45 +6,45 @@ lazy = true
sha256 = "e20107404aba1c2c0ed3fad4314033a2fa600cdc0c55d03bc1bfe4f8e5031105"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/alexnet.tar.gz"

[resnet101]
git-tree-sha1 = "6c9143d40950726405b88db0cc021fa1dcbc0896"
lazy = true
# [resnet101]
# git-tree-sha1 = "6c9143d40950726405b88db0cc021fa1dcbc0896"
# lazy = true

[[resnet101.download]]
sha256 = "3840f05b3d996b2b3ea1e8fb6617775fd60ad6b8769402200fdc9c8b8dca246f"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet101.tar.gz"
# [[resnet101.download]]
# sha256 = "3840f05b3d996b2b3ea1e8fb6617775fd60ad6b8769402200fdc9c8b8dca246f"
# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet101.tar.gz"

[resnet152]
git-tree-sha1 = "892915c44de37537aad97da3de8a4458dfa36297"
lazy = true
# [resnet152]
# git-tree-sha1 = "892915c44de37537aad97da3de8a4458dfa36297"
# lazy = true

[[resnet152.download]]
sha256 = "6033a1ecc46d7f4ed1139067c5f9f5ea0d247656e9abbbe755c4702ec5a636d6"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet152.tar.gz"
# [[resnet152.download]]
# sha256 = "6033a1ecc46d7f4ed1139067c5f9f5ea0d247656e9abbbe755c4702ec5a636d6"
# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet152.tar.gz"

[resnet18]
git-tree-sha1 = "1d4a46fee1bb87eeef0ce2c85f63cfe0ff47d4de"
lazy = true
# [resnet18]
# git-tree-sha1 = "1d4a46fee1bb87eeef0ce2c85f63cfe0ff47d4de"
# lazy = true

[[resnet18.download]]
sha256 = "f4041ea1d1ec9bba86c7a5a519daaa49bb096a55fcd4ebf74f0743c8bdcb1c35"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet18.tar.gz"
# [[resnet18.download]]
# sha256 = "f4041ea1d1ec9bba86c7a5a519daaa49bb096a55fcd4ebf74f0743c8bdcb1c35"
# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet18.tar.gz"

[resnet34]
git-tree-sha1 = "306a8055ae9207ae2a316e31b376254557e481c9"
lazy = true
# [resnet34]
# git-tree-sha1 = "306a8055ae9207ae2a316e31b376254557e481c9"
# lazy = true

[[resnet34.download]]
sha256 = "d62e40ee9213ea9611e3fcedc958df4011da1fa108fb1537bac91e6b7778a3c8"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet34.tar.gz"
# [[resnet34.download]]
# sha256 = "d62e40ee9213ea9611e3fcedc958df4011da1fa108fb1537bac91e6b7778a3c8"
# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet34.tar.gz"

[resnet50]
git-tree-sha1 = "8c5866edb29b53f581a9ed7148efa1dbccde6133"
lazy = true
# [resnet50]
# git-tree-sha1 = "8c5866edb29b53f581a9ed7148efa1dbccde6133"
# lazy = true

[[resnet50.download]]
sha256 = "275365d76e592c6ea35574853a75ee068767641664e7817aedf394fcd7fea25a"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet50.tar.gz"
# [[resnet50.download]]
# sha256 = "275365d76e592c6ea35574853a75ee068767641664e7817aedf394fcd7fea25a"
# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet50.tar.gz"

[vgg11]
git-tree-sha1 = "ea7e8ef9399a0fe0aad2331781af5d6435950d36"
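Each remaining entry in Artifacts.toml pairs a `git-tree-sha1` content hash with a download stanza, and `lazy = true` defers the download until the artifact is first requested. As a minimal sketch of how such an entry is consumed (assuming standard Pkg artifact tooling and code living inside the package so the string macro can find this Artifacts.toml; the function name is hypothetical and nothing here comes from the commit itself):

```julia
# Minimal sketch: resolving a lazy artifact declared in Artifacts.toml.
# "vgg11" matches an entry above; `vgg11_weights_dir` is an illustrative name.
using Artifacts, LazyArtifacts  # LazyArtifacts enables on-demand downloads

function vgg11_weights_dir()
    # Looks up the git-tree-sha1 for "vgg11", downloads and verifies the
    # tarball on first use, and returns the local directory path.
    return artifact"vgg11"
end
```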
2 changes: 1 addition & 1 deletion README.md
@@ -23,7 +23,7 @@ Pkg.add("Boltz")
## Getting Started

```julia
using Boltz, Lux
using Boltz, Lux, Metalhead

model, ps, st = resnet(:resnet18; pretrained=true)
```
5 changes: 0 additions & 5 deletions ext/BoltzMetalheadExt.jl
@@ -30,11 +30,6 @@ function resnet(name::Symbol; pretrained=false, kwargs...)
transform(ResNet(152).layers)
end

# Compatibility with pretrained weights
if pretrained
model = Chain(model[1], model[2])
end

return _initialize_model(name, model; pretrained, kwargs...)
end

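For context, the extension builds the Metalhead model, converts its underlying Flux chain with `transform` (as in the context line above), and hands the result to `_initialize_model`. A rough sketch of one branch, with the converter passed in explicitly since its exact import is not shown in this hunk:

```julia
# Sketch of the resnet-18 branch above; `build_resnet18` is a hypothetical
# helper and `transform` stands for the Flux-to-Lux converter the extension uses.
using Metalhead: ResNet

function build_resnet18(transform)
    flux_layers = ResNet(18).layers   # the underlying Flux chain, as in the diff
    return transform(flux_layers)     # convert to Lux-compatible layers
end
```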
9 changes: 1 addition & 8 deletions src/Boltz.jl
@@ -13,14 +13,7 @@ function __init__()
end

# Define functions. Methods defined in files or in extensions later
for f in (:alexnet,
:convmixer,
:densenet,
:googlenet,
:mobilenet,
:resnet,
:resnext,
:vgg,
for f in (:alexnet, :convmixer, :densenet, :googlenet, :mobilenet, :resnet, :resnext, :vgg,
:vision_transformer)
@eval function $(f) end
end
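The loop above declares empty generic functions so that extensions (such as the Metalhead extension) can attach methods later without Boltz loading those packages eagerly. A small self-contained illustration of the pattern (module and method bodies here are hypothetical):

```julia
# The parent package declares zero-method functions; an extension adds methods.
module Parent
for f in (:resnet, :vgg)
    @eval function $(f) end   # defines `resnet` and `vgg` with no methods yet
end
end

# Later, in code that is only loaded when the heavy dependency is available:
function Parent.resnet(name::Symbol; kwargs...)
    return "constructed $(name)"   # a real extension would build the model here
end

Parent.resnet(:resnet18)   # now dispatches to the extension's method
```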
8 changes: 2 additions & 6 deletions src/utils.jl
@@ -43,12 +43,8 @@ function _get_pretrained_weights_path(name::String)
end
end

function _initialize_model(name::Symbol,
model;
pretrained::Bool=false,
rng=nothing,
seed=0,
kwargs...)
function _initialize_model(name::Symbol, model; pretrained::Bool=false, rng=nothing,
seed=0, kwargs...)
if pretrained
path = _get_pretrained_weights_path(name)
ps = load(joinpath(path, "$name.jld2"), "parameters")
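For context, `_initialize_model` either restores serialized parameters (and, presumably, states) from the artifact directory in the `pretrained` branch shown above, or falls back to seeded random initialization. A sketch of what that fallback typically looks like with the Lux API (assumed shape, not copied from this commit):

```julia
# Assumed shape of the non-pretrained branch: seed an RNG and let Lux
# initialize parameters and states for the model.
using Lux, Random

function init_from_scratch(model; rng=nothing, seed=0)
    if rng === nothing
        rng = Random.default_rng()
        Random.seed!(rng, seed)
    end
    ps, st = Lux.setup(rng, model)
    return model, ps, st
end
```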
89 changes: 21 additions & 68 deletions src/vision/vgg.jl
@@ -11,15 +11,12 @@ A VGG block of convolution layers ([reference](https://arxiv.org/abs/1409.1556v6
- `batchnorm`: set to `true` to include batch normalization after each convolution
"""
function _vgg_block(input_filters, output_filters, depth, batchnorm)
k = (3, 3)
p = (1, 1)
layers = []
for _ in 1:depth
push!(layers,
Conv(k, input_filters => output_filters, batchnorm ? identity : relu; pad=p))
if batchnorm
push!(layers, BatchNorm(output_filters, relu))
end
Conv((3, 3), input_filters => output_filters,
batchnorm ? identity : relu; pad=(1, 1)))
batchnorm && push!(layers, BatchNorm(output_filters, relu))
input_filters = output_filters
end
return Chain(layers...)
@@ -63,12 +60,8 @@ Create VGG classifier (fully connected) layers
- `dropout`: the dropout level between each fully connected layer
"""
function _vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
return Chain(FlattenLayer(),
Dense(Int(prod(imsize)), fcsize, relu),
Dropout(dropout),
Dense(fcsize, fcsize, relu),
Dropout(dropout),
Dense(fcsize, nclasses))
return Chain(FlattenLayer(), Dense(Int(prod(imsize)), fcsize, relu), Dropout(dropout),
Dense(fcsize, fcsize, relu), Dropout(dropout), Dense(fcsize, nclasses))
end

"""
@@ -104,69 +97,29 @@ function vgg(name::Symbol; kwargs...)
assert_name_present_in(name,
(:vgg11, :vgg11_bn, :vgg13, :vgg13_bn, :vgg16, :vgg16_bn, :vgg19, :vgg19_bn))
model = if name == :vgg11
vgg((224, 224);
config=VGG_CONV_CONFIG[VGG_CONFIG[11]],
inchannels=3,
batchnorm=false,
nclasses=1000,
fcsize=4096,
dropout=0.5f0)
vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[11]], inchannels=3,
batchnorm=false, nclasses=1000, fcsize=4096, dropout=0.5f0)
elseif name == :vgg11_bn
vgg((224, 224);
config=VGG_CONV_CONFIG[VGG_CONFIG[11]],
inchannels=3,
batchnorm=true,
nclasses=1000,
fcsize=4096,
dropout=0.5f0)
vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[11]], inchannels=3,
batchnorm=true, nclasses=1000, fcsize=4096, dropout=0.5f0)
elseif name == :vgg13
vgg((224, 224);
config=VGG_CONV_CONFIG[VGG_CONFIG[13]],
inchannels=3,
batchnorm=false,
nclasses=1000,
fcsize=4096,
dropout=0.5f0)
vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[13]], inchannels=3,
batchnorm=false, nclasses=1000, fcsize=4096, dropout=0.5f0)
elseif name == :vgg13_bn
vgg((224, 224);
config=VGG_CONV_CONFIG[VGG_CONFIG[13]],
inchannels=3,
batchnorm=true,
nclasses=1000,
fcsize=4096,
dropout=0.5f0)
vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[13]], inchannels=3,
batchnorm=true, nclasses=1000, fcsize=4096, dropout=0.5f0)
elseif name == :vgg16
vgg((224, 224);
config=VGG_CONV_CONFIG[VGG_CONFIG[16]],
inchannels=3,
batchnorm=false,
nclasses=1000,
fcsize=4096,
dropout=0.5f0)
vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[16]], inchannels=3,
batchnorm=false, nclasses=1000, fcsize=4096, dropout=0.5f0)
elseif name == :vgg16_bn
vgg((224, 224);
config=VGG_CONV_CONFIG[VGG_CONFIG[16]],
inchannels=3,
batchnorm=true,
nclasses=1000,
fcsize=4096,
dropout=0.5f0)
vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[16]], inchannels=3,
batchnorm=true, nclasses=1000, fcsize=4096, dropout=0.5f0)
elseif name == :vgg19
vgg((224, 224);
config=VGG_CONV_CONFIG[VGG_CONFIG[19]],
inchannels=3,
batchnorm=false,
nclasses=1000,
fcsize=4096,
dropout=0.5f0)
vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[19]], inchannels=3,
batchnorm=false, nclasses=1000, fcsize=4096, dropout=0.5f0)
elseif name == :vgg19_bn
vgg((224, 224);
config=VGG_CONV_CONFIG[VGG_CONFIG[19]],
inchannels=3,
batchnorm=true,
nclasses=1000,
fcsize=4096,
dropout=0.5f0)
vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[19]], inchannels=3,
batchnorm=true, nclasses=1000, fcsize=4096, dropout=0.5f0)
end
return _initialize_model(name, model; kwargs...)
end
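The branches above differ only in the convolution configuration and the `batchnorm` flag; construction and inference then follow the same pattern as the other Boltz vision models. A hedged usage sketch (the input and batch sizes are illustrative):

```julia
# Build a randomly initialized VGG-11 with batch norm and run one image.
using Boltz, Lux

model, ps, st = vgg(:vgg11_bn)
x = randn(Float32, 224, 224, 3, 1)      # one 224×224 RGB image in WHCN layout
y, st = Lux.apply(model, x, ps, st)     # y holds class logits for 1000 classes
```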
87 changes: 27 additions & 60 deletions src/vision/vit.jl
@@ -14,11 +14,8 @@ struct MultiHeadAttention{Q, A, P} <:
projection::P
end

function MultiHeadAttention(in_planes::Int,
number_heads::Int;
qkv_bias::Bool=false,
attention_dropout_rate::T=0.0f0,
projection_dropout_rate::T=0.0f0) where {T}
function MultiHeadAttention(in_planes::Int, number_heads::Int; qkv_bias::Bool=false,
attention_dropout_rate::T=0.0f0, projection_dropout_rate::T=0.0f0) where {T}
@assert in_planes % number_heads==0 "`in_planes` should be divisible by `number_heads`"
qkv_layer = Dense(in_planes, in_planes * 3; use_bias=qkv_bias)
attention_dropout = Dropout(attention_dropout_rate)
@@ -32,37 +29,26 @@ function (m::MultiHeadAttention)(x::AbstractArray{T, 3}, ps, st) where {T}

x_reshaped = reshape(x, nfeatures, seq_len * batch_size)
qkv, st_qkv = m.qkv_layer(x_reshaped, ps.qkv_layer, st.qkv_layer)
qkv_reshaped = reshape(qkv,
nfeatures ÷ m.number_heads,
m.number_heads,
seq_len,
qkv_reshaped = reshape(qkv, nfeatures ÷ m.number_heads, m.number_heads, seq_len,
3 * batch_size)
query, key, value = _fast_chunk(qkv_reshaped, Val(3), Val(4))

scale = convert(T, sqrt(size(query, 1) / m.number_heads))
key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)),
m.number_heads,
nfeatures ÷ m.number_heads,
seq_len * batch_size)
query_reshaped = reshape(query,
nfeatures ÷ m.number_heads,
m.number_heads,
key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.number_heads,
nfeatures ÷ m.number_heads, seq_len * batch_size)
query_reshaped = reshape(query, nfeatures ÷ m.number_heads, m.number_heads,
seq_len * batch_size)

attention = softmax(batched_mul(query_reshaped, key_reshaped) .* scale)
attention, st_attention = m.attention_dropout(attention,
ps.attention_dropout,
attention, st_attention = m.attention_dropout(attention, ps.attention_dropout,
st.attention_dropout)

value_reshaped = reshape(value,
nfeatures ÷ m.number_heads,
m.number_heads,
value_reshaped = reshape(value, nfeatures ÷ m.number_heads, m.number_heads,
seq_len * batch_size)
pre_projection = reshape(batched_mul(attention, value_reshaped),
(nfeatures, seq_len, batch_size))
y, st_projection = m.projection(reshape(pre_projection, size(pre_projection, 1), :),
ps.projection,
st.projection)
ps.projection, st.projection)

st_ = (qkv_layer=st_qkv, attention=st_attention, projection=st_projection)
return reshape(y, :, seq_len, batch_size), st_
@@ -131,25 +117,19 @@ Transformer as used in the base ViT architecture.
- `mlp_ratio`: ratio of MLP layers to the number of input channels
- `dropout_rate`: dropout rate
"""
function transformer_encoder(in_planes,
depth,
number_heads;
mlp_ratio=4.0f0,
function transformer_encoder(in_planes, depth, number_heads; mlp_ratio=4.0f0,
dropout_rate=0.0f0)
hidden_planes = floor(Int, mlp_ratio * in_planes)
layers = [Chain(SkipConnection(Chain(LayerNorm((in_planes, 1); affine=true),
MultiHeadAttention(in_planes,
number_heads;
attention_dropout_rate=dropout_rate,
projection_dropout_rate=dropout_rate)),
+),
SkipConnection(Chain(LayerNorm((in_planes, 1); affine=true),
Chain(Dense(in_planes => hidden_planes, gelu),
Dropout(dropout_rate),
Dense(hidden_planes => in_planes),
Dropout(dropout_rate));
disable_optimizations=true),
+)) for _ in 1:depth]
MultiHeadAttention(in_planes, number_heads;
attention_dropout_rate=dropout_rate,
projection_dropout_rate=dropout_rate)),
+),
SkipConnection(Chain(LayerNorm((in_planes, 1); affine=true),
Chain(Dense(in_planes => hidden_planes, gelu), Dropout(dropout_rate),
Dense(hidden_planes => in_planes), Dropout(dropout_rate));
disable_optimizations=true),
+)) for _ in 1:depth]
return Chain(layers...; disable_optimizations=true)
end

@@ -166,38 +146,25 @@ function patch_embedding(imsize::Tuple{<:Int, <:Int}=(224, 224);
"Image dimensions must be divisible by the patch size."

return Chain(Conv(patch_size, in_channels => embed_planes; stride=patch_size),
flatten ? _flatten_spatial : identity,
norm_layer(embed_planes))
flatten ? _flatten_spatial : identity, norm_layer(embed_planes))
end

# ViT Implementation
function vision_transformer(;
imsize::Tuple{<:Int, <:Int}=(256, 256),
in_channels::Int=3,
patch_size::Tuple{<:Int, <:Int}=(16, 16),
embed_planes::Int=768,
depth::Int=6,
number_heads=16,
mlp_ratio=4.0f0,
dropout_rate=0.1f0,
embedding_dropout_rate=0.1f0,
pool::Symbol=:class,
num_classes::Int=1000,
kwargs...)
function vision_transformer(; imsize::Tuple{<:Int, <:Int}=(256, 256), in_channels::Int=3,
patch_size::Tuple{<:Int, <:Int}=(16, 16), embed_planes::Int=768, depth::Int=6,
number_heads=16, mlp_ratio=4.0f0, dropout_rate=0.1f0, embedding_dropout_rate=0.1f0,
pool::Symbol=:class, num_classes::Int=1000, kwargs...)
@assert pool in (:class, :mean) "Pool type must be either :class (class token) or :mean (mean pooling)"
number_patches = prod(imsize .÷ patch_size)

return Chain(Chain(patch_embedding(imsize; in_channels, patch_size, embed_planes),
ClassTokens(embed_planes),
ViPosEmbedding(embed_planes, number_patches + 1),
ClassTokens(embed_planes), ViPosEmbedding(embed_planes, number_patches + 1),
Dropout(embedding_dropout_rate),
transformer_encoder(embed_planes, depth, number_heads; mlp_ratio, dropout_rate),
((pool == :class) ? WrappedFunction(x -> x[:, 1, :]) :
WrappedFunction(_seconddimmean));
disable_optimizations=true),
WrappedFunction(_seconddimmean)); disable_optimizations=true),
Chain(LayerNorm((embed_planes,); affine=true),
Dense(embed_planes, num_classes, tanh);
disable_optimizations=true);
Dense(embed_planes, num_classes, tanh); disable_optimizations=true);
disable_optimizations=true)
end

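The keyword constructor above returns a plain Lux `Chain`, so parameters and states still need to be set up before inference. A rough usage sketch under that assumption (keyword values here are just the defaults shown in the diff):

```julia
# Build the ViT defined above and run a small batch through it.
using Boltz, Lux, Random

model = Boltz.vision_transformer(; imsize=(256, 256), depth=6, number_heads=16)
ps, st = Lux.setup(Random.default_rng(), model)

x = randn(Float32, 256, 256, 3, 2)      # two 256×256 RGB images
y, st = Lux.apply(model, x, ps, st)     # logits for the default 1000 classes
```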