Skip to content

Commit

Permalink
add maxpool to refiners layer
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Nov 20, 2023
1 parent f666bc8 commit 2d4c477
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/refiners/fluxion/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
from refiners.fluxion.layers.embedding import Embedding
from refiners.fluxion.layers.converter import Converter
from refiners.fluxion.layers.maxpool import MaxPool1d, MaxPool2d

__all__ = [
"Embedding",
Expand Down Expand Up @@ -104,4 +105,6 @@
"ReflectionPad2d",
"PixelUnshuffle",
"Converter",
"MaxPool1d",
"MaxPool2d",
]
42 changes: 42 additions & 0 deletions src/refiners/fluxion/layers/maxpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from torch import nn
from refiners.fluxion.layers.module import Module


class MaxPool1d(nn.MaxPool1d, Module):
def __init__(
self,
kernel_size: int,
stride: int | None = None,
padding: int = 0,
dilation: int = 1,
return_indices: bool = False,
ceil_mode: bool = False,
) -> None:
super().__init__(
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
return_indices=return_indices,
ceil_mode=ceil_mode,
)


class MaxPool2d(nn.MaxPool2d, Module):
def __init__(
self,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int] | None = None,
padding: int | tuple[int, int] = (0, 0),
dilation: int | tuple[int, int] = (1, 1),
return_indices: bool = False,
ceil_mode: bool = False,
) -> None:
super().__init__(
kernel_size=kernel_size,
stride=stride,
padding=padding, # type: ignore
dilation=dilation,
return_indices=return_indices,
ceil_mode=ceil_mode,
)

0 comments on commit 2d4c477

Please sign in to comment.