Commit

add preview to pyink
cgarciae committed Jul 21, 2023
1 parent 0f4e662 commit f86fd1c
Showing 12 changed files with 61 additions and 52 deletions.
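
Nearly everything in the diff below is mechanical reformatting produced by pyink's preview style, which this commit turns on in pyproject.toml. The most frequent pattern is that long annotated assignments are now wrapped by parenthesizing the right-hand side instead of splitting the type annotation. A minimal, self-contained sketch of that pattern (stand-in names, not copied from the Flax sources):

from typing import Callable, List


def zeros_init() -> Callable[[int], List[int]]:
  # Stand-in initializer used only for this illustration.
  return lambda n: [0] * n


# Preview-style wrapping, as applied throughout this commit:
bias_init: Callable[[int], List[int]] = (
    zeros_init()
)

# The stable-style wrapping it replaces:
# bias_init: Callable[
#     [int], List[int]
# ] = zeros_init()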
6 changes: 2 additions & 4 deletions examples/cloud/launch_gce.py
@@ -202,8 +202,7 @@ def launch_gce(*, vm_name: str, startup_script: str):


def print_howto(login_args: Sequence[str]):
-  print(
-      f"""
+  print(f"""
###############################################################################
###############################################################################
@@ -227,8 +226,7 @@ def print_howto(login_args: Sequence[str]):
###############################################################################
###############################################################################
-"""
-  )
+""")


def main(_):
10 changes: 5 additions & 5 deletions examples/nlp_seq/train.py
@@ -42,9 +42,9 @@

FLAGS = flags.FLAGS

-flags.DEFINE_string('model_dir', default='', help=('Directory for model data.'))
+flags.DEFINE_string('model_dir', default='', help='Directory for model data.')

-flags.DEFINE_string('experiment', default='xpos', help=('Experiment name.'))
+flags.DEFINE_string('experiment', default='xpos', help='Experiment name.')

flags.DEFINE_integer('batch_size', default=64, help='Batch size for training.')

@@ -58,7 +58,7 @@
    'num_train_steps', default=75000, help='Number of train steps.'
)

-flags.DEFINE_float('learning_rate', default=0.05, help=('Learning rate.'))
+flags.DEFINE_float('learning_rate', default=0.05, help='Learning rate.')

flags.DEFINE_float(
    'weight_decay',
@@ -74,9 +74,9 @@
    'random_seed', default=0, help='Integer for PRNG random seed.'
)

-flags.DEFINE_string('train', default='', help=('Path to training data.'))
+flags.DEFINE_string('train', default='', help='Path to training data.')

-flags.DEFINE_string('dev', default='', help=('Path to development data.'))
+flags.DEFINE_string('dev', default='', help='Path to development data.')


def create_learning_rate_scheduler(
3 changes: 2 additions & 1 deletion flax/core/nn/linear.py
@@ -75,7 +75,8 @@ def dense_general(
    if set(batch_dims) != set(range(max_dim + 1)):
      raise ValueError(
          'batch_dims %s must be consecutive leading '
-          'dimensions starting from 0.' % str(batch_dims)
+          'dimensions starting from 0.'
+          % str(batch_dims)
      )

  ndim = inputs.ndim
9 changes: 6 additions & 3 deletions flax/errors.py
@@ -708,7 +708,8 @@ def __call__(self, x):

  def __init__(self):
    super().__init__(
-        'Can only call init, init_with_output or apply methods on an instance of the Module class, not the Module class itself'
+        'Can only call init, init_with_output or apply methods on an instance'
+        ' of the Module class, not the Module class itself'
    )


@@ -757,7 +758,8 @@ def __call__(self, x):

  def __init__(self):
    super().__init__(
-        'Trying to access a property that is accessing a non-existent attribute.'
+        'Trying to access a property that is accessing a non-existent'
+        ' attribute.'
    )


@@ -772,7 +774,8 @@ class InvalidCheckpointError(FlaxError):

  def __init__(self, path, step):
    super().__init__(
-        f'Trying to save an outdated checkpoint at step: "{step}" and path: "{path}".'
+        f'Trying to save an outdated checkpoint at step: "{step}" and path:'
+        f' "{path}".'
    )
6 changes: 3 additions & 3 deletions flax/linen/attention.py
@@ -237,9 +237,9 @@ class MultiHeadDotProductAttention(Module):
  deterministic: Optional[bool] = None
  precision: PrecisionLike = None
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
  use_bias: bool = True
  attention_fn: Callable[..., Array] = dot_product_attention
  decode: bool = False
18 changes: 9 additions & 9 deletions flax/linen/experimental/layers_with_named_axes.py
@@ -68,9 +68,9 @@ class Dense(nn.Module):
  param_dtype: DType = jnp.float32
  precision: PrecisionLike = None
  kernel_init: Callable[[PRNGKey, Shape, DType], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, DType], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, DType], Array] = (
+      initializers.zeros_init()
+  )
  kernel_axes: Tuple[str, ...] = ()
  dot_general: DotGeneralT = lax.dot_general

@@ -305,12 +305,12 @@ class LayerNorm(nn.Module):
  param_dtype: DType = jnp.float32
  use_bias: bool = True
  use_scale: bool = True
-  bias_init: Callable[
-      [PRNGKey, Shape, DType], Array
-  ] = initializers.zeros_init()
-  scale_init: Callable[
-      [PRNGKey, Shape, DType], Array
-  ] = initializers.ones_init()
+  bias_init: Callable[[PRNGKey, Shape, DType], Array] = (
+      initializers.zeros_init()
+  )
+  scale_init: Callable[[PRNGKey, Shape, DType], Array] = (
+      initializers.ones_init()
+  )

  @nn.compact
  def __call__(self, x):
29 changes: 15 additions & 14 deletions flax/linen/linear.py
@@ -93,9 +93,9 @@ class DenseGeneral(Module):
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
  precision: PrecisionLike = None
  dot_general: DotGeneralT = lax.dot_general

@@ -117,7 +117,8 @@ def __call__(self, inputs: Array) -> Array:
      if set(batch_dims) != set(range(max_dim + 1)):
        raise ValueError(
            'batch_dims %s must be consecutive leading '
-            'dimensions starting from 0.' % str(batch_dims)
+            'dimensions starting from 0.'
+            % str(batch_dims)
        )

    ndim = inputs.ndim
@@ -207,9 +208,9 @@ class Dense(Module):
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
  dot_general: DotGeneralT = lax.dot_general

  @compact
@@ -334,9 +335,9 @@ class _Conv(Module):
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
  conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated

  @property
@@ -444,7 +445,7 @@ def maybe_broadcast(
    else:
      if self.feature_group_count != 1:
        raise NotImplementedError(
-            f'`lax.conv_general_dilated_local` does not support '
+            '`lax.conv_general_dilated_local` does not support '
            f'`feature_group_count != 1`, got `{self.feature_group_count}`.'
        )

@@ -660,9 +661,9 @@ class ConvTranspose(Module):
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )
  transpose_kernel: bool = False

  @compact
12 changes: 6 additions & 6 deletions flax/linen/recurrent.py
@@ -225,9 +225,9 @@ class DenseParams(Module):
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[
-      [PRNGKey, Shape, Dtype], Array
-  ] = initializers.zeros_init()
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+      initializers.zeros_init()
+  )

  @compact
  def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]:
@@ -728,9 +728,9 @@ class RNN(Module):
  reverse: bool = False
  keep_order: bool = False
  unroll: int = 1
-  variable_axes: Mapping[
-      lift.CollectionFilter, lift.InOutScanAxis
-  ] = FrozenDict()
+  variable_axes: Mapping[lift.CollectionFilter, lift.InOutScanAxis] = (
+      FrozenDict()
+  )
  variable_broadcast: lift.CollectionFilter = 'params'
  variable_carry: lift.CollectionFilter = False
  split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict(
8 changes: 4 additions & 4 deletions flax/training/checkpoints.py
@@ -1074,10 +1074,10 @@ def restore_checkpoint(
          target
      )
    if isinstance(orbax_checkpointer._handler, ocp.PyTreeCheckpointHandler):  # pylint: disable=protected-access
-      restore_kwargs[
-          'transforms'
-      ] = orbax_utils.maybe_construct_transformations(
-          target, orbax_transforms
+      restore_kwargs['transforms'] = (
+          orbax_utils.maybe_construct_transformations(
+              target, orbax_transforms
+          )
      )
    restored = orbax_checkpointer.restore(
        ckpt_path, item=target, **restore_kwargs
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -150,4 +150,5 @@ exclude_lines = [
[tool.pyink]
pyink-indentation = 2
pyink-use-majority-quotes = true
-line-length = 80
+line-length = 80
+preview = true
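
The `preview = true` line added above is what opts pyink into the preview style responsible for the rewraps elsewhere in this commit. Another effect visible in examples/cloud/launch_gce.py is that triple-quoted strings now hug the call parentheses; a minimal sketch with a hypothetical function name (not the real launch_gce.py code):

def print_banner() -> None:
  # Preview style keeps the opening quotes on the call line and closes the
  # call together with the string, i.e. print(f""" ... """).
  print(f"""
banner text
""")


# Stable-style formatting before this commit:
# def print_banner() -> None:
#   print(
#       f"""
# banner text
# """
#   )


print_banner()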
4 changes: 3 additions & 1 deletion tests/core/core_frozen_dict_test.py
@@ -128,7 +128,9 @@ def test_utility_copy(self, x, add_or_replace, actual_new_x):
      },
      {
          'x': FrozenDict({'a': 1, 'b': {'c': 2}}),
-          'pretty_str': 'FrozenDict({\n a: 1,\n b: {\n c: 2,\n },\n})',
+          'pretty_str': (
+              'FrozenDict({\n a: 1,\n b: {\n c: 2,\n },\n})'
+          ),
      },
      {
          'x': 345,
5 changes: 4 additions & 1 deletion tests/core/core_scope_test.py
@@ -128,7 +128,10 @@ def test_inconsistent_param_shapes(self):
    def f(scope):
      scope.param('test', nn.initializers.ones_init(), (4,))

-    msg = r'Initializer expected to generate shape \(2,\) but got shape \(4,\) instead for parameter "test" in "/"'
+    msg = (
+        r'Initializer expected to generate shape \(2,\) but got shape \(4,\)'
+        r' instead for parameter "test" in "/"'
+    )
    with self.assertRaisesRegex(errors.ScopeParamShapeError, msg):
      apply(f)(freeze({'params': {'test': np.ones((2,))}}))
