Skip to content

Commit

Permalink
refactor: move the Upsample layer
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 5, 2024
1 parent 1122d40 commit e47f063
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 147 deletions.
292 changes: 146 additions & 146 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,152 @@ function Base.show(io::IO, l::ConvTranspose)
print(io, ")")
end

"""
Upsample(mode = :nearest; [scale, size, align_corners=false])
Upsample(scale, mode = :nearest)
Upsampling Layer.
## Layer Construction
### Option 1
- `mode`: Set to `:nearest`, `:linear`, `:bilinear` or `:trilinear`
Exactly one of two keywords must be specified:
- If `scale` is a number, this applies to all but the last two dimensions (channel and
batch) of the input. It may also be a tuple, to control dimensions individually.
- Alternatively, keyword `size` accepts a tuple, to directly specify the leading
dimensions of the output.
### Option 2
- If `scale` is a number, this applies to all but the last two dimensions (channel and
batch) of the input. It may also be a tuple, to control dimensions individually.
- `mode`: Set to `:nearest`, `:bilinear` or `:trilinear`
Currently supported upsampling `mode`s and corresponding NNlib's methods are:
- `:nearest` -> `NNlib.upsample_nearest`
- `:bilinear` -> `NNlib.upsample_bilinear`
- `:trilinear` -> `NNlib.upsample_trilinear`
# Extended Help
## Other Keyword Arguments
- `align_corners`: If `true`, the corner pixels of the input and output tensors are
aligned, and thus preserving the values at those pixels. This only has effect when mode
is one of `:bilinear` or `:trilinear`.
## Inputs
- `x`: For the input dimensions look into the documentation for the corresponding `NNlib`
function
+ As a rule of thumb, `:nearest` should work with arrays of arbitrary dimensions
+ `:bilinear` works with 4D Arrays
+ `:trilinear` works with 5D Arrays
## Returns
- Upsampled Input of size `size` or of size `(I_1 x scale[1], ..., I_N x scale[N], C, N)`
- Empty `NamedTuple()`
"""
@concrete struct Upsample <: AbstractLuxLayer
scale
size
upsample_mode <: StaticSymbol
align_corners <: Bool
end

# Keyword constructor. `mode` must be one of `:nearest`, `:bilinear` or `:trilinear`, and
# exactly one of `scale` / `size` has to be provided; otherwise an `ArgumentError` is
# thrown. The mode is stored as a static symbol.
function Upsample(mode::SymbolType=static(:nearest); scale=nothing,
        size=nothing, align_corners::Bool=false)
    @argcheck dynamic(mode) in (:nearest, :bilinear, :trilinear)

    # Reject both "neither given" and "both given": the two flags must differ.
    scale_missing = isnothing(scale)
    size_missing = isnothing(size)
    if scale_missing == size_missing
        throw(ArgumentError("Either scale or size should be specified (but not both)."))
    end

    return Upsample(scale, size, static(mode), align_corners)
end

# Positional convenience form: forwards `scale` to the keyword constructor, which performs
# all validation.
function Upsample(scale, mode::SymbolType=static(:nearest))
    return Upsample(mode; scale=scale)
end

# Forward pass, general case: upsample `x` by the stored `scale`. The parameters argument
# is unused and the state is handed back untouched.
function (layer::Upsample)(x::AbstractArray, _, st::NamedTuple)
    y = lux_upsample_scale_dispatch(
        layer.upsample_mode, x, layer.scale, layer.align_corners)
    return y, st
end

# Forward pass when the first type parameter is `Nothing` (presumably the type of the
# `scale` field -- verify against `@concrete`'s parameter order): resize `x` to the
# stored `size` instead.
function (layer::Upsample{Nothing})(x::AbstractArray, _, st::NamedTuple)
    y = lux_upsample_size_dispatch(
        layer.upsample_mode, x, layer.size, layer.align_corners)
    return y, st
end

