From 7ce981bce295eb5fda30e4f9cd77929b6f003cc9 Mon Sep 17 00:00:00 2001 From: Anselm Levskaya Date: Wed, 13 Sep 2023 10:46:53 -0700 Subject: [PATCH] Update minimum python version info. Flax follows JAX (and numpy) minimum python version schedule. We weren't specifying the minimum version metadata in pyproject.toml Also includes some random autoformatting fixes. --- examples/lm1b/input_pipeline.py | 2 +- examples/ppo/ppo_lib.py | 2 +- examples/wmt/input_pipeline.py | 4 ++-- pyproject.toml | 5 ++++- tests/linen/linen_module_test.py | 14 ++++++++------ 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/examples/lm1b/input_pipeline.py b/examples/lm1b/input_pipeline.py index 5c8f388d06..4090f724d5 100644 --- a/examples/lm1b/input_pipeline.py +++ b/examples/lm1b/input_pipeline.py @@ -322,7 +322,7 @@ def get_datasets( config: ml_collections.ConfigDict, *, n_devices: int, - vocab_path: Optional[str] = None + vocab_path: Optional[str] = None, ): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index c108cccf1c..2b7258acf9 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -139,7 +139,7 @@ def train_step( *, clip_param: float, vf_coeff: float, - entropy_coeff: float + entropy_coeff: float, ): """Compilable train step. diff --git a/examples/wmt/input_pipeline.py b/examples/wmt/input_pipeline.py index 9bf75209c4..215222c1d3 100644 --- a/examples/wmt/input_pipeline.py +++ b/examples/wmt/input_pipeline.py @@ -47,7 +47,7 @@ def get_raw_dataset( dataset_builder: tfds.core.DatasetBuilder, split: str, *, - reverse_translation: bool = False + reverse_translation: bool = False, ) -> tf.data.Dataset: """Loads a raw WMT dataset and normalizes feature keys. @@ -333,7 +333,7 @@ def get_wmt_datasets( *, n_devices: int, reverse_translation: bool = True, - vocab_path: Optional[str] = None + vocab_path: Optional[str] = None, ): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: diff --git a/pyproject.toml b/pyproject.toml index 3fb500754b..b299351035 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "flax" +requires-python = ">=3.9" description = "Flax: A neural network library for JAX designed for flexibility" keywords = [] authors = [ @@ -25,7 +26,9 @@ classifiers = [ "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dynamic = ["version", "readme"] diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index c0e1cf7f25..f2b5da866a 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -2442,16 +2442,18 @@ def _interceptor(f, args, kwargs, context): mod = IdentityModule(parent=None) x = 7 - with nn.intercept_methods( - op_interceptor(lambda a: a + 1) - ), nn.intercept_methods(op_interceptor(lambda a: a**2)): + with ( + nn.intercept_methods(op_interceptor(lambda a: a + 1)), + nn.intercept_methods(op_interceptor(lambda a: a**2)), + ): y = mod(x) self.assertEqual(y, (x**2) + 1) - with nn.intercept_methods( - op_interceptor(lambda a: a**2) - ), nn.intercept_methods(op_interceptor(lambda a: a + 1)): + with ( + nn.intercept_methods(op_interceptor(lambda a: a**2)), + nn.intercept_methods(op_interceptor(lambda a: a + 1)), + ): y = mod(x) self.assertEqual(y, (x + 1) ** 2)