refactor: cleanup Vision module
avik-pal committed Aug 24, 2024
1 parent 957b58b commit 39088c3
Showing 6 changed files with 25 additions and 15 deletions.
6 changes: 3 additions & 3 deletions Project.toml
@@ -46,9 +46,9 @@ ForwardDiff = "0.10.36"
 GPUArraysCore = "0.1.6"
 JLD2 = "0.4.48"
 LazyArtifacts = "1.10"
-Lux = "0.5.62"
-LuxCore = "0.1.21"
-MLDataDevices = "1"
+Lux = "0.5.65"
+LuxCore = "0.1.24"
+MLDataDevices = "1.0.1"
 Markdown = "1.10"
 Metalhead = "0.9"
 NNlib = "0.9.21"
1 change: 0 additions & 1 deletion src/layers/Layers.jl
@@ -15,7 +15,6 @@ using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
 using NNlib: NNlib
 using WeightInitializers: zeros32, randn32
 
-using ..Boltz: Boltz
 using ..Utils: fast_chunk, should_type_assert, mapreduce_stack, unwrap_val, safe_kron,
                is_extension_loaded
 
7 changes: 6 additions & 1 deletion src/utils.jl
@@ -75,6 +75,11 @@ mapreduce_stack(xs) = mapreduce(unsqueezeN, catN, xs)
 unwrap_val(x) = x
 unwrap_val(::Val{T}) where {T} = T
 
+function safe_warning(msg::AbstractString)
+    @warn msg maxlog=1
+    return
+end
+
 safe_kron(a, b) = map(safe_kron_internal, a, b)
 function safe_kron_internal(a::AbstractVector, b::AbstractVector)
     return safe_kron_internal(get_device_type((a, b)), a, b)
@@ -85,7 +90,7 @@ function safe_kron_internal(::Type{CUDADevice}, a::AbstractVector, b::AbstractVector)
     return vec(kron(reshape(a, :, 1), reshape(b, 1, :)))
 end
 function safe_kron_internal(::Type{D}, a::AbstractVector, b::AbstractVector) where {D}
-    @warn "`kron` is not supported on $(D). Falling back to `kron` on CPU." maxlog=1
+    safe_warning("`kron` is not supported on $(D). Falling back to `kron` on CPU.")
     a_cpu = a |> CPUDevice()
     b_cpu = b |> CPUDevice()
     return safe_kron_internal(CPUDevice, a_cpu, b_cpu) |> get_device((a, b))
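For context (not part of the diff), a minimal sketch of how the fallback behaves; the inputs, the `Boltz.Utils` module path, and the element types are assumptions:

    using Boltz  # assumed package providing the internal Utils module
    # Each element pair is Kronecker-multiplied; on a device without `kron` support the data
    # is moved to CPU, `kron` runs there, and the result is moved back to the input device.
    # The fallback warning is emitted only once because `safe_warning` uses `maxlog=1`.
    a = (rand(Float32, 3), rand(Float32, 2))
    b = (rand(Float32, 4), rand(Float32, 5))
    out = Boltz.Utils.safe_kron(a, b)  # tuple of vectors of length 12 and 10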
9 changes: 7 additions & 2 deletions src/vision/Vision.jl
@@ -6,15 +6,20 @@ using ConcreteStructs: @concrete
 using Random: Xoshiro
 
 using Lux: Lux
-using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
+using LuxCore: LuxCore, AbstractExplicitLayer
 using NNlib: relu
 
 using ..InitializeModels: maybe_initialize_model, INITIALIZE_KWARGS
-using ..Layers: Layers
+using ..Layers: Layers, ConvBatchNormActivation, ClassTokens, ViPosEmbedding,
+               VisionTransformerEncoder
 using ..Utils: flatten_spatial, second_dim_mean, is_extension_loaded
 
 include("extensions.jl")
 include("vit.jl")
 include("vgg.jl")
 
+@compat(public,
+    (AlexNet, ConvMixer, DenseNet, MobileNet, ResNet,
+        ResNeXt, GoogLeNet, ViT, VisionTransformer, VGG))
+
 end
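A brief aside on the new `@compat(public, ...)` block (not part of the diff): with Compat.jl on Julia 1.11 or newer it expands to the `public` keyword, so the listed constructors become part of the declared API without being exported; on older Julia versions it is a no-op. A minimal sketch, assuming the submodule path is `Boltz.Vision`:

    using Boltz
    # On Julia >= 1.11 the publicly declared names can be introspected:
    Base.ispublic(Boltz.Vision, :VGG)  # true
    Base.ispublic(Boltz.Vision, :ViT)  # true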
11 changes: 6 additions & 5 deletions src/vision/vgg.jl
@@ -1,8 +1,8 @@
-function __vgg_convolutional_layers(config, batchnorm, inchannels)
+function vgg_convolutional_layers(config, batchnorm, inchannels)
     layers = Vector{AbstractExplicitLayer}(undef, length(config) * 2)
     input_filters = inchannels
     for (i, (chs, depth)) in enumerate(config)
-        layers[2i - 1] = Layers.ConvBatchNormActivation(
+        layers[2i - 1] = ConvBatchNormActivation(
             (3, 3), input_filters => chs, depth, relu; last_layer_activation=true,
             conv_kwargs=(; pad=(1, 1)), use_norm=batchnorm, flatten_model=true)
         layers[2i] = Lux.MaxPool((2, 2))
@@ -11,7 +11,7 @@ function __vgg_convolutional_layers(config, batchnorm, inchannels)
     return Lux.Chain(layers...)
 end
 
-function __vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
+function vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
     return Lux.Chain(Lux.FlattenLayer(), Lux.Dense(Int(prod(imsize)) => fcsize, relu),
         Lux.Dropout(dropout), Lux.Dense(fcsize => fcsize, relu),
         Lux.Dropout(dropout), Lux.Dense(fcsize => nclasses))
@@ -38,14 +38,15 @@ Create a VGG model [1].
 image recognition." arXiv preprint arXiv:1409.1556 (2014).
 """
 function VGG(imsize; config, inchannels, batchnorm=false, nclasses, fcsize, dropout)
-    feature_extractor = __vgg_convolutional_layers(config, batchnorm, inchannels)
+    feature_extractor = vgg_convolutional_layers(config, batchnorm, inchannels)
+
     img = ones(Float32, (imsize..., inchannels, 2))
     rng = Xoshiro(0)
     # TODO: Use Lux.outputsize once it is ready
     _ps, _st = LuxCore.setup(rng, feature_extractor)
     outsize = size(first(feature_extractor(img, _ps, _st)))
 
-    classifier = __vgg_classifier_layers(outsize[1:((end - 1))], nclasses, fcsize, dropout)
+    classifier = vgg_classifier_layers(outsize[1:((end - 1))], nclasses, fcsize, dropout)
 
     return Lux.Chain(feature_extractor, classifier)
 end
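For readers wanting to try the constructor shown above, a hedged usage sketch follows; the VGG-11-style `config` and the hyperparameter values are illustrative choices, not taken from this commit:

    # Illustrative only: each `config` entry is (output channels, number of conv layers),
    # matching the `for (i, (chs, depth)) in enumerate(config)` loop above.
    config = [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)]
    model = VGG((224, 224); config, inchannels=3, nclasses=1000, fcsize=4096, dropout=0.5f0)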
6 changes: 3 additions & 3 deletions src/vision/vit.jl
@@ -21,10 +21,10 @@ function VisionTransformer(;
 
     return Lux.Chain(
         Lux.Chain(__patch_embedding(imsize; in_channels, patch_size, embed_planes),
-            Layers.ClassTokens(embed_planes),
-            Layers.ViPosEmbedding(embed_planes, number_patches + 1),
+            ClassTokens(embed_planes),
+            ViPosEmbedding(embed_planes, number_patches + 1),
             Lux.Dropout(embedding_dropout_rate),
-            Layers.VisionTransformerEncoder(
+            VisionTransformerEncoder(
                 embed_planes, depth, number_heads; mlp_ratio, dropout_rate),
             Lux.WrappedFunction(ifelse(pool === :class, x -> x[:, 1, :], second_dim_mean));
             disable_optimizations=true),
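As a side note, the `Lux.WrappedFunction` line selects the pooled representation: `pool === :class` keeps the class-token slice, anything else averages over the token dimension via `second_dim_mean`. A minimal sketch of the two branches; the (features, tokens, batch) layout and the mean-pooling semantics of `second_dim_mean` are assumptions inferred from the code, not stated in the diff:

    using Statistics: mean
    x = rand(Float32, 768, 197, 4)                   # (embed_planes, number_patches + 1, batch)
    class_pool = x[:, 1, :]                          # (768, 4): take the class token
    mean_pool  = dropdims(mean(x; dims=2); dims=2)   # (768, 4): average over all tokens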
