Fix MLP normalization argument (#377)
weihua916 authored Mar 14, 2024
1 parent 96bdf12 commit c1f0cb8
Showing 3 changed files with 20 additions and 16 deletions.
test/nn/models/test_mlp.py (4 changes: 3 additions & 1 deletion)
@@ -6,7 +6,8 @@


 @pytest.mark.parametrize('batch_size', [0, 5])
-def test_mlp(batch_size):
+@pytest.mark.parametrize('normalization', ["layer_norm", "batch_norm"])
+def test_mlp(batch_size, normalization):
     channels = 8
     out_channels = 1
     num_layers = 3
@@ -20,6 +21,7 @@ def test_mlp(batch_size):
         num_layers=num_layers,
         col_stats=dataset.col_stats,
         col_names_dict=tensor_frame.col_names_dict,
+        normalization=normalization,
     )
     model.reset_parameters()
     out = model(tensor_frame)
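For context (not part of the diff): stacking two `pytest.mark.parametrize` decorators runs the test over the Cartesian product of their values, so `test_mlp` now covers 2 x 2 = 4 combinations. A minimal standalone sketch of that behavior, with a hypothetical test function:

import pytest

# Each decorator multiplies the case count: 2 batch sizes x 2 normalizations = 4 runs.
@pytest.mark.parametrize('batch_size', [0, 5])
@pytest.mark.parametrize('normalization', ["layer_norm", "batch_norm"])
def test_combinations(batch_size, normalization):
    assert batch_size in (0, 5)
    assert normalization in ("layer_norm", "batch_norm")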
torch_frame/nn/models/mlp.py (16 changes: 9 additions & 7 deletions)
@@ -49,8 +49,8 @@ class MLP(Module):
             for categorical feature and :obj:`LinearEncoder()` for
             numerical feature)
         normalization (str, optional): The type of normalization to use.
-            :obj:`batchnorm`, :obj:`layernorm`, or :obj:`None`.
-            (default: :obj:`layernorm`)
+            :obj:`batch_norm`, :obj:`layer_norm`, or :obj:`None`.
+            (default: :obj:`layer_norm`)
         dropout_prob (float): The dropout probability (default: `0.2`).
     """
     def __init__(
@@ -62,7 +62,7 @@ def __init__(
         col_names_dict: dict[torch_frame.stype, list[str]],
         stype_encoder_dict: dict[torch_frame.stype, StypeEncoder]
         | None = None,
-        normalization: str | None = "layernorm",
+        normalization: str | None = "layer_norm",
         dropout_prob: float = 0.2,
     ) -> None:
         super().__init__()
@@ -81,10 +81,13 @@ def __init__(
         )

         self.mlp = Sequential()
-        norm_cls = LayerNorm if normalization == "layernorm" else BatchNorm1d
+
         for _ in range(num_layers - 1):
             self.mlp.append(Linear(channels, channels))
-            self.mlp.append(norm_cls(channels))
+            if normalization == "layer_norm":
+                self.mlp.append(LayerNorm(channels))
+            elif normalization == "batch_norm":
+                self.mlp.append(BatchNorm1d(channels))
             self.mlp.append(ReLU())
             self.mlp.append(Dropout(p=dropout_prob))
         self.mlp.append(Linear(channels, out_channels))
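Aside from the diff: the construction loop above amounts to the following standalone sketch, assuming a recent PyTorch where `nn.Sequential.append` is available (the helper name `make_mlp_body` is illustrative, not the library API). A string flag selects the normalization layer; any other value, including `None`, simply skips it:

from torch.nn import BatchNorm1d, Dropout, LayerNorm, Linear, ReLU, Sequential

def make_mlp_body(channels: int, out_channels: int, num_layers: int,
                  normalization: str | None = "layer_norm",
                  dropout_prob: float = 0.2) -> Sequential:
    mlp = Sequential()
    for _ in range(num_layers - 1):
        mlp.append(Linear(channels, channels))
        # Pick the normalization layer from the string flag, or skip it.
        if normalization == "layer_norm":
            mlp.append(LayerNorm(channels))
        elif normalization == "batch_norm":
            mlp.append(BatchNorm1d(channels))
        mlp.append(ReLU())
        mlp.append(Dropout(p=dropout_prob))
    mlp.append(Linear(channels, out_channels))
    return mlp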
@@ -94,8 +97,7 @@ def __init__(
     def reset_parameters(self) -> None:
         self.encoder.reset_parameters()
         for param in self.mlp:
-            if (isinstance(param, Linear) or isinstance(param, BatchNorm1d)
-                    or isinstance(param, LayerNorm)):
+            if hasattr(param, 'reset_parameters'):
                 param.reset_parameters()

     def forward(self, tf: TensorFrame) -> Tensor:
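The `reset_parameters` change above swaps explicit `isinstance` checks for duck typing: any submodule that exposes `reset_parameters()` is reset, so new layer types need no extra branches. A small illustrative snippet of the same pattern (not repository code):

from torch.nn import BatchNorm1d, Linear, ReLU, Sequential

mlp = Sequential(Linear(8, 8), BatchNorm1d(8), ReLU(), Linear(8, 1))
for module in mlp:
    # Linear and BatchNorm1d define reset_parameters(); ReLU does not and is skipped.
    if hasattr(module, 'reset_parameters'):
        module.reset_parameters()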
torch_frame/nn/models/resnet.py (16 changes: 8 additions & 8 deletions)
@@ -32,16 +32,16 @@ class FCResidualBlock(Module):
         in_channels (int): The number of input channels.
         out_channels (int): The number of output channels.
         normalization (str, optional): The type of normalization to use.
-            :obj:`batchnorm`, :class:`torch.nn.LayerNorm`, or :obj:`None`.
-            (default: :class:`torch.nn.LayerNorm`)
+            :obj:`layer_norm`, :obj:`batch_norm`, or :obj:`None`.
+            (default: :obj:`layer_norm`)
         dropout_prob (float): The dropout probability (default: `0.0`, i.e.,
             no dropout).
     """
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        normalization: str | None = "layernorm",
+        normalization: str | None = "layer_norm",
         dropout_prob: float = 0.0,
     ) -> None:
         super().__init__()
@@ -52,10 +52,10 @@ def __init__(

         self.norm1: BatchNorm1d | LayerNorm | None
         self.norm2: BatchNorm1d | LayerNorm | None
-        if normalization == "batchnorm":
+        if normalization == "batch_norm":
             self.norm1 = BatchNorm1d(out_channels)
             self.norm2 = BatchNorm1d(out_channels)
-        elif normalization == "layernorm":
+        elif normalization == "layer_norm":
             self.norm1 = LayerNorm(out_channels)
             self.norm2 = LayerNorm(out_channels)
         else:
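As a standalone illustration of the corrected spelling (the helper name `make_norms` is hypothetical, not the library API): only the underscored strings "batch_norm" and "layer_norm" select a layer, and anything else, including `None`, falls through to no normalization:

from torch.nn import BatchNorm1d, LayerNorm

def make_norms(out_channels: int, normalization: str | None):
    if normalization == "batch_norm":
        return BatchNorm1d(out_channels), BatchNorm1d(out_channels)
    elif normalization == "layer_norm":
        return LayerNorm(out_channels), LayerNorm(out_channels)
    return None, None  # e.g. normalization=None disables normalization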
@@ -127,8 +127,8 @@ class ResNet(Module):
             for categorical feature and :obj:`LinearEncoder()` for
             numerical feature)
         normalization (str, optional): The type of normalization to use.
-            :obj:`batchnorm`, :obj:`layernorm`, or :obj:`None`.
-            (default: :obj:`layernorm`)
+            :obj:`batch_norm`, :obj:`layer_norm`, or :obj:`None`.
+            (default: :obj:`layer_norm`)
         dropout_prob (float): The dropout probability (default: `0.2`).
     """
     def __init__(
@@ -140,7 +140,7 @@ def __init__(
         col_names_dict: dict[torch_frame.stype, list[str]],
         stype_encoder_dict: dict[torch_frame.stype, StypeEncoder]
         | None = None,
-        normalization: str | None = "layernorm",
+        normalization: str | None = "layer_norm",
         dropout_prob: float = 0.2,
     ) -> None:
         super().__init__()
