diff --git a/Project.toml b/Project.toml index 7eb0770..ed653d0 100644 --- a/Project.toml +++ b/Project.toml @@ -46,9 +46,9 @@ ForwardDiff = "0.10.36" GPUArraysCore = "0.1.6" JLD2 = "0.4.48" LazyArtifacts = "1.10" -Lux = "0.5.62" -LuxCore = "0.1.21" -MLDataDevices = "1" +Lux = "0.5.65" +LuxCore = "0.1.24" +MLDataDevices = "1.0.1" Markdown = "1.10" Metalhead = "0.9" NNlib = "0.9.21" diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 3b1cad7..4780670 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -15,7 +15,6 @@ using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer using NNlib: NNlib using WeightInitializers: zeros32, randn32 -using ..Boltz: Boltz using ..Utils: fast_chunk, should_type_assert, mapreduce_stack, unwrap_val, safe_kron, is_extension_loaded diff --git a/src/utils.jl b/src/utils.jl index e93be00..fa25657 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -75,6 +75,11 @@ mapreduce_stack(xs) = mapreduce(unsqueezeN, catN, xs) unwrap_val(x) = x unwrap_val(::Val{T}) where {T} = T +function safe_warning(msg::AbstractString) + @warn msg maxlog=1 + return +end + safe_kron(a, b) = map(safe_kron_internal, a, b) function safe_kron_internal(a::AbstractVector, b::AbstractVector) return safe_kron_internal(get_device_type((a, b)), a, b) @@ -85,7 +90,7 @@ function safe_kron_internal(::Type{CUDADevice}, a::AbstractVector, b::AbstractVe return vec(kron(reshape(a, :, 1), reshape(b, 1, :))) end function safe_kron_internal(::Type{D}, a::AbstractVector, b::AbstractVector) where {D} - @warn "`kron` is not supported on $(D). Falling back to `kron` on CPU." maxlog=1 + safe_warning("`kron` is not supported on $(D). Falling back to `kron` on CPU.") a_cpu = a |> CPUDevice() b_cpu = b |> CPUDevice() return safe_kron_internal(CPUDevice, a_cpu, b_cpu) |> get_device((a, b)) diff --git a/src/vision/Vision.jl b/src/vision/Vision.jl index 70eb8b2..8674067 100644 --- a/src/vision/Vision.jl +++ b/src/vision/Vision.jl @@ -6,15 +6,20 @@ using ConcreteStructs: @concrete using Random: Xoshiro using Lux: Lux -using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer +using LuxCore: LuxCore, AbstractExplicitLayer using NNlib: relu using ..InitializeModels: maybe_initialize_model, INITIALIZE_KWARGS -using ..Layers: Layers +using ..Layers: Layers, ConvBatchNormActivation, ClassTokens, ViPosEmbedding, + VisionTransformerEncoder using ..Utils: flatten_spatial, second_dim_mean, is_extension_loaded include("extensions.jl") include("vit.jl") include("vgg.jl") +@compat(public, + (AlexNet, ConvMixer, DenseNet, MobileNet, ResNet, + ResNeXt, GoogLeNet, ViT, VisionTransformer, VGG)) + end diff --git a/src/vision/vgg.jl b/src/vision/vgg.jl index 310a195..cd3380e 100644 --- a/src/vision/vgg.jl +++ b/src/vision/vgg.jl @@ -1,8 +1,8 @@ -function __vgg_convolutional_layers(config, batchnorm, inchannels) +function vgg_convolutional_layers(config, batchnorm, inchannels) layers = Vector{AbstractExplicitLayer}(undef, length(config) * 2) input_filters = inchannels for (i, (chs, depth)) in enumerate(config) - layers[2i - 1] = Layers.ConvBatchNormActivation( + layers[2i - 1] = ConvBatchNormActivation( (3, 3), input_filters => chs, depth, relu; last_layer_activation=true, conv_kwargs=(; pad=(1, 1)), use_norm=batchnorm, flatten_model=true) layers[2i] = Lux.MaxPool((2, 2)) @@ -11,7 +11,7 @@ function __vgg_convolutional_layers(config, batchnorm, inchannels) return Lux.Chain(layers...) end -function __vgg_classifier_layers(imsize, nclasses, fcsize, dropout) +function vgg_classifier_layers(imsize, nclasses, fcsize, dropout) return Lux.Chain(Lux.FlattenLayer(), Lux.Dense(Int(prod(imsize)) => fcsize, relu), Lux.Dropout(dropout), Lux.Dense(fcsize => fcsize, relu), Lux.Dropout(dropout), Lux.Dense(fcsize => nclasses)) @@ -38,14 +38,15 @@ Create a VGG model [1]. image recognition." arXiv preprint arXiv:1409.1556 (2014). """ function VGG(imsize; config, inchannels, batchnorm=false, nclasses, fcsize, dropout) - feature_extractor = __vgg_convolutional_layers(config, batchnorm, inchannels) + feature_extractor = vgg_convolutional_layers(config, batchnorm, inchannels) img = ones(Float32, (imsize..., inchannels, 2)) rng = Xoshiro(0) + # TODO: Use Lux.outputsize once it is ready _ps, _st = LuxCore.setup(rng, feature_extractor) outsize = size(first(feature_extractor(img, _ps, _st))) - classifier = __vgg_classifier_layers(outsize[1:((end - 1))], nclasses, fcsize, dropout) + classifier = vgg_classifier_layers(outsize[1:((end - 1))], nclasses, fcsize, dropout) return Lux.Chain(feature_extractor, classifier) end diff --git a/src/vision/vit.jl b/src/vision/vit.jl index c857259..cc94349 100644 --- a/src/vision/vit.jl +++ b/src/vision/vit.jl @@ -21,10 +21,10 @@ function VisionTransformer(; return Lux.Chain( Lux.Chain(__patch_embedding(imsize; in_channels, patch_size, embed_planes), - Layers.ClassTokens(embed_planes), - Layers.ViPosEmbedding(embed_planes, number_patches + 1), + ClassTokens(embed_planes), + ViPosEmbedding(embed_planes, number_patches + 1), Lux.Dropout(embedding_dropout_rate), - Layers.VisionTransformerEncoder( + VisionTransformerEncoder( embed_planes, depth, number_heads; mlp_ratio, dropout_rate), Lux.WrappedFunction(ifelse(pool === :class, x -> x[:, 1, :], second_dim_mean)); disable_optimizations=true),