Skip to content

Commit

Permalink
style: format some python
Browse files Browse the repository at this point in the history
  • Loading branch information
jesseylin committed Jul 24, 2023
1 parent 6816be8 commit 088ff31
Show file tree
Hide file tree
Showing 2 changed files with 301 additions and 215 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,40 @@


def roll_n(X: torch.Tensor, axis: int, n: int) -> torch.Tensor:
""" Rolls a tensor with a shift n on the specified axis"""
f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None)
for i in range(X.dim()))
"""Rolls a tensor with a shift n on the specified axis"""
f_idx = tuple(
slice(None, None, None) if i != axis else slice(0, n, None)
for i in range(X.dim())
)
b_idx = tuple(
slice(
None,
None,
None) if i != axis else slice(
n,
None,
None) for i in range(
X.dim()))
slice(None, None, None) if i != axis else slice(n, None, None)
for i in range(X.dim())
)
front = X[f_idx]
back = X[b_idx]
return torch.cat([back, front], axis)


def center_of_mass(input_array: torch.Tensor) -> torch.Tensor:

normalizer = input_array.sum()
grids = torch.meshgrid(*[torch.arange(0, i) for i in input_array.shape])

center = torch.tensor([(input_array *
grids[dir].double()).sum() /
normalizer for dir in range(input_array.ndim)])
center = torch.tensor(
[
(input_array * grids[dir].double()).sum() / normalizer
for dir in range(input_array.ndim)
]
)

if torch.any(torch.isnan(center)):
center = torch.tensor(
[int((input_array.shape[0] - 1) / 2), int((input_array.shape[1] - 1) / 2)])
[int((input_array.shape[0] - 1) / 2), int((input_array.shape[1] - 1) / 2)]
)

return center


def calc_distance_matrix(size_y: int, size_x: int) -> torch.Tensor:

dist_mat = torch.zeros([size_y, size_x])

mid_y = (size_y - 1) / 2
Expand All @@ -54,16 +53,15 @@ def calc_distance_matrix(size_y: int, size_x: int) -> torch.Tensor:

for y in range(size_y):
for x in range(size_x):
dist_mat[y][x] = (1 -
int(torch.linalg.norm(mid -
torch.tensor([y, x]))) /
max_dist) ** DISTANCE_WEIGHT
dist_mat[y][x] = (
1 - int(torch.linalg.norm(mid - torch.tensor([y, x]))) / max_dist
) ** DISTANCE_WEIGHT

return dist_mat


def calc_image_moments(image: torch.Tensor) -> typing.Dict[str, torch.Tensor]:
'''
"""
Calculates the image moments for an image.
For more information see:
Expand All @@ -75,15 +73,14 @@ def calc_image_moments(image: torch.Tensor) -> typing.Dict[str, torch.Tensor]:
:param image: 2d gray scale image in form of a numpy array.
:return: Namedtupel with the different moments.
'''
"""

eps = 0.00001

size_y = image.shape[0]
size_x = image.shape[1]

x_grid, y_grid = torch.meshgrid(
torch.arange(0, size_x), torch.arange(0, size_y))
x_grid, y_grid = torch.meshgrid(torch.arange(0, size_x), torch.arange(0, size_y))

y_power1_image = y_grid * image
y_power2_image = y_grid * y_power1_image
Expand Down Expand Up @@ -121,9 +118,9 @@ def calc_image_moments(image: torch.Tensor) -> typing.Dict[str, torch.Tensor]:

# in the case of very small activation (m00 ~ 0) the position becomes
# infinity, also then use the center position
if mY == float('inf'):
if mY == float("inf"):
mY = (image.shape[0] - 1) / 2
if mX == float('inf'):
if mX == float("inf"):
mX = (image.shape[1] - 1) / 2

# calculate the central moments
Expand All @@ -140,12 +137,31 @@ def calc_image_moments(image: torch.Tensor) -> typing.Dict[str, torch.Tensor]:
mu03 = m03 - 3 * mX * m02 + 2 * X2 * m01
mu21 = m21 - 2 * mY * m11 - mX * m20 + 2 * Y2 * m01
mu12 = m12 - 2 * mX * m11 - mY * m02 + 2 * X2 * m10
mu22 = m22 - 2 * mX * m21 + X2 * m20 - 2 * mY * m12 + 4 * XY * m11 - \
2 * mY * X2 * m10 + Y2 * m02 - 2 * Y2 * mX * m01 + Y2 * X2 * m00
mu31 = m31 - mX * m30 + 3 * mY * \
(mX * m20 - m21) + 3 * Y2 * (m11 - mX * m10) + Y3 * (mX * m00 - m01)
mu13 = m13 - mY * m03 + 3 * mX * \
(mY * m02 - m12) + 3 * X2 * (m11 - mY * m01) + X3 * (mY * m00 - m10)
mu22 = (
m22
- 2 * mX * m21
+ X2 * m20
- 2 * mY * m12
+ 4 * XY * m11
- 2 * mY * X2 * m10
+ Y2 * m02
- 2 * Y2 * mX * m01
+ Y2 * X2 * m00
)
mu31 = (
m31
- mX * m30
+ 3 * mY * (mX * m20 - m21)
+ 3 * Y2 * (m11 - mX * m10)
+ Y3 * (mX * m00 - m01)
)
mu13 = (
m13
- mY * m03
+ 3 * mX * (mY * m02 - m12)
+ 3 * X2 * (m11 - mY * m01)
+ X3 * (mY * m00 - m10)
)
mu40 = m40 - 4 * mY * m30 + 6 * Y2 * m20 - 4 * Y3 * m10 + Y2 * Y2 * m00
mu04 = m04 - 4 * mX * m03 + 6 * X2 * m02 - 4 * X3 * m01 + X2 * X2 * m00

Expand Down Expand Up @@ -277,7 +293,8 @@ def calc_image_moments(image: torch.Tensor) -> typing.Dict[str, torch.Tensor]:
flusser10=flusser10,
flusser11=flusser11,
flusser12=flusser12,
flusser13=flusser13)
flusser13=flusser13,
)

return result

Expand All @@ -287,11 +304,13 @@ class LeniaStatistics:
Outputs 17-dimensional embedding.
"""

