diff --git a/lightorch/nn/functional.py b/lightorch/nn/functional.py index 3c63249..62f94e7 100644 --- a/lightorch/nn/functional.py +++ b/lightorch/nn/functional.py @@ -2,10 +2,9 @@ from torch import nn, Tensor from typing import Optional, Union, Tuple, Callable, List import torch.nn.functional as F -from lightning.pytorch import LightningModule from einops import rearrange from torch.fft import fftn -from .utils import FeatureExtractor +from .utils import FeatureExtractor2D def _fourierconvNd(n: int, x: Tensor, weight: Tensor, bias: Union[Tensor,None]) -> Tensor: @@ -358,7 +357,7 @@ def style_loss( input: Tensor, target: Tensor, F_p: Tensor, - feature_extractor: FeatureExtractor = None, + feature_extractor: FeatureExtractor2D = None, ) -> Tensor: if feature_extractor is not None: phi_input: Tensor = feature_extractor(input) @@ -377,7 +376,7 @@ def perceptual_loss( input: Tensor, target: Tensor, N_phi_p: Tensor, - feature_extractor: FeatureExtractor = None, + feature_extractor: FeatureExtractor2D = None, ) -> Tensor: if feature_extractor is not None: phi_input: Tensor = feature_extractor(input)