From 088ff31a1154db3171771521daf83cd074a3e43a Mon Sep 17 00:00:00 2001
From: Jesse Y Lin
Date: Mon, 24 Jul 2023 17:30:09 +0200
Subject: [PATCH] style: format some python

---
 .../flow_lenia/flow_lenia/descriptor.py       | 167 +++++----
 .../systems/flow_lenia/flow_lenia/main.py     | 349 +++++++++++-------
 2 files changed, 301 insertions(+), 215 deletions(-)

diff --git a/services/base/AutoDiscServer/libs/adtool_custom/systems/flow_lenia/flow_lenia/descriptor.py b/services/base/AutoDiscServer/libs/adtool_custom/systems/flow_lenia/flow_lenia/descriptor.py
index 869f3dd6..2e2b9ec2 100644
--- a/services/base/AutoDiscServer/libs/adtool_custom/systems/flow_lenia/flow_lenia/descriptor.py
+++ b/services/base/AutoDiscServer/libs/adtool_custom/systems/flow_lenia/flow_lenia/descriptor.py
@@ -9,41 +9,40 @@
 
 def roll_n(X: torch.Tensor, axis: int, n: int) -> torch.Tensor:
-    """ Rolls a tensor with a shift n on the specified axis"""
-    f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None)
-                  for i in range(X.dim()))
+    """Rolls a tensor with a shift n on the specified axis"""
+    f_idx = tuple(
+        slice(None, None, None) if i != axis else slice(0, n, None)
+        for i in range(X.dim())
+    )
     b_idx = tuple(
-        slice(
-            None,
-            None,
-            None) if i != axis else slice(
-            n,
-            None,
-            None) for i in range(
-            X.dim()))
+        slice(None, None, None) if i != axis else slice(n, None, None)
+        for i in range(X.dim())
+    )
     front = X[f_idx]
     back = X[b_idx]
     return torch.cat([back, front], axis)
 
 
 def center_of_mass(input_array: torch.Tensor) -> torch.Tensor:
-
     normalizer = input_array.sum()
     grids = torch.meshgrid(*[torch.arange(0, i) for i in input_array.shape])
 
-    center = torch.tensor([(input_array *
-                            grids[dir].double()).sum() /
-                           normalizer for dir in range(input_array.ndim)])
+    center = torch.tensor(
+        [
+            (input_array * grids[dir].double()).sum() / normalizer
+            for dir in range(input_array.ndim)
+        ]
+    )
 
     if torch.any(torch.isnan(center)):
         center = torch.tensor(
-            [int((input_array.shape[0] - 1) / 2), int((input_array.shape[1] - 1) / 2)])
+            [int((input_array.shape[0] - 1) / 2), int((input_array.shape[1] - 1) / 2)]
+        )
 
     return center
 
 
 def calc_distance_matrix(size_y: int, size_x: int) -> torch.Tensor:
-
     dist_mat = torch.zeros([size_y, size_x])
 
     mid_y = (size_y - 1) / 2
@@ -54,16 +53,15 @@ def calc_distance_matrix(size_y: int, size_x: int) -> torch.Tensor:
 
     for y in range(size_y):
         for x in range(size_x):
-            dist_mat[y][x] = (1 -
-                              int(torch.linalg.norm(mid -
-                                                    torch.tensor([y, x]))) /
-                              max_dist) ** DISTANCE_WEIGHT
+            dist_mat[y][x] = (
+                1 - int(torch.linalg.norm(mid - torch.tensor([y, x]))) / max_dist
+            ) ** DISTANCE_WEIGHT
 
     return dist_mat
 
 
 def calc_image_moments(image: torch.Tensor) -> typing.Dict[str, torch.Tensor]:
-    '''
+    """
     Calculates the image moments for an image.
 
     For more information see:
@@ -75,15 +73,14 @@ def calc_image_moments(image: torch.Tensor) -> typing.Dict[str, torch.Tensor]:
 
     :param image: 2d gray scale image in form of a numpy array.
     :return: Namedtuple with the different moments.
- ''' + """ eps = 0.00001 size_y = image.shape[0] size_x = image.shape[1] - x_grid, y_grid = torch.meshgrid( - torch.arange(0, size_x), torch.arange(0, size_y)) + x_grid, y_grid = torch.meshgrid(torch.arange(0, size_x), torch.arange(0, size_y)) y_power1_image = y_grid * image y_power2_image = y_grid * y_power1_image @@ -121,9 +118,9 @@ def calc_image_moments(image: torch.Tensor) -> typing.Dict[str, torch.Tensor]: # in the case of very small activation (m00 ~ 0) the position becomes # infinity, also then use the center position - if mY == float('inf'): + if mY == float("inf"): mY = (image.shape[0] - 1) / 2 - if mX == float('inf'): + if mX == float("inf"): mX = (image.shape[1] - 1) / 2 # calculate the central moments @@ -140,12 +137,31 @@ def calc_image_moments(image: torch.Tensor) -> typing.Dict[str, torch.Tensor]: mu03 = m03 - 3 * mX * m02 + 2 * X2 * m01 mu21 = m21 - 2 * mY * m11 - mX * m20 + 2 * Y2 * m01 mu12 = m12 - 2 * mX * m11 - mY * m02 + 2 * X2 * m10 - mu22 = m22 - 2 * mX * m21 + X2 * m20 - 2 * mY * m12 + 4 * XY * m11 - \ - 2 * mY * X2 * m10 + Y2 * m02 - 2 * Y2 * mX * m01 + Y2 * X2 * m00 - mu31 = m31 - mX * m30 + 3 * mY * \ - (mX * m20 - m21) + 3 * Y2 * (m11 - mX * m10) + Y3 * (mX * m00 - m01) - mu13 = m13 - mY * m03 + 3 * mX * \ - (mY * m02 - m12) + 3 * X2 * (m11 - mY * m01) + X3 * (mY * m00 - m10) + mu22 = ( + m22 + - 2 * mX * m21 + + X2 * m20 + - 2 * mY * m12 + + 4 * XY * m11 + - 2 * mY * X2 * m10 + + Y2 * m02 + - 2 * Y2 * mX * m01 + + Y2 * X2 * m00 + ) + mu31 = ( + m31 + - mX * m30 + + 3 * mY * (mX * m20 - m21) + + 3 * Y2 * (m11 - mX * m10) + + Y3 * (mX * m00 - m01) + ) + mu13 = ( + m13 + - mY * m03 + + 3 * mX * (mY * m02 - m12) + + 3 * X2 * (m11 - mY * m01) + + X3 * (mY * m00 - m10) + ) mu40 = m40 - 4 * mY * m30 + 6 * Y2 * m20 - 4 * Y3 * m10 + Y2 * Y2 * m00 mu04 = m04 - 4 * mX * m03 + 6 * X2 * m02 - 4 * X3 * m01 + X2 * X2 * m00 @@ -277,7 +293,8 @@ def calc_image_moments(image: torch.Tensor) -> typing.Dict[str, torch.Tensor]: flusser10=flusser10, flusser11=flusser11, flusser12=flusser12, - flusser13=flusser13) + flusser13=flusser13, + ) return result @@ -287,11 +304,13 @@ class LeniaStatistics: Outputs 17-dimensional embedding. 
""" - def __init__(self, premap_key: str = "output", - postmap_key: str = "output", - SX: int = 256, - SY: int = 256): - + def __init__( + self, + premap_key: str = "output", + postmap_key: str = "output", + SX: int = 256, + SY: int = 256, + ): self.premap_key = premap_key self.postmap_key = postmap_key @@ -300,33 +319,34 @@ def __init__(self, premap_key: str = "output", # model self._statistic_names = [ - 'activation_mass', - 'activation_volume', - 'activation_density', - 'activation_mass_distribution', - 'activation_hu1', - 'activation_hu2', - 'activation_hu3', - 'activation_hu4', - 'activation_hu5', - 'activation_hu6', - 'activation_hu7', - 'activation_hu8', - 'activation_flusser9', - 'activation_flusser10', - 'activation_flusser11', - 'activation_flusser12', - 'activation_flusser13'] + "activation_mass", + "activation_volume", + "activation_density", + "activation_mass_distribution", + "activation_hu1", + "activation_hu2", + "activation_hu3", + "activation_hu4", + "activation_hu5", + "activation_hu6", + "activation_hu7", + "activation_hu8", + "activation_flusser9", + "activation_flusser10", + "activation_flusser11", + "activation_flusser12", + "activation_flusser13", + ] self._n_latents = len(self._statistic_names) def map(self, input: typing.Dict) -> typing.Dict: """ - Compute statistics on Lenia's output - Args: - input: Lenia's output - is_output_new_discovery: indicates if it is a new discovery - Returns: - Return a torch tensor in dict + Compute statistics on Lenia's output + Args: + input: Lenia's output + is_output_new_discovery: indicates if it is a new discovery + Returns: + Return a torch tensor in dict """ intermed_dict = deepcopy(input) @@ -346,17 +366,18 @@ def map(self, input: typing.Dict) -> typing.Dict: def sample(self): return self.projector.sample() - def _calc_distance(self, embedding_a: torch.Tensor, - embedding_b: torch.Tensor) -> torch.Tensor: + def _calc_distance( + self, embedding_a: torch.Tensor, embedding_b: torch.Tensor + ) -> torch.Tensor: """ - Compute the distance between 2 embeddings in the latent space - /!\\ batch mode embedding_a and embedding_b can be N*M or M + Compute the distance between 2 embeddings in the latent space + /!\\ batch mode embedding_a and embedding_b can be N*M or M """ # l2 loss return (embedding_a - embedding_b).pow(2).sum(-1).sqrt() def _calc_static_statistics(self, final_obs: torch.Tensor) -> torch.Tensor: - '''Calculates the final statistics for lenia last observation''' + """Calculates the final statistics for lenia last observation""" feature_vector = torch.zeros(self._n_latents) cur_idx = 0 @@ -375,10 +396,10 @@ def _calc_static_statistics(self, final_obs: torch.Tensor) -> torch.Tensor: activation_shift_to_center = mid - activation_center_of_mass activation = final_obs + centered_activation = roll_n(activation, 0, activation_shift_to_center[0].int()) centered_activation = roll_n( - activation, 0, activation_shift_to_center[0].int()) - centered_activation = roll_n( - centered_activation, 1, activation_shift_to_center[1].int()) + centered_activation, 1, activation_shift_to_center[1].int() + ) # calculate the image moments activation_moments = calc_image_moments(centered_activation) @@ -406,13 +427,13 @@ def _calc_static_statistics(self, final_obs: torch.Tensor) -> torch.Tensor: cur_idx += 1 # mass distribution around the center - distance_weight_matrix = calc_distance_matrix( - self.SY, self.SX) + distance_weight_matrix = calc_distance_matrix(self.SY, self.SX) if activation_mass <= EPS: activation_mass_distribution = 1.0 else: 
activation_mass_distribution = torch.sum( - distance_weight_matrix * centered_activation) / torch.sum(centered_activation) + distance_weight_matrix * centered_activation + ) / torch.sum(centered_activation) activation_mass_distribution_data = activation_mass_distribution feature_vector[cur_idx] = activation_mass_distribution_data diff --git a/services/base/AutoDiscServer/libs/adtool_custom/systems/flow_lenia/flow_lenia/main.py b/services/base/AutoDiscServer/libs/adtool_custom/systems/flow_lenia/flow_lenia/main.py index 76178395..9e3d4122 100644 --- a/services/base/AutoDiscServer/libs/adtool_custom/systems/flow_lenia/flow_lenia/main.py +++ b/services/base/AutoDiscServer/libs/adtool_custom/systems/flow_lenia/flow_lenia/main.py @@ -12,7 +12,7 @@ from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter -os.environ['FFMPEG_BINARY'] = 'ffmpeg' +os.environ["FFMPEG_BINARY"] = "ffmpeg" class VideoWriter: @@ -43,7 +43,7 @@ def __exit__(self, *kw): def show(self, **kw): self.close() - fn = self.params['filename'] + fn = self.params["filename"] display(mvp.ipython_display(fn, **kw)) @@ -51,22 +51,19 @@ def sigmoid(x): return 0.5 * (jnp.tanh(x / 2) + 1) -def ker_f(x, a, w, b): return ( - b * jnp.exp(- (x[..., None] - a)**2 / w)).sum(-1) +def ker_f(x, a, w, b): + return (b * jnp.exp(-((x[..., None] - a) ** 2) / w)).sum(-1) -def bell(x, m, s): return jnp.exp(-((x - m) / s)**2 / 2) +def bell(x, m, s): + return jnp.exp(-(((x - m) / s) ** 2) / 2) def growth(U, m, s): return bell(U, m, s) * 2 - 1 -kx = jnp.array([ - [1., 0., -1.], - [2., 0., -2.], - [1., 0., -1.] -]) +kx = jnp.array([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]]) ky = jnp.transpose(kx) @@ -75,29 +72,38 @@ def sobel_x(A): A : (x, y, c) ret : (x, y, c) """ - return jnp.dstack([jsp.signal.convolve2d(A[:, :, c], kx, mode='same') - for c in range(A.shape[-1])]) + return jnp.dstack( + [jsp.signal.convolve2d(A[:, :, c], kx, mode="same") for c in range(A.shape[-1])] + ) def sobel_y(A): - return jnp.dstack([jsp.signal.convolve2d(A[:, :, c], ky, mode='same') - for c in range(A.shape[-1])]) + return jnp.dstack( + [jsp.signal.convolve2d(A[:, :, c], ky, mode="same") for c in range(A.shape[-1])] + ) @jax.jit def sobel(A): return jnp.concatenate( - (sobel_y(A)[ - :, :, None, :], sobel_x(A)[ - :, :, None, :]), axis=2) + (sobel_y(A)[:, :, None, :], sobel_x(A)[:, :, None, :]), axis=2 + ) def get_kernels(SX: int, SY: int, nb_k: int, params): mid = SX // 2 - Ds = [np.linalg.norm(np.mgrid[-mid:mid, -mid:mid], axis=0) / - ((params['R'] + 15) * params['r'][k]) for k in range(nb_k)] # (x,y,k) - K = jnp.dstack([sigmoid(-(D - 1) * 10) * ker_f(D, params["a"][k], - params["w"][k], params["b"][k]) for k, D in zip(range(nb_k), Ds)]) + Ds = [ + np.linalg.norm(np.mgrid[-mid:mid, -mid:mid], axis=0) + / ((params["R"] + 15) * params["r"][k]) + for k in range(nb_k) + ] # (x,y,k) + K = jnp.dstack( + [ + sigmoid(-(D - 1) * 10) + * ker_f(D, params["a"][k], params["w"][k], params["b"][k]) + for k, D in zip(range(nb_k), Ds) + ] + ) nK = K / jnp.sum(K, axis=(0, 1), keepdims=True) return nK @@ -122,18 +128,18 @@ def conn_from_lists(c0, c1, C): class ReintegrationTracking: - def __init__( - self, - SX=256, - SY=256, - dt=.2, - dd=5, - sigma=.65, - border="wall", - has_hidden=False, - hidden_dims=None, - mix="softmax"): + self, + SX=256, + SY=256, + dt=0.2, + dd=5, + sigma=0.65, + border="wall", + has_hidden=False, + hidden_dims=None, + mix="softmax", + ): self.SX = SX self.SY = SY self.dt = dt @@ -141,7 +147,7 @@ def __init__( self.sigma = sigma self.has_hidden = 
has_hidden self.hidden_dims = hidden_dims - self.border = border if border in ['wall', 'torus'] else 'wall' + self.border = border if border in ["wall", "torus"] else "wall" self.mix = mix self.apply = self._build_apply() @@ -150,10 +156,9 @@ def __call__(self, *args): return self.apply(*args) def _build_apply(self): - x, y = jnp.arange(self.SX), jnp.arange(self.SY) X, Y = jnp.meshgrid(x, y) - pos = jnp.dstack((Y, X)) + .5 # (SX, SY, 2) + pos = jnp.dstack((Y, X)) + 0.5 # (SX, SY, 2) dxs = [] dys = [] dd = self.dd @@ -170,21 +175,30 @@ def _build_apply(self): def step(X, mu, dx, dy): Xr = jnp.roll(X, (dx, dy), axis=(0, 1)) mur = jnp.roll(mu, (dx, dy), axis=(0, 1)) - if self.border == 'torus': - dpmu = jnp.min(jnp.stack( - [jnp.absolute(pos[..., None] - (mur + jnp.array([di, dj])[None, None, :, None])) - for di in (-self.SX, 0, self.SX) for dj in (-self.SY, 0, self.SY)] - ), axis=0) + if self.border == "torus": + dpmu = jnp.min( + jnp.stack( + [ + jnp.absolute( + pos[..., None] + - (mur + jnp.array([di, dj])[None, None, :, None]) + ) + for di in (-self.SX, 0, self.SX) + for dj in (-self.SY, 0, self.SY) + ] + ), + axis=0, + ) else: dpmu = jnp.absolute(pos[..., None] - mur) - sz = .5 - dpmu + self.sigma - area = jnp.prod(jnp.clip(sz, 0, min( - 1, 2 * self.sigma)), axis=2) / (4 * self.sigma**2) + sz = 0.5 - dpmu + self.sigma + area = jnp.prod(jnp.clip(sz, 0, min(1, 2 * self.sigma)), axis=2) / ( + 4 * self.sigma**2 + ) nX = Xr * area return nX def apply(X, F): - ma = self.dd - self.sigma # upper bound of the flow maggnitude # (x, y, 2, c) : target positions (distribution centers) mu = pos[..., None] + jnp.clip(self.dt * F, -ma, ma) @@ -193,33 +207,42 @@ def apply(X, F): nX = step(X, mu, dxs, dys).sum(axis=0) return nX + # ----------------------------------------------------------------------------------------------- else: @partial(jax.vmap, in_axes=(None, None, None, 0, 0)) def step_flow(X, H, mu, dx, dy): - """Summary - """ + """Summary""" Xr = jnp.roll(X, (dx, dy), axis=(0, 1)) Hr = jnp.roll(H, (dx, dy), axis=(0, 1)) # (x, y, k) mur = jnp.roll(mu, (dx, dy), axis=(0, 1)) - if self.border == 'torus': - dpmu = jnp.min(jnp.stack( - [jnp.absolute(pos[..., None] - (mur + jnp.array([di, dj])[None, None, :, None])) - for di in (-self.SX, 0, self.SX) for dj in (-self.SY, 0, self.SY)] - ), axis=0) + if self.border == "torus": + dpmu = jnp.min( + jnp.stack( + [ + jnp.absolute( + pos[..., None] + - (mur + jnp.array([di, dj])[None, None, :, None]) + ) + for di in (-self.SX, 0, self.SX) + for dj in (-self.SY, 0, self.SY) + ] + ), + axis=0, + ) else: dpmu = jnp.absolute(pos[..., None] - mur) - sz = .5 - dpmu + self.sigma - area = jnp.prod(jnp.clip(sz, 0, min( - 1, 2 * self.sigma)), axis=2) / (4 * self.sigma**2) + sz = 0.5 - dpmu + self.sigma + area = jnp.prod(jnp.clip(sz, 0, min(1, 2 * self.sigma)), axis=2) / ( + 4 * self.sigma**2 + ) nX = Xr * area return nX, Hr def apply(X, H, F): - ma = self.dd - self.sigma # upper bound of the flow maggnitude # (x, y, 2, c) : target positions (distribution centers) mu = pos[..., None] + jnp.clip(self.dt * F, -ma, ma) @@ -227,7 +250,7 @@ def apply(X, H, F): mu = jnp.clip(mu, self.sigma, self.SX - self.sigma) nX, nH = step_flow(X, H, mu, dxs, dys) - if self.mix == 'avg': + if self.mix == "avg": nH = jnp.sum(nH * nX.sum(axis=-1, keepdims=True), axis=0) nX = jnp.sum(nH, axis=0) nH = nH / (nX.sum(axis=-1, keepdims=True) + 1e-10) @@ -235,29 +258,39 @@ def apply(X, H, F): elif self.mix == "softmax": expnX = jnp.exp(nX.sum(axis=-1, keepdims=True)) - 1 nX = jnp.sum(nX, axis=0) - 
nH = jnp.sum(nH * expnX, axis=0) / \ - (expnX.sum(axis=0) + 1e-10) # avg rule + nH = jnp.sum(nH * expnX, axis=0) / ( + expnX.sum(axis=0) + 1e-10 + ) # avg rule elif self.mix == "stoch": categorical = jax.random.categorical( jax.random.PRNGKey(42), jnp.log(nX.sum(axis=-1, keepdims=True)), - axis=0) + axis=0, + ) mask = jax.nn.one_hot( - categorical, num_classes=(2 * self.dd + 1)**2, axis=-1) + categorical, num_classes=(2 * self.dd + 1) ** 2, axis=-1 + ) mask = jnp.transpose(mask, (3, 0, 1, 2)) nH = jnp.sum(nH * mask, axis=0) nX = jnp.sum(nX, axis=0) elif self.mix == "stoch_gene_wise": mask = jnp.concatenate( - [jax.nn.one_hot(jax.random.categorical( - jax.random.PRNGKey(42), - jnp.log(nX.sum(axis=-1, keepdims=True)), - axis=0), - num_classes=(2 * dd + 1)**2, axis=-1) - for _ in range(self.hidden_dims)], - axis=2) + [ + jax.nn.one_hot( + jax.random.categorical( + jax.random.PRNGKey(42), + jnp.log(nX.sum(axis=-1, keepdims=True)), + axis=0, + ), + num_classes=(2 * dd + 1) ** 2, + axis=-1, + ) + for _ in range(self.hidden_dims) + ], + axis=2, + ) # (2dd+1**2, x, y, nb_k) mask = jnp.transpose(mask, (3, 0, 1, 2)) nH = jnp.sum(nH * mask, axis=0) @@ -270,8 +303,8 @@ def apply(X, H, F): @chex.dataclass class Params: - """Flow Lenia update rule parameters - """ + """Flow Lenia update rule parameters""" + r: jnp.ndarray b: jnp.ndarray w: jnp.ndarray @@ -284,8 +317,8 @@ class Params: @chex.dataclass class CompiledParams: - """Flow Lenia compiled parameters - """ + """Flow Lenia compiled parameters""" + fK: jnp.ndarray m: jnp.ndarray s: jnp.ndarray @@ -309,17 +342,18 @@ def __init__(self, nb_k: int): nb_k (int): number of kernels in the update rule """ self.nb_k = nb_k - self.kernel_keys = 'r b w a m s h'.split() + self.kernel_keys = "r b w a m s h".split() self.spaces = { - "r": {'low': .2, 'high': 1., 'mut_std': .2, 'shape': None}, - "b": {'low': .001, 'high': 1., 'mut_std': .2, 'shape': (3,)}, - "w": {'low': .01, 'high': .5, 'mut_std': .2, 'shape': (3,)}, - "a": {'low': .0, 'high': 1., 'mut_std': .2, 'shape': (3,)}, - "m": {'low': .05, 'high': .5, 'mut_std': .2, 'shape': None}, - "s": {'low': .001, 'high': .18, 'mut_std': .01, 'shape': None}, - "h": {'low': .01, 'high': 1., 'mut_std': .2, 'shape': None}, - 'R': {'low': 2., 'high': 25., 'mut_std': .2, 'shape': None}, + "r": {"low": 0.2, "high": 1.0, "mut_std": 0.2, "shape": None}, + "b": {"low": 0.001, "high": 1.0, "mut_std": 0.2, "shape": (3,)}, + "w": {"low": 0.01, "high": 0.5, "mut_std": 0.2, "shape": (3,)}, + "a": {"low": 0.0, "high": 1.0, "mut_std": 0.2, "shape": (3,)}, + "m": {"low": 0.05, "high": 0.5, "mut_std": 0.2, "shape": None}, + "s": {"low": 0.001, "high": 0.18, "mut_std": 0.01, "shape": None}, + "h": {"low": 0.01, "high": 1.0, "mut_std": 0.2, "shape": None}, + "R": {"low": 2.0, "high": 25.0, "mut_std": 0.2, "shape": None}, } + # ----------------------------------------------------------------------------- def sample(self, key: jnp.ndarray) -> Params: @@ -332,28 +366,25 @@ def sample(self, key: jnp.ndarray) -> Params: key (jnp.ndarray): random generation key """ kernels = {} - for k in 'rmsh': + for k in "rmsh": key, subkey = jax.random.split(key) kernels[k] = jax.random.uniform( key=subkey, - minval=self.spaces[k]['low'], - maxval=self.spaces[k]['high'], - shape=( - self.nb_k, - )) + minval=self.spaces[k]["low"], + maxval=self.spaces[k]["high"], + shape=(self.nb_k,), + ) for k in "awb": key, subkey = jax.random.split(key) kernels[k] = jax.random.uniform( key=subkey, - minval=self.spaces[k]['low'], - maxval=self.spaces[k]['high'], - 
shape=( - self.nb_k, - 3)) + minval=self.spaces[k]["low"], + maxval=self.spaces[k]["high"], + shape=(self.nb_k, 3), + ) R = jax.random.uniform( - key=key, - minval=self.spaces['R']['low'], - maxval=self.spaces['R']['high']) + key=key, minval=self.spaces["R"]["low"], maxval=self.spaces["R"]["high"] + ) return Params(R=R, **kernels) @@ -390,22 +421,30 @@ def compute_kernels(params: Params) -> CompiledParams: CompiledParams: compiled params which can be used as update rule """ - Ds = [np.linalg.norm(np.mgrid[-mid:mid, -mid:mid], axis=0) / - ((params.R + 15) * params.r[k]) for k in range(nb_k)] # (x,y,k) - K = jnp.dstack([sigmoid(-(D - 1) * 10) * ker_f(D, params.a[k], - params.w[k], params.b[k]) for k, D in zip(range(nb_k), Ds)]) + Ds = [ + np.linalg.norm(np.mgrid[-mid:mid, -mid:mid], axis=0) + / ((params.R + 15) * params.r[k]) + for k in range(nb_k) + ] # (x,y,k) + K = jnp.dstack( + [ + sigmoid(-(D - 1) * 10) + * ker_f(D, params.a[k], params.w[k], params.b[k]) + for k, D in zip(range(nb_k), Ds) + ] + ) # Normalize kernels nK = K / jnp.sum(K, axis=(0, 1), keepdims=True) - fK = jnp.fft.fft2(jnp.fft.fftshift(nK, axes=(0, 1)), - axes=(0, 1)) # Get kernels fft + fK = jnp.fft.fft2( + jnp.fft.fftshift(nK, axes=(0, 1)), axes=(0, 1) + ) # Get kernels fft return CompiledParams(fK=fK, m=params.m, s=params.s, h=params.h) self.apply = jax.jit(compute_kernels) def __call__(self, params: Params): - """callback to apply - """ + """callback to apply""" return self.apply(params) @@ -413,11 +452,12 @@ def __call__(self, params: Params): # ==================================================FLOW LENIA============ # ================================================================================================================== + @chex.dataclass class Config: - """Configuration of Flow Lenia system - """ + """Configuration of Flow Lenia system""" + SX: int SY: int nb_k: int @@ -426,17 +466,17 @@ class Config: c1: t.Iterable dt: float dd: int = 5 - sigma: float = .65 + sigma: float = 0.65 n: int = 2 - theta_A: float = 1. - border: str = 'wall' + theta_A: float = 1.0 + border: str = "wall" @chex.dataclass class State: - """State of the system - """ + """State of the system""" + A: jnp.ndarray @@ -465,7 +505,8 @@ def __init__(self, config: Config): self.rule_space = RuleSpace(config.nb_k) self.kernel_computer = KernelComputer( - self.config.SX, self.config.SY, self.config.nb_k) + self.config.SX, self.config.SY, self.config.nb_k + ) self.RT = ReintegrationTracking( self.config.SX, @@ -473,7 +514,8 @@ def __init__(self, config: Config): self.config.dt, self.config.dd, self.config.sigma, - self.config.border) + self.config.border, + ) self.step_fn = self._build_step_fn() @@ -522,13 +564,13 @@ def step(state: State, params: CompiledParams) -> State: fAk = fA[:, :, self.config.c0] # (x,y,k) - U = jnp.real(jnp.fft.ifft2( - params.fK * fAk, axes=(0, 1))) # (x,y,k) + U = jnp.real(jnp.fft.ifft2(params.fK * fAk, axes=(0, 1))) # (x,y,k) U = growth(U, params.m, params.s) * params.h # (x,y,k) - U = jnp.dstack([U[:, :, self.config.c1[c]].sum(axis=-1) - for c in range(self.config.C)]) # (x,y,c) + U = jnp.dstack( + [U[:, :, self.config.c1[c]].sum(axis=-1) for c in range(self.config.C)] + ) # (x,y,c) # -------------------------------FLOW------------------------------------------ @@ -536,8 +578,9 @@ def step(state: State, params: CompiledParams) -> State: nabla_A = sobel(A.sum(axis=-1, keepdims=True)) # (x, y, 2, 1) - alpha = jnp.clip((A[:, :, None, :] / - self.config.theta_A)**self.config.n, .0, 1.) 
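# ------------------------------------------------------------------------------
# [editor's note] The hunk around this note reformats the flow-field step of
# FlowLenia.step. The logic it preserves: alpha gates between the growth
# gradient nabla_U and the anti-crowding gradient -nabla_A, saturating where a
# channel's activation approaches theta_A. A minimal self-contained sketch;
# `toy_flow_field` is a hypothetical name, not part of this patch:
#
import jax.numpy as jnp


def toy_flow_field(nabla_U, nabla_A, A, theta_A=1.0, n=2):
    """Blend growth flow against anti-crowding flow, as in FlowLenia.step.

    nabla_U, nabla_A: (x, y, 2, c) spatial gradients; A: (x, y, c) in [0, 1].
    """
    # alpha -> 1 where a channel's mass nears theta_A, so matter is pushed
    # away from crowded regions instead of following the growth gradient
    alpha = jnp.clip((A[:, :, None, :] / theta_A) ** n, 0.0, 1.0)
    return nabla_U * (1 - alpha) - nabla_A * alpha
# ------------------------------------------------------------------------------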
+ alpha = jnp.clip( + (A[:, :, None, :] / self.config.theta_A) ** self.config.n, 0.0, 1.0 + ) F = nabla_U * (1 - alpha) - nabla_A * alpha @@ -550,14 +593,17 @@ def step(state: State, params: CompiledParams) -> State: # ------------------------------------------------------------------------------ def _build_rollout( - self) -> t.Callable[[CompiledParams, State, int], t.Tuple[State, State]]: + self, + ) -> t.Callable[[CompiledParams, State, int], t.Tuple[State, State]]: """build rollout function Returns: t.Callable[[CompiledParams, State, int], t.Tuple[State, State]]: Description """ - def scan_step(carry: t.Tuple[State, CompiledParams], - x) -> t.Tuple[t.Tuple[State, CompiledParams], State]: + + def scan_step( + carry: t.Tuple[State, CompiledParams], x + ) -> t.Tuple[t.Tuple[State, CompiledParams], State]: """Summary Args: @@ -571,8 +617,9 @@ def scan_step(carry: t.Tuple[State, CompiledParams], nstate = jax.jit(self.step_fn)(state, params) return (nstate, params), nstate - def rollout(params: CompiledParams, init_state: State, - steps: int) -> t.Tuple[State, State]: + def rollout( + params: CompiledParams, init_state: State, steps: int + ) -> t.Tuple[State, State]: """Summary Args: @@ -583,8 +630,7 @@ def rollout(params: CompiledParams, init_state: State, Returns: t.Tuple[State, State]: Description """ - return jax.lax.scan( - scan_step, (init_state, params), None, length=steps) + return jax.lax.scan(scan_step, (init_state, params), None, length=steps) return rollout @@ -614,7 +660,8 @@ def __init__(self, config: Config): self.rule_space = RuleSpace(config.nb_k) self.kernel_computer = KernelComputer( - self.config.SX, self.config.SY, self.config.nb_k) + self.config.SX, self.config.SY, self.config.nb_k + ) self.RT = ReintegrationTracking( self.config.SX, @@ -622,7 +669,8 @@ def __init__(self, config: Config): self.config.dt, self.config.dd, self.config.sigma, - self.config.border) + self.config.border, + ) self.step_fn = self._build_step_fn() @@ -671,17 +719,17 @@ def step(state: State, params: CompiledParams) -> State: fAk = fA[:, :, self.config.c0] # (x,y,k) - U = jnp.real(jnp.fft.ifft2( - params.fK * fAk, axes=(0, 1))) # (x,y,k) + U = jnp.real(jnp.fft.ifft2(params.fK * fAk, axes=(0, 1))) # (x,y,k) U = growth(U, params.m, params.s) * params.h # (x,y,k) - U = jnp.dstack([U[:, :, self.config.c1[c]].sum(axis=-1) - for c in range(self.config.C)]) # (x,y,c) + U = jnp.dstack( + [U[:, :, self.config.c1[c]].sum(axis=-1) for c in range(self.config.C)] + ) # (x,y,c) # -------------------------------FLOW------------------------------------------ - nA = jnp.clip(A + self.config.dt * U, 0., 1.) 
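# ------------------------------------------------------------------------------
# [editor's note] The line being reformatted around this note belongs to the
# second system class in main.py (its name is not visible in these hunks).
# Unlike FlowLenia's mass-conserving transport above, it uses the classic
# Lenia rule: FFT convolution, bell-shaped growth, then a clipped Euler step.
# A minimal single-channel sketch, assuming a precomputed kernel FFT;
# `toy_lenia_step` is a hypothetical name, not part of this patch:
#
import jax.numpy as jnp


def toy_lenia_step(A, fK, m, s, h, dt=0.2):
    """One classic Lenia step for a single-channel world A of shape (x, y).

    fK is the FFT of the (fftshift-ed) kernel; m, s, h are growth parameters.
    """
    fA = jnp.fft.fft2(A)                  # world in Fourier space
    U = jnp.real(jnp.fft.ifft2(fK * fA))  # kernel convolution
    G = (jnp.exp(-(((U - m) / s) ** 2) / 2) * 2 - 1) * h  # growth in [-h, h]
    return jnp.clip(A + dt * G, 0.0, 1.0)  # Euler step keeps A in [0, 1]
# ------------------------------------------------------------------------------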
+        nA = jnp.clip(A + self.config.dt * U, 0.0, 1.0)
 
         return State(A=nA)
 
@@ -690,14 +738,17 @@ def step(state: State, params: CompiledParams) -> State:
     # ------------------------------------------------------------------------------
 
     def _build_rollout(
-            self) -> t.Callable[[CompiledParams, State, int], t.Tuple[State, State]]:
+        self,
+    ) -> t.Callable[[CompiledParams, State, int], t.Tuple[State, State]]:
         """build rollout function
 
         Returns:
             t.Callable[[CompiledParams, State, int], t.Tuple[State, State]]: Description
         """
-        def scan_step(carry: t.Tuple[State, CompiledParams],
-                      x) -> t.Tuple[t.Tuple[State, CompiledParams], State]:
+
+        def scan_step(
+            carry: t.Tuple[State, CompiledParams], x
+        ) -> t.Tuple[t.Tuple[State, CompiledParams], State]:
             """Summary
 
             Args:
@@ -711,8 +762,9 @@ def scan_step(carry: t.Tuple[State, CompiledParams],
             nstate = jax.jit(self.step_fn)(state, params)
             return (nstate, params), nstate
 
-        def rollout(params: CompiledParams, init_state: State,
-                    steps: int) -> t.Tuple[State, State]:
+        def rollout(
+            params: CompiledParams, init_state: State, steps: int
+        ) -> t.Tuple[State, State]:
             """Summary
 
             Args:
@@ -723,10 +775,11 @@ def rollout(params: CompiledParams, init_state: State,
             Returns:
                 t.Tuple[State, State]: Description
             """
-            return jax.lax.scan(
-                scan_step, (init_state, params), None, length=steps)
+            return jax.lax.scan(scan_step, (init_state, params), None, length=steps)
 
         return rollout
+
+
 # @title Utils
 
 
@@ -752,8 +805,18 @@ def main():
     M = np.ones((C, C), dtype=int) * nb_k
     nb_k = int(M.sum())
     c0, c1 = conn_from_matrix(M)
-    config = Config(SX=SX, SY=SY, nb_k=nb_k, C=C, c0=c0, c1=c1,
-                    dt=dt, theta_A=theta_A, dd=5, sigma=sigma)
+    config = Config(
+        SX=SX,
+        SY=SY,
+        nb_k=nb_k,
+        C=C,
+        c0=c0,
+        c1=c1,
+        dt=dt,
+        theta_A=theta_A,
+        dd=5,
+        sigma=sigma,
+    )
     fl = FlowLenia(config)
     roll_fn = jax.jit(fl.rollout_fn, static_argnums=(2,))
 
@@ -765,8 +828,10 @@ def main():
     c_params = fl.kernel_computer(params)  # Process params
 
     mx, my = SX // 2, SY // 2  # center coordinates
-    A0 = jnp.zeros((SX, SY, C)).at[mx - 20:mx + 20, my - 20:my + 20, :].set(
-        jax.random.uniform(state_seed, (40, 40, C))
+    A0 = (
+        jnp.zeros((SX, SY, C))
+        .at[mx - 20 : mx + 20, my - 20 : my + 20, :]
+        .set(jax.random.uniform(state_seed, (40, 40, C)))
     )
 
     state = State(A=A0)
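# ------------------------------------------------------------------------------
# [editor's note] main() is truncated above, right after the initial State is
# built. For orientation, a hedged sketch of how the pieces compose from here,
# using only APIs visible in this diff; the step count, the import path, and
# the names `final_carry`/`traj`/`stats`/`last_frame`/`embedding` are
# illustrative assumptions, not repo code:

import numpy as np
import torch

from flow_lenia.descriptor import LeniaStatistics  # assumed import path

(final_carry, traj) = roll_fn(c_params, state, 500)  # 500 steps is arbitrary

# rollout() wraps jax.lax.scan, so traj.A stacks one (SX, SY, C) frame per
# step; a single channel of the last frame can be handed to the
# 17-dimensional descriptor defined in descriptor.py:
stats = LeniaStatistics(SX=SX, SY=SY)
last_frame = torch.from_numpy(np.array(traj.A[-1, :, :, 0]))
embedding = stats.map({"output": last_frame})["output"]
# ------------------------------------------------------------------------------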