diff --git a/torch_dreams/maco/features_visualizations/preconditioning.py b/torch_dreams/maco/features_visualizations/preconditioning.py index 1a9185e..6d3a9c1 100644 --- a/torch_dreams/maco/features_visualizations/preconditioning.py +++ b/torch_dreams/maco/features_visualizations/preconditioning.py @@ -21,6 +21,12 @@ "spectrums/imagenet_decorrelated.npy" + + + + + +# Define the recorrelate_colors function for PyTorch def recorrelate_colors(images: torch.Tensor) -> torch.Tensor: """ Map uncorrelated colors to 'normal colors' by using empirical color @@ -36,31 +42,26 @@ def recorrelate_colors(images: torch.Tensor) -> torch.Tensor: ------- images : torch.Tensor Images recorrelated. - """ - # Define the empirical color correlation matrix from ImageNet + """ imagenet_color_correlation = torch.tensor( [[0.56282854, 0.58447580, 0.58447580], [0.19482528, 0.00000000, -0.19482528], [0.04329450, -0.10823626, 0.06494176]], dtype=torch.float32 ) + images_flat = images.reshape(-1, 3) + images_flat = torch.matmul(images_flat, imagenet_color_correlation) + return images_flat.view_as(images) - # Reshape images to a 2D tensor where each row represents a pixel's RGB values - images_flat = images.view(-1, 3) - # Apply the color correlation matrix - images_flat = torch.matmul(images_flat, imagenet_color_correlation) - # Reshape the flat images back to their original shape - return images_flat.view_as(images) -def to_valid_rgb(images: torch.Tensor, - normalizer: str = 'sigmoid', - values_range: tuple = (0, 1)) -> torch.Tensor: +def to_valid_rgb_fixed(images: torch.Tensor, normalizer: str = 'sigmoid', values_range: tuple = (0, 1)) -> torch.Tensor: """ + Apply transformations to map tensors to valid rgb images. Parameters @@ -77,23 +78,29 @@ def to_valid_rgb(images: torch.Tensor, ------- images : torch.Tensor Images after correction + """ images = recorrelate_colors(images) - if normalizer == 'sigmoid': images = torch.sigmoid(images) elif normalizer == 'clip': images = torch.clamp(images, values_range[0], values_range[1]) else: raise ValueError(f"Invalid normalizer: {normalizer}") + + # Rescale according to value range, now correctly handling the reduction over dimensions + images_flat = images.view(images.size(0), -1) # Flatten all dimensions except the batch + min_vals = images_flat.min(dim=1, keepdim=True)[0].view(images.size(0), 1, 1, 1) + max_vals = images_flat.max(dim=1, keepdim=True)[0].view(images.size(0), 1, 1, 1) + + images = (images - min_vals) / (max_vals - min_vals) # Normalize to [0, 1] + images = images * (values_range[1] - values_range[0]) + values_range[0] # Scale to [min_value, max_value] + + return images + + - # rescale according to value range - images = images - images.min(dim=(1, 2, 3), keepdim=True)[0] - images = images / images.max(dim=(1, 2, 3), keepdim=True)[0] - images *= (values_range[1] - values_range[0]) - images += values_range[0] - return images def fft_2d_freq(width: int, height: int) -> np.ndarray: