From cc33ada33f25fa8f1d74e34f9e2981b5b621f1cd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 08:39:18 -0400 Subject: [PATCH] refactor: static fields in layers --- Project.toml | 4 +- ext/LuxSimpleChainsExt.jl | 4 +- src/Lux.jl | 3 +- src/contrib/contrib.jl | 1 + src/contrib/debug.jl | 47 +++-- src/extended_ops.jl | 22 +- src/helpers/compact.jl | 32 +-- src/helpers/stateful.jl | 88 ++++---- src/layers/basic.jl | 218 ++++++++++---------- src/layers/containers.jl | 43 ++-- src/layers/conv.jl | 289 ++++++++++++++------------- src/layers/extension.jl | 19 +- src/layers/normalize.jl | 175 ++++++++-------- src/layers/recurrent.jl | 176 ++++++++-------- src/transform/simplechains.jl | 33 +-- src/utils.jl | 16 +- test/helpers/size_propagator_test.jl | 4 +- test/qa_tests.jl | 7 +- 18 files changed, 596 insertions(+), 585 deletions(-) diff --git a/Project.toml b/Project.toml index ae1de4e3f..55fd49779 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.68-DEV" +version = "0.5.68" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -11,7 +11,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" @@ -78,7 +77,6 @@ ChainRulesCore = "1.24" Compat = "4.15" ComponentArrays = "0.15.16" ConcreteStructs = "0.2.3" -ConstructionBase = "1.5" DispatchDoctor = "0.4.12" DynamicExpressions = "0.16, 0.17, 0.18, 0.19" Enzyme = "0.12.26" diff --git a/ext/LuxSimpleChainsExt.jl b/ext/LuxSimpleChainsExt.jl index 7e1d8ddd1..c7a607b25 100644 --- a/ext/LuxSimpleChainsExt.jl +++ b/ext/LuxSimpleChainsExt.jl @@ -25,8 +25,8 @@ end equivalent_simplechains_fn(::typeof(NNlib.relu)) = SimpleChains.relu equivalent_simplechains_fn(f::F) where {F} = f -function Lux.make_simplechain_network(layer::Dense{use_bias}) where {use_bias} - return SimpleChains.TurboDense{use_bias}( +function Lux.make_simplechain_network(layer::Dense) + return SimpleChains.TurboDense{Lux.has_bias(layer)}( equivalent_simplechains_fn(layer.activation), layer.out_dims) end diff --git a/src/Lux.jl b/src/Lux.jl index b755ae569..712996c45 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -8,7 +8,6 @@ using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore, NoTangent, @thunk using Compat: @compat using ConcreteStructs: @concrete -using ConstructionBase: ConstructionBase using FastClosures: @closure using Functors: Functors, fmap using GPUArraysCore: @allowscalar @@ -16,7 +15,7 @@ using LossFunctions: LossFunctions using Markdown: @doc_str using Optimisers: Optimisers using Random: Random, AbstractRNG -using Static: StaticBool, True, False, static +using Static: StaticBool, StaticInt, StaticSymbol, True, False, static, known, dynamic using Reexport: @reexport using Statistics: mean using UnrolledUtilities: unrolled_map, unrolled_mapreduce diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index 0713b8866..fdcf6b0c2 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -14,6 +14,7 @@ using Markdown: @doc_str using Optimisers: Optimisers using Random: AbstractRNG, Random using Setfield: Setfield +using Static: 
StaticSymbol, StaticBool, True, known, static, dynamic const CRC = ChainRulesCore diff --git a/src/contrib/debug.jl b/src/contrib/debug.jl index 3d07f1783..318720882 100644 --- a/src/contrib/debug.jl +++ b/src/contrib/debug.jl @@ -1,6 +1,8 @@ """ - DebugLayer(layer::AbstractExplicitLayer; nan_check::Symbol=:both, - error_check::Bool=true, location::KeyPath=KeyPath()) + DebugLayer(layer::AbstractExplicitLayer; + nan_check::Union{Symbol, StaticSymbol, Val}=static(:both), + error_check::Union{StaticBool, Bool, Val{true}, Val{false}}=True(), + location::Union{KeyPath, String}=KeyPath()) A wrapper over Lux layers that adds checks for NaNs and errors. This is useful for debugging. @@ -41,15 +43,18 @@ track where the error originates. See [`Lux.Experimental.@debug_mode`](@ref) to construct this layer. """ -@concrete struct DebugLayer{NaNCheck, ErrorCheck} <: - AbstractExplicitContainerLayer{(:layer,)} +@concrete struct DebugLayer <: AbstractExplicitContainerLayer{(:layer,)} + nan_check <: StaticSymbol + error_check <: StaticBool layer <: AbstractExplicitLayer location::KeyPath end -function DebugLayer(layer::AbstractExplicitLayer; nan_check::Symbol=:both, - error_check::Bool=true, location::Union{KeyPath, String}=KeyPath()) - @argcheck nan_check in (:both, :forward, :backward, :none) +function DebugLayer(layer::AbstractExplicitLayer; + nan_check::Union{Symbol, StaticSymbol, Val}=static(:both), + error_check::Union{StaticBool, Bool, Val{true}, Val{false}}=True(), + location::Union{KeyPath, String}=KeyPath()) + @argcheck dynamic(nan_check) in (:both, :forward, :backward, :none) if location isa String Base.depwarn( @@ -58,23 +63,23 @@ function DebugLayer(layer::AbstractExplicitLayer; nan_check::Symbol=:both, location = KeyPath(Symbol.(split(location, "."))...) end - return DebugLayer{nan_check, error_check}(layer, location) + return DebugLayer(static(nan_check), static(error_check), layer, location) end -function (d::DebugLayer{NaNCheck, ErrorCheck})(x, ps, st) where {NaNCheck, ErrorCheck} +function (d::DebugLayer)(x, ps, st) CRC.ignore_derivatives() do @info lazy"Input Type: $(typeof(x)) | Input Structure: $(Utils.structure(x))." @info lazy"Running Layer: $(d.layer) at location $(d.location)!" - if NaNCheck ∈ (:both, :forward) + if known(d.nan_check) ∈ (:both, :forward) check_nan_and_throw(x, "input", d.layer, d.location) check_nan_and_throw(ps, "parameters", d.layer, d.location) check_nan_and_throw(st, "states", d.layer, d.location) end end - y, stₙ = debug_layer_impl( - d.layer, x, ps, st, d.location, ErrorCheck, NaNCheck ∈ (:both, :backward)) + y, stₙ = debug_layer_impl(d.layer, x, ps, st, d.location, known(d.error_check), + known(d.nan_check) ∈ (:both, :backward)) CRC.ignore_derivatives() do - if NaNCheck ∈ (:both, :forward) + if known(d.nan_check) ∈ (:both, :forward) check_nan_and_throw(y, "output", d.layer, d.location) check_nan_and_throw(stₙ, "states", d.layer, d.location) end @@ -99,34 +104,34 @@ function check_nan_and_throw(x, str::AbstractString, layer, location::KeyPath) return fmap_with_path(nan_check, x) end -function debug_layer_impl(layer, x, ps, st, location, EC, NC) +function debug_layer_impl(layer, x, ps, st, location, error_check, _) y, stₙ = try apply(layer, x, ps, st) catch - EC && + error_check && @error "Layer $(layer) failed!! This layer is present at location $(location)." 
rethrow() end return y, stₙ end -function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, - ::typeof(debug_layer_impl), layer, x, ps, st, location, EC, NC) +function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(debug_layer_impl), + layer, x, ps, st, location, error_check, nan_check_backward) result, ∇debug_layer_internal = CRC.rrule_via_ad(cfg, apply, layer, x, ps, st) syms = ("LuxCore.apply", "layer", "x", "ps", "st") function ∇debug_layer_internal_with_checks(Δ) - NC && check_nan_and_throw(Δ, "pullback input", layer, location) + nan_check_backward && check_nan_and_throw(Δ, "pullback input", layer, location) gs = try ∇debug_layer_internal(Δ) catch - EC && + error_check && @error "Backward Pass for Layer $(layer) failed!! This layer is present at location $(location)." rethrow() end - if NC - for (i, g) in enumerate(gs) + if nan_check_backward + foreach(enumerate(gs)) do (i, g) check_nan_and_throw(g, "pullback output ($(syms[i]))", layer, location) end end diff --git a/src/extended_ops.jl b/src/extended_ops.jl index 4e8cb37a1..2bcd12555 100644 --- a/src/extended_ops.jl +++ b/src/extended_ops.jl @@ -12,12 +12,14 @@ using Compat: @compat using EnzymeCore: EnzymeCore using FastClosures: @closure using MLDataDevices: get_device_type, AbstractGPUDevice, AbstractDevice -using Static: StaticBool, known +using Static: StaticBool, StaticSymbol, known using ..Utils: Utils const CRC = ChainRulesCore +const KnownSymbolType{v} = Union{Val{v}, StaticSymbol{v}} + # `xlogx` and `xlogy` ## We don't use `LogExpFunctions` since they don't support GPU broadcasting. See ## https://github.com/LuxDL/Lux.jl/pull/796. Additionally we have special broadcast rrules. @@ -79,14 +81,15 @@ end """ getproperty(x, ::Val{v}) + getproperty(x, ::StaticSymbol{v}) -Similar to `Base.getproperty` but requires a `Val`. Additionally if `v` is not present in -`x`, then `nothing` is returned. +Similar to `Base.getproperty` but requires a `Val` (or `Static.StaticSymbol`). Additionally, +if `v` is not present in `x`, then `nothing` is returned. """ -function getproperty(x, ::Val{v}) where {v} +function getproperty(x, ::KnownSymbolType{v}) where {v} return v ∈ Base.propertynames(x) ? Base.getproperty(x, v) : nothing end -@generated function getproperty(x::NamedTuple{names}, ::Val{v}) where {names, v} +@generated function getproperty(x::NamedTuple{names}, ::KnownSymbolType{v}) where {names, v} return v ∈ names ? 
:(x.$v) : :(nothing) end @@ -228,3 +231,12 @@ const safe_eachslice = LuxOps.eachslice const private_xlogx = LuxOps.xlogx const private_xlogy = LuxOps.xlogy const private_foldl_init = LuxOps.foldl_init + +# These are defined here to avoid a circular dependency among modules +for (op, field) in (:bias => :use_bias, :affine => :affine, + :track_stats => :track_stats, :train_state => :train_state) + @eval function $(Symbol(:has_, op))(l::AbstractExplicitLayer) + res = known(safe_getproperty(l, Val($(Meta.quot(field))))) + return ifelse(res === nothing, false, res) + end +end diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl index ba196e331..bbbb8544c 100644 --- a/src/helpers/compact.jl +++ b/src/helpers/compact.jl @@ -292,19 +292,16 @@ macro non_trainable(x) return esc(:($(CompactMacroImpl.NonTrainable)($(x)))) end -@concrete struct CompactLuxLayer{dispatch} <: - AbstractExplicitContainerLayer{(:layers, :value_storage)} - f - name +struct CompactLuxLayer{dispatch, F, N, L, V, SK} <: + AbstractExplicitContainerLayer{(:layers, :value_storage)} + d::StaticSymbol{dispatch} + f::F + name::N strings::NTuple{3, String} - setup_strings - layers - value_storage - stored_kwargs -end - -function ConstructionBase.constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch} - return CompactLuxLayer{dispatch} + setup_strings::Any + layers::L + value_storage::V + stored_kwargs::SK end function initialparameters(rng::AbstractRNG, m::CompactLuxLayer) @@ -320,8 +317,8 @@ function initialstates(rng::AbstractRNG, m::CompactLuxLayer) base_states, (; ₋₋₋kwargs₋₋₋=NamedTuple{m.stored_kwargs[1]}(m.stored_kwargs[2]))) end -function CompactLuxLayer{dispatch}( - f::F, name::NAME_TYPE, str::Tuple, splatted_kwargs; kws...) where {F, dispatch} +function CompactLuxLayer(dispatch::StaticSymbol, f::F, name::NAME_TYPE, + str::Tuple, splatted_kwargs; kws...) where {F} layers, others = [], [] setup_strings = NamedTuple() for (name, val) in pairs(kws) @@ -353,7 +350,7 @@ function CompactLuxLayer{dispatch}( NamedTuple((name => CompactMacroImpl.kwarg_descriptor(val),))) end end - return CompactLuxLayer{dispatch}(f, name, str, setup_strings, NamedTuple((; layers...)), + return CompactLuxLayer(dispatch, f, name, str, setup_strings, NamedTuple((; layers...)), CompactMacroImpl.ValueStorage(; others...), splatted_kwargs) end @@ -423,6 +420,7 @@ using ChainRulesCore: @non_differentiable using ConcreteStructs: @concrete using MacroTools: MacroTools, @capture, combinedef, splitdef using Random: AbstractRNG +using Static: static using LuxCore: LuxCore, AbstractExplicitLayer using ..Lux: Lux, CompactLuxLayer, LuxCompactModelParsingException, StatefulLuxLayer, @@ -467,6 +465,7 @@ function compact_macro_impl(_exs...) # check if user has provided a custom dispatch dispatch, kwexs = extract_reserved_kwarg(kwexs, :dispatch) + dispatch === nothing && (dispatch = QuoteNode(:₋₋₋no_special_dispatch₋₋₋)) # Extract splatted kwargs splat_idxs = findall(ex -> ex.head == :..., kwexs) @@ -495,7 +494,8 @@ function compact_macro_impl(_exs...) 
fex = supportself(fex, vars, splatted_kwargs) # assemble - return esc(:($CompactLuxLayer{$dispatch}($fex, $name, ($layer, $input, $block), + return esc(:($CompactLuxLayer( + $(static)($(dispatch)), $fex, $name, ($layer, $input, $block), (($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...)))) end diff --git a/src/helpers/stateful.jl b/src/helpers/stateful.jl index 31bb003ec..9d47a1c86 100644 --- a/src/helpers/stateful.jl +++ b/src/helpers/stateful.jl @@ -46,46 +46,17 @@ mutable struct StatefulLuxLayer{ST, M <: AbstractExplicitLayer, psType, stType} ps::psType st::stType st_any::Any -end - -function StatefulLuxLayer{ST}(model, ps, st, st_any) where {ST} - return StatefulLuxLayer{ST, typeof(model), typeof(ps), typeof(st)}( - model, ps, st, st_any) -end + fixed_state_type::ST -for op in (:trainmode, :testmode) - @eval function LuxCore.$(op)(s::StatefulLuxLayer{ST}) where {ST} - return StatefulLuxLayer{ST}(s.model, s.ps, LuxCore.$(op)(get_state(s))) + function StatefulLuxLayer( + model::AbstractExplicitLayer, ps, st, st_any, fixed_state_type::StaticBool) + return new{typeof(fixed_state_type), typeof(model), typeof(ps), typeof(st)}( + model, ps, st, st_any, fixed_state_type) end end -function LuxCore.update_state( - s::StatefulLuxLayer{ST}, key::Symbol, value; kwargs...) where {ST} - st = LuxCore.update_state(get_state(s), key, value; kwargs...) - return StatefulLuxLayer{ST}(s.model, s.ps, st) -end - -function Base.show(io::IO, ::MIME"text/plain", s::StatefulLuxLayer{ST}) where {ST} - PrettyPrinting.print_wrapper_model(io, "StatefulLuxLayer{$ST}", s.model) -end - -function Functors.functor(::Type{<:StatefulLuxLayer{FT}}, x) where {FT} - return ((; x.model, x.ps, x.st, x.st_any), - nt -> StatefulLuxLayer{FT}(nt.model, nt.ps, nt.st, nt.st_any)) -end - -function LuxCore.parameterlength(m::StatefulLuxLayer) - m.ps === nothing && return LuxCore.parameterlength(m.model) - return LuxCore.parameterlength(m.ps) -end -function LuxCore.statelength(m::StatefulLuxLayer{FT}) where {FT} - FT && return LuxCore.statelength(m.st) - return LuxCore.statelength(m.st_any) -end -LuxCore.apply(m::StatefulLuxLayer, x, p) = m(x, p) - -function ConstructionBase.constructorof(::Type{<:StatefulLuxLayer{FT}}) where {FT} - return StatefulLuxLayer{FT} +function StatefulLuxLayer{ST}(model, ps, st, st_any) where {ST} + return StatefulLuxLayer(model, ps, st, st_any, static(ST)) end function StatefulLuxLayer(model::AbstractExplicitLayer, st::NamedTuple; kwargs...) @@ -104,22 +75,53 @@ function StatefulLuxLayer{false}(model::AbstractExplicitLayer, ps, st::NamedTupl return StatefulLuxLayer{false}(model, ps, nothing, st) end -get_state(s::StatefulLuxLayer{true}) = s.st -get_state(s::StatefulLuxLayer{false}) = s.st_any +for op in (:trainmode, :testmode) + @eval function LuxCore.$(op)(s::StatefulLuxLayer) + return StatefulLuxLayer{dynamic(s.fixed_state_type)}( + s.model, s.ps, LuxCore.$(op)(get_state(s))) + end +end + +function LuxCore.update_state(s::StatefulLuxLayer, key::Symbol, value; kwargs...) + st = LuxCore.update_state(get_state(s), key, value; kwargs...) 
+ return StatefulLuxLayer{dynamic(s.fixed_state_type)}(s.model, s.ps, st) +end -CRC.@non_differentiable get_state(::Any) +function Base.show(io::IO, ::MIME"text/plain", s::StatefulLuxLayer) + PrettyPrinting.print_wrapper_model( + io, "StatefulLuxLayer{$(dynamic(s.fixed_state_type))}", s.model) +end + +function Functors.functor(::Type{<:StatefulLuxLayer}, x) + recon = let ft = x.fixed_state_type + nt -> StatefulLuxLayer(nt.model, nt.ps, nt.st, nt.st_any, ft) + end + return (; x.model, x.ps, x.st, x.st_any), recon +end + +function parameterlength(m::StatefulLuxLayer) + m.ps === nothing && return parameterlength(m.model) + return parameterlength(m.ps) +end +statelength(m::StatefulLuxLayer) = statelength(get_state(m)) +apply(m::StatefulLuxLayer, x, p) = m(x, p) + +get_state(s::StatefulLuxLayer{True}) = s.st +get_state(s::StatefulLuxLayer{False}) = s.st_any + +CRC.@non_differentiable get_state(::StatefulLuxLayer) function set_state!( - s::StatefulLuxLayer{true, M, psType, stType}, st::stType) where {M, psType, stType} + s::StatefulLuxLayer{True, <:Any, <:Any, stType}, st::stType) where {stType} s.st = st end -function set_state!(::StatefulLuxLayer{true, M, psType, stType}, - ::stType2) where {M, psType, stType, stType2} +function set_state!( + ::StatefulLuxLayer{True, <:Any, <:Any, stType}, ::stType2) where {stType, stType2} throw(ArgumentError("Output state from the model has type `$(stType2)`, but expected \ `$(stType)`. Construct the Stateful layer as \ `StatefulLuxLayer{false}` instead of `StatefulLuxLayer{true}`.")) end -set_state!(s::StatefulLuxLayer{false}, st) = (s.st_any = st) +set_state!(s::StatefulLuxLayer{False}, st) = (s.st_any = st) CRC.@non_differentiable set_state!(::Any...) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 9e2e8d9e5..77cd98f48 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -38,7 +38,7 @@ end outputsize(r::ReshapeLayer) = r.dims -function (r::ReshapeLayer)(x::AbstractArray, ps, st::NamedTuple) +function (r::ReshapeLayer)(x::AbstractArray, _, st::NamedTuple) return reshape(x, r.dims..., size(x, ndims(x))), st end @@ -81,34 +81,35 @@ julia> y, st_new = model(x, ps, st) ([3.0, 2.0, 1.0], NamedTuple()) ``` """ -@kwdef struct ReverseSequence{D <: Union{Int, Nothing}} <: AbstractExplicitLayer - dim::D = nothing +@concrete struct ReverseSequence <: AbstractExplicitLayer + dim <: Union{Nothing, StaticInt} end -function (r::ReverseSequence{Nothing})(x::AbstractVector{T}, ps, st::NamedTuple) where {T} - return safe_reverse(x), st +ReverseSequence(dim) = ReverseSequence(static(dim)) +ReverseSequence(; dim=nothing) = ReverseSequence(static(dim)) + +function (r::ReverseSequence{Nothing})(x::AbstractArray, _, st::NamedTuple) + return safe_reverse(x; dims=max(ndims(x) - 1, 1)), st end -function (r::ReverseSequence{Nothing})( - x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N} - return safe_reverse(x; dims=ndims(x) - 1), st +function (r::ReverseSequence{StaticInt{1}})(x::AbstractVector, _, st::NamedTuple) + return safe_reverse(x), st end -function (r::ReverseSequence)(x::AbstractVector{T}, ps, st::NamedTuple) where {T} - r.dim == 1 && return safe_reverse(x), st - throw(ArgumentError("Cannot specify a dimension other than 1 for AbstractVector{T}")) +function (r::ReverseSequence{StaticInt{N}})(::AbstractVector, _, st::NamedTuple) where {N} + throw(ArgumentError("Cannot specify a dimension ($(N) != 1) for AbstractVector")) end -function (r::ReverseSequence)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N} - return safe_reverse(x; 
dims=r.dim), st +function (r::ReverseSequence{StaticInt{N}})(x::AbstractArray, _, st::NamedTuple) where {N} + return safe_reverse(x; dims=N), st end """ - FlattenLayer(N = nothing) + FlattenLayer(; N = nothing) Flattens the passed array into a matrix. -## Arguments +## Keyword Arguments - `N`: Flatten the first `N` dimensions of the input array. If `nothing`, then all dimensions (except the last) are flattened. Note that the batch dimension is never @@ -140,11 +141,12 @@ julia> y, st_new = model(x, ps, st); (8, 2) ``` """ -struct FlattenLayer{NT <: Union{Nothing, Int}} <: AbstractExplicitLayer - N::NT # FIXME: In v1 promote this to type parameter to allow type-stable forward pass +@concrete struct FlattenLayer <: AbstractExplicitLayer + N <: Union{Nothing, StaticInt} end -FlattenLayer(; N=nothing) = FlattenLayer(N) +FlattenLayer(N) = FlattenLayer(static(N)) +FlattenLayer(; N=nothing) = FlattenLayer(static(N)) function (::FlattenLayer{Nothing})(x::AbstractArray{T, N}, _, st::NamedTuple) where {T, N} return reshape(x, :, size(x, N)), st @@ -175,16 +177,17 @@ Return a view of all the data of the input `x` where the index for dimension `di - `view(x,:,:,...,i,:,:,...)` where `i` is in position `d` - Empty `NamedTuple()` """ -struct SelectDim{dim, index} <: AbstractExplicitLayer end +@concrete struct SelectDim <: AbstractExplicitLayer + dim <: StaticInt + index <: StaticInt +end -SelectDim(dim, index) = SelectDim{dim, index}() +SelectDim(dim, index) = SelectDim(static(dim), static(index)) -function (s::SelectDim{dim, index})(x, ps, st::NamedTuple) where {dim, index} - return selectdim(x, dim, index), st -end +(s::SelectDim)(x, _, st::NamedTuple) = selectdim(x, known(s.dim), known(s.index)), st -function Base.show(io::IO, ::SelectDim{dim, index}) where {dim, index} - return print(io, "SelectDim(", dim, ", ", index, ")") +function Base.show(io::IO, s::SelectDim) + return print(io, "SelectDim(dim = ", s.dim, ", index = ", s.index, ")") end """ @@ -211,7 +214,7 @@ julia> y, st_new = model(x, ps, st) """ struct NoOpLayer <: AbstractExplicitLayer end -(noop::NoOpLayer)(x, ps, st::NamedTuple) = x, st +(noop::NoOpLayer)(x, _, st::NamedTuple) = x, st """ WrappedFunction{DC}(f) @@ -232,15 +235,21 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. 
An easier thing to do would be ## Inputs - - `x`: s.t `hasmethod(f, (typeof(x),))` is `true` + - `x`: s.t `hasmethod(f, (typeof(x),))` is `true` if :direct_call else + `hasmethod(f, (typeof(x), NamedTuple, NamedTuple))` is `true` ## Returns - Output of `f(x)` - Empty `NamedTuple()` """ -@concrete struct WrappedFunction{DC} <: AbstractExplicitLayer - func +struct WrappedFunction{DC, F} <: AbstractExplicitLayer + call_mode::StaticSymbol{DC} + func::F +end + +function WrappedFunction{call_mode}(f::F) where {call_mode, F} + return WrappedFunction(static(call_mode), f) end function WrappedFunction(f::F) where {F} @@ -266,17 +275,17 @@ function (wf::WrappedFunction{:runtime_check})(x, ps, st::NamedTuple) end wrapped_function_call(f, x, ps, st, ::False) = f(x, ps, st) -wrapped_function_call(f, x, ps, st, ::True) = f(x), st +wrapped_function_call(f, x, _, st, ::True) = f(x), st function Base.show(io::IO, w::WrappedFunction{T}) where {T} - print(io, "WrappedFunction{$(Meta.quot(T))}(") + print(io, "WrappedFunction(", static(w.call_mode), ", ") show(io, w.func) print(io, ")") end """ Dense(in_dims => out_dims, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias::Bool=true, allow_fast_activation::Bool=true) + init_bias=zeros32, use_bias=True(), allow_fast_activation=True()) Create a traditional fully connected layer, whose forward pass is given by: `y = activation.(weight * x .+ bias)` @@ -311,33 +320,35 @@ Create a traditional fully connected layer, whose forward pass is given by: - `weight`: Weight Matrix of size `(out_dims, in_dims)` - `bias`: Bias of size `(out_dims, 1)` (present if `use_bias=true`) """ -@concrete struct Dense{use_bias} <: AbstractExplicitLayer +@concrete struct Dense <: AbstractExplicitLayer activation - in_dims::Int - out_dims::Int + in_dims <: IntegerType + out_dims <: IntegerType init_weight init_bias + use_bias <: StaticBool end -function Base.show(io::IO, d::Dense{use_bias}) where {use_bias} +function Base.show(io::IO, d::Dense) print(io, "Dense($(d.in_dims) => $(d.out_dims)") (d.activation == identity) || print(io, ", $(d.activation)") - use_bias || print(io, ", use_bias=false") + has_bias(d) || print(io, ", use_bias=false") return print(io, ")") end -function Dense(mapping::Pair{<:Int, <:Int}, activation=identity; kwargs...) +function Dense(mapping::Pair{<:IntegerType, <:IntegerType}, activation=identity; kwargs...) return Dense(first(mapping), last(mapping), activation; kwargs...) end -function Dense(in_dims::Int, out_dims::Int, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias::Bool=true, allow_fast_activation::Bool=true) - activation = allow_fast_activation ? NNlib.fast_act(activation) : activation - return Dense{use_bias}(activation, in_dims, out_dims, init_weight, init_bias) +function Dense(in_dims::IntegerType, out_dims::IntegerType, activation=identity; + init_weight=glorot_uniform, init_bias=zeros32, + use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) + activation = dynamic(allow_fast_activation) ? 
NNlib.fast_act(activation) : activation + return Dense(activation, in_dims, out_dims, init_weight, init_bias, static(use_bias)) end -function initialparameters(rng::AbstractRNG, d::Dense{use_bias}) where {use_bias} - if use_bias +function initialparameters(rng::AbstractRNG, d::Dense) + if has_bias(d) return (weight=d.init_weight(rng, d.out_dims, d.in_dims), bias=d.init_bias(rng, d.out_dims, 1)) #TODO: In v1 make it a vector else @@ -345,29 +356,23 @@ function initialparameters(rng::AbstractRNG, d::Dense{use_bias}) where {use_bias end end -function parameterlength(d::Dense{use_bias}) where {use_bias} - return use_bias ? d.out_dims * (d.in_dims + 1) : d.out_dims * d.in_dims -end +parameterlength(d::Dense) = d.out_dims * d.in_dims + has_bias(d) * d.out_dims statelength(d::Dense) = 0 outputsize(d::Dense) = (d.out_dims,) -function (d::Dense)(x::AbstractVector, ps, st::NamedTuple) - return vec(first(d(reshape(x, :, 1), ps, st))), st -end - function (d::Dense)(x::AbstractArray, ps, st::NamedTuple) - return reshape(first(d(reshape(x, size(x, 1), :), ps, st)), :, size(x)[2:end]...), st -end - -function (d::Dense)(x::AbstractMatrix, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) bias = safe_vec(safe_getproperty(ps, Val(:bias))) - return fused_dense_bias_activation(d.activation, ps.weight, y, bias), st + z = matrix_to_array( + fused_dense_bias_activation(d.activation, ps.weight, make_abstract_matrix(y), bias), + y) + return z, st end """ - Scale(dims, activation=identity; init_weight=ones32, init_bias=zeros32, bias::Bool=true) + Scale(dims, activation=identity; init_weight=ones32, init_bias=zeros32, use_bias=True(), + allow_fast_activation=True()) Create a Sparsely Connected Layer with a very specific structure (only Diagonal Elements are non-zero). The forward pass is given by: `y = activation.(weight .* x .+ bias)` @@ -402,61 +407,61 @@ Elements are non-zero). The forward pass is given by: `y = activation.(weight .* - `weight`: Weight Array of size `(dims...)` - `bias`: Bias of size `(dims...)` """ -@concrete struct Scale{use_bias} <: AbstractExplicitLayer +@concrete struct Scale{UB <: StaticBool} <: AbstractExplicitLayer activation - dims + dims <: Tuple{Vararg{IntegerType}} init_weight init_bias + use_bias::UB end -function Base.show(io::IO, d::Scale{use_bias}) where {use_bias} +function Base.show(io::IO, d::Scale) print(io, "Scale($(d.dims)") (d.activation == identity) || print(io, ", $(d.activation)") - use_bias || print(io, ", use_bias=false") + has_bias(d) || print(io, ", use_bias=false") return print(io, ")") end -function Scale( - dims::Tuple{Vararg{Integer}}, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias::Bool=true, allow_fast_activation::Bool=true) - activation = allow_fast_activation ? NNlib.fast_act(activation) : activation - return Scale{use_bias}(activation, dims, init_weight, init_bias) +function Scale(dims::Tuple{Vararg{IntegerType}}, activation=identity; + init_weight=glorot_uniform, init_bias=zeros32, + use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) + activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + return Scale(activation, dims, init_weight, init_bias, static(use_bias)) end -function Scale(s1::Integer, s23::Integer...; _act=identity, kwargs...) +function Scale(s1::IntegerType, s23::IntegerType...; _act=identity, kwargs...) return Scale(tuple(s1, s23...), _act; kwargs...) end function Scale(size_act...; kwargs...) return Scale(size_act[1:(end - 1)]...; _act=size_act[end], kwargs...) 
end -function initialparameters(rng::AbstractRNG, d::Scale{use_bias}) where {use_bias} - if use_bias - return (weight=d.init_weight(rng, d.dims...), bias=d.init_bias(rng, d.dims...)) - else - return (weight=d.init_weight(rng, d.dims...),) +function initialparameters(rng::AbstractRNG, d::Scale) + if has_bias(d) + return (; weight=d.init_weight(rng, d.dims...), bias=d.init_bias(rng, d.dims...)) end + return (; weight=d.init_weight(rng, d.dims...),) end -parameterlength(d::Scale{use_bias}) where {use_bias} = (1 + use_bias) * prod(d.dims) +parameterlength(d::Scale) = (1 + has_bias(d)) * prod(d.dims) statelength(d::Scale) = 0 outputsize(d::Scale) = d.dims -function (d::Scale{false})(x::AbstractArray, ps, st::NamedTuple) +function (d::Scale{False})(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) return @.(d.activation(y .* ps.weight)), st end -function (d::Scale{true})(x::AbstractArray, ps, st::NamedTuple) +function (d::Scale{True})(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) return @.(d.activation(y * ps.weight + ps.bias)), st end """ Bilinear((in1_dims, in2_dims) => out, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias::Bool=true, allow_fast_activation::Bool=true) + init_bias=zeros32, use_bias=True(), allow_fast_activation=True()) Bilinear(in12_dims => out, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias::Bool=true, allow_fast_activation::Bool=true) + init_bias=zeros32, use_bias=True(), allow_fast_activation=True()) Create a fully connected layer between two inputs and an output, and otherwise similar to [`Dense`](@ref). Its output, given vectors `x` & `y`, is another vector `z` with, for all @@ -504,35 +509,38 @@ with `B` the Bilinear layer. - `weight`: Weight Matrix of size `(out_dims, in1_dims, in2_dims)` - `bias`: Bias of size `(out_dims, 1)` (present if `use_bias=true`) """ -@concrete struct Bilinear{use_bias} <: AbstractExplicitLayer +@concrete struct Bilinear <: AbstractExplicitLayer activation - in1_dims::Int - in2_dims::Int - out_dims::Int + in1_dims <: IntegerType + in2_dims <: IntegerType + out_dims <: IntegerType init_weight init_bias + use_bias <: StaticBool end -function Base.show(io::IO, b::Bilinear{use_bias}) where {use_bias} +function Base.show(io::IO, b::Bilinear) print(io, "Bilinear(($(b.in1_dims), $(b.in2_dims)) => $(b.out_dims)") (b.activation == identity) || print(io, ", $(b.activation)") - use_bias || print(io, ", use_bias=false") + has_bias(b) || print(io, ", use_bias=false") return print(io, ")") end -function Bilinear(((in1_dims, in2_dims), out)::Pair{<:Tuple, <:Integer}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - use_bias::Bool=true, allow_fast_activation::Bool=true) - activation = allow_fast_activation ? NNlib.fast_act(activation) : activation - return Bilinear{use_bias}(activation, in1_dims, in2_dims, out, init_weight, init_bias) -end -function Bilinear( - (in12_dims, out)::Pair{<:Integer, <:Integer}, activation=identity; kwargs...) +function Bilinear((in12_dims, out)::Pair{<:IntegerType, <:IntegerType}, + activation=identity; kwargs...) return Bilinear((in12_dims, in12_dims) => out, activation; kwargs...) 
end -function initialparameters(rng::AbstractRNG, b::Bilinear{use_bias}) where {use_bias} - if use_bias +function Bilinear(((in1_dims, in2_dims), out)::Pair{<:Tuple, <:IntegerType}, + activation=identity; init_weight=glorot_uniform, init_bias=zeros32, + use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) + activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + return Bilinear( + activation, in1_dims, in2_dims, out, init_weight, init_bias, static(use_bias)) +end + +function initialparameters(rng::AbstractRNG, b::Bilinear) + if has_bias(b) return (weight=b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims), bias=b.init_bias(rng, b.out_dims, 1)) # TODO: In v1.0 make it a vector else @@ -540,8 +548,8 @@ function initialparameters(rng::AbstractRNG, b::Bilinear{use_bias}) where {use_b end end -function parameterlength(b::Bilinear{use_bias}) where {use_bias} - return b.out_dims * b.in1_dims * b.in2_dims + use_bias * b.out_dims +function parameterlength(b::Bilinear) + return b.out_dims * b.in1_dims * b.in2_dims + has_bias(b) * b.out_dims end statelength(b::Bilinear) = 0 @@ -609,14 +617,12 @@ This layer is often used to store word embeddings and retrieve them using indice - Empty `NamedTuple()` """ @concrete struct Embedding <: AbstractExplicitLayer - in_dims - out_dims::Int + in_dims <: Union{IntegerType, Tuple{Vararg{IntegerType}}} + out_dims <: IntegerType init_weight end -function Embedding( - (in_dims, out_dims)::Pair{<:Union{Integer, NTuple{<:Any, <:Integer}}, <:Integer}; - init_weight=randn32) +function Embedding((in_dims, out_dims)::Pair; init_weight=randn32) return Embedding(in_dims, out_dims, init_weight) end @@ -624,6 +630,12 @@ function initialparameters(rng::AbstractRNG, e::Embedding) return (weight=e.init_weight(rng, e.out_dims, e.in_dims...),) end +function Base.show(io::IO, e::Embedding) + return print(io, "Embedding(", e.in_dims, " => ", e.out_dims, ")") +end + +outputsize(e::Embedding) = (e.out_dims,) + (e::Embedding)(x::Integer, ps, st::NamedTuple) = view(ps.weight, :, x), st function (e::Embedding)(x::AbstractVector{<:Integer}, ps, st::NamedTuple) return NNlib.gather(ps.weight, x), st @@ -648,12 +660,6 @@ function (e::Embedding)(::Tuple{}, _, ::NamedTuple) throw(ArgumentError("Input tuple must contain at least one element")) end -function Base.show(io::IO, e::Embedding) - return print(io, "Embedding(", e.in_dims, " => ", e.out_dims, ")") -end - -outputsize(e::Embedding) = (e.out_dims,) - """ PeriodicEmbedding(idxs, periods) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index d4a1f82a1..21c1506a8 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -256,8 +256,8 @@ BranchLayer( # plus 0 states. ``` """ -struct BranchLayer{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} - layers::T +@concrete struct BranchLayer <: AbstractExplicitContainerLayer{(:layers,)} + layers <: NamedTuple name end @@ -595,8 +595,8 @@ See also [`Parallel`](@ref) to reduce with other operators. [1] Goodfellow, Warde-Farley, Mirza, Courville & Bengio "Maxout Networks" [https://arxiv.org/abs/1302.4389](https://arxiv.org/abs/1302.4389) """ -struct Maxout{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} - layers::T +@concrete struct Maxout <: AbstractExplicitContainerLayer{(:layers,)} + layers <: NamedTuple end Maxout(layers...) = Maxout(Utils.named_tuple_layers(layers...)) @@ -679,32 +679,35 @@ times for gradients might be unreasonably high. 
- State of `model` """ -struct RepeatedLayer{N, IJ, M <: AbstractExplicitLayer} <: - AbstractExplicitContainerLayer{(:model,)} - model::M +@concrete struct RepeatedLayer <: AbstractExplicitContainerLayer{(:model,)} + nrepeats <: StaticInt + input_injection <: StaticBool + model <: AbstractExplicitLayer end -LuxCore.display_name(::RepeatedLayer{N, IJ}) where {N, IJ} = "RepeatedLayer{$N, $IJ}" - -RepeatedLayer{N, IJ}(model) where {N, IJ} = RepeatedLayer{N, IJ, typeof(model)}(model) -RepeatedLayer{N, IJ}(; model) where {N, IJ} = RepeatedLayer{N, IJ, typeof(model)}(model) +function LuxCore.display_name(r::RepeatedLayer) + return "RepeatedLayer{nrepeats = $(known(r.nrepeats)), \ + input_injection = $(known(r.input_injection))}" +end -function RepeatedLayer(model::AbstractExplicitLayer; repeats::Val{N}=Val(10), - input_injection::Val{IJ}=Val(false)) where {N, IJ} - return RepeatedLayer{N, IJ}(model) +function RepeatedLayer( + model::AbstractExplicitLayer; repeats::Union{StaticInt, Integer, Val}=Val(10), + input_injection::Union{StaticBool, Bool, Val{true}, Val{false}}=Val(false)) + return RepeatedLayer(static(repeats), static(input_injection), model) end (m::RepeatedLayer)(x, ps, st) = repeatedlayer(m, m.model, x, ps, st) @generated function repeatedlayer(::RepeatedLayer{N, IJ}, model, x, ps, st) where {N, IJ} - sts = ntuple(_ -> gensym("st"), N) - xs = ntuple(_ -> gensym("x"), N + IJ) + sts = ntuple(_ -> gensym("st"), known(N)) + xs = ntuple(_ -> gensym("x"), known(N) + known(IJ)) calls = [] - IJ && push!(calls, :($(xs[1]) = x)) - for i in 1:N + known(IJ) && push!(calls, :($(xs[1]) = x)) + for i in 1:known(N) push!(calls, - :(($(xs[i + IJ]), $(sts[i])) = apply( - model, $(IJ ? :(($(xs[i]), x)) : :x), ps, $(i == 1 ? :st : sts[i - 1])))) + :(($(xs[i + known(IJ)]), $(sts[i])) = apply( + model, $(known(IJ) ? :(($(xs[i]), x)) : :x), + ps, $(i == 1 ? :st : sts[i - 1])))) end return quote $(calls...) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index d8fb2572d..ad40385a0 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -62,7 +62,7 @@ end @doc doc""" Conv(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, stride=1, - pad=0, dilation=1, groups=1, use_bias=true, allow_fast_activation=true) + pad=0, dilation=1, groups=1, use_bias=True(), allow_fast_activation=True()) Standard convolutional layer. 
@@ -139,41 +139,44 @@ O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s - `weight`: Convolution kernel - `bias`: Bias (present if `use_bias=true`) """ -@concrete struct Conv{N, use_bias, M} <: AbstractExplicitLayer +@concrete struct Conv <: AbstractExplicitLayer activation - in_chs::Int - out_chs::Int - kernel_size::NTuple{N, Int} - stride::NTuple{N, Int} - pad::NTuple{M, Int} - dilation::NTuple{N, Int} - groups::Int + in_chs <: IntegerType + out_chs <: IntegerType + kernel_size <: Tuple{Vararg{IntegerType}} + stride <: Tuple{Vararg{IntegerType}} + pad <: Tuple{Vararg{IntegerType}} + dilation <: Tuple{Vararg{IntegerType}} + groups <: IntegerType init_weight init_bias + use_bias <: StaticBool end -function Conv(k::NTuple{N, Integer}, ch::Pair{<:Integer, <:Integer}, activation=identity; - init_weight=glorot_uniform, init_bias=zeros32, stride=1, pad=0, dilation=1, - groups=1, use_bias::Bool=true, allow_fast_activation::Bool=true) where {N} - stride = Utils.expand(Val(N), stride) - dilation = Utils.expand(Val(N), dilation) +function Conv(k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, + activation=identity; init_weight=glorot_uniform, + init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, + use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) + stride = Utils.expand(Val(length(k)), stride) + dilation = Utils.expand(Val(length(k)), dilation) pad = calc_padding(pad, k, dilation, stride) - activation = allow_fast_activation ? NNlib.fast_act(activation) : activation + @argcheck allequal(length, (stride, dilation, k)) - return Conv{N, use_bias, length(pad)}(activation, first(ch), last(ch), k, stride, - pad, dilation, groups, init_weight, init_bias) + activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + return Conv(activation, first(ch), last(ch), k, stride, pad, dilation, + groups, init_weight, init_bias, static(use_bias)) end -function initialparameters(rng::AbstractRNG, c::Conv{N, use_bias}) where {N, use_bias} +function initialparameters(rng::AbstractRNG, c::Conv) weight = init_conv_filter( - rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, groups=c.groups) - !use_bias && return (; weight) - return (; weight, bias=c.init_bias(rng, ntuple(_ -> 1, N)..., c.out_chs, 1)) # TODO: flatten in v1 + rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, c.groups) + has_bias(c) || return (; weight) + return (; weight, + bias=c.init_bias(rng, ntuple(_ -> 1, length(c.kernel_size))..., c.out_chs, 1)) # TODO: flatten in v1 end -function parameterlength(c::Conv{N, use_bias}) where {N, use_bias} - return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + - (use_bias ? 
c.out_chs : 0) +function parameterlength(c::Conv) + return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs end function (c::Conv)(x::AbstractArray, ps, st::NamedTuple) @@ -183,7 +186,7 @@ function (c::Conv)(x::AbstractArray, ps, st::NamedTuple) return fused_conv_bias_activation(c.activation, ps.weight, y, bias, cdims), st end -function Base.show(io::IO, l::Conv{N, use_bias}) where {N, use_bias} +function Base.show(io::IO, l::Conv) print(io, "Conv(", l.kernel_size) print(io, ", ", l.in_chs, " => ", l.out_chs) l.activation == identity || print(io, ", ", l.activation) @@ -192,8 +195,8 @@ function Base.show(io::IO, l::Conv{N, use_bias}) where {N, use_bias} all(==(1), l.dilation) || print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) (l.groups == 1) || print(io, ", groups=", l.groups) - (use_bias == false) && print(io, ", use_bias=false") - return print(io, ")") + has_bias(l) || print(io, ", use_bias=false") + print(io, ")") end @doc doc""" @@ -241,19 +244,21 @@ value. See also [`Conv`](@ref), [`MeanPool`](@ref), [`GlobalMaxPool`](@ref), [`AdaptiveMaxPool`](@ref) """ -struct MaxPool{N, M} <: AbstractExplicitLayer - k::NTuple{N, Int} - pad::NTuple{M, Int} - stride::NTuple{N, Int} +@concrete struct MaxPool <: AbstractExplicitLayer + k <: Tuple{Vararg{IntegerType}} + pad <: Tuple{Vararg{IntegerType}} + stride <: Tuple{Vararg{IntegerType}} end -function MaxPool(k::NTuple{N, Integer}; pad=0, stride=k) where {N} - stride = Utils.expand(Val(N), stride) +function MaxPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k) + stride = Utils.expand(Val(length(k)), stride) pad = calc_padding(pad, k, 1, stride) - return MaxPool{N, length(pad)}(k, pad, stride) + @argcheck allequal(length, (stride, k)) + + return MaxPool(k, pad, stride) end -function (m::MaxPool{N, M})(x, ps, st::NamedTuple) where {N, M} +function (m::MaxPool)(x, _, st::NamedTuple) return maxpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st end @@ -309,19 +314,21 @@ value. 
See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalMeanPool`](@ref), [`AdaptiveMeanPool`](@ref) """ -struct MeanPool{N, M} <: AbstractExplicitLayer - k::NTuple{N, Int} - pad::NTuple{M, Int} - stride::NTuple{N, Int} +@concrete struct MeanPool <: AbstractExplicitLayer + k <: Tuple{Vararg{IntegerType}} + pad <: Tuple{Vararg{IntegerType}} + stride <: Tuple{Vararg{IntegerType}} end -function MeanPool(k::NTuple{N, Integer}; pad=0, stride=k) where {N} - stride = Utils.expand(Val(N), stride) +function MeanPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k) + stride = Utils.expand(Val(length(k)), stride) pad = calc_padding(pad, k, 1, stride) - return MeanPool{N, length(pad)}(k, pad, stride) + @argcheck allequal(length, (stride, k)) + + return MeanPool(k, pad, stride) end -function (m::MeanPool{N, M})(x, ps, st::NamedTuple) where {N, M} +function (m::MeanPool)(x, _, st::NamedTuple) return meanpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st end @@ -379,45 +386,50 @@ Currently supported upsampling `mode`s and corresponding NNlib's methods are: - Upsampled Input of size `size` or of size `(I_1 x scale[1], ..., I_N x scale[N], C, N)` - Empty `NamedTuple()` """ -@concrete struct Upsample{mode} <: AbstractExplicitLayer +@concrete struct Upsample <: AbstractExplicitLayer scale size + upsample_mode <: StaticSymbol end -function Upsample(mode::Symbol=:nearest; scale=nothing, size=nothing) - mode in [:nearest, :bilinear, :trilinear] || - throw(ArgumentError("mode=:$mode is not supported.")) +function Upsample(mode::SymbolType=static(:nearest); scale=nothing, size=nothing) + @argcheck dynamic(mode) in (:nearest, :bilinear, :trilinear) if !xor(isnothing(scale), isnothing(size)) throw(ArgumentError("Either scale or size should be specified (but not both).")) end - return Upsample{mode}(scale, size) + return Upsample(scale, size, static(mode)) end -Upsample(scale, mode::Symbol=:nearest) = Upsample(mode; scale) +Upsample(scale, mode::SymbolType=static(:nearest)) = Upsample(mode; scale) + +function (m::Upsample)(x::AbstractArray, _, st::NamedTuple) + return lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale), st +end +function (m::Upsample{Nothing})(x::AbstractArray, _, st::NamedTuple) + return lux_upsample_size_dispatch(m.upsample_mode, x, m.size), st +end for interp in (:nearest, :bilinear, :trilinear) - interp_func = Symbol(:upsample_, interp) + nnlib_interp_func = Symbol(:upsample_, interp) @eval begin - function (m::Upsample{$(Meta.quot(interp))})(x::AbstractArray, ps, st::NamedTuple) - return NNlib.$(interp_func)(x, m.scale), st + function lux_upsample_scale_dispatch(::StaticSymbol{$(Meta.quot(interp))}, x, scale) + return $(nnlib_interp_func)(x, scale) end - function (m::Upsample{$(Meta.quot(interp)), Nothing})( - x::AbstractArray, ps, st::NamedTuple) - return NNlib.$(interp_func)(x; m.size), st + function lux_upsample_size_dispatch(::StaticSymbol{$(Meta.quot(interp))}, x, size) + return $(nnlib_interp_func)(x; size) end end end -function (m::Upsample{:nearest, Int})(x::AbstractArray, ps, st::NamedTuple) - return NNlib.upsample_nearest(x, ntuple(i -> m.scale, ndims(x) - 2)), st +function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer) + return NNlib.upsample_nearest(x, ntuple(i -> scale, ndims(x) - 2)) end -function Base.show(io::IO, u::Upsample{mode}) where {mode} - print(io, "Upsample(") - print(io, ":", mode) +function Base.show(io::IO, u::Upsample) + print(io, "Upsample(", u.upsample_mode) u.scale !== nothing && print(io, ", scale = $(u.scale)") u.size !== nothing 
&& print(io, ", size = $(u.size)") - return print(io, ")") + print(io, ")") end """ @@ -439,7 +451,7 @@ See also [`MaxPool`](@ref), [`AdaptiveMaxPool`](@ref), [`GlobalMeanPool`](@ref) """ struct GlobalMaxPool <: AbstractExplicitLayer end -function (g::GlobalMaxPool)(x, ps, st::NamedTuple) +function (g::GlobalMaxPool)(x, _, st::NamedTuple) return maxpool(x, PoolDims(x, size(x)[1:(end - 2)])), st end @@ -462,7 +474,7 @@ See also [`MeanPool`](@ref), [`AdaptiveMeanPool`](@ref), [`GlobalMaxPool`](@ref) """ struct GlobalMeanPool <: AbstractExplicitLayer end -function (g::GlobalMeanPool)(x, ps, st::NamedTuple) +function (g::GlobalMeanPool)(x, _, st::NamedTuple) return meanpool(x, PoolDims(x, size(x)[1:(end - 2)])), st end @@ -488,18 +500,16 @@ Adaptive Max Pooling layer. Calculates the necessary window size such that its o See also [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref). """ -struct AdaptiveMaxPool{S, O} <: AbstractExplicitLayer - out::NTuple{O, Int} - AdaptiveMaxPool(out::NTuple{O, Int}) where {O} = new{O + 2, O}(out) +struct AdaptiveMaxPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractExplicitLayer + out::O + AdaptiveMaxPool(out) = new{length(out) + 2, typeof(out)}(out) end -function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}, ps, st::NamedTuple) where {S, T} +function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T} return maxpool(x, compute_adaptive_pooling_dims(x, a.out)), st end -function Base.show(io::IO, a::AdaptiveMaxPool) - return print(io, "AdaptiveMaxPool(", a.out, ")") -end +Base.show(io::IO, a::AdaptiveMaxPool) = print(io, "AdaptiveMaxPool(", a.out, ")") """ AdaptiveMeanPool(out::NTuple) @@ -523,18 +533,16 @@ Adaptive Mean Pooling layer. Calculates the necessary window size such that its See also [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref). """ -struct AdaptiveMeanPool{S, O} <: AbstractExplicitLayer - out::NTuple{O, Int} - AdaptiveMeanPool(out::NTuple{O, Int}) where {O} = new{O + 2, O}(out) +struct AdaptiveMeanPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractExplicitLayer + out::O + AdaptiveMeanPool(out) = new{length(out) + 2, typeof(out)}(out) end -function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, ps, st::NamedTuple) where {S, T} +function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T} return meanpool(x, compute_adaptive_pooling_dims(x, a.out)), st end -function Base.show(io::IO, a::AdaptiveMeanPool) - return print(io, "AdaptiveMeanPool(", a.out, ")") -end +Base.show(io::IO, a::AdaptiveMeanPool) = print(io, "AdaptiveMeanPool(", a.out, ")") """ PixelShuffle(r::Int) @@ -563,12 +571,12 @@ function set to `Base.Fix2(pixel_shuffle, r)` - Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)` for D-dimensional data, where `D = ndims(x) - 2` """ -PixelShuffle(r::Int) = WrappedFunction{:direct_call}(Base.Fix2(pixel_shuffle, r)) +PixelShuffle(r::IntegerType) = WrappedFunction{:direct_call}(Base.Fix2(pixel_shuffle, r)) @doc doc""" CrossCor(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, stride=1, - pad=0, dilation=1, use_bias=true, allow_fast_activation=true) + pad=0, dilation=1, groups=1, use_bias=True(), allow_fast_activation=True()) Cross Correlation layer. @@ -595,6 +603,9 @@ number of observations in a batch. 
- `stride`: Should each be either single integer, or a tuple with `N` integers - `dilation`: Should each be either single integer, or a tuple with `N` integers + - `groups`: Expected to be an `Int`. It specifies the number of groups to divide a + convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` + and `out_chs` must be divisible by `groups`. - `pad`: Specifies the number of elements added to the borders of the data array. It can be @@ -633,50 +644,55 @@ O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s - `weight`: Convolution kernel - `bias`: Bias (present if `use_bias=true`) """ -@concrete struct CrossCor{N, use_bias, M} <: AbstractExplicitLayer +@concrete struct CrossCor <: AbstractExplicitLayer activation - in_chs::Int - out_chs::Int - kernel_size::NTuple{N, Int} - stride::NTuple{N, Int} - pad::NTuple{M, Int} - dilation::NTuple{N, Int} + in_chs <: IntegerType + out_chs <: IntegerType + kernel_size <: Tuple{Vararg{IntegerType}} + stride <: Tuple{Vararg{IntegerType}} + pad <: Tuple{Vararg{IntegerType}} + dilation <: Tuple{Vararg{IntegerType}} + groups <: IntegerType init_weight init_bias + use_bias <: StaticBool end -function CrossCor( - k::NTuple{N, Integer}, ch::Pair{<:Integer, <:Integer}, activation=identity; - init_weight=glorot_uniform, init_bias=zeros32, stride=1, pad=0, dilation=1, - use_bias::Bool=true, allow_fast_activation::Bool=true) where {N} - stride = Utils.expand(Val(N), stride) - dilation = Utils.expand(Val(N), dilation) +function CrossCor(k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, + activation=identity; init_weight=glorot_uniform, + init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, + use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) + stride = Utils.expand(Val(length(k)), stride) + dilation = Utils.expand(Val(length(k)), dilation) pad = calc_padding(pad, k, dilation, stride) - activation = allow_fast_activation ? NNlib.fast_act(activation) : activation + @argcheck allequal(length, (stride, dilation, k)) - return CrossCor{N, use_bias, length(pad)}( - activation, first(ch), last(ch), k, stride, pad, dilation, init_weight, init_bias) + activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + return CrossCor(activation, first(ch), last(ch), k, stride, pad, dilation, + groups, init_weight, init_bias, static(use_bias)) end -function initialparameters(rng::AbstractRNG, c::CrossCor{N, use_bias}) where {N, use_bias} - weight = init_conv_filter(rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight) - !use_bias && return (; weight) - return (; weight, bias=c.init_bias(rng, ntuple(_ -> 1, N)..., c.out_chs, 1)) # TODO: flatten in v1 +function initialparameters(rng::AbstractRNG, c::CrossCor) + weight = init_conv_filter( + rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, c.groups) + has_bias(c) || return (; weight) + return (; weight, + bias=c.init_bias(rng, ntuple(_ -> 1, length(c.kernel_size))..., c.out_chs, 1)) # TODO: flatten in v1 end -function parameterlength(c::CrossCor{N, use_bias}) where {N, use_bias} - return prod(c.kernel_size) * c.in_chs * c.out_chs + (use_bias ? 
c.out_chs : 0) +function parameterlength(c::CrossCor) + return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs end function (c::CrossCor)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) cdims = DenseConvDims( - DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation); F=true) + DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups); F=true) bias = safe_vec(safe_getproperty(ps, Val(:bias))) return fused_conv_bias_activation(c.activation, ps.weight, y, bias, cdims), st end -function Base.show(io::IO, l::CrossCor{N, use_bias}) where {N, use_bias} +function Base.show(io::IO, l::CrossCor) print(io, "CrossCor(", l.kernel_size) print(io, ", ", l.in_chs, " => ", l.out_chs) l.activation == identity || print(io, ", ", l.activation) @@ -684,15 +700,16 @@ function Base.show(io::IO, l::CrossCor{N, use_bias}) where {N, use_bias} all(==(1), l.stride) || print(io, ", stride=", PrettyPrinting.tuple_string(l.stride)) all(==(1), l.dilation) || print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) - (use_bias == false) && print(io, ", use_bias=false") + (l.groups == 1) || print(io, ", groups=", l.groups) + has_bias(l) || print(io, ", use_bias=false") return print(io, ")") end @doc doc""" ConvTranspose(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - stride=1, pad=0, dilation=1, groups=1, use_bias=true, - allow_fast_activation=true) + stride=1, pad=0, dilation=1, groups=1, use_bias=True(), + allow_fast_activation=True()) Standard convolutional transpose layer. @@ -747,51 +764,52 @@ Standard convolutional transpose layer. - `weight`: Convolution Transpose kernel - `bias`: Bias (present if `use_bias=true`) """ -@concrete struct ConvTranspose{N, use_bias, M} <: AbstractExplicitLayer +@concrete struct ConvTranspose <: AbstractExplicitLayer activation - in_chs::Int - out_chs::Int - kernel_size::NTuple{N, Int} - stride::NTuple{N, Int} - pad::NTuple{M, Int} - dilation::NTuple{N, Int} - groups::Int + in_chs <: IntegerType + out_chs <: IntegerType + kernel_size <: Tuple{Vararg{IntegerType}} + stride <: Tuple{Vararg{IntegerType}} + pad <: Tuple{Vararg{IntegerType}} + dilation <: Tuple{Vararg{IntegerType}} + groups <: IntegerType init_weight init_bias + use_bias <: StaticBool end function ConvTranspose( - k::NTuple{N, Integer}, ch::Pair{<:Integer, <:Integer}, activation=identity; - init_weight=glorot_uniform, init_bias=zeros32, stride=1, pad=0, dilation=1, - use_bias::Bool=true, groups=1, allow_fast_activation::Bool=true) where {N} - stride = Utils.expand(Val(N), stride) - dilation = Utils.expand(Val(N), dilation) + k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, + activation=identity; init_weight=glorot_uniform, + init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, + use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) + stride = Utils.expand(Val(length(k)), stride) + dilation = Utils.expand(Val(length(k)), dilation) pad = if pad isa SamePad calc_padding(pad, k .- stride .+ 1, dilation, stride) else calc_padding(pad, k, dilation, stride) end - activation = allow_fast_activation ? NNlib.fast_act(activation) : activation + @argcheck allequal(length, (stride, dilation, k)) - return ConvTranspose{N, use_bias, length(pad)}( - activation, first(ch), last(ch), k, stride, - pad, dilation, groups, init_weight, init_bias) + activation = dynamic(allow_fast_activation) ? 
NNlib.fast_act(activation) : activation + return ConvTranspose(activation, first(ch), last(ch), k, stride, pad, dilation, + groups, init_weight, init_bias, static(use_bias)) end -function initialparameters( - rng::AbstractRNG, c::ConvTranspose{N, use_bias}) where {N, use_bias} +function initialparameters(rng::AbstractRNG, c::ConvTranspose) weight = init_conv_filter( rng, c.kernel_size, c.out_chs => c.in_chs; init=c.init_weight, c.groups) - !use_bias && return (; weight) - return (; weight, bias=c.init_bias(rng, ntuple(_ -> 1, N)..., c.out_chs, 1)) # TODO: flatten in v1 + has_bias(c) || return (; weight) + return (; weight, + bias=c.init_bias(rng, ntuple(_ -> 1, length(c.kernel_size))..., c.out_chs, 1)) # TODO: flatten in v1 end -function parameterlength(c::ConvTranspose{N, use_bias}) where {N, use_bias} - return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + - (use_bias ? c.out_chs : 0) +function parameterlength(c::ConvTranspose) + return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs end -function (c::ConvTranspose{N})(x::AbstractArray, ps, st::NamedTuple) where {N} +function (c::ConvTranspose)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) cdims = conv_transpose_dims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups) bias = safe_vec(safe_getproperty(ps, Val(:bias))) @@ -801,17 +819,12 @@ end function Base.show(io::IO, l::ConvTranspose) print(io, "ConvTranspose(", l.kernel_size) print(io, ", ", l.in_chs, " => ", l.out_chs) - _print_convtranspose_opt(io, l) - return print(io, ")") -end - -function _print_convtranspose_opt(io::IO, l::ConvTranspose{N, use_bias}) where {N, use_bias} l.activation == identity || print(io, ", ", l.activation) all(==(0), l.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(l.pad)) all(==(1), l.stride) || print(io, ", stride=", PrettyPrinting.tuple_string(l.stride)) all(==(1), l.dilation) || print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) (l.groups == 1) || print(io, ", groups=", l.groups) - (use_bias == false) && print(io, ", use_bias=false") - return nothing + has_bias(l) || print(io, ", use_bias=false") + return print(io, ")") end diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 7d20908ce..6997f4538 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -192,24 +192,27 @@ regular `Array` or not. Default is `false`. 
""" struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractExplicitLayer}} <: AbstractExplicitLayer + to_array::ToArray layer::SL lux_layer::LL function SimpleChainsLayer{ToArray}(layer, lux_layer=nothing) where {ToArray} - return new{ToArray, typeof(layer), typeof(lux_layer)}(layer, lux_layer) + to_array = static(ToArray) + return new{typeof(to_array), typeof(layer), typeof(lux_layer)}( + to_array, layer, lux_layer) end - function SimpleChainsLayer(layer, ToArray::Union{Bool, Val}=Val(false)) - return new{Utils.unwrap_val(ToArray), typeof(layer), Nothing}(layer, nothing) + function SimpleChainsLayer(layer, ToArray::BoolType=False()) + to_array = static(ToArray) + return new{typeof(to_array), typeof(layer), Nothing}(to_array, layer, nothing) end end function Base.show( io::IO, ::MIME"text/plain", s::SimpleChainsLayer{ToArray}) where {ToArray} - PrettyPrinting.print_wrapper_model(io, "SimpleChainsLayer{$ToArray}", s.lux_layer) + PrettyPrinting.print_wrapper_model( + io, "SimpleChainsLayer{to_array=$ToArray}", s.lux_layer) end -initialstates(::AbstractRNG, ::SimpleChainsLayer) = (;) - function (sc::SimpleChainsLayer)(x, ps, st) y = match_eltype(sc, ps, st, x) return ( @@ -218,8 +221,8 @@ function (sc::SimpleChainsLayer)(x, ps, st) st) end -simple_chain_output(::SimpleChainsLayer{false}, y) = y -simple_chain_output(::SimpleChainsLayer{true}, y) = convert(Array, y) +simple_chain_output(::SimpleChainsLayer{False}, y) = y +simple_chain_output(::SimpleChainsLayer{True}, y) = convert(Array, y) apply_simple_chain(layer, x, ps, ::CPUDevice) = layer(x, ps) diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 130d5a0be..dc7de1252 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -1,14 +1,6 @@ -abstract type AbstractNormalizationLayer{affine, track_stats} <: AbstractExplicitLayer end - -has_affine(::AbstractNormalizationLayer{A, T}) where {A, T} = A -is_tracking_stats(::AbstractNormalizationLayer{A, T}) where {A, T} = T - -CRC.@non_differentiable has_affine(::Any) -CRC.@non_differentiable is_tracking_stats(::Any) - @doc doc""" BatchNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, - affine=true, track_stats=true, epsilon=1f-5, momentum=0.1f0, + affine=True(), track_stats=True(), epsilon=1f-5, momentum=0.1f0, allow_fast_activation::Bool=true) [Batch Normalization](https://arxiv.org/abs/1502.03167) layer. @@ -94,43 +86,40 @@ Chain( See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) """ -@concrete struct BatchNorm{affine, track_stats, N} <: - AbstractNormalizationLayer{affine, track_stats} +@concrete struct BatchNorm{N} <: AbstractExplicitLayer activation epsilon::N momentum::N - chs::Int + chs <: IntegerType init_bias init_scale + affine <: StaticBool + track_stats <: StaticBool end -function BatchNorm(chs::Int, activation=identity; init_bias=zeros32, - init_scale=ones32, affine::Bool=true, track_stats::Bool=true, - epsilon=1.0f-5, momentum=0.1f0, allow_fast_activation::Bool=true) - activation = allow_fast_activation ? NNlib.fast_act(activation) : activation - return BatchNorm{affine, track_stats}( - activation, epsilon, momentum, chs, init_bias, init_scale) +function BatchNorm(chs::IntegerType, activation=identity; init_bias=zeros32, + init_scale=ones32, affine::BoolType=True(), track_stats::BoolType=True(), + epsilon=1.0f-5, momentum=0.1f0, allow_fast_activation::BoolType=True()) + activation = dynamic(allow_fast_activation) ? 
NNlib.fast_act(activation) : activation + return BatchNorm(activation, epsilon, momentum, chs, init_bias, + init_scale, static(affine), static(track_stats)) end function initialparameters(rng::AbstractRNG, l::BatchNorm) - if has_affine(l) - return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) - else - return NamedTuple() - end + has_affine(l) && return (; scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) + return (;) end function initialstates(rng::AbstractRNG, l::BatchNorm) - if is_tracking_stats(l) + if has_track_stats(l) return (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs), training=Val(true)) - else - return (; training=Val(true)) end + return (; training=Val(true)) end parameterlength(l::BatchNorm) = ifelse(has_affine(l), l.chs * 2, 0) -statelength(l::BatchNorm) = ifelse(is_tracking_stats(l), l.chs * 2, 0) + 1 +statelength(l::BatchNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1 function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple) CRC.ignore_derivatives() do @@ -148,7 +137,7 @@ function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple) end function update_batchnorm_state(BN::BatchNorm, st::NamedTuple, stats) - is_tracking_stats(BN) && return merge(st, (; stats.running_mean, stats.running_var)) + has_track_stats(BN) && return merge(st, (; stats.running_mean, stats.running_var)) return st end @@ -158,7 +147,7 @@ function Base.show(io::IO, l::BatchNorm) print(io, "BatchNorm($(l.chs)") (l.activation == identity) || print(io, ", $(l.activation)") print(io, ", affine=$(has_affine(l))") - print(io, ", track_stats=$(is_tracking_stats(l))") + print(io, ", track_stats=$(has_track_stats(l))") return print(io, ")") end @@ -233,26 +222,28 @@ Chain( See also [`GroupNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) """ -@concrete struct GroupNorm{affine} <: AbstractNormalizationLayer{affine, false} +@concrete struct GroupNorm <: AbstractExplicitLayer activation epsilon - chs::Int + chs <: IntegerType init_bias init_scale - groups::Int + groups <: IntegerType + affine <: StaticBool end -function GroupNorm(chs::Integer, groups::Integer, activation=identity; init_bias=zeros32, - init_scale=ones32, affine=true, epsilon=1.0f-5, allow_fast_activation::Bool=true) +function GroupNorm(chs::IntegerType, groups::IntegerType, activation=identity; + init_bias=zeros32, init_scale=ones32, affine::BoolType=True(), + epsilon=1.0f-5, allow_fast_activation::BoolType=True()) @argcheck chs % groups==0 "The number of groups ($(groups)) must divide the number of channels ($chs)" - activation = allow_fast_activation ? NNlib.fast_act(activation) : activation - - return GroupNorm{affine}(activation, epsilon, chs, init_bias, init_scale, groups) + activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + return GroupNorm( + activation, epsilon, chs, init_bias, init_scale, groups, static(affine)) end function initialparameters(rng::AbstractRNG, l::GroupNorm) - return has_affine(l) ? (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) : - (;) + return has_affine(l) ? + (; scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) : (;) end parameterlength(l::GroupNorm) = has_affine(l) ? 
(l.chs * 2) : 0 @@ -345,27 +336,25 @@ Chain( See also [`BatchNorm`](@ref), [`GroupNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) """ -@concrete struct InstanceNorm{affine} <: AbstractNormalizationLayer{affine, false} +@concrete struct InstanceNorm <: AbstractExplicitLayer activation epsilon - chs::Int + chs <: IntegerType init_bias init_scale + affine <: StaticBool end function InstanceNorm( - chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, - affine=true, epsilon=1.0f-5, allow_fast_activation::Bool=true) - activation = allow_fast_activation ? NNlib.fast_act(activation) : activation - return InstanceNorm{affine}(activation, epsilon, chs, init_bias, init_scale) + chs::IntegerType, activation=identity; init_bias=zeros32, init_scale=ones32, + affine::BoolType=True(), epsilon=1.0f-5, allow_fast_activation::BoolType=True()) + activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + return InstanceNorm(activation, epsilon, chs, init_bias, init_scale, static(affine)) end function initialparameters(rng::AbstractRNG, l::InstanceNorm) - if has_affine(l) - return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) - else - return (scale=nothing, bias=nothing) - end + has_affine(l) && return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) + return (;) end initialstates(::AbstractRNG, ::InstanceNorm) = (; training=Val(true)) @@ -387,8 +376,8 @@ function Base.show(io::IO, l::InstanceNorm) end @doc doc""" - WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N,Symbol}, - dims::Union{Tuple,Nothing}=nothing) + WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N, Symbol}, + dims::Union{Tuple, Nothing}=nothing) Applies [weight normalization](https://arxiv.org/abs/1602.07868) to a parameter in the given layer. @@ -427,31 +416,28 @@ parameters: one specifying the magnitude (e.g. `weight_g`) and one specifying th - Same as that of `layer` """ -@concrete struct WeightNorm{which_params, L <: AbstractExplicitLayer} <: - AbstractExplicitLayer - layer::L +@concrete struct WeightNorm <: AbstractExplicitLayer + layer <: AbstractExplicitLayer + which_params dims -end - -function WeightNorm{which_params}(layer::AbstractExplicitLayer; - dims::Union{Tuple, Nothing}=nothing) where {which_params} - return WeightNorm{which_params}(layer, dims) -end -function WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N, Symbol}, - dims::Union{Tuple, Nothing}=nothing) where {N} - return WeightNorm{which_params}(layer; dims) + function WeightNorm( + layer::AbstractExplicitLayer, which_params, dims::Union{Tuple, Nothing}=nothing) + which_params = static(which_params) + dims = static(dims) + return new{typeof(layer), typeof(which_params), typeof(dims)}( + layer, which_params, dims) + end end -function initialparameters( - rng::AbstractRNG, wn::WeightNorm{which_params}) where {which_params} +function initialparameters(rng::AbstractRNG, wn::WeightNorm) ps_layer = initialparameters(rng, wn.layer) ps_normalized = [] ps_unnormalized = [] i = 1 for k in propertynames(ps_layer) v = ps_layer[k] - if k in which_params + if k in known(wn.which_params) if all(iszero, v) throw(ArgumentError("Parameter $(k) is completely zero. This will result \ in NaN gradients. Either remove this parameter from \ @@ -459,7 +445,7 @@ function initialparameters( actual layer. Typically this is controlled using the \ `init_$(k)` keyword argument.")) end - dim = wn.dims === nothing ? ndims(v) : wn.dims[i] + dim = wn.dims === nothing ? 
ndims(v) : known(wn.dims[i]) push!(ps_normalized, Symbol(string(k) * "_g") => Utils.norm_except(v; dims=dim)) push!(ps_normalized, Symbol(string(k) * "_v") => v) i += 1 @@ -479,28 +465,27 @@ function (wn::WeightNorm)(x, ps, st::NamedTuple) return apply(wn.layer, y, Utils.merge(psₙ, ps.unnormalized), st) end -@inbounds @generated function get_weight_normalized_parameters( - ::WeightNorm{which_params}, dims::T, ps) where {T, which_params} - parameter_names = string.(which_params) - v_parameter_names = Symbol.(parameter_names .* "_v") - g_parameter_names = Symbol.(parameter_names .* "_g") - normalized_params_symbol = [gensym(p) for p in parameter_names] +@generated function get_weight_normalized_parameters( + ::WeightNorm{L, WP}, dims::T, ps) where {L, WP, T} + which_params = known(WP) + v_parameter_names = Symbol.(which_params, :_v) + g_parameter_names = Symbol.(which_params, :_g) + normalized_params_symbol = [gensym(p) for p in which_params] function get_norm_except_invoke(i) return if T <: Tuple - :(Utils.norm_except(ps.$(v_parameter_names[i]); dims=dims[$i])) + :(Utils.norm_except(ps.$(v_parameter_names[i]); dims=known(dims[$i]))) else :(Utils.norm_except(ps.$(v_parameter_names[i]))) end end calls = [] - for i in 1:length(parameter_names) + for (i, (v_param, g_param)) in enumerate(zip(v_parameter_names, g_parameter_names)) push!(calls, - :($(normalized_params_symbol[i]) = ps.$(v_parameter_names[i]) .* - (ps.$(g_parameter_names[i]) ./ + :($(normalized_params_symbol[i]) = ps.$(v_param) .* (ps.$(g_param) ./ ($(get_norm_except_invoke(i)) .+ - eps(eltype(ps.$(v_parameter_names[i]))))))) + eps(eltype(ps.$(v_param))))))) end push!(calls, :(return NamedTuple{$(which_params)}(tuple($(Tuple(normalized_params_symbol)...))))) @@ -508,13 +493,14 @@ end return Expr(:block, calls...) end -function Base.show(io::IO, w::WeightNorm{which_params}) where {which_params} - return print(io, "WeightNorm{", which_params, "}(", w.layer, ", dims = ", w.dims, ")") +function Base.show(io::IO, ::MIME"text/plain", w::WeightNorm) + return print(io, "WeightNorm(", w.layer, ", dims = ", known(w.dims), + ", normalized_parameters = ", known(w.which_params), ")") end @doc doc""" LayerNorm(shape::NTuple{N, Int}, activation=identity; epsilon=1f-5, dims=Colon(), - affine::Bool=true, init_bias=zeros32, init_scale=ones32,) + affine=true, init_bias=zeros32, init_scale=ones32) Computes mean and standard deviation over the whole input array, and uses these to normalize the whole array. Optionally applies an elementwise affine transformation @@ -547,7 +533,7 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. - `epsilon`: a value added to the denominator for numerical stability. - `dims`: Dimensions to normalize the array over. - If `affine=true`, it also applies a shift and a rescale to the input through to - learnable per-channel bias and scale parameters. + learnable per-element bias and scale parameters. + `init_bias`: Controls how the `bias` is initialized + `init_scale`: Controls how the `scale` is initialized @@ -571,29 +557,30 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. 
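To make the formula quoted above concrete, here is a hand-rolled sketch of what `LayerNorm` computes with the default `dims=Colon()` and `affine=true`. It illustrates the docstring's equation rather than Lux's internal code path; the helper name `layernorm_sketch` and the array sizes are invented for the example.

```julia
using Statistics

# y = γ .* (x .- μ) ./ sqrt(σ² + ϵ) .+ β, with μ and σ² taken over the whole array
function layernorm_sketch(x, γ, β; ϵ=1.0f-5)
    μ = mean(x)
    σ² = var(x; corrected=false, mean=μ)
    return γ .* (x .- μ) ./ sqrt(σ² + ϵ) .+ β
end

x = randn(Float32, 4, 4, 3, 2)
γ = ones(Float32, 4, 4, 3, 1)    # `scale`, shape `(shape..., 1)` as documented above
β = zeros(Float32, 4, 4, 3, 1)   # `bias`
y = layernorm_sketch(x, γ, β)
```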
+ `bias`: Bias of shape `(shape..., 1)` + `scale`: Scale of shape `(shape..., 1)` """ -@concrete struct LayerNorm{affine, N} <: AbstractNormalizationLayer{affine, false} - shape::NTuple{N, Int} +@concrete struct LayerNorm <: AbstractExplicitLayer + shape activation epsilon init_bias init_scale dims + affine <: StaticBool end -function LayerNorm(shape::NTuple{N, <:Int}, activation=identity; epsilon::T=1.0f-5, - dims=Colon(), affine::Bool=true, init_bias=zeros32, - init_scale=ones32, allow_fast_activation::Bool=true) where {N, T} - activation = allow_fast_activation ? NNlib.fast_act(activation) : activation - return LayerNorm{affine, N}(shape, activation, epsilon, init_bias, init_scale, dims) +function LayerNorm( + shape, activation=identity; epsilon=1.0f-5, dims=Colon(), affine::BoolType=True(), + init_bias=zeros32, init_scale=ones32, allow_fast_activation::BoolType=True()) + activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + return LayerNorm( + shape, activation, epsilon, init_bias, init_scale, dims, static(affine)) end function initialparameters(rng::AbstractRNG, ln::LayerNorm) if has_affine(ln) - return (bias=ln.init_bias(rng, ln.shape..., 1), - scale=ln.init_scale(rng, ln.shape..., 1)) - else - return NamedTuple() + dims = (ln.shape..., 1) + return (; bias=ln.init_bias(rng, dims...), scale=ln.init_scale(rng, dims...)) end + return (;) end function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 42eb0dd10..fd207bcad 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,13 +1,8 @@ -abstract type AbstractRecurrentCell{use_bias, train_state} <: AbstractExplicitLayer end +abstract type AbstractRecurrentCell <: AbstractExplicitLayer end const AbstractDebugRecurrentCell = Experimental.DebugLayer{ <:Any, <:Any, <:AbstractRecurrentCell} -function ConstructionBase.constructorof(::Type{<:AbstractRecurrentCell{ - use_bias, train_state}}) where {use_bias, train_state} - return AbstractRecurrentCell{use_bias, train_state} -end - # Fallback for vector inputs function (rnn::AbstractRecurrentCell)(x::AbstractVector, ps, st::NamedTuple) (y, carry), stₙ = rnn(reshape(x, :, 1), ps, st) @@ -90,23 +85,22 @@ automatically operate over a sequence of inputs. For some discussion on this topic, see https://github.com/LuxDL/Lux.jl/issues/472. 
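A short sketch of how the refactored `Recurrence` is used, following the discussion above. Assumptions not in the patch: the `(in_dims, timesteps, batch_size)` layout expected by `BatchLastIndex()` in current Lux releases, and the concrete sizes.

```julia
using Lux, Random

rng = Random.default_rng()

# `return_sequence` stays a plain keyword; it is stored as `static(true)` internally.
rnn = Recurrence(RNNCell(3 => 5); return_sequence=true)

ps, st = Lux.setup(rng, rnn)
x = randn(Float32, 3, 10, 8)     # (in_dims, timesteps, batch_size) for BatchLastIndex()
ys, _ = rnn(x, ps, st)
length(ys)                       # one (5 × 8) output per timestep with return_sequence=true
```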
""" -@concrete struct Recurrence{R} <: AbstractExplicitContainerLayer{(:cell,)} +@concrete struct Recurrence{R <: StaticBool} <: AbstractExplicitContainerLayer{(:cell,)} cell <: Union{<:AbstractRecurrentCell, <:AbstractDebugRecurrentCell} ordering <: AbstractTimeSeriesDataBatchOrdering + return_sequence::R end -ConstructionBase.constructorof(::Type{<:Recurrence{R}}) where {R} = Recurrence{R} - function Recurrence(cell; ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex(), return_sequence::Bool=false) - return Recurrence{return_sequence}(cell, ordering) + return Recurrence(cell, ordering, static(return_sequence)) end function (r::Recurrence)(x::AbstractArray, ps, st::NamedTuple) return apply(r, safe_eachslice(x, r.ordering), ps, st) end -function (r::Recurrence{false})(x::Union{AbstractVector, NTuple}, ps, st::NamedTuple) +function (r::Recurrence{False})(x::Union{AbstractVector, NTuple}, ps, st::NamedTuple) (out, carry), st = apply(r.cell, first(x), ps, st) for xᵢ in x[(begin + 1):end] (out, carry), st = apply(r.cell, (xᵢ, carry), ps, st) @@ -114,7 +108,7 @@ function (r::Recurrence{false})(x::Union{AbstractVector, NTuple}, ps, st::NamedT return out, st end -function (r::Recurrence{true})(x::Union{AbstractVector, NTuple}, ps, st::NamedTuple) +function (r::Recurrence{True})(x::Union{AbstractVector, NTuple}, ps, st::NamedTuple) function recur_op(::Nothing, input) (out, carry), state = apply(r.cell, input, ps, st) return [out], carry, state @@ -173,13 +167,10 @@ end function applyrecurrentcell(l::AbstractRecurrentCell, x, ps, st, carry) return apply(l, (x, carry), ps, st) end - -function applyrecurrentcell(l::AbstractRecurrentCell, x, ps, st, ::Nothing) - return apply(l, x, ps, st) -end +applyrecurrentcell(l::AbstractRecurrentCell, x, ps, st, ::Nothing) = apply(l, x, ps, st) @doc doc""" - RNNCell(in_dims => out_dims, activation=tanh; bias::Bool=true, train_state::Bool=false, + RNNCell(in_dims => out_dims, activation=tanh; use_bias=True(), train_state=False(), init_bias=zeros32, init_weight=glorot_uniform, init_state=ones32) An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). @@ -191,7 +182,7 @@ An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). - `in_dims`: Input Dimension - `out_dims`: Output (Hidden State) Dimension - `activation`: Activation function - - `bias`: Set to false to deactivate bias + - `use_bias`: Set to false to deactivate bias - `train_state`: Trainable initial hidden state can be activated by setting this to `true` - `init_bias`: Initializer for bias - `init_weight`: Initializer for weight @@ -226,43 +217,42 @@ An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). 
- `rng`: Controls the randomness (if any) in the initial state generation """ -@concrete struct RNNCell{use_bias, train_state} <: - AbstractRecurrentCell{use_bias, train_state} +@concrete struct RNNCell <: AbstractRecurrentCell + train_state <: StaticBool activation - in_dims::Int - out_dims::Int + in_dims <: IntegerType + out_dims <: IntegerType init_bias init_weight init_state + use_bias <: StaticBool end -function RNNCell((in_dims, out_dims)::Pair{<:Int, <:Int}, activation=tanh; - use_bias::Bool=true, train_state::Bool=false, init_bias=zeros32, - init_weight=glorot_uniform, init_state=ones32) - return RNNCell{use_bias, train_state}( - activation, in_dims, out_dims, init_bias, init_weight, init_state) +function RNNCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}, activation=tanh; + use_bias::BoolType=True(), train_state::BoolType=False(), + init_bias=zeros32, init_weight=glorot_uniform, init_state=ones32) + return RNNCell(static(train_state), activation, in_dims, out_dims, + init_bias, init_weight, init_state, static(use_bias)) end -function initialparameters( - rng::AbstractRNG, rnn::RNNCell{use_bias, TS}) where {use_bias, TS} +function initialparameters(rng::AbstractRNG, rnn::RNNCell) ps = (weight_ih=rnn.init_weight(rng, rnn.out_dims, rnn.in_dims), weight_hh=rnn.init_weight(rng, rnn.out_dims, rnn.out_dims)) - use_bias && (ps = merge(ps, (bias=rnn.init_bias(rng, rnn.out_dims),))) - TS && (ps = merge(ps, (hidden_state=rnn.init_state(rng, rnn.out_dims),))) + has_bias(rnn) && (ps = merge(ps, (bias=rnn.init_bias(rng, rnn.out_dims),))) + has_train_state(rnn) && + (ps = merge(ps, (hidden_state=rnn.init_state(rng, rnn.out_dims),))) return ps end initialstates(rng::AbstractRNG, ::RNNCell) = (rng=Utils.sample_replicate(rng),) -function (rnn::RNNCell{use_bias, false})( - x::AbstractMatrix, ps, st::NamedTuple) where {use_bias} +function (rnn::RNNCell{False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) hidden_state = Utils.init_hidden_state(rng, rnn, x) return rnn((x, (hidden_state,)), ps, merge(st, (; rng))) end -function (rnn::RNNCell{use_bias, true})( - x::AbstractMatrix, ps, st::NamedTuple) where {use_bias} +function (rnn::RNNCell{True})(x::AbstractMatrix, ps, st::NamedTuple) hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) return rnn((x, (hidden_state,)), ps, st) end @@ -277,12 +267,12 @@ function (rnn::RNNCell)( return (hₙ, (hₙ,)), st end -function Base.show(io::IO, r::RNNCell{use_bias, TS}) where {use_bias, TS} +function Base.show(io::IO, r::RNNCell) print(io, "RNNCell($(r.in_dims) => $(r.out_dims)") (r.activation == identity) || print(io, ", $(r.activation)") - use_bias || print(io, ", use_bias=false") - TS && print(io, ", train_state=true") - return print(io, ")") + has_bias(r) || print(io, ", use_bias=false") + has_train_state(r) && print(io, ", train_state=true") + print(io, ")") end @doc doc""" @@ -360,74 +350,70 @@ Long Short-Term (LSTM) Cell - `rng`: Controls the randomness (if any) in the initial state generation """ -@concrete struct LSTMCell{use_bias, train_state, train_memory} <: - AbstractRecurrentCell{use_bias, train_state} - in_dims::Int - out_dims::Int +@concrete struct LSTMCell <: AbstractRecurrentCell + train_state <: StaticBool + train_memory <: StaticBool + in_dims <: IntegerType + out_dims <: IntegerType init_bias init_weight init_state init_memory + use_bias <: StaticBool end -function LSTMCell((in_dims, out_dims)::Pair{<:Int, <:Int}; - use_bias::Bool=true, - train_state::Bool=false, - train_memory::Bool=false, - 
init_weight::NTuple{4, Function}=( - glorot_uniform, glorot_uniform, glorot_uniform, glorot_uniform), - init_bias::NTuple{4, Function}=(zeros32, zeros32, ones32, zeros32), - init_state::Function=zeros32, - init_memory::Function=zeros32) - return LSTMCell{use_bias, train_state, train_memory}( - in_dims, out_dims, init_bias, init_weight, init_state, init_memory) +function LSTMCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}; + use_bias::BoolType=True(), train_state::BoolType=False(), + train_memory::BoolType=False(), init_weight=glorot_uniform, + init_bias=zeros32, init_state=zeros32, init_memory=zeros32) + init_weight isa NTuple{4} || (init_weight = ntuple(Returns(init_weight), 4)) + init_bias isa NTuple{4} || (init_bias = ntuple(Returns(init_bias), 4)) + return LSTMCell(static(train_state), static(train_memory), in_dims, out_dims, + init_bias, init_weight, init_state, init_memory, static(use_bias)) end -function initialparameters(rng::AbstractRNG, - lstm::LSTMCell{use_bias, TS, train_memory}) where {use_bias, TS, train_memory} +function initialparameters(rng::AbstractRNG, lstm::LSTMCell) weight_i = vcat([init_weight(rng, lstm.out_dims, lstm.in_dims) for init_weight in lstm.init_weight]...) weight_h = vcat([init_weight(rng, lstm.out_dims, lstm.out_dims) for init_weight in lstm.init_weight]...) ps = (; weight_i, weight_h) - if use_bias + if has_bias(lstm) # TODO: in v1 we make this a flat vector bias = vcat([init_bias(rng, lstm.out_dims, 1) for init_bias in lstm.init_bias]...) ps = merge(ps, (bias=bias,)) end - TS && (ps = merge(ps, (hidden_state=lstm.init_state(rng, lstm.out_dims),))) - train_memory && (ps = merge(ps, (memory=lstm.init_memory(rng, lstm.out_dims),))) + has_train_state(lstm) && + (ps = merge(ps, (hidden_state=lstm.init_state(rng, lstm.out_dims),))) + known(lstm.train_memory) && + (ps = merge(ps, (memory=lstm.init_memory(rng, lstm.out_dims),))) return ps end initialstates(rng::AbstractRNG, ::LSTMCell) = (rng=Utils.sample_replicate(rng),) -function (lstm::LSTMCell{use_bias, false, false})( - x::AbstractMatrix, ps, st::NamedTuple) where {use_bias} +function (lstm::LSTMCell{False, False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) hidden_state = Utils.init_hidden_state(rng, lstm, x) memory = Utils.init_hidden_state(rng, lstm, x) return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng))) end -function (lstm::LSTMCell{use_bias, true, false})( - x::AbstractMatrix, ps, st::NamedTuple) where {use_bias} +function (lstm::LSTMCell{True, False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) memory = Utils.init_hidden_state(rng, lstm, x) return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng))) end -function (lstm::LSTMCell{use_bias, false, true})( - x::AbstractMatrix, ps, st::NamedTuple) where {use_bias} +function (lstm::LSTMCell{False, True})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) hidden_state = Utils.init_hidden_state(rng, lstm, x) memory = Utils.init_trainable_hidden_state(ps.memory, x) return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng))) end -function (lstm::LSTMCell{use_bias, true, true})( - x::AbstractMatrix, ps, st::NamedTuple) where {use_bias} +function (lstm::LSTMCell{True, True})(x::AbstractMatrix, ps, st::NamedTuple) hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) memory = Utils.init_trainable_hidden_state(ps.memory, x) return lstm((x, (hidden_state, memory)), ps, st) @@ -449,13 +435,12 @@ 
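One user-visible consequence of the `LSTMCell` constructor change above: a single initializer function is now expanded to a 4-tuple via `ntuple(Returns(...), 4)`, so the two constructions below describe the same parameterization. This sketch assumes `glorot_uniform` is re-exported by Lux (it is the default used in the constructor above); the explicit `Xoshiro` seeding only makes the comparison deterministic and is not part of the patch.

```julia
using Lux, Random

lstm_a = LSTMCell(3 => 5; init_weight=glorot_uniform)
lstm_b = LSTMCell(3 => 5;
    init_weight=(glorot_uniform, glorot_uniform, glorot_uniform, glorot_uniform))

ps_a, _ = Lux.setup(Random.Xoshiro(0), lstm_a)
ps_b, _ = Lux.setup(Random.Xoshiro(0), lstm_b)

ps_a.weight_i == ps_b.weight_i               # same draws from the same seed
ps_a.weight_h == ps_b.weight_h
```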
function (lstm::LSTMCell)( return (hidden_state₂, (hidden_state₂, memory₂)), st end -function Base.show(io::IO, - lstm::LSTMCell{use_bias, TS, train_memory}) where {use_bias, TS, train_memory} +function Base.show(io::IO, lstm::LSTMCell) print(io, "LSTMCell($(lstm.in_dims) => $(lstm.out_dims)") - use_bias || print(io, ", use_bias=false") - TS && print(io, ", train_state=true") - train_memory && print(io, ", train_memory=true") - return print(io, ")") + has_bias(lstm) || print(io, ", use_bias=false") + has_train_state(lstm) && print(io, ", train_state=true") + known(lstm.train_memory) && print(io, ", train_memory=true") + print(io, ")") end @doc doc""" @@ -521,51 +506,50 @@ Gated Recurrent Unit (GRU) Cell - `rng`: Controls the randomness (if any) in the initial state generation """ -@concrete struct GRUCell{use_bias, train_state} <: - AbstractRecurrentCell{use_bias, train_state} - in_dims::Int - out_dims::Int +@concrete struct GRUCell <: AbstractRecurrentCell + train_state <: StaticBool + in_dims <: IntegerType + out_dims <: IntegerType init_bias init_weight init_state + use_bias <: StaticBool end -function GRUCell((in_dims, out_dims)::Pair{<:Int, <:Int}; - use_bias::Bool=true, train_state::Bool=false, - init_weight::NTuple{3, Function}=(glorot_uniform, glorot_uniform, glorot_uniform), - init_bias::NTuple{3, Function}=(zeros32, zeros32, zeros32), - init_state::Function=zeros32) - return GRUCell{use_bias, train_state}( - in_dims, out_dims, init_bias, init_weight, init_state) +function GRUCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}; + use_bias::BoolType=True(), train_state::BoolType=False(), + init_weight=glorot_uniform, init_bias=zeros32, init_state=zeros32) + init_weight isa NTuple{3} || (init_weight = ntuple(Returns(init_weight), 3)) + init_bias isa NTuple{3} || (init_bias = ntuple(Returns(init_bias), 3)) + return GRUCell(static(train_state), in_dims, out_dims, init_bias, + init_weight, init_state, static(use_bias)) end -function initialparameters( - rng::AbstractRNG, gru::GRUCell{use_bias, TS}) where {use_bias, TS} +function initialparameters(rng::AbstractRNG, gru::GRUCell) weight_i = vcat([init_weight(rng, gru.out_dims, gru.in_dims) for init_weight in gru.init_weight]...) weight_h = vcat([init_weight(rng, gru.out_dims, gru.out_dims) for init_weight in gru.init_weight]...) ps = (; weight_i, weight_h) - if use_bias + if has_bias(gru) bias_i = gru.init_bias[1](rng, gru.out_dims, 1) # TODO: in v1 we make this a flat vector bias_h = vcat([init_bias(rng, gru.out_dims, 1) for init_bias in gru.init_bias]...) ps = merge(ps, (bias_i=bias_i, bias_h=bias_h)) end - TS && (ps = merge(ps, (hidden_state=gru.init_state(rng, gru.out_dims),))) + has_train_state(gru) && + (ps = merge(ps, (hidden_state=gru.init_state(rng, gru.out_dims),))) return ps end initialstates(rng::AbstractRNG, ::GRUCell) = (rng=Utils.sample_replicate(rng),) -function (gru::GRUCell{use_bias, true})( - x::AbstractMatrix, ps, st::NamedTuple) where {use_bias} +function (gru::GRUCell{True})(x::AbstractMatrix, ps, st::NamedTuple) hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) return gru((x, (hidden_state,)), ps, st) end -function (gru::GRUCell{use_bias, false})( - x::AbstractMatrix, ps, st::NamedTuple) where {use_bias} +function (gru::GRUCell{False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) st = merge(st, (; rng)) hidden_state = Utils.init_hidden_state(rng, gru, x) @@ -592,11 +576,11 @@ end gru_cell_compute(x, r, y, ::Nothing) = @. 
tanh_fast(x + r * y) gru_cell_compute(x, r, y, bias) = @. tanh_fast(x + r * y + bias) -function Base.show(io::IO, g::GRUCell{use_bias, TS}) where {use_bias, TS} +function Base.show(io::IO, g::GRUCell) print(io, "GRUCell($(g.in_dims) => $(g.out_dims)") - use_bias || print(io, ", use_bias=false") - TS && print(io, ", train_state=true") - return print(io, ")") + has_bias(g) || print(io, ", use_bias=false") + has_train_state(g) && print(io, ", train_state=true") + print(io, ")") end """ diff --git a/src/transform/simplechains.jl b/src/transform/simplechains.jl index 7aed5961a..d224fcd64 100644 --- a/src/transform/simplechains.jl +++ b/src/transform/simplechains.jl @@ -34,40 +34,27 @@ julia> lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3), Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10))); -julia> adaptor = ToSimpleChainsAdaptor((28, 28, 1)) -ToSimpleChainsAdaptor{Tuple{Static.StaticInt{28}, Static.StaticInt{28}, Static.StaticInt{1}}}((static(28), static(28), static(1)), false) - -julia> simple_chains_model = adapt(adaptor, lux_model) # or adaptor(lux_model) -SimpleChainsLayer{false}( - Chain( - layer_1 = Conv((5, 5), 1 => 6, relu), # 156 parameters - layer_2 = MaxPool((2, 2)), - layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters - layer_4 = MaxPool((2, 2)), - layer_5 = FlattenLayer{Int64}(3), - layer_6 = Dense(256 => 128, relu), # 32_896 parameters - layer_7 = Dense(128 => 84, relu), # 10_836 parameters - layer_8 = Dense(84 => 10), # 850 parameters - ), -) # Total: 47_154 parameters, - # plus 0 states. +julia> adaptor = ToSimpleChainsAdaptor((28, 28, 1)); + +julia> simple_chains_model = adapt(adaptor, lux_model); # or adaptor(lux_model) julia> ps, st = Lux.setup(Random.default_rng(), simple_chains_model); julia> x = randn(Float32, 28, 28, 1, 1); -julia> size(first(simple_chains_model(x, ps, st))) == (10, 1) -true +julia> size(first(simple_chains_model(x, ps, st))) +(10, 1) ``` """ -struct ToSimpleChainsAdaptor{ID} <: AbstractFromLuxAdaptor +struct ToSimpleChainsAdaptor{ID, AT} <: AbstractFromLuxAdaptor input_dims::ID - convert_to_array::Bool + convert_to_array::AT - function ToSimpleChainsAdaptor(input_dims, convert_to_array::Bool=false) + function ToSimpleChainsAdaptor(input_dims, convert_to_array::BoolType=False()) input_dims isa Number && (input_dims = (input_dims,)) input_dims isa Tuple{Vararg{Integer}} && (input_dims = static(input_dims)) - return new{typeof(input_dims)}(input_dims, convert_to_array) + return new{typeof(input_dims), typeof(convert_to_array)}( + input_dims, convert_to_array) end end diff --git a/src/utils.jl b/src/utils.jl index eb780911d..d47cea613 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,12 +8,17 @@ using EnzymeCore: EnzymeRules using ForwardDiff: Dual using Functors: fmapstructure using Random: AbstractRNG +using Static: Static, StaticBool, StaticInteger, StaticSymbol using LuxCore: LuxCore, AbstractExplicitLayer using MLDataDevices: get_device const CRC = ChainRulesCore +const BoolType = Union{StaticBool, Bool, Val{true}, Val{false}} +const IntegerType = Union{Integer, StaticInteger} +const SymbolType = Union{Symbol, StaticSymbol} + # Aliased `size` from Base size(x::AbstractArray) = Base.size(x) size(x::T) where {T} = hasmethod(Base.size, Tuple{T}) ? 
Base.size(x) : nothing @@ -188,9 +193,18 @@ function named_tuple_layers(layers::Vararg{AbstractExplicitLayer, N}) where {N} return NamedTuple{ntuple(i -> Symbol(:layer_, i), N)}(layers) end +make_abstract_matrix(x::AbstractVector) = reshape(x, :, 1) +make_abstract_matrix(x::AbstractMatrix) = x +make_abstract_matrix(x::AbstractArray{T, N}) where {T, N} = reshape(x, Base.size(x, 1), :) + +matrix_to_array(x::AbstractMatrix, ::AbstractVector) = vec(x) +matrix_to_array(x::AbstractMatrix, ::AbstractMatrix) = x +matrix_to_array(x::AbstractMatrix, y::AbstractArray) = reshape(x, :, size(y)[2:end]...) + end -using .Utils: Utils +using .Utils: Utils, BoolType, IntegerType, SymbolType, make_abstract_matrix, + matrix_to_array const safe_reverse = Utils.reverse const safe_vec = Utils.vec diff --git a/test/helpers/size_propagator_test.jl b/test/helpers/size_propagator_test.jl index bf08fde72..9dde070ce 100644 --- a/test/helpers/size_propagator_test.jl +++ b/test/helpers/size_propagator_test.jl @@ -3,7 +3,7 @@ @testset "Simple Chain (LeNet)" begin lenet = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), - Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(), + Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3), Dense(256 => 120, relu), Dense(120 => 84, relu), Dense(84 => 10)) ps, st = Lux.setup(rng, lenet) @@ -14,7 +14,7 @@ @testset "Chain with BatchNorm" begin lenet = Chain(Conv((5, 5), 1 => 6, relu), BatchNorm(6, relu), MaxPool((2, 2)), Conv((5, 5), 6 => 16, relu), BatchNorm(16, relu), - MaxPool((2, 2)), FlattenLayer(), Dense(256 => 120, relu), + MaxPool((2, 2)), FlattenLayer(3), Dense(256 => 120, relu), BatchNorm(120, relu), Dense(120 => 84, relu), Dropout(0.5f0), BatchNorm(84, relu), Dense(84 => 10), BatchNorm(10, relu)) ps, st = Lux.setup(rng, lenet) diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 64ca3b946..06e20f31e 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -1,14 +1,12 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin using Aqua, ChainRulesCore, ForwardDiff - Aqua.test_all(Lux; piracies=false, ambiguities=false) + Aqua.test_all(Lux; ambiguities=false) Aqua.test_ambiguities(Lux; exclude=[ForwardDiff.jacobian, ForwardDiff.gradient, Lux.AutoDiffInternalImpl.batched_jacobian, Lux.AutoDiffInternalImpl.jacobian_vector_product, Lux.AutoDiffInternalImpl.jacobian_vector_product_impl]) - Aqua.test_piracies( - Lux; treat_as_own=[ChainRulesCore.frule, ChainRulesCore.rrule, Core.kwcall]) end @testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] tags=[:others] begin @@ -31,8 +29,7 @@ end end # Some of the tests are flaky on prereleases -@testitem "doctests: Quality Assurance" tags=[:others] skip=:(length(VERSION.prerelease) > - 0) begin +@testitem "doctests: Quality Assurance" tags=[:others] begin using Documenter doctestexpr = quote
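The `BoolType`/`IntegerType`/`SymbolType` aliases introduced above all funnel through Static.jl, whose `static`/`known`/`dynamic` helpers this patch uses pervasively. A standalone sketch of that conversion pattern, using only the Static.jl public API (nothing here is Lux-specific):

```julia
using Static: static, known, dynamic, True, False, StaticSymbol

@assert static(true) === True()           # lift a Bool into the type domain
@assert dynamic(True()) === true          # recover a plain Bool
@assert static(:both) isa StaticSymbol    # same pattern for Symbols
@assert known(static(:both)) === :both    # read the value back out

# `static` is a no-op on already-static values, which is what lets the layer
# constructors accept either `use_bias=true` or `use_bias=True()`.
@assert static(True()) === True()
```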