Skip to content

Commit

Permalink
added the required files
Browse files Browse the repository at this point in the history
  • Loading branch information
sushmanthreddy committed Mar 24, 2024
1 parent 535ac7f commit 458ea3f
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions torch_dreams/maco/features_visualizations/preconditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
"spectrums/imagenet_decorrelated.npy"







# Define the recorrelate_colors function for PyTorch
def recorrelate_colors(images: torch.Tensor) -> torch.Tensor:
"""
Map uncorrelated colors to 'normal colors' by using empirical color
Expand All @@ -36,31 +42,26 @@ def recorrelate_colors(images: torch.Tensor) -> torch.Tensor:
-------
images : torch.Tensor
Images recorrelated.
"""
# Define the empirical color correlation matrix from ImageNet
"""
imagenet_color_correlation = torch.tensor(
[[0.56282854, 0.58447580, 0.58447580],
[0.19482528, 0.00000000, -0.19482528],
[0.04329450, -0.10823626, 0.06494176]],
dtype=torch.float32
)
images_flat = images.reshape(-1, 3)
images_flat = torch.matmul(images_flat, imagenet_color_correlation)
return images_flat.view_as(images)

# Reshape images to a 2D tensor where each row represents a pixel's RGB values
images_flat = images.view(-1, 3)

# Apply the color correlation matrix
images_flat = torch.matmul(images_flat, imagenet_color_correlation)

# Reshape the flat images back to their original shape
return images_flat.view_as(images)



def to_valid_rgb(images: torch.Tensor,
normalizer: str = 'sigmoid',
values_range: tuple = (0, 1)) -> torch.Tensor:
def to_valid_rgb_fixed(images: torch.Tensor, normalizer: str = 'sigmoid', values_range: tuple = (0, 1)) -> torch.Tensor:
"""
Apply transformations to map tensors to valid rgb images.
Parameters
Expand All @@ -77,23 +78,29 @@ def to_valid_rgb(images: torch.Tensor,
-------
images : torch.Tensor
Images after correction
"""
images = recorrelate_colors(images)

if normalizer == 'sigmoid':
images = torch.sigmoid(images)
elif normalizer == 'clip':
images = torch.clamp(images, values_range[0], values_range[1])
else:
raise ValueError(f"Invalid normalizer: {normalizer}")

# Rescale according to value range, now correctly handling the reduction over dimensions
images_flat = images.view(images.size(0), -1) # Flatten all dimensions except the batch
min_vals = images_flat.min(dim=1, keepdim=True)[0].view(images.size(0), 1, 1, 1)
max_vals = images_flat.max(dim=1, keepdim=True)[0].view(images.size(0), 1, 1, 1)

images = (images - min_vals) / (max_vals - min_vals) # Normalize to [0, 1]
images = images * (values_range[1] - values_range[0]) + values_range[0] # Scale to [min_value, max_value]

return images



# rescale according to value range
images = images - images.min(dim=(1, 2, 3), keepdim=True)[0]
images = images / images.max(dim=(1, 2, 3), keepdim=True)[0]
images *= (values_range[1] - values_range[0])
images += values_range[0]

return images


def fft_2d_freq(width: int, height: int) -> np.ndarray:
Expand Down

0 comments on commit 458ea3f

Please sign in to comment.