diff --git a/examples/cloud/launch_gce.py b/examples/cloud/launch_gce.py
index a4a6c3765a..0520b30b5a 100644
--- a/examples/cloud/launch_gce.py
+++ b/examples/cloud/launch_gce.py
@@ -202,8 +202,7 @@ def launch_gce(*, vm_name: str, startup_script: str):
 
 
 def print_howto(login_args: Sequence[str]):
-  print(
-      f"""
+  print(f"""
 ###############################################################################
 ###############################################################################
 
@@ -227,8 +226,7 @@ def print_howto(login_args: Sequence[str]):
 
 ###############################################################################
 ###############################################################################
-"""
-  )
+""")
 
 
 def main(_):
diff --git a/examples/nlp_seq/train.py b/examples/nlp_seq/train.py
index bcb268cf9f..72ce469507 100644
--- a/examples/nlp_seq/train.py
+++ b/examples/nlp_seq/train.py
@@ -42,9 +42,9 @@
 
 FLAGS = flags.FLAGS
 
-flags.DEFINE_string('model_dir', default='', help=('Directory for model data.'))
+flags.DEFINE_string('model_dir', default='', help='Directory for model data.')
 
-flags.DEFINE_string('experiment', default='xpos', help=('Experiment name.'))
+flags.DEFINE_string('experiment', default='xpos', help='Experiment name.')
 
 flags.DEFINE_integer('batch_size', default=64, help='Batch size for training.')
 
@@ -58,7 +58,7 @@
     'num_train_steps', default=75000, help='Number of train steps.'
 )
 
-flags.DEFINE_float('learning_rate', default=0.05, help=('Learning rate.'))
+flags.DEFINE_float('learning_rate', default=0.05, help='Learning rate.')
 
 flags.DEFINE_float(
     'weight_decay',
@@ -74,9 +74,9 @@
     'random_seed', default=0, help='Integer for PRNG random seed.'
 )
 
-flags.DEFINE_string('train', default='', help=('Path to training data.'))
+flags.DEFINE_string('train', default='', help='Path to training data.')
 
-flags.DEFINE_string('dev', default='', help=('Path to development data.'))
+flags.DEFINE_string('dev', default='', help='Path to development data.')
 
 
 def create_learning_rate_scheduler(
diff --git a/flax/core/nn/linear.py b/flax/core/nn/linear.py
index b5816e4a2d..cfd62154e9 100644
--- a/flax/core/nn/linear.py
+++ b/flax/core/nn/linear.py
@@ -75,7 +75,8 @@ def dense_general(
     if set(batch_dims) != set(range(max_dim + 1)):
       raise ValueError(
           'batch_dims %s must be consecutive leading '
-          'dimensions starting from 0.' % str(batch_dims)
+          'dimensions starting from 0.'
+          % str(batch_dims)
       )
 
   ndim = inputs.ndim
diff --git a/flax/errors.py b/flax/errors.py
index c8ef7d79df..df28287728 100644
--- a/flax/errors.py
+++ b/flax/errors.py
@@ -708,7 +708,8 @@ def __call__(self, x):
 
   def __init__(self):
     super().__init__(
-        'Can only call init, init_with_output or apply methods on an instance of the Module class, not the Module class itself'
+        'Can only call init, init_with_output or apply methods on an instance'
+        ' of the Module class, not the Module class itself'
     )
 
 
@@ -757,7 +758,8 @@ def __call__(self, x):
 
   def __init__(self):
     super().__init__(
-        'Trying to access a property that is accessing a non-existent attribute.'
+        'Trying to access a property that is accessing a non-existent'
+        ' attribute.'
     )
 
 
@@ -772,7 +774,8 @@ class InvalidCheckpointError(FlaxError):
 
   def __init__(self, path, step):
     super().__init__(
-        f'Trying to save an outdated checkpoint at step: "{step}" and path: "{path}".'
+        f'Trying to save an outdated checkpoint at step: "{step}" and path:'
+        f' "{path}".'
     )
 
 
diff --git a/flax/linen/attention.py b/flax/linen/attention.py
index 7b74685cb8..3ecdc72ff1 100644
--- a/flax/linen/attention.py
+++ b/flax/linen/attention.py
@@ -237,9 +237,9 @@ class MultiHeadDotProductAttention(Module):
   deterministic: Optional[bool] = None
   precision: PrecisionLike = None
   kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
   use_bias: bool = True
   attention_fn: Callable[..., Array] = dot_product_attention
   decode: bool = False
diff --git a/flax/linen/experimental/layers_with_named_axes.py b/flax/linen/experimental/layers_with_named_axes.py
index 4c5ca87a2e..cb8c4e09d8 100644
--- a/flax/linen/experimental/layers_with_named_axes.py
+++ b/flax/linen/experimental/layers_with_named_axes.py
@@ -68,9 +68,9 @@ class Dense(nn.Module):
   param_dtype: DType = jnp.float32
   precision: PrecisionLike = None
   kernel_init: Callable[[PRNGKey, Shape, DType], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, DType], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, DType], Array] = (
+      initializers.zeros_init()
+  )
   kernel_axes: Tuple[str, ...] = ()
   dot_general: DotGeneralT = lax.dot_general
 
@@ -305,12 +305,12 @@ class LayerNorm(nn.Module):
   param_dtype: DType = jnp.float32
   use_bias: bool = True
   use_scale: bool = True
-  bias_init: Callable[
-      [PRNGKey, Shape, DType], Array
-  ] = initializers.zeros_init()
-  scale_init: Callable[
-      [PRNGKey, Shape, DType], Array
-  ] = initializers.ones_init()
+  bias_init: Callable[[PRNGKey, Shape, DType], Array] = (
+      initializers.zeros_init()
+  )
+  scale_init: Callable[[PRNGKey, Shape, DType], Array] = (
+      initializers.ones_init()
+  )
 
   @nn.compact
   def __call__(self, x):