def __init__(self, premap_key: str = "output",
postmap_key: str = "output",
SX: int = 256,
SY: int = 256):

def __init__(
self,
premap_key: str = "output",
postmap_key: str = "output",
SX: int = 256,
SY: int = 256,
):
self.premap_key = premap_key
self.postmap_key = postmap_key

Expand All @@ -300,33 +319,34 @@ def __init__(self, premap_key: str = "output",

# model
self._statistic_names = [
'activation_mass',
'activation_volume',
'activation_density',
'activation_mass_distribution',
'activation_hu1',
'activation_hu2',
'activation_hu3',
'activation_hu4',
'activation_hu5',
'activation_hu6',
'activation_hu7',
'activation_hu8',
'activation_flusser9',
'activation_flusser10',
'activation_flusser11',
'activation_flusser12',
'activation_flusser13']
"activation_mass",
"activation_volume",
"activation_density",
"activation_mass_distribution",
"activation_hu1",
"activation_hu2",
"activation_hu3",
"activation_hu4",
"activation_hu5",
"activation_hu6",
"activation_hu7",
"activation_hu8",
"activation_flusser9",
"activation_flusser10",
"activation_flusser11",
"activation_flusser12",
"activation_flusser13",
]
self._n_latents = len(self._statistic_names)

def map(self, input: typing.Dict) -> typing.Dict:
"""
Compute statistics on Lenia's output
Args:
input: Lenia's output
is_output_new_discovery: indicates if it is a new discovery
Returns:
Return a torch tensor in dict
Compute statistics on Lenia's output
Args:
input: Lenia's output
is_output_new_discovery: indicates if it is a new discovery
Returns:
Return a torch tensor in dict
"""
intermed_dict = deepcopy(input)

Expand All @@ -346,17 +366,18 @@ def map(self, input: typing.Dict) -> typing.Dict:
def sample(self):
return self.projector.sample()

def _calc_distance(self, embedding_a: torch.Tensor,
embedding_b: torch.Tensor) -> torch.Tensor:
def _calc_distance(
self, embedding_a: torch.Tensor, embedding_b: torch.Tensor
) -> torch.Tensor:
"""
Compute the distance between 2 embeddings in the latent space
/!\\ batch mode embedding_a and embedding_b can be N*M or M
Compute the distance between 2 embeddings in the latent space
/!\\ batch mode embedding_a and embedding_b can be N*M or M
"""
# l2 loss
return (embedding_a - embedding_b).pow(2).sum(-1).sqrt()

def _calc_static_statistics(self, final_obs: torch.Tensor) -> torch.Tensor:
'''Calculates the final statistics for lenia last observation'''
"""Calculates the final statistics for lenia last observation"""

feature_vector = torch.zeros(self._n_latents)
cur_idx = 0
Expand All @@ -375,10 +396,10 @@ def _calc_static_statistics(self, final_obs: torch.Tensor) -> torch.Tensor:
activation_shift_to_center = mid - activation_center_of_mass

activation = final_obs
centered_activation = roll_n(activation, 0, activation_shift_to_center[0].int())
centered_activation = roll_n(
activation, 0, activation_shift_to_center[0].int())
centered_activation = roll_n(
centered_activation, 1, activation_shift_to_center[1].int())
centered_activation, 1, activation_shift_to_center[1].int()
)

# calculate the image moments
activation_moments = calc_image_moments(centered_activation)
Expand Down Expand Up @@ -406,13 +427,13 @@ def _calc_static_statistics(self, final_obs: torch.Tensor) -> torch.Tensor:
cur_idx += 1

# mass distribution around the center
distance_weight_matrix = calc_distance_matrix(
self.SY, self.SX)
distance_weight_matrix = calc_distance_matrix(self.SY, self.SX)
if activation_mass <= EPS:
activation_mass_distribution = 1.0
else:
activation_mass_distribution = torch.sum(
distance_weight_matrix * centered_activation) / torch.sum(centered_activation)
distance_weight_matrix * centered_activation
) / torch.sum(centered_activation)

activation_mass_distribution_data = activation_mass_distribution
feature_vector[cur_idx] = activation_mass_distribution_data
Expand Down
Loading

0 comments on commit 088ff31

Please sign in to comment.