Skip to content

Commit

Permalink
replace Reshape layer by torch.nn.UnShuffle for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Nov 13, 2024
1 parent f5acef3 commit 47dcf7c
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 17 deletions.
1 change: 0 additions & 1 deletion deel/torchlip/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
from .module import LipschitzModule
from .module import Sequential
from .module import vanilla_model
from .module import Reshape
from .pooling import ScaledAdaptiveAvgPool2d
from .pooling import ScaledAvgPool2d
from .pooling import ScaledL2NormPool2d
Expand Down
11 changes: 1 addition & 10 deletions deel/torchlip/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _is_supported_1lip_layer(layer):
torch.nn.ReLU,
torch.nn.Sigmoid,
torch.nn.Tanh,
Reshape,
torch.nn.Unflatten,
)
if isinstance(layer, supported_1lip_layers):
return True
Expand Down Expand Up @@ -182,12 +182,3 @@ def vanilla_export(self):
else:
layers.append((name, copy.deepcopy(layer)))
return TorchSequential(OrderedDict(layers))


class Reshape(torch.nn.Module):
def __init__(self, target_shape):
super(Reshape, self).__init__()
self.target_shape = target_shape

def forward(self, x):
return reshape(x, self.target_shape)
2 changes: 1 addition & 1 deletion tests/test_compute_layer_sv.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def train_compute_and_verifySV(
logdir = os.path.join("logs", uft.LIP_LAYERS, "%s" % layer_type.__name__)
os.makedirs(logdir, exist_ok=True)

callback_list = []
callback_list = []
if "callbacks" in kwargs and (kwargs["callbacks"] is not None):
callback_list = callback_list + kwargs["callbacks"]
# train model
Expand Down
4 changes: 1 addition & 3 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,7 @@ def train_k_lip_model(
logdir = os.path.join("logs", uft.LIP_LAYERS, "%s" % layer_type.__name__)
os.makedirs(logdir, exist_ok=True)

callback_list = (
[]
)
callback_list = []
if kwargs["callbacks"] is not None:
callback_list = callback_list + kwargs["callbacks"]
# train model
Expand Down
4 changes: 3 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ def test_warning_unsupported_1Lip_layers():
), # kl.Activation("relu"),
uft.get_instance_framework(tSoftmax, {}), # kl.Softmax(),
uft.get_instance_framework(Flatten, {}), # kl.Flatten(),
uft.get_instance_framework(tReshape, {"target_shape": (10,)}), # kl.Reshape(),
uft.get_instance_framework(
tReshape, {"dim": -1, "unflattened_size": (10,)}
), # kl.Reshape(),
uft.get_instance_framework(
tMaxPool2d, {"kernel_size": (2, 2)}
), # kl.MaxPool2d(),
Expand Down
2 changes: 1 addition & 1 deletion tests/utils_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torch.nn import Conv2d as tConv2d
from torch.nn import Conv2d as PadConv2d
from torch.nn import Upsample as tUpSampling2d
from torch.nn import Unflatten as tReshape
from torch import int32 as type_int32
from torch.nn.functional import pad
from torch.nn import MultiMarginLoss as tMultiMarginLoss
Expand All @@ -40,7 +41,6 @@
from deel.torchlip.modules import ScaledL2NormPool2d
from deel.torchlip.modules import InvertibleDownSampling
from deel.torchlip.modules import InvertibleUpSampling
from deel.torchlip.modules import Reshape as tReshape
from deel.torchlip.utils import evaluate_lip_const

from deel.torchlip.modules import (
Expand Down

0 comments on commit 47dcf7c

Please sign in to comment.