Skip to content

Commit

Permalink
Update minimum python version info.
Browse files Browse the repository at this point in the history
Flax follows JAX (and numpy) minimum python version schedule.
We weren't specifying the minimum version metadata in pyproject.toml

Also includes some random autoformatting fixes.
  • Loading branch information
levskaya committed Sep 13, 2023
1 parent ca3ea06 commit 7ce981b
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion examples/lm1b/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def get_datasets(
config: ml_collections.ConfigDict,
*,
n_devices: int,
vocab_path: Optional[str] = None
vocab_path: Optional[str] = None,
):
"""Load and return dataset of batched examples for use during training."""
if vocab_path is None:
Expand Down
2 changes: 1 addition & 1 deletion examples/ppo/ppo_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def train_step(
*,
clip_param: float,
vf_coeff: float,
entropy_coeff: float
entropy_coeff: float,
):
"""Compilable train step.
Expand Down
4 changes: 2 additions & 2 deletions examples/wmt/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_raw_dataset(
dataset_builder: tfds.core.DatasetBuilder,
split: str,
*,
reverse_translation: bool = False
reverse_translation: bool = False,
) -> tf.data.Dataset:
"""Loads a raw WMT dataset and normalizes feature keys.
Expand Down Expand Up @@ -333,7 +333,7 @@ def get_wmt_datasets(
*,
n_devices: int,
reverse_translation: bool = True,
vocab_path: Optional[str] = None
vocab_path: Optional[str] = None,
):
"""Load and return dataset of batched examples for use during training."""
if vocab_path is None:
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "flax"
requires-python = ">=3.9"
description = "Flax: A neural network library for JAX designed for flexibility"
keywords = []
authors = [
Expand All @@ -25,7 +26,9 @@ classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dynamic = ["version", "readme"]
Expand Down
14 changes: 8 additions & 6 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2442,16 +2442,18 @@ def _interceptor(f, args, kwargs, context):

mod = IdentityModule(parent=None)
x = 7
with nn.intercept_methods(
op_interceptor(lambda a: a + 1)
), nn.intercept_methods(op_interceptor(lambda a: a**2)):
with (
nn.intercept_methods(op_interceptor(lambda a: a + 1)),
nn.intercept_methods(op_interceptor(lambda a: a**2)),
):
y = mod(x)

self.assertEqual(y, (x**2) + 1)

with nn.intercept_methods(
op_interceptor(lambda a: a**2)
), nn.intercept_methods(op_interceptor(lambda a: a + 1)):
with (
nn.intercept_methods(op_interceptor(lambda a: a**2)),
nn.intercept_methods(op_interceptor(lambda a: a + 1)),
):
y = mod(x)

self.assertEqual(y, (x + 1) ** 2)
Expand Down

0 comments on commit 7ce981b

Please sign in to comment.