diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 900d8d3b2..66c5e6178 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -363,6 +363,152 @@ function Base.show(io::IO, l::ConvTranspose) print(io, ")") end +""" + Upsample(mode = :nearest; [scale, size, align_corners=false]) + Upsample(scale, mode = :nearest) + +Upsampling Layer. + +## Layer Construction + +### Option 1 + + - `mode`: Set to `:nearest`, `:linear`, `:bilinear` or `:trilinear` + +Exactly one of two keywords must be specified: + + - If `scale` is a number, this applies to all but the last two dimensions (channel and + batch) of the input. It may also be a tuple, to control dimensions individually. + - Alternatively, keyword `size` accepts a tuple, to directly specify the leading + dimensions of the output. + +### Option 2 + + - If `scale` is a number, this applies to all but the last two dimensions (channel and + batch) of the input. It may also be a tuple, to control dimensions individually. + - `mode`: Set to `:nearest`, `:bilinear` or `:trilinear` + +Currently supported upsampling `mode`s and corresponding NNlib's methods are: + + - `:nearest` -> `NNlib.upsample_nearest` + - `:bilinear` -> `NNlib.upsample_bilinear` + - `:trilinear` -> `NNlib.upsample_trilinear` + +# Extended Help + +## Other Keyword Arguments + + - `align_corners`: If `true`, the corner pixels of the input and output tensors are + aligned, and thus preserving the values at those pixels. This only has effect when mode + is one of `:bilinear` or `:trilinear`. + +## Inputs + + - `x`: For the input dimensions look into the documentation for the corresponding `NNlib` + function + + + As a rule of thumb, `:nearest` should work with arrays of arbitrary dimensions + + `:bilinear` works with 4D Arrays + + `:trilinear` works with 5D Arrays + +## Returns + + - Upsampled Input of size `size` or of size `(I_1 x scale[1], ..., I_N x scale[N], C, N)` + - Empty `NamedTuple()` +""" +@concrete struct Upsample <: AbstractLuxLayer + scale + size + upsample_mode <: StaticSymbol + align_corners <: Bool +end + +function Upsample(mode::SymbolType=static(:nearest); scale=nothing, + size=nothing, align_corners::Bool=false) + @argcheck dynamic(mode) in (:nearest, :bilinear, :trilinear) + + if !xor(isnothing(scale), isnothing(size)) + throw(ArgumentError("Either scale or size should be specified (but not both).")) + end + return Upsample(scale, size, static(mode), align_corners) +end + +Upsample(scale, mode::SymbolType=static(:nearest)) = Upsample(mode; scale) + +function (m::Upsample)(x::AbstractArray, _, st::NamedTuple) + return lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale, m.align_corners), st +end +function (m::Upsample{Nothing})(x::AbstractArray, _, st::NamedTuple) + return lux_upsample_size_dispatch(m.upsample_mode, x, m.size, m.align_corners), st +end + +for interp in (:bilinear, :trilinear) + nnlib_interp_func = Symbol(:upsample_, interp) + @eval begin + function lux_upsample_scale_dispatch( + ::StaticSymbol{$(Meta.quot(interp))}, x, scale, align_corners) + return $(nnlib_interp_func)(x, scale) + end + function lux_upsample_size_dispatch( + ::StaticSymbol{$(Meta.quot(interp))}, x, size, align_corners) + return $(nnlib_interp_func)(x; size) + end + end +end + +function lux_upsample_size_dispatch(::StaticSymbol{:nearest}, x, size, _) + return NNlib.upsample_nearest(x; size) +end +function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale, _) + return NNlib.upsample_nearest(x, scale) +end +function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer, _) + return NNlib.upsample_nearest(x, ntuple(i -> scale, ndims(x) - 2)) +end + +function Base.show(io::IO, u::Upsample) + print(io, "Upsample(", u.upsample_mode) + u.scale !== nothing && print(io, ", scale = $(u.scale)") + u.size !== nothing && print(io, ", size = $(u.size)") + u.align_corners && print(io, ", align_corners = $(u.align_corners)") + print(io, ")") +end + +""" + PixelShuffle(r::Int) + +Pixel shuffling layer with upscale factor `r`. Usually used for generating higher +resolution images while upscaling them. + +See `NNlib.pixel_shuffle` for more details. + +PixelShuffle is not a Layer, rather it returns a [`WrappedFunction`](@ref) with the +function set to `Base.Fix2(pixel_shuffle, r)` + +## Arguments + + - `r`: Upscale factor + +## Inputs + + - `x`: For 4D-arrays representing N images, the operation converts input + `size(x) == (W, H, r² x C, N)` to output of size `(r x W, r x H, C, N)`. For + D-dimensional data, it expects `ndims(x) == D + 2` with channel and batch dimensions, and + divides the number of channels by `rᴰ`. + +## Returns + + - Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)` + for D-dimensional data, where `D = ndims(x) - 2` +""" +@concrete struct PixelShuffle <: AbstractLuxWrapperLayer{:layer} + layer <: AbstractLuxLayer +end + +function PixelShuffle(r::IntegerType) + return PixelShuffle(WrappedFunction(Base.Fix2(pixel_shuffle, r))) +end + @doc doc""" MaxPool(window::NTuple; pad=0, stride=window) @@ -503,117 +649,6 @@ function Base.show(io::IO, m::MeanPool) print(io, ")") end -""" - Upsample(mode = :nearest; [scale, size, align_corners=false]) - Upsample(scale, mode = :nearest) - -Upsampling Layer. - -## Layer Construction - -### Option 1 - - - `mode`: Set to `:nearest`, `:linear`, `:bilinear` or `:trilinear` - -Exactly one of two keywords must be specified: - - - If `scale` is a number, this applies to all but the last two dimensions (channel and - batch) of the input. It may also be a tuple, to control dimensions individually. - - Alternatively, keyword `size` accepts a tuple, to directly specify the leading - dimensions of the output. - -### Option 2 - - - If `scale` is a number, this applies to all but the last two dimensions (channel and - batch) of the input. It may also be a tuple, to control dimensions individually. - - `mode`: Set to `:nearest`, `:bilinear` or `:trilinear` - -Currently supported upsampling `mode`s and corresponding NNlib's methods are: - - - `:nearest` -> `NNlib.upsample_nearest` - - `:bilinear` -> `NNlib.upsample_bilinear` - - `:trilinear` -> `NNlib.upsample_trilinear` - -# Extended Help - -## Other Keyword Arguments - - - `align_corners`: If `true`, the corner pixels of the input and output tensors are - aligned, and thus preserving the values at those pixels. This only has effect when mode - is one of `:bilinear` or `:trilinear`. - -## Inputs - - - `x`: For the input dimensions look into the documentation for the corresponding `NNlib` - function - - + As a rule of thumb, `:nearest` should work with arrays of arbitrary dimensions - + `:bilinear` works with 4D Arrays - + `:trilinear` works with 5D Arrays - -## Returns - - - Upsampled Input of size `size` or of size `(I_1 x scale[1], ..., I_N x scale[N], C, N)` - - Empty `NamedTuple()` -""" -@concrete struct Upsample <: AbstractLuxLayer - scale - size - upsample_mode <: StaticSymbol - align_corners <: Bool -end - -function Upsample(mode::SymbolType=static(:nearest); scale=nothing, - size=nothing, align_corners::Bool=false) - @argcheck dynamic(mode) in (:nearest, :bilinear, :trilinear) - - if !xor(isnothing(scale), isnothing(size)) - throw(ArgumentError("Either scale or size should be specified (but not both).")) - end - return Upsample(scale, size, static(mode), align_corners) -end - -Upsample(scale, mode::SymbolType=static(:nearest)) = Upsample(mode; scale) - -function (m::Upsample)(x::AbstractArray, _, st::NamedTuple) - return lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale, m.align_corners), st -end -function (m::Upsample{Nothing})(x::AbstractArray, _, st::NamedTuple) - return lux_upsample_size_dispatch(m.upsample_mode, x, m.size, m.align_corners), st -end - -for interp in (:bilinear, :trilinear) - nnlib_interp_func = Symbol(:upsample_, interp) - @eval begin - function lux_upsample_scale_dispatch( - ::StaticSymbol{$(Meta.quot(interp))}, x, scale, align_corners) - return $(nnlib_interp_func)(x, scale) - end - function lux_upsample_size_dispatch( - ::StaticSymbol{$(Meta.quot(interp))}, x, size, align_corners) - return $(nnlib_interp_func)(x; size) - end - end -end - -function lux_upsample_size_dispatch(::StaticSymbol{:nearest}, x, size, _) - return NNlib.upsample_nearest(x; size) -end -function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale, _) - return NNlib.upsample_nearest(x, scale) -end -function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer, _) - return NNlib.upsample_nearest(x, ntuple(i -> scale, ndims(x) - 2)) -end - -function Base.show(io::IO, u::Upsample) - print(io, "Upsample(", u.upsample_mode) - u.scale !== nothing && print(io, ", scale = $(u.scale)") - u.size !== nothing && print(io, ", size = $(u.size)") - u.align_corners && print(io, ", align_corners = $(u.align_corners)") - print(io, ")") -end - """ GlobalMaxPool() @@ -725,38 +760,3 @@ function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) whe end Base.show(io::IO, a::AdaptiveMeanPool) = print(io, "AdaptiveMeanPool(", a.out, ")") - -""" - PixelShuffle(r::Int) - -Pixel shuffling layer with upscale factor `r`. Usually used for generating higher -resolution images while upscaling them. - -See `NNlib.pixel_shuffle` for more details. - -PixelShuffle is not a Layer, rather it returns a [`WrappedFunction`](@ref) with the -function set to `Base.Fix2(pixel_shuffle, r)` - -## Arguments - - - `r`: Upscale factor - -## Inputs - - - `x`: For 4D-arrays representing N images, the operation converts input - `size(x) == (W, H, r² x C, N)` to output of size `(r x W, r x H, C, N)`. For - D-dimensional data, it expects `ndims(x) == D + 2` with channel and batch dimensions, and - divides the number of channels by `rᴰ`. - -## Returns - - - Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)` - for D-dimensional data, where `D = ndims(x) - 2` -""" -@concrete struct PixelShuffle <: AbstractLuxWrapperLayer{:layer} - layer <: AbstractLuxLayer -end - -function PixelShuffle(r::IntegerType) - return PixelShuffle(WrappedFunction(Base.Fix2(pixel_shuffle, r))) -end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index fe039bfd6..878c713be 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -422,7 +422,7 @@ function initialparameters(rng::AbstractRNG, lstm::LSTMCell) for init_bias in lstm.init_bias]...) bias_hh = vcat([init_rnn_bias(rng, init_bias, lstm.out_dims, lstm.out_dims) for init_bias in lstm.init_bias]...) - ps = merge(ps, (bias_ih, bias_hh)) + ps = merge(ps, (; bias_ih, bias_hh)) end has_train_state(lstm) && (ps = merge(ps, (hidden_state=lstm.init_state(rng, lstm.out_dims),)))