Skip to content

Commit

Permalink
Deprecation warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
jamaliki committed Oct 24, 2024
1 parent d04efd0 commit 8b3a686
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
12 changes: 12 additions & 0 deletions model_angelo/models/common_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,15 @@ def __init__(self):
def forward(self, x, y):
s = torch.sigmoid(self.gate)
return s * x + (1 - s) * y


class Upsample(nn.Module):
def __init__(self, scale_factor, mode="nearest"):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode

def forward(self, x):
return F.interpolate(
x, scale_factor=self.scale_factor, mode=self.mode
)
3 changes: 2 additions & 1 deletion model_angelo/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
NormalizeStd,
RegularBlock,
ResBlock,
Upsample
)


Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
self.main_layers.append(pooling_class(2, stride=2))
for i in range(self.num_blocks):
if i == self.num_blocks - 2 and downsample_x:
self.main_layers.append(nn.Upsample(scale_factor=2))
self.main_layers.append(Upsample(scale_factor=2))
self.main_layers.append(
ResBlock(
in_channels=self.g_width,
Expand Down
1 change: 0 additions & 1 deletion model_angelo/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,4 +571,3 @@ def compile_if_possible(module: nn.Module) -> nn.Module:
module = torch.compile(module)
return module


0 comments on commit 8b3a686

Please sign in to comment.