# Generate the `:bilinear` and `:trilinear` methods of the two dispatch helpers. Each one
# forwards to the matching `upsample_<mode>` function.
#
# `align_corners` is passed on as a keyword. It used to be accepted and then dropped, so
# the layer's `align_corners` setting had no effect and the backend default was always
# used.
for interp in (:bilinear, :trilinear)
    nnlib_interp_func = Symbol(:upsample_, interp)
    @eval begin
        # Upsample `x` by `scale` using the interpolating backend for this mode.
        function lux_upsample_scale_dispatch(
                ::StaticSymbol{$(Meta.quot(interp))}, x, scale, align_corners)
            return $(nnlib_interp_func)(x, scale; align_corners)
        end
        # Upsample `x` to the explicit leading output dimensions `size`.
        function lux_upsample_size_dispatch(
                ::StaticSymbol{$(Meta.quot(interp))}, x, size, align_corners)
            return $(nnlib_interp_func)(x; size, align_corners)
        end
    end
end

# `:nearest` mode with an explicit output size. The trailing `align_corners` argument is
# ignored for nearest-neighbour upsampling.
lux_upsample_size_dispatch(::StaticSymbol{:nearest}, x, size, _) =
    NNlib.upsample_nearest(x; size=size)

# `:nearest` mode with a scale that is passed to NNlib unchanged.
lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale, _) =
    NNlib.upsample_nearest(x, scale)

# `:nearest` mode with a single integer scale: repeat it once for every dimension of `x`
# except the last two before handing it to NNlib.
function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer, _)
    n_leading = ndims(x) - 2
    per_dim_scale = ntuple(d -> scale, n_leading)
    return NNlib.upsample_nearest(x, per_dim_scale)
end

# Compact one-line display: the mode always, then `scale` / `size` when set, and
# `align_corners` only when it is `true`.
function Base.show(io::IO, u::Upsample)
    print(io, "Upsample(", u.upsample_mode)
    if !isnothing(u.scale)
        print(io, ", scale = $(u.scale)")
    end
    if !isnothing(u.size)
        print(io, ", size = $(u.size)")
    end
    if u.align_corners
        print(io, ", align_corners = $(u.align_corners)")
    end
    print(io, ")")
    return nothing
end

"""
PixelShuffle(r::Int)
Pixel shuffling layer with upscale factor `r`. Usually used for generating higher
resolution images while upscaling them.
See `NNlib.pixel_shuffle` for more details.
PixelShuffle is not a Layer, rather it returns a [`WrappedFunction`](@ref) with the
function set to `Base.Fix2(pixel_shuffle, r)`
## Arguments
- `r`: Upscale factor
## Inputs
- `x`: For 4D-arrays representing N images, the operation converts input
`size(x) == (W, H, r² x C, N)` to output of size `(r x W, r x H, C, N)`. For
D-dimensional data, it expects `ndims(x) == D + 2` with channel and batch dimensions, and
divides the number of channels by `rᴰ`.
## Returns
- Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)`
for D-dimensional data, where `D = ndims(x) - 2`
"""
@concrete struct PixelShuffle <: AbstractLuxWrapperLayer{:layer}
layer <: AbstractLuxLayer
end

# Build the layer from an upscale factor: fix `r` as the second argument of
# `pixel_shuffle` and wrap the resulting callable in a `WrappedFunction`.
function PixelShuffle(r::IntegerType)
    shuffle_by_r = Base.Fix2(pixel_shuffle, r)
    return PixelShuffle(WrappedFunction(shuffle_by_r))
end