diff --git a/flax/linen/linear.py b/flax/linen/linear.py
index 522b71696f..463efeb10d 100644
--- a/flax/linen/linear.py
+++ b/flax/linen/linear.py
@@ -93,9 +93,9 @@ class DenseGeneral(Module):
   dtype: Optional[Dtype] = None
   param_dtype: Dtype = jnp.float32
   kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
   precision: PrecisionLike = None
   dot_general: DotGeneralT = lax.dot_general
 
@@ -117,7 +117,8 @@ def __call__(self, inputs: Array) -> Array:
       if set(batch_dims) != set(range(max_dim + 1)):
         raise ValueError(
             'batch_dims %s must be consecutive leading '
-            'dimensions starting from 0.' % str(batch_dims)
+            'dimensions starting from 0.'
+            % str(batch_dims)
         )
 
     ndim = inputs.ndim
@@ -207,9 +208,9 @@ class Dense(Module):
   param_dtype: Dtype = jnp.float32
   precision: PrecisionLike = None
   kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
   dot_general: DotGeneralT = lax.dot_general
 
   @compact
@@ -334,9 +335,9 @@ class _Conv(Module):
   param_dtype: Dtype = jnp.float32
   precision: PrecisionLike = None
   kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
   conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated
 
   @property
@@ -444,7 +445,7 @@ def maybe_broadcast(
     else:
       if self.feature_group_count != 1:
         raise NotImplementedError(
-            f'`lax.conv_general_dilated_local` does not support '
+            '`lax.conv_general_dilated_local` does not support '
             f'`feature_group_count != 1`, got `{self.feature_group_count}`.'
         )
 
@@ -660,9 +661,9 @@ class ConvTranspose(Module):
   param_dtype: Dtype = jnp.float32
   precision: PrecisionLike = None
   kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
   transpose_kernel: bool = False
 
   @compact
diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py
index e805aef4af..259127380b 100644
--- a/flax/linen/recurrent.py
+++ b/flax/linen/recurrent.py
@@ -225,9 +225,9 @@ class DenseParams(Module):
   param_dtype: Dtype = jnp.float32
   precision: PrecisionLike = None
   kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
 
   @compact
   def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]:
@@ -728,9 +728,9 @@ class RNN(Module):
   reverse: bool = False
   keep_order: bool = False
   unroll: int = 1
-  variable_axes: Mapping[
-      lift.CollectionFilter, lift.InOutScanAxis
-  ] = FrozenDict()
+  variable_axes: Mapping[lift.CollectionFilter, lift.InOutScanAxis] = (
+      FrozenDict()
+  )
   variable_broadcast: lift.CollectionFilter = 'params'
   variable_carry: lift.CollectionFilter = False
   split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict(
diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py
index e14eacb07c..84d8bfe4cf 100644
--- a/flax/training/checkpoints.py
+++ b/flax/training/checkpoints.py
@@ -1074,10 +1074,10 @@ def restore_checkpoint(
           target
       )
       if isinstance(orbax_checkpointer._handler, ocp.PyTreeCheckpointHandler):  # pylint: disable=protected-access
-        restore_kwargs[
-            'transforms'
-        ] = orbax_utils.maybe_construct_transformations(
-            target, orbax_transforms
+        restore_kwargs['transforms'] = (
+            orbax_utils.maybe_construct_transformations(
+                target, orbax_transforms
+            )
         )
       restored = orbax_checkpointer.restore(
           ckpt_path, item=target, **restore_kwargs
diff --git a/pyproject.toml b/pyproject.toml
index 344fb82713..e8fc9051b3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -150,4 +150,5 @@ exclude_lines = [
 [tool.pyink]
 pyink-indentation = 2
 pyink-use-majority-quotes = true
-line-length = 80
\ No newline at end of file
+line-length = 80
+preview = true
\ No newline at end of file
diff --git a/tests/core/core_frozen_dict_test.py b/tests/core/core_frozen_dict_test.py
index 4f2df7283f..a44405ffb9 100644
--- a/tests/core/core_frozen_dict_test.py
+++ b/tests/core/core_frozen_dict_test.py
@@ -128,7 +128,9 @@ def test_utility_copy(self, x, add_or_replace, actual_new_x):
      },
      {
          'x': FrozenDict({'a': 1, 'b': {'c': 2}}),
-          'pretty_str': 'FrozenDict({\n a: 1,\n b: {\n c: 2,\n },\n})',
+          'pretty_str': (
+              'FrozenDict({\n a: 1,\n b: {\n c: 2,\n },\n})'
+          ),
      },
      {
          'x': 345,
diff --git a/tests/core/core_scope_test.py b/tests/core/core_scope_test.py
index ba7fa5e56a..63744e9560 100644
--- a/tests/core/core_scope_test.py
+++ b/tests/core/core_scope_test.py
@@ -128,7 +128,10 @@ def test_inconsistent_param_shapes(self):
     def f(scope):
       scope.param('test', nn.initializers.ones_init(), (4,))
 
-    msg = r'Initializer expected to generate shape \(2,\) but got shape \(4,\) instead for parameter "test" in "/"'
+    msg = (
+        r'Initializer expected to generate shape \(2,\) but got shape \(4,\)'
+        r' instead for parameter "test" in "/"'
+    )
     with self.assertRaisesRegex(errors.ScopeParamShapeError, msg):
       apply(f)(freeze({'params': {'test': np.ones((2,))}}))