diff --git a/.github/analytics/get_repo_metrics.py b/.github/analytics/get_repo_metrics.py index cb9166c19e..600270a0df 100644 --- a/.github/analytics/get_repo_metrics.py +++ b/.github/analytics/get_repo_metrics.py @@ -331,7 +331,9 @@ def main(_): ) issues.get() - df_issues = df_issues0 = pd.DataFrame(list(_get_issues_features(issues.raw_data))) + df_issues = df_issues0 = pd.DataFrame( + list(_get_issues_features(issues.raw_data)) + ) df_issues['issue_response_time'] = ( df_issues['time_labeled_or_converted'] - df_issues['created_at'] ) @@ -350,7 +352,9 @@ def main(_): prs.get() df_prs = df_prs0 = pd.DataFrame(list(_get_pr_features(prs.raw_data))) - time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min(axis=1) + time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min( + axis=1 + ) df_prs['pr_response_time'] = time_response - df_prs['ready_for_review_at'] df_prs['pr_resolution_time'] = ( df_prs['time_merged_or_closed'] - df_prs['ready_for_review_at'] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2fd3f6b440..2b27b92b9a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,12 +13,16 @@ repos: hooks: - id: jupytext args: [--sync] +- repo: https://github.com/google/pyink + rev: 23.5.0 + hooks: + - id: pyink - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 hooks: - id: check-toml - id: trailing-whitespace - exclude: ^docs/.*\.md$|^dev/.*\.py$ + exclude: ^docs/.*\.md$ - repo: https://github.com/kynan/nbstripout rev: 0.6.1 hooks: @@ -29,7 +33,3 @@ repos: --extra-keys, "metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab", ] -- repo: https://github.com/google/pyink - rev: 23.5.0 - hooks: - - id: pyink diff --git a/dev/update_requirements.py b/dev/update_requirements.py index 8c576d923b..f44226e1e9 100644 --- a/dev/update_requirements.py +++ b/dev/update_requirements.py @@ -39,6 +39,8 @@ Alternatively, the list can also be provided from the local environment with: python dev --versions="$(pip freeze | sed s/==/-/g) flax-0.3.6" + +This will use the currently installed versions of all packages. """ import pathlib diff --git a/examples/imagenet/train_test.py b/examples/imagenet/train_test.py index 8dd86b723f..f9cc657d67 100644 --- a/examples/imagenet/train_test.py +++ b/examples/imagenet/train_test.py @@ -55,7 +55,9 @@ def test_create_model_local(self): Uses smaller inputs than `test_create_model` to due to higher compute. """ - model = train.create_model(model_cls=models._ResNet1Local, half_precision=False) # pylint: disable=protected-access + model = train.create_model( + model_cls=models._ResNet1Local, half_precision=False + ) # pylint: disable=protected-access params, batch_stats = train.initialized(random.PRNGKey(0), 64, model) variables = {'params': params, 'batch_stats': batch_stats} x = random.normal(random.PRNGKey(1), (1, 64, 64, 3)) diff --git a/examples/lm1b/temperature_sampler_test.py b/examples/lm1b/temperature_sampler_test.py index 0627d68530..a0c7f46ece 100644 --- a/examples/lm1b/temperature_sampler_test.py +++ b/examples/lm1b/temperature_sampler_test.py @@ -37,7 +37,9 @@ def tokens_to_logits(tokens, cache): logits = logits.squeeze(axis=1) return logits, cache - new_tokens = temperature_sample(tokens, cache, tokens_to_logits, key, topk=5) + new_tokens = temperature_sample( + tokens, cache, tokens_to_logits, key, topk=5 + ) np.testing.assert_array_equal(new_tokens, [[5, 6, 7, 8]]) diff --git a/examples/lm1b/train.py b/examples/lm1b/train.py index 3b65b7e11d..5284a8072b 100644 --- a/examples/lm1b/train.py +++ b/examples/lm1b/train.py @@ -285,7 +285,9 @@ def per_host_sum_pmap(in_tree): host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices) def pre_pmap(xs): - return jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs) + return jax.tree_util.tree_map( + lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs + ) def post_pmap(xs): return jax.tree_util.tree_map(lambda x: x[0], xs) @@ -525,7 +527,9 @@ def encode_strings(strs, max_len): # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation("train", step_num=step): - batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_iter))) + batch = common_utils.shard( + jax.tree_util.tree_map(np.asarray, next(train_iter)) + ) state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs) train_metrics.append(metrics) @@ -542,7 +546,9 @@ def encode_strings(strs, max_len): lr = train_metrics.pop("learning_rate").mean() metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop("denominator") - summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop + summary = jax.tree_util.tree_map( + lambda x: x / denominator, metrics_sums + ) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr summary["perplexity"] = jnp.clip( jnp.exp(summary["loss"]), a_max=1.0e4 diff --git a/examples/nlp_seq/train.py b/examples/nlp_seq/train.py index 8f5a1d192a..72ce469507 100644 --- a/examples/nlp_seq/train.py +++ b/examples/nlp_seq/train.py @@ -42,9 +42,9 @@ FLAGS = flags.FLAGS -flags.DEFINE_string('model_dir', default='', help=('Directory for model data.')) +flags.DEFINE_string('model_dir', default='', help='Directory for model data.') -flags.DEFINE_string('experiment', default='xpos', help=('Experiment name.')) +flags.DEFINE_string('experiment', default='xpos', help='Experiment name.') flags.DEFINE_integer('batch_size', default=64, help='Batch size for training.') @@ -58,7 +58,7 @@ 'num_train_steps', default=75000, help='Number of train steps.' ) -flags.DEFINE_float('learning_rate', default=0.05, help=('Learning rate.')) +flags.DEFINE_float('learning_rate', default=0.05, help='Learning rate.') flags.DEFINE_float( 'weight_decay', @@ -74,9 +74,9 @@ 'random_seed', default=0, help='Integer for PRNG random seed.' ) -flags.DEFINE_string('train', default='', help=('Path to training data.')) +flags.DEFINE_string('train', default='', help='Path to training data.') -flags.DEFINE_string('dev', default='', help=('Path to development data.')) +flags.DEFINE_string('dev', default='', help='Path to development data.') def create_learning_rate_scheduler( @@ -354,7 +354,9 @@ def eval_step(params, batch): tick = time.time() best_dev_score = 0 for step, batch in zip(range(num_train_steps), train_iter): - batch = common_utils.shard(jax.tree_util.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access + batch = common_utils.shard( + jax.tree_util.tree_map(lambda x: x._numpy(), batch) + ) # pylint: disable=protected-access state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index 9d9e31af2f..c108cccf1c 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -195,7 +195,9 @@ def get_experience( sim_state = sim.conn.recv() sim_states.append(sim_state) sim_states = np.concatenate(sim_states, axis=0) - log_probs, values = agent.policy_action(state.apply_fn, state.params, sim_states) + log_probs, values = agent.policy_action( + state.apply_fn, state.params, sim_states + ) log_probs, values = jax.device_get((log_probs, values)) probs = np.exp(np.array(log_probs)) for i, sim in enumerate(simulators): @@ -343,7 +345,9 @@ def train( score = test_episodes.policy_test(1, state.apply_fn, state.params, game) frames = step * config.num_agents * config.actor_steps summary_writer.scalar('game_score', score, frames) - logging.info('Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score) + logging.info( + 'Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score + ) # Core training code. alpha = ( diff --git a/examples/seq2seq/models.py b/examples/seq2seq/models.py index dd34e60eb8..ea1c50198d 100644 --- a/examples/seq2seq/models.py +++ b/examples/seq2seq/models.py @@ -109,7 +109,9 @@ def __call__( encoding format). """ # Encode inputs. - encoder = nn.RNN(nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder') + encoder = nn.RNN( + nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder' + ) decoder = nn.RNN( DecoderLSTMCell( decoder_inputs.shape[-1], self.teacher_force, self.vocab_size diff --git a/examples/sst2/models.py b/examples/sst2/models.py index 3511fc2c3f..f5a4519970 100644 --- a/examples/sst2/models.py +++ b/examples/sst2/models.py @@ -197,12 +197,16 @@ def setup(self): def __call__(self, embedded_inputs, lengths): # Forward LSTM. - initial_state = self.forward_lstm.initialize_carry(embedded_inputs[:, 0].shape) + initial_state = self.forward_lstm.initialize_carry( + embedded_inputs[:, 0].shape + ) _, forward_outputs = self.forward_lstm(initial_state, embedded_inputs) # Backward LSTM. reversed_inputs = flip_sequences(embedded_inputs, lengths) - initial_state = self.backward_lstm.initialize_carry(reversed_inputs[:, 0].shape) + initial_state = self.backward_lstm.initialize_carry( + reversed_inputs[:, 0].shape + ) _, backward_outputs = self.backward_lstm(initial_state, reversed_inputs) backward_outputs = flip_sequences(backward_outputs, lengths) diff --git a/examples/wmt/decode.py b/examples/wmt/decode.py index 3755adb51f..7d3e048d99 100644 --- a/examples/wmt/decode.py +++ b/examples/wmt/decode.py @@ -155,7 +155,9 @@ def beam_init(batch_size, beam_size, max_decode_len, cache): finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) # add beam dimension to attention cache pytree elements - beam_cache0 = jax.tree_util.tree_map(lambda x: add_beam_dim(x, beam_size), cache) + beam_cache0 = jax.tree_util.tree_map( + lambda x: add_beam_dim(x, beam_size), cache + ) return BeamState( cur_index=cur_index0, live_logprobs=live_logprobs0, @@ -350,13 +352,15 @@ def beam_search_loop_body_fn(state): [state.finished_flags, newly_finished], axis=1 ) # --> [batch, beams, length], [batch, beams], [batch, beams] - top_finished_seq, top_finished_scores, top_finished_flags = ( - gather_topk_beams( - [finished_seqs, finished_scores, finished_flags], - finished_scores, - batch_size, - beam_size, - ) + ( + top_finished_seq, + top_finished_scores, + top_finished_flags, + ) = gather_topk_beams( + [finished_seqs, finished_scores, finished_flags], + finished_scores, + batch_size, + beam_size, ) return BeamState( diff --git a/examples/wmt/train.py b/examples/wmt/train.py index 35adeb913f..929cdf9c93 100644 --- a/examples/wmt/train.py +++ b/examples/wmt/train.py @@ -355,7 +355,9 @@ def per_host_sum_pmap(in_tree): host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices) def pre_pmap(xs): - return jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs) + return jax.tree_util.tree_map( + lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs + ) def post_pmap(xs): return jax.tree_util.tree_map(lambda x: x[0], xs) @@ -626,7 +628,9 @@ def decode_tokens(toks): # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation("train", step_num=step): - batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_iter))) + batch = common_utils.shard( + jax.tree_util.tree_map(np.asarray, next(train_iter)) + ) state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs) train_metrics.append(metrics) @@ -643,7 +647,9 @@ def decode_tokens(toks): lr = train_metrics.pop("learning_rate").mean() metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop("denominator") - summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop + summary = jax.tree_util.tree_map( + lambda x: x / denominator, metrics_sums + ) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary) diff --git a/flax/core/axes_scan.py b/flax/core/axes_scan.py index e6f7a82007..b8ae5e417b 100644 --- a/flax/core/axes_scan.py +++ b/flax/core/axes_scan.py @@ -156,7 +156,9 @@ def body_fn(c, xs, init_mode=False): 'broadcasted variable has a data dependency on the scan body.' ) out_flat.append(const) - broadcast_in, constants_out = jax.tree_util.tree_unflatten(out_tree(), out_flat) + broadcast_in, constants_out = jax.tree_util.tree_unflatten( + out_tree(), out_flat + ) c, ys = lax.scan( body_fn, init, xs, length=length, reverse=reverse, unroll=unroll diff --git a/flax/core/meta.py b/flax/core/meta.py index 509b2a2926..16e703247e 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -237,7 +237,9 @@ def body(mdl, c): value: Any names: LogicalNames = struct.field(pytree_node=False) - mesh: Optional[jax.sharding.Mesh] = struct.field(default=None, pytree_node=False) + mesh: Optional[jax.sharding.Mesh] = struct.field( + default=None, pytree_node=False + ) def unbox(self, apply_constraint=True) -> Any: """Returns the wrapped value with the partitioning applied as a sharding constraint.""" diff --git a/flax/errors.py b/flax/errors.py index 5f9655e05d..df28287728 100644 --- a/flax/errors.py +++ b/flax/errors.py @@ -52,7 +52,9 @@ def __init__(self): class FlaxError(Exception): def __init__(self, message): - error_page = 'https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html' + error_page = ( + 'https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html' + ) module_name = self.__class__.__module__ class_name = self.__class__.__name__ error_msg = f'{message} ({error_page}#{module_name}.{class_name})' @@ -706,7 +708,8 @@ def __call__(self, x): def __init__(self): super().__init__( - 'Can only call init, init_with_output or apply methods on an instance of the Module class, not the Module class itself' + 'Can only call init, init_with_output or apply methods on an instance' + ' of the Module class, not the Module class itself' ) @@ -755,7 +758,8 @@ def __call__(self, x): def __init__(self): super().__init__( - 'Trying to access a property that is accessing a non-existent attribute.' + 'Trying to access a property that is accessing a non-existent' + ' attribute.' ) @@ -770,7 +774,8 @@ class InvalidCheckpointError(FlaxError): def __init__(self, path, step): super().__init__( - f'Trying to save an outdated checkpoint at step: "{step}" and path: "{path}".' + f'Trying to save an outdated checkpoint at step: "{step}" and path:' + f' "{path}".' ) diff --git a/flax/jax_utils.py b/flax/jax_utils.py index 6755734079..512b52515f 100644 --- a/flax/jax_utils.py +++ b/flax/jax_utils.py @@ -218,7 +218,9 @@ def transpose_out(x): def body_wrapper(c, xs): if keepdims: - xs = jax.tree_util.tree_map(lambda x: x.reshape((1,) * len(axis) + x.shape), xs) + xs = jax.tree_util.tree_map( + lambda x: x.reshape((1,) * len(axis) + x.shape), xs + ) xs = jax.tree_util.tree_map(transpose_out, xs) c, ys = body_fn(c, xs) if keepdims: diff --git a/flax/linen/attention.py b/flax/linen/attention.py index d469688183..3ecdc72ff1 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -237,7 +237,9 @@ class MultiHeadDotProductAttention(Module): deterministic: Optional[bool] = None precision: PrecisionLike = None kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init() + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = ( + initializers.zeros_init() + ) use_bias: bool = True attention_fn: Callable[..., Array] = dot_product_attention decode: bool = False @@ -316,9 +318,12 @@ def __call__( 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32) ) if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = ( - cached_key.value.shape - ) + ( + *batch_dims, + max_length, + num_heads, + depth_per_head, + ) = cached_key.value.shape # shape check of cached keys against query input expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head) if expected_shape != query.shape: diff --git a/flax/linen/experimental/layers_with_named_axes.py b/flax/linen/experimental/layers_with_named_axes.py index a2ca4d619c..cb8c4e09d8 100644 --- a/flax/linen/experimental/layers_with_named_axes.py +++ b/flax/linen/experimental/layers_with_named_axes.py @@ -39,7 +39,9 @@ default_kernel_init = initializers.lecun_normal() -default_embed_init = initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0) +default_embed_init = initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0 +) class Dense(nn.Module): @@ -66,7 +68,9 @@ class Dense(nn.Module): param_dtype: DType = jnp.float32 precision: PrecisionLike = None kernel_init: Callable[[PRNGKey, Shape, DType], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, DType], Array] = initializers.zeros_init() + bias_init: Callable[[PRNGKey, Shape, DType], Array] = ( + initializers.zeros_init() + ) kernel_axes: Tuple[str, ...] = () dot_general: DotGeneralT = lax.dot_general @@ -301,8 +305,12 @@ class LayerNorm(nn.Module): param_dtype: DType = jnp.float32 use_bias: bool = True use_scale: bool = True - bias_init: Callable[[PRNGKey, Shape, DType], Array] = initializers.zeros_init() - scale_init: Callable[[PRNGKey, Shape, DType], Array] = initializers.ones_init() + bias_init: Callable[[PRNGKey, Shape, DType], Array] = ( + initializers.zeros_init() + ) + scale_init: Callable[[PRNGKey, Shape, DType], Array] = ( + initializers.ones_init() + ) @nn.compact def __call__(self, x): diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 2e6eee4429..463efeb10d 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -93,7 +93,9 @@ class DenseGeneral(Module): dtype: Optional[Dtype] = None param_dtype: Dtype = jnp.float32 kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init() + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = ( + initializers.zeros_init() + ) precision: PrecisionLike = None dot_general: DotGeneralT = lax.dot_general @@ -206,7 +208,9 @@ class Dense(Module): param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init() + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = ( + initializers.zeros_init() + ) dot_general: DotGeneralT = lax.dot_general @compact @@ -331,7 +335,9 @@ class _Conv(Module): param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init() + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = ( + initializers.zeros_init() + ) conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated @property @@ -439,7 +445,7 @@ def maybe_broadcast( else: if self.feature_group_count != 1: raise NotImplementedError( - f'`lax.conv_general_dilated_local` does not support ' + '`lax.conv_general_dilated_local` does not support ' f'`feature_group_count != 1`, got `{self.feature_group_count}`.' ) @@ -655,7 +661,9 @@ class ConvTranspose(Module): param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init() + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = ( + initializers.zeros_init() + ) transpose_kernel: bool = False @compact @@ -755,7 +763,8 @@ def __call__(self, inputs: Array) -> Array: # Compute period along each spatial dimension - it's input size scaled # by the stride. scaled_x_dims = [ - x_dim * stride for x_dim, stride in zip(jnp.shape(inputs)[1:-1], strides) + x_dim * stride + for x_dim, stride in zip(jnp.shape(inputs)[1:-1], strides) ] # Compute difference between the current size of y and the final output # size, and complement this difference to 2 * period - that gives how @@ -797,7 +806,9 @@ def __call__(self, inputs: Array) -> Array: return y -default_embed_init = initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0) +default_embed_init = initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0 +) class Embed(Module): diff --git a/flax/linen/module.py b/flax/linen/module.py index 995a68a05a..39bcadadc5 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -580,7 +580,9 @@ def reset(self) -> None: def export(self) -> '_ModuleInternalState': """Exports transform-preserved state across transform boundary.""" - setup_state = SetupState.TRANSFORMED if self.setup_called else SetupState.NEW + setup_state = ( + SetupState.TRANSFORMED if self.setup_called else SetupState.NEW + ) cloned = _ModuleInternalState( in_compact_method=self.in_compact_method, in_setup=self.in_setup, @@ -970,7 +972,9 @@ def _call_wrapped_method(self, fun, args, kwargs): if filter_fn and filter_fn(self, fun_name): self.sow('intermediates', fun_name, y) if add_call_info: - _args, _kwargs, _y = flax.linen.summary._represent_tree((args, kwargs, y)) + _args, _kwargs, _y = flax.linen.summary._represent_tree( + (args, kwargs, y) + ) _context.call_info_stack[-1].calls.append( _CallInfo( call_index, @@ -1284,7 +1288,11 @@ def clone( Returns: A clone of the this Module with the updated attributes and parent. """ - attrs = {f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.init} + attrs = { + f.name: getattr(self, f.name) + for f in dataclasses.fields(self) + if f.init + } attrs.update(parent=parent, **updates) @@ -2027,7 +2035,9 @@ def __call__(self, x): self.scope.put_variable(col, name, xs) return True - def perturb(self, name: str, value: T, collection: str = 'perturbations') -> T: + def perturb( + self, name: str, value: T, collection: str = 'perturbations' + ) -> T: """Add an zero-value variable ('perturbation') to the intermediate value. The gradient of `value` would be the same as the gradient of this diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 747740c259..259127380b 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -90,7 +90,9 @@ class RNNCellBase(Module): """RNN cell base class.""" @nowrap - def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]) -> Carry: + def initialize_carry( + self, rng: PRNGKey, input_shape: Tuple[int, ...] + ) -> Carry: """Initialize the RNN cell carry. Args: @@ -223,7 +225,9 @@ class DenseParams(Module): param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init() + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = ( + initializers.zeros_init() + ) @compact def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: @@ -322,7 +326,9 @@ def _concat_dense( bias = jnp.concatenate(biases, axis=-1) else: bias = None - inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype) + inputs, kernel, bias = promote_dtype( + inputs, kernel, bias, dtype=self.dtype + ) y = jnp.dot(inputs, kernel) if use_bias: # This assert is here since mypy can't infer that bias cannot be None @@ -727,7 +733,9 @@ class RNN(Module): ) variable_broadcast: lift.CollectionFilter = 'params' variable_carry: lift.CollectionFilter = False - split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict({'params': False}) + split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict( + {'params': False} + ) def __post_init__(self) -> None: if self.cell_size is not NEVER: @@ -794,7 +802,9 @@ def __call__( # Infer the number of batch dimensions from the input shape. # Cells like ConvLSTM have additional spatial dimensions. - time_axis = 0 if time_major else inputs.ndim - (self.cell.num_feature_axes + 1) + time_axis = ( + 0 if time_major else inputs.ndim - (self.cell.num_feature_axes + 1) + ) # make time_axis positive if time_axis < 0: diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index 92f5796825..88236b91f7 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -257,7 +257,9 @@ def _with_sharding_constraint_one_fallback( raise ValueError(f'Axis names {axis_resources} did not match a rule') else: return x - return _with_sharding_constraint(x, jax.sharding.PartitionSpec(*mesh_axes), mesh=mesh) + return _with_sharding_constraint( + x, jax.sharding.PartitionSpec(*mesh_axes), mesh=mesh + ) def _is_logical_spec(x): diff --git a/flax/linen/summary.py b/flax/linen/summary.py index 09a171f6a4..67c195c337 100644 --- a/flax/linen/summary.py +++ b/flax/linen/summary.py @@ -71,7 +71,9 @@ class _PartitionedArrayRepresentation(_ValueRepresentation): def from_partitioned( cls, partitioned: meta.Partitioned ) -> '_PartitionedArrayRepresentation': - return cls(_ArrayRepresentation.from_array(partitioned.value), partitioned.names) + return cls( + _ArrayRepresentation.from_array(partitioned.value), partitioned.names + ) def render(self): return self.array_representation.render() + f' [dim]P[/dim]{self.names}' @@ -115,7 +117,9 @@ def __post_init__(self): self.module_variables = self.module_variables self.counted_variables = self.counted_variables - def size_and_bytes(self, collections: Iterable[str]) -> Dict[str, Tuple[int, int]]: + def size_and_bytes( + self, collections: Iterable[str] + ) -> Dict[str, Tuple[int, int]]: return { col: ( _size_and_bytes(self.counted_variables[col]) @@ -249,7 +253,9 @@ def __call__(self, x): """ def _tabulate_fn(*fn_args, **fn_kwargs): - table_fn = _get_module_table(module, depth=depth, show_repeated=show_repeated) + table_fn = _get_module_table( + module, depth=depth, show_repeated=show_repeated + ) table = table_fn(rngs, *fn_args, mutable=mutable, **fn_kwargs, **kwargs) return _render_table(table, console_kwargs, table_kwargs, column_kwargs) @@ -324,18 +330,26 @@ def _get_module_variables( all_paths: Set[Tuple[str, ...]], ) -> Tuple[MutableVariableDict, Any]: """A function that takes a path and variables structure and returns a - (module_variables, submodule_variables) tuple for that path. _get_module_variables - uses the `all_paths` set to determine if a variable belongs to a submodule or not.""" + + (module_variables, submodule_variables) tuple for that path. + _get_module_variables + uses the `all_paths` set to determine if a variable belongs to a submodule or + not. + """ module_variables = _get_path_variables(path, variables) submodule_variables: Any = {collection: {} for collection in module_variables} - all_keys = set(key for collection in module_variables.values() for key in collection) + all_keys = set( + key for collection in module_variables.values() for key in collection + ) for key in all_keys: submodule_path = path + (key,) if submodule_path in all_paths: for collection in module_variables: if key in module_variables[collection]: - submodule_variables[collection][key] = module_variables[collection].pop(key) + submodule_variables[collection][key] = module_variables[ + collection + ].pop(key) return module_variables, submodule_variables @@ -470,7 +484,9 @@ def _render_table( ) rich_table.caption_style = 'bold' - rich_table.caption = f'\nTotal Parameters: {_size_and_bytes_repr(*caption_totals)}' + rich_table.caption = ( + f'\nTotal Parameters: {_size_and_bytes_repr(*caption_totals)}' + ) return '\n' + _get_rich_repr(rich_table, console_kwargs) + '\n' @@ -489,7 +505,9 @@ def _size_and_bytes_repr(size: int, num_bytes: int) -> str: def _size_and_bytes(pytree: Any) -> Tuple[int, int]: leaves = jax.tree_util.tree_leaves(pytree) size = sum(x.size for x in leaves if hasattr(x, 'size')) - num_bytes = sum(x.size * x.dtype.itemsize for x in leaves if hasattr(x, 'size')) + num_bytes = sum( + x.size * x.dtype.itemsize for x in leaves if hasattr(x, 'size') + ) return size, num_bytes diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index e8b927365f..51fbf177ee 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -1293,7 +1293,9 @@ def body_fn(mdl, c): def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs): - return lift.cond(pred, t_fn, f_fn, scope, *ops, variables=variables, rngs=rngs) + return lift.cond( + pred, t_fn, f_fn, scope, *ops, variables=variables, rngs=rngs + ) def cond( @@ -1362,7 +1364,9 @@ def _switch_wrapper(*args, variables, rngs, n_branches): # then scope, index, and the rest are *operands branches = args[:n_branches] scope, index, *operands = args[n_branches:] - return lift.switch(index, branches, scope, *operands, variables=variables, rngs=rngs) + return lift.switch( + index, branches, scope, *operands, variables=variables, rngs=rngs + ) def switch( diff --git a/flax/struct.py b/flax/struct.py index 33968021c7..3385f09206 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -156,7 +156,9 @@ def from_state_dict(x, state): ) value = getattr(x, name) value_state = state.pop(name) - updates[name] = serialization.from_state_dict(value, value_state, name=name) + updates[name] = serialization.from_state_dict( + value, value_state, name=name + ) if state: names = ','.join(state.keys()) raise ValueError( diff --git a/flax/training/early_stopping.py b/flax/training/early_stopping.py index f6675e4184..fd95fd761f 100644 --- a/flax/training/early_stopping.py +++ b/flax/training/early_stopping.py @@ -65,7 +65,10 @@ def update(self, metric): `best_metric` and `early_stop` is the updated `EarlyStop` object. """ - if math.isinf(self.best_metric) or self.best_metric - metric > self.min_delta: + if ( + math.isinf(self.best_metric) + or self.best_metric - metric > self.min_delta + ): return True, self.replace(best_metric=metric, patience_count=0) else: should_stop = self.patience_count >= self.patience or self.should_stop diff --git a/pyproject.toml b/pyproject.toml index 6cad172004..e8fc9051b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,4 +149,6 @@ exclude_lines = [ [tool.pyink] pyink-indentation = 2 -pyink-use-majority-quotes = true \ No newline at end of file +pyink-use-majority-quotes = true +line-length = 80 +preview = true \ No newline at end of file diff --git a/tests/core/core_frozen_dict_test.py b/tests/core/core_frozen_dict_test.py index 4f2df7283f..a44405ffb9 100644 --- a/tests/core/core_frozen_dict_test.py +++ b/tests/core/core_frozen_dict_test.py @@ -128,7 +128,9 @@ def test_utility_copy(self, x, add_or_replace, actual_new_x): }, { 'x': FrozenDict({'a': 1, 'b': {'c': 2}}), - 'pretty_str': 'FrozenDict({\n a: 1,\n b: {\n c: 2,\n },\n})', + 'pretty_str': ( + 'FrozenDict({\n a: 1,\n b: {\n c: 2,\n },\n})' + ), }, { 'x': 345, diff --git a/tests/core/core_lift_test.py b/tests/core/core_lift_test.py index 525616918d..c433c2276e 100644 --- a/tests/core/core_lift_test.py +++ b/tests/core/core_lift_test.py @@ -116,7 +116,9 @@ def f(scope, x): vars_t = jax.tree_util.tree_map( jnp.ones_like, scope.variables().get('params', {}) ) - _, out_t = lift.jvp(g, scope, (x,), (jnp.zeros_like(x),), {'params': vars_t}) + _, out_t = lift.jvp( + g, scope, (x,), (jnp.zeros_like(x),), {'params': vars_t} + ) return out_t x = jnp.ones((3,)) diff --git a/tests/core/core_scope_test.py b/tests/core/core_scope_test.py index 2ada76dfc0..63744e9560 100644 --- a/tests/core/core_scope_test.py +++ b/tests/core/core_scope_test.py @@ -36,7 +36,9 @@ def f(scope): self.assertTrue(scope.has_rng('params')) self.assertFalse(scope.has_rng('dropout')) rng = scope.make_rng('params') - self.assertTrue(np.all(rng == LazyRng.create(random.PRNGKey(0), 1).as_jax_rng())) + self.assertTrue( + np.all(rng == LazyRng.create(random.PRNGKey(0), 1).as_jax_rng()) + ) init(f)(random.PRNGKey(0)) @@ -67,7 +69,9 @@ def union_check(a, b, ans): scope.DenyList(['b', 'c']), scope.DenyList(set(['b'])), ) - union_check(scope.DenyList(['a', 'b']), ['b', 'c'], scope.DenyList(set(['a']))) + union_check( + scope.DenyList(['a', 'b']), ['b', 'c'], scope.DenyList(set(['a'])) + ) def test_intersect_filter(self): def intersect_check(a, b, ans): @@ -94,7 +98,9 @@ def subtract_check(a, b, ans): subtract_check(False, False, set()) subtract_check(True, True, False) subtract_check(True, 'a', scope.DenyList('a')) - subtract_check(scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), set(['c'])) + subtract_check( + scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), set(['c']) + ) subtract_check( scope.DenyList(['a', 'b']), ['b', 'c'], @@ -122,7 +128,10 @@ def test_inconsistent_param_shapes(self): def f(scope): scope.param('test', nn.initializers.ones_init(), (4,)) - msg = r'Initializer expected to generate shape \(2,\) but got shape \(4,\) instead for parameter "test" in "/"' + msg = ( + r'Initializer expected to generate shape \(2,\) but got shape \(4,\)' + r' instead for parameter "test" in "/"' + ) with self.assertRaisesRegex(errors.ScopeParamShapeError, msg): apply(f)(freeze({'params': {'test': np.ones((2,))}})) diff --git a/tests/core/design/core_attention_test.py b/tests/core/design/core_attention_test.py index 686ad5938f..2ddb468890 100644 --- a/tests/core/design/core_attention_test.py +++ b/tests/core/design/core_attention_test.py @@ -38,7 +38,9 @@ def softmax_attn(scope: Scope, weights: Array): def with_dropout(fn, rate: float, deterministic: bool = False): def attn_fn(scope: Scope, weights: Array): attn_weights = fn(scope, weights) - return nn.dropout(scope, attn_weights, deterministic=deterministic, rate=rate) + return nn.dropout( + scope, attn_weights, deterministic=deterministic, rate=rate + ) return attn_fn diff --git a/tests/core/design/core_auto_encoder_test.py b/tests/core/design/core_auto_encoder_test.py index ac6d0b1f8c..785caf626b 100644 --- a/tests/core/design/core_auto_encoder_test.py +++ b/tests/core/design/core_auto_encoder_test.py @@ -106,7 +106,9 @@ def test_auto_encoder_hp_struct(self): x = jnp.ones((1, 4)) x_r, variables = init(ae)(random.PRNGKey(0), x) self.assertEqual(x.shape, x_r.shape) - variable_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) + variable_shapes = unfreeze( + jax.tree_util.tree_map(jnp.shape, variables['params']) + ) self.assertEqual( variable_shapes, { @@ -122,12 +124,16 @@ def test_auto_encoder_hp_struct(self): ) def test_auto_encoder_with_scope(self): - ae = lambda scope, x: AutoEncoder2(scope, latents=2, features=4, hidden=3)(x) + ae = lambda scope, x: AutoEncoder2(scope, latents=2, features=4, hidden=3)( + x + ) x = jnp.ones((1, 4)) x_r, variables = init(ae)(random.PRNGKey(0), x) self.assertEqual(x.shape, x_r.shape) - variable_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) + variable_shapes = unfreeze( + jax.tree_util.tree_map(jnp.shape, variables['params']) + ) self.assertEqual( variable_shapes, { @@ -143,12 +149,16 @@ def test_auto_encoder_with_scope(self): ) def test_auto_encoder_bind_method(self): - ae = lambda scope, x: AutoEncoder3.create(scope, latents=2, features=4, hidden=3)(x) + ae = lambda scope, x: AutoEncoder3.create( + scope, latents=2, features=4, hidden=3 + )(x) x = jnp.ones((1, 4)) x_r, variables = init(ae)(random.PRNGKey(0), x) self.assertEqual(x.shape, x_r.shape) - variable_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) + variable_shapes = unfreeze( + jax.tree_util.tree_map(jnp.shape, variables['params']) + ) self.assertEqual( variable_shapes, { diff --git a/tests/core/design/core_dense_test.py b/tests/core/design/core_dense_test.py index 55b0ddf0e3..a90d93cf3b 100644 --- a/tests/core/design/core_dense_test.py +++ b/tests/core/design/core_dense_test.py @@ -33,7 +33,9 @@ class Dense: bias_init: Any = nn.initializers.zeros_init() def __call__(self, scope, x): - kernel = scope.param('kernel', self.kernel_init, (x.shape[-1], self.features)) + kernel = scope.param( + 'kernel', self.kernel_init, (x.shape[-1], self.features) + ) y = x @ kernel if self.bias: bias = scope.param('bias', self.bias_init, (self.features,)) diff --git a/tests/core/design/core_resnet_test.py b/tests/core/design/core_resnet_test.py index f15ec159a9..ac7628bd7e 100644 --- a/tests/core/design/core_resnet_test.py +++ b/tests/core/design/core_resnet_test.py @@ -72,7 +72,9 @@ def resnet( strides = (2, 2) block_features = features * 2**i block_scope = scope.push(f'block_{i}_{j}') - x = residual_block(block_scope, x, conv, norm, act, block_features, strides) + x = residual_block( + block_scope, x, conv, norm, act, block_features, strides + ) # we can access parameters of the sub module by operating on the scope # Example: # block_scope.get_kind('params')['conv_1']['kernel'] diff --git a/tests/io_test.py b/tests/io_test.py index 70f9f5e63e..ca9d771831 100644 --- a/tests/io_test.py +++ b/tests/io_test.py @@ -176,7 +176,9 @@ def test_makedirs(self, backend_mode): with io.override_mode(backend_mode): io.makedirs(test_dir_path) - self.assertTrue(os.path.exists(test_dir_path) and (os.path.isdir(test_dir_path))) + self.assertTrue( + os.path.exists(test_dir_path) and (os.path.isdir(test_dir_path)) + ) def test_glob(self): with tempfile.TemporaryDirectory() as temp_dir_path: diff --git a/tests/linen/linen_combinators_test.py b/tests/linen/linen_combinators_test.py index 985e688ea7..685c01b434 100644 --- a/tests/linen/linen_combinators_test.py +++ b/tests/linen/linen_combinators_test.py @@ -37,7 +37,9 @@ class MLP(nn.Module): def __call__(self, inputs): x = inputs for layer_size in self.layer_sizes[:-1]: - x = nn.Dense(features=layer_size, kernel_init=nn.initializers.ones_init())(x) + x = nn.Dense( + features=layer_size, kernel_init=nn.initializers.ones_init() + )(x) if self.activation is not None: x = self.activation(x) x = nn.Dense( diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index 263b9196b6..562bd03224 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -102,7 +102,6 @@ def test_init_module(self): self.assertEqual(params, {'bias': jnp.array([1.0])}) def test_lazy_init(self): - class Foo(nn.Module): @compact @@ -128,7 +127,9 @@ def __call__(self, x): return x * k with self.assertRaises(errors.LazyInitError): - Foo().lazy_init(random.PRNGKey(0), jax.ShapeDtypeStruct((8, 4), jnp.float32)) + Foo().lazy_init( + random.PRNGKey(0), jax.ShapeDtypeStruct((8, 4), jnp.float32) + ) def test_arg_module(self): rngkey = jax.random.PRNGKey(0) @@ -867,7 +868,9 @@ def test(self): Foo().init({}, method='allowed_apply_fn') # attributes which are not callables yield TypeError. - with self.assertRaisesRegex(TypeError, "'Foo.not_callable' must be a callable"): + with self.assertRaisesRegex( + TypeError, "'Foo.not_callable' must be a callable" + ): Foo().apply({}, method='not_callable') # test same for init. Foo().init({}, method='not_callable') @@ -1476,7 +1479,6 @@ def loss(params, perturbations, inputs, targets): ) def test_perturb_noop(self): - class Foo(nn.Module): @nn.compact @@ -1497,7 +1499,9 @@ def __call__(self, x): module.apply({'params': params}, x) # check errors if perturbations is passed but empty - with self.assertRaisesRegex(errors.ScopeCollectionNotFound, 'Tried to access'): + with self.assertRaisesRegex( + errors.ScopeCollectionNotFound, 'Tried to access' + ): module.apply({'params': params, 'perturbations': {}}, x) # check no error if perturbations is passed and not empty @@ -1570,7 +1574,6 @@ def f(foo, x): np.testing.assert_allclose(x, y) def test_unbind(self): - class Foo(nn.Module): def setup(self): @@ -1593,8 +1596,12 @@ def __call__(self, x): self.assertIsInstance(decoder, nn.Dense) self.assertEqual(decoder.features, 2) - np.testing.assert_equal(variables['params']['encoder'], encoder_vars['params']) - np.testing.assert_equal(variables['params']['decoder'], decoder_vars['params']) + np.testing.assert_equal( + variables['params']['encoder'], encoder_vars['params'] + ) + np.testing.assert_equal( + variables['params']['decoder'], decoder_vars['params'] + ) def test_passing_mutable_variables(self): class Foo(nn.Module): @@ -2296,7 +2303,6 @@ def __call__(self, x): y = bar.apply(vs, x) def test_nested_class_optional_adoption_name_preservation(self): - class Foo(nn.Module): @nn.compact @@ -2400,7 +2406,6 @@ def __call__(self, x): vs = foo.init(k, x) def test_relaxed_intercollection_conflict(self): - class Foo(nn.Module): @nn.compact @@ -2415,7 +2420,6 @@ def __call__(self, x): vs = foo.init(k, x) def test_relaxed_intercollection_conflict_set(self): - class Foo(nn.Module): @nn.compact diff --git a/tests/linen/linen_recurrent_test.py b/tests/linen/linen_recurrent_test.py index 5d07003c24..2f09a487bf 100644 --- a/tests/linen/linen_recurrent_test.py +++ b/tests/linen/linen_recurrent_test.py @@ -49,7 +49,9 @@ def test_rnn_basic_forward(self): for layer_params in variables['params']['cell'].values(): if 'bias' in layer_params: self.assertEqual(layer_params['bias'].shape, (channels_out,)) - self.assertIn(layer_params['kernel'].shape[0], [channels_in, channels_out]) + self.assertIn( + layer_params['kernel'].shape[0], [channels_in, channels_out] + ) self.assertEqual(layer_params['kernel'].shape[1], channels_out) def test_rnn_multiple_batch_dims(self): @@ -70,7 +72,9 @@ def test_rnn_multiple_batch_dims(self): for layer_params in variables['params']['cell'].values(): if 'bias' in layer_params: self.assertEqual(layer_params['bias'].shape, (channels_out,)) - self.assertIn(layer_params['kernel'].shape[0], [channels_in, channels_out]) + self.assertIn( + layer_params['kernel'].shape[0], [channels_in, channels_out] + ) self.assertEqual(layer_params['kernel'].shape[1], channels_out) def test_rnn_unroll(self): @@ -91,7 +95,9 @@ def test_rnn_unroll(self): for layer_params in variables['params']['cell'].values(): if 'bias' in layer_params: self.assertEqual(layer_params['bias'].shape, (channels_out,)) - self.assertIn(layer_params['kernel'].shape[0], [channels_in, channels_out]) + self.assertIn( + layer_params['kernel'].shape[0], [channels_in, channels_out] + ) self.assertEqual(layer_params['kernel'].shape[1], channels_out) def test_rnn_time_major(self): @@ -118,7 +124,9 @@ def test_rnn_time_major(self): for layer_params in variables['params']['cell'].values(): if 'bias' in layer_params: self.assertEqual(layer_params['bias'].shape, (channels_out,)) - self.assertIn(layer_params['kernel'].shape[0], [channels_in, channels_out]) + self.assertIn( + layer_params['kernel'].shape[0], [channels_in, channels_out] + ) self.assertEqual(layer_params['kernel'].shape[1], channels_out) def test_rnn_with_spatial_dimensions(self): @@ -168,11 +176,15 @@ def test_numerical_equivalence(self): ys: jnp.ndarray (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs) - cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), xs[:, 0].shape) + cell_carry = rnn.cell.initialize_carry( + jax.random.PRNGKey(0), xs[:, 0].shape + ) cell_params = variables['params']['cell'] for i in range(seq_len): - cell_carry, y = rnn.cell.apply({'params': cell_params}, cell_carry, xs[:, i, :]) + cell_carry, y = rnn.cell.apply( + {'params': cell_params}, cell_carry, xs[:, i, :] + ) np.testing.assert_allclose(y, ys[:, i, :], rtol=1e-5) np.testing.assert_allclose(cell_carry, carry, rtol=1e-5) @@ -184,7 +196,9 @@ def test_numerical_equivalence_with_mask(self): channels_out = 6 key = jax.random.PRNGKey(0) - seq_lengths = jax.random.randint(key, (batch_size,), minval=1, maxval=seq_len + 1) + seq_lengths = jax.random.randint( + key, (batch_size,), minval=1, maxval=seq_len + 1 + ) rnn = nn.RNN(nn.LSTMCell(channels_out), return_carry=True) @@ -194,19 +208,25 @@ def test_numerical_equivalence_with_mask(self): jax.random.PRNGKey(0), xs, seq_lengths=seq_lengths ) - cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), xs[:, 0].shape) + cell_carry = rnn.cell.initialize_carry( + jax.random.PRNGKey(0), xs[:, 0].shape + ) cell_params = variables['params']['cell'] carries = [] for i in range(seq_len): - cell_carry, y = rnn.cell.apply({'params': cell_params}, cell_carry, xs[:, i, :]) + cell_carry, y = rnn.cell.apply( + {'params': cell_params}, cell_carry, xs[:, i, :] + ) np.testing.assert_allclose(y, ys[:, i, :], rtol=1e-5) carries.append(cell_carry) for batch_idx, length in enumerate(seq_lengths): t = int(length) - 1 for carries_t_, carry_ in zip(carries[t], carry): - np.testing.assert_allclose(carries_t_[batch_idx], carry_[batch_idx], rtol=1e-5) + np.testing.assert_allclose( + carries_t_[batch_idx], carry_[batch_idx], rtol=1e-5 + ) def test_numerical_equivalence_single_batch(self): batch_size = 3 @@ -223,7 +243,9 @@ def test_numerical_equivalence_single_batch(self): cell_params = variables['params']['cell'] for batch_idx in range(batch_size): - cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), xs[:1, 0].shape) + cell_carry = rnn.cell.initialize_carry( + jax.random.PRNGKey(0), xs[:1, 0].shape + ) for i in range(seq_len): cell_carry, y = rnn.cell.apply( @@ -252,7 +274,9 @@ def test_numerical_equivalence_single_batch_nn_scan(self): xs = jnp.ones((batch_size, seq_len, channels_in)) carry = rnn.initialize_carry(jax.random.PRNGKey(0), xs[:, 0].shape) ys: jnp.ndarray - (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), carry, xs) + (carry, ys), variables = rnn.init_with_output( + jax.random.PRNGKey(0), carry, xs + ) cell_params = variables['params'] @@ -276,7 +300,9 @@ def test_numerical_equivalence_single_batch_jax_scan(self): channels_in = 5 channels_out = 6 - xs = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, seq_len, channels_in)) + xs = jax.random.uniform( + jax.random.PRNGKey(0), (batch_size, seq_len, channels_in) + ) cell: nn.LSTMCell = nn.LSTMCell(channels_out) carry = cell.initialize_carry(jax.random.PRNGKey(0), xs[:, 0].shape) variables = cell.init(jax.random.PRNGKey(0), carry, xs[:, 0]) @@ -292,7 +318,9 @@ def scan_fn(carry, x): cell_carry = cell.initialize_carry(jax.random.PRNGKey(0), xs[:, 0].shape) for i in range(seq_len): - cell_carry, y = cell.apply({'params': cell_params}, cell_carry, xs[:, i, :]) + cell_carry, y = cell.apply( + {'params': cell_params}, cell_carry, xs[:, i, :] + ) np.testing.assert_allclose(y, ys[:, i, :], rtol=1e-4) np.testing.assert_allclose(cell_carry, carry, rtol=1e-4) @@ -312,7 +340,9 @@ def test_reverse(self): cell_params = variables['params']['cell'] for batch_idx in range(batch_size): - cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), xs[:1, 0].shape) + cell_carry = rnn.cell.initialize_carry( + jax.random.PRNGKey(0), xs[:1, 0].shape + ) for i in range(seq_len): cell_carry, y = rnn.cell.apply( @@ -348,7 +378,9 @@ def test_reverse_but_keep_order(self): cell_params = variables['params']['cell'] for batch_idx in range(batch_size): - cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), xs[:1, 0].shape) + cell_carry = rnn.cell.initialize_carry( + jax.random.PRNGKey(0), xs[:1, 0].shape + ) for i in range(seq_len): cell_carry, y = rnn.cell.apply( @@ -356,7 +388,9 @@ def test_reverse_but_keep_order(self): cell_carry, xs[batch_idx, seq_len - i - 1, :][None], ) - np.testing.assert_allclose(y[0], ys[batch_idx, seq_len - i - 1, :], rtol=1e-5) + np.testing.assert_allclose( + y[0], ys[batch_idx, seq_len - i - 1, :], rtol=1e-5 + ) np.testing.assert_allclose( cell_carry, @@ -476,7 +510,9 @@ def test_return_carry(self): xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray - (carry, ys), variables = bdirectional.init_with_output(jax.random.PRNGKey(0), xs) + (carry, ys), variables = bdirectional.init_with_output( + jax.random.PRNGKey(0), xs + ) carry_forward, carry_backward = carry self.assertEqual(ys.shape, (batch_size, seq_len, channels_out * 2)) diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 1976999f08..54a46064a9 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -247,7 +247,9 @@ def test_group_norm(self): key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 5, 4, 4, 32)) - model_cls = nn.GroupNorm(num_groups=2, use_bias=False, use_scale=False, epsilon=e) + model_cls = nn.GroupNorm( + num_groups=2, use_bias=False, use_scale=False, epsilon=e + ) y, _ = model_cls.init_with_output(key2, x) self.assertEqual(x.dtype, y.dtype) @@ -266,7 +268,9 @@ def test_group_norm_raises(self): key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 5, 4, 4, 32)) - model_cls = nn.GroupNorm(num_groups=3, use_bias=False, use_scale=False, epsilon=e) + model_cls = nn.GroupNorm( + num_groups=3, use_bias=False, use_scale=False, epsilon=e + ) with self.assertRaises(ValueError): model_cls.init_with_output(key2, x) diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 8f8b8bcf13..6c9b5c03ff 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -327,7 +327,9 @@ def __call__(self, c, xs): # simulate scan in python for comparison: c = init_carry ys = [] - lstmcell_variables = freeze({'params': init_variables['params']['lstm_cell']}) + lstmcell_variables = freeze( + {'params': init_variables['params']['lstm_cell']} + ) for i in range(xs.shape[0]): c, y = nn.LSTMCell(2).apply(lstmcell_variables, c, xs[i]) ys.append(y[None, ...]) @@ -363,7 +365,9 @@ def __call__(self, c, b, xs): # simulate scan in python for comparison: c = init_carry ys = [] - lstmcell_variables = freeze({'params': init_variables['params']['lstm_cell']}) + lstmcell_variables = freeze( + {'params': init_variables['params']['lstm_cell']} + ) for i in range(xs.shape[0]): c, y = nn.LSTMCell(2).apply(lstmcell_variables, c, xs[i]) ys.append(y[None, ...]) @@ -1267,7 +1271,9 @@ class TiedAutencoder(nn.Module): @nn.compact def _call(self, x, decode): def f(self): - return nn.Dense(self.features if decode else self.latents, use_bias=False)(x) + return nn.Dense( + self.features if decode else self.latents, use_bias=False + )(x) if decode: map_fn = trans @@ -1367,7 +1373,9 @@ class Foo(nn.Module): @nn.compact def __call__(self, x): bar = Bar() - vars_t = jax.tree_util.tree_map(jnp.ones_like, bar.variables.get('params', {})) + vars_t = jax.tree_util.tree_map( + jnp.ones_like, bar.variables.get('params', {}) + ) _, out_t = nn.jvp( Bar.__call__, bar, (x,), (jnp.zeros_like(x),), {'params': vars_t} ) @@ -1427,7 +1435,6 @@ def __call__(self, x): ) def test_custom_vjp(self): - class Foo(nn.Module): @nn.compact @@ -1884,16 +1891,26 @@ def fn(mdl, x): vars = copy(vars, updates) self.assertEqual(vars['state'], {'0_count': 1, '1_count': 1, '2_count': 1}) - self.assertEqual(vars['params']['heads_0']['layers_0']['kernel'].shape, (3, 10)) + self.assertEqual( + vars['params']['heads_0']['layers_0']['kernel'].shape, (3, 10) + ) self.assertEqual(vars['params']['heads_0']['layers_0']['bias'].shape, (10,)) - self.assertEqual(vars['params']['heads_0']['layers_1']['kernel'].shape, (10, 7)) + self.assertEqual( + vars['params']['heads_0']['layers_1']['kernel'].shape, (10, 7) + ) self.assertEqual(vars['params']['heads_0']['layers_1']['bias'].shape, (7,)) - self.assertEqual(vars['params']['heads_0']['layers_2']['kernel'].shape, (7, 5)) + self.assertEqual( + vars['params']['heads_0']['layers_2']['kernel'].shape, (7, 5) + ) self.assertEqual(vars['params']['heads_0']['layers_2']['bias'].shape, (5,)) - self.assertEqual(vars['params']['heads_1']['layers_0']['kernel'].shape, (3, 11)) + self.assertEqual( + vars['params']['heads_1']['layers_0']['kernel'].shape, (3, 11) + ) self.assertEqual(vars['params']['heads_1']['layers_0']['bias'].shape, (11,)) - self.assertEqual(vars['params']['heads_1']['layers_1']['kernel'].shape, (11, 5)) + self.assertEqual( + vars['params']['heads_1']['layers_1']['kernel'].shape, (11, 5) + ) self.assertEqual(vars['params']['heads_1']['layers_1']['bias'].shape, (5,)) self.assertEqual(vars['params']['heads_2']['kernel'].shape, (3, 5)) diff --git a/tests/linen/partitioning_test.py b/tests/linen/partitioning_test.py index f5b2faddc6..7b79238b82 100644 --- a/tests/linen/partitioning_test.py +++ b/tests/linen/partitioning_test.py @@ -489,7 +489,9 @@ def __call__(self, x): with partitioning.axis_rules(p_rules): variables = Foo().init(jax.random.PRNGKey(0), jnp.array([1, 2, 3])) variables = unfreeze(variables) - variables['params'] = jax.tree_util.tree_map(lambda x: x.shape, variables['params']) + variables['params'] = jax.tree_util.tree_map( + lambda x: x.shape, variables['params'] + ) self.assertDictEqual( variables, { @@ -506,7 +508,9 @@ def __call__(self, x): jax.random.PRNGKey(0), jnp.array([[1, 2, 3], [4, 5, 6]]) ) variables = unfreeze(variables) - variables['params'] = jax.tree_util.tree_map(lambda x: x.shape, variables['params']) + variables['params'] = jax.tree_util.tree_map( + lambda x: x.shape, variables['params'] + ) self.assertDictEqual( variables, { diff --git a/tests/linen/summary_test.py b/tests/linen/summary_test.py index 0caecb02e6..eb2f5624b7 100644 --- a/tests/linen/summary_test.py +++ b/tests/linen/summary_test.py @@ -555,7 +555,6 @@ def __call__(self): self.assertIn("4.141592", lines[7]) def test_partitioned_params(self): - class Classifier(nn.Module): @nn.compact @@ -608,7 +607,9 @@ def __call__(self, x): return nn.Dense(2)(h) x = jnp.ones((16, 9)) - rep = Foo().tabulate(jax.random.PRNGKey(0), x, console_kwargs=CONSOLE_TEST_KWARGS) + rep = Foo().tabulate( + jax.random.PRNGKey(0), x, console_kwargs=CONSOLE_TEST_KWARGS + ) lines = rep.splitlines() self.assertIn("Total Parameters: 50", lines[-2]) diff --git a/tests/traverse_util_test.py b/tests/traverse_util_test.py index cdddeb5154..172a5585da 100644 --- a/tests/traverse_util_test.py +++ b/tests/traverse_util_test.py @@ -297,10 +297,18 @@ def test_path_aware_map_with_multi_transform(self): updates, new_state = tx.update(gradients, state, params) new_params = optax.apply_updates(params, updates) - self.assertTrue(np.allclose(new_params['linear_1']['b'], params['linear_1']['b'])) - self.assertTrue(np.allclose(new_params['linear_2']['b'], params['linear_2']['b'])) - self.assertFalse(np.allclose(new_params['linear_1']['w'], params['linear_1']['w'])) - self.assertFalse(np.allclose(new_params['linear_2']['w'], params['linear_2']['w'])) + self.assertTrue( + np.allclose(new_params['linear_1']['b'], params['linear_1']['b']) + ) + self.assertTrue( + np.allclose(new_params['linear_2']['b'], params['linear_2']['b']) + ) + self.assertFalse( + np.allclose(new_params['linear_1']['w'], params['linear_1']['w']) + ) + self.assertFalse( + np.allclose(new_params['linear_2']['w'], params['linear_2']['w']) + ) def test_path_aware_map_with_masked(self): params = {