@doc doc"""
MaxPool(window::NTuple; pad=0, stride=window)
Expand Down Expand Up @@ -503,117 +649,6 @@ function Base.show(io::IO, m::MeanPool)
print(io, ")")
end

"""
Upsample(mode = :nearest; [scale, size, align_corners=false])
Upsample(scale, mode = :nearest)
Upsampling Layer.
## Layer Construction
### Option 1
- `mode`: Set to `:nearest`, `:linear`, `:bilinear` or `:trilinear`
Exactly one of two keywords must be specified:
- If `scale` is a number, this applies to all but the last two dimensions (channel and
batch) of the input. It may also be a tuple, to control dimensions individually.
- Alternatively, keyword `size` accepts a tuple, to directly specify the leading
dimensions of the output.
### Option 2
- If `scale` is a number, this applies to all but the last two dimensions (channel and
batch) of the input. It may also be a tuple, to control dimensions individually.
- `mode`: Set to `:nearest`, `:bilinear` or `:trilinear`
Currently supported upsampling `mode`s and corresponding NNlib's methods are:
- `:nearest` -> `NNlib.upsample_nearest`
- `:bilinear` -> `NNlib.upsample_bilinear`
- `:trilinear` -> `NNlib.upsample_trilinear`
# Extended Help
## Other Keyword Arguments
- `align_corners`: If `true`, the corner pixels of the input and output tensors are
aligned, and thus preserving the values at those pixels. This only has effect when mode
is one of `:bilinear` or `:trilinear`.
## Inputs
- `x`: For the input dimensions look into the documentation for the corresponding `NNlib`
function
+ As a rule of thumb, `:nearest` should work with arrays of arbitrary dimensions
+ `:bilinear` works with 4D Arrays
+ `:trilinear` works with 5D Arrays
## Returns
- Upsampled Input of size `size` or of size `(I_1 x scale[1], ..., I_N x scale[N], C, N)`
- Empty `NamedTuple()`
"""
@concrete struct Upsample <: AbstractLuxLayer
scale
size
upsample_mode <: StaticSymbol
align_corners <: Bool
end

# Keyword constructor. `mode` must be one of `:nearest`, `:bilinear` or `:trilinear`, and
# exactly one of `scale` / `size` has to be provided; otherwise an `ArgumentError` is
# thrown. The mode is stored as a static symbol.
function Upsample(mode::SymbolType=static(:nearest); scale=nothing,
        size=nothing, align_corners::Bool=false)
    @argcheck dynamic(mode) in (:nearest, :bilinear, :trilinear)

    # Reject both "neither given" and "both given": the two flags must differ.
    scale_missing = isnothing(scale)
    size_missing = isnothing(size)
    if scale_missing == size_missing
        throw(ArgumentError("Either scale or size should be specified (but not both)."))
    end

    return Upsample(scale, size, static(mode), align_corners)
end

# Positional convenience form: forwards `scale` to the keyword constructor, which performs
# all validation.
function Upsample(scale, mode::SymbolType=static(:nearest))
    return Upsample(mode; scale=scale)
end

# Forward pass, general case: upsample `x` by the stored `scale`. The parameters argument
# is unused and the state is handed back untouched.
function (layer::Upsample)(x::AbstractArray, _, st::NamedTuple)
    y = lux_upsample_scale_dispatch(
        layer.upsample_mode, x, layer.scale, layer.align_corners)
    return y, st
end

# Forward pass when the first type parameter is `Nothing` (presumably the type of the
# `scale` field -- verify against `@concrete`'s parameter order): resize `x` to the
# stored `size` instead.
function (layer::Upsample{Nothing})(x::AbstractArray, _, st::NamedTuple)
    y = lux_upsample_size_dispatch(
        layer.upsample_mode, x, layer.size, layer.align_corners)
    return y, st
end

# Generate the `:bilinear` and `:trilinear` methods of the two dispatch helpers. Each one
# forwards to the matching `upsample_<mode>` function.
#
# `align_corners` is passed on as a keyword. It used to be accepted and then dropped, so
# the layer's `align_corners` setting had no effect and the backend default was always
# used.
for interp in (:bilinear, :trilinear)
    nnlib_interp_func = Symbol(:upsample_, interp)
    @eval begin
        # Upsample `x` by `scale` using the interpolating backend for this mode.
        function lux_upsample_scale_dispatch(
                ::StaticSymbol{$(Meta.quot(interp))}, x, scale, align_corners)
            return $(nnlib_interp_func)(x, scale; align_corners)
        end
        # Upsample `x` to the explicit leading output dimensions `size`.
        function lux_upsample_size_dispatch(
                ::StaticSymbol{$(Meta.quot(interp))}, x, size, align_corners)
            return $(nnlib_interp_func)(x; size, align_corners)
        end
    end
end

# `:nearest` mode with an explicit output size. The trailing `align_corners` argument is
# ignored for nearest-neighbour upsampling.
lux_upsample_size_dispatch(::StaticSymbol{:nearest}, x, size, _) =
    NNlib.upsample_nearest(x; size=size)

# `:nearest` mode with a scale that is passed to NNlib unchanged.
lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale, _) =
    NNlib.upsample_nearest(x, scale)

# `:nearest` mode with a single integer scale: repeat it once for every dimension of `x`
# except the last two before handing it to NNlib.
function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer, _)
    n_leading = ndims(x) - 2
    per_dim_scale = ntuple(d -> scale, n_leading)
    return NNlib.upsample_nearest(x, per_dim_scale)
end

# Compact one-line display: the mode always, then `scale` / `size` when set, and
# `align_corners` only when it is `true`.
function Base.show(io::IO, u::Upsample)
    print(io, "Upsample(", u.upsample_mode)
    if !isnothing(u.scale)
        print(io, ", scale = $(u.scale)")
    end
    if !isnothing(u.size)
        print(io, ", size = $(u.size)")
    end
    if u.align_corners
        print(io, ", align_corners = $(u.align_corners)")
    end
    print(io, ")")
    return nothing
end

"""
GlobalMaxPool()
Expand Down Expand Up @@ -725,38 +760,3 @@ function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) whe
end

