Skip to content

Commit

Permalink
Merge pull request #3331 from levskaya:version_fix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565109475
  • Loading branch information
Flax Authors committed Sep 13, 2023
2 parents ca3ea06 + 7ce981b commit 9c40e03
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion examples/lm1b/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def get_datasets(
config: ml_collections.ConfigDict,
*,
n_devices: int,
vocab_path: Optional[str] = None
vocab_path: Optional[str] = None,
):
"""Load and return dataset of batched examples for use during training."""
if vocab_path is None:
Expand Down
2 changes: 1 addition & 1 deletion examples/ppo/ppo_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def train_step(
*,
clip_param: float,
vf_coeff: float,
entropy_coeff: float
entropy_coeff: float,
):
"""Compilable train step.
Expand Down
4 changes: 2 additions & 2 deletions examples/wmt/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_raw_dataset(
dataset_builder: tfds.core.DatasetBuilder,
split: str,
*,
reverse_translation: bool = False
reverse_translation: bool = False,
) -> tf.data.Dataset:
"""Loads a raw WMT dataset and normalizes feature keys.
Expand Down Expand Up @@ -333,7 +333,7 @@ def get_wmt_datasets(
*,
n_devices: int,
reverse_translation: bool = True,
vocab_path: Optional[str] = None
vocab_path: Optional[str] = None,
):
"""Load and return dataset of batched examples for use during training."""
if vocab_path is None:
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "flax"
requires-python = ">=3.9"
description = "Flax: A neural network library for JAX designed for flexibility"
keywords = []
authors = [
Expand All @@ -25,7 +26,9 @@ classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dynamic = ["version", "readme"]
Expand Down
14 changes: 8 additions & 6 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2442,16 +2442,18 @@ def _interceptor(f, args, kwargs, context):

mod = IdentityModule(parent=None)
x = 7
with nn.intercept_methods(
op_interceptor(lambda a: a + 1)
), nn.intercept_methods(op_interceptor(lambda a: a**2)):
with (
nn.intercept_methods(op_interceptor(lambda a: a + 1)),
nn.intercept_methods(op_interceptor(lambda a: a**2)),
):
y = mod(x)

self.assertEqual(y, (x**2) + 1)

with nn.intercept_methods(
op_interceptor(lambda a: a**2)
), nn.intercept_methods(op_interceptor(lambda a: a + 1)):
with (
nn.intercept_methods(op_interceptor(lambda a: a**2)),
nn.intercept_methods(op_interceptor(lambda a: a + 1)),
):
y = mod(x)

self.assertEqual(y, (x + 1) ** 2)
Expand Down

0 comments on commit 9c40e03

Please sign in to comment.