# One-line display of the layer: its name followed by the `out` field in parentheses.
function Base.show(io::IO, a::AdaptiveMeanPool)
    return print(io, "AdaptiveMeanPool(", a.out, ")")
end

"""
PixelShuffle(r::Int)
Pixel shuffling layer with upscale factor `r`. Usually used for generating higher
resolution images while upscaling them.
See `NNlib.pixel_shuffle` for more details.
PixelShuffle is not a Layer, rather it returns a [`WrappedFunction`](@ref) with the
function set to `Base.Fix2(pixel_shuffle, r)`
## Arguments
- `r`: Upscale factor
## Inputs
- `x`: For 4D-arrays representing N images, the operation converts input
`size(x) == (W, H, r² x C, N)` to output of size `(r x W, r x H, C, N)`. For
D-dimensional data, it expects `ndims(x) == D + 2` with channel and batch dimensions, and
divides the number of channels by `rᴰ`.
## Returns
- Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)`
for D-dimensional data, where `D = ndims(x) - 2`
"""
@concrete struct PixelShuffle <: AbstractLuxWrapperLayer{:layer}
layer <: AbstractLuxLayer
end

# Build the layer from an upscale factor: fix `r` as the second argument of
# `pixel_shuffle` and wrap the resulting callable in a `WrappedFunction`.
function PixelShuffle(r::IntegerType)
    shuffle_by_r = Base.Fix2(pixel_shuffle, r)
    return PixelShuffle(WrappedFunction(shuffle_by_r))
end
2 changes: 1 addition & 1 deletion src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ function initialparameters(rng::AbstractRNG, lstm::LSTMCell)
for init_bias in lstm.init_bias]...)
bias_hh = vcat([init_rnn_bias(rng, init_bias, lstm.out_dims, lstm.out_dims)
for init_bias in lstm.init_bias]...)
ps = merge(ps, (bias_ih, bias_hh))
ps = merge(ps, (; bias_ih, bias_hh))
end
has_train_state(lstm) &&
(ps = merge(ps, (hidden_state=lstm.init_state(rng, lstm.out_dims),)))
Expand Down

0 comments on commit e47f063

Please sign in to comment.