From 338caabfd6f8f587c33ce69ef72e59fe4d6da418 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 21 Jul 2023 18:56:27 +0000 Subject: [PATCH 1/3] add pyink --- .pre-commit-config.yaml | 6 +++++- pyproject.toml | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dd93c78bd3..2fd3f6b440 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: hooks: - id: check-toml - id: trailing-whitespace - exclude: ^docs/.*\.md$ + exclude: ^docs/.*\.md$|^dev/.*\.py$ - repo: https://github.com/kynan/nbstripout rev: 0.6.1 hooks: @@ -29,3 +29,7 @@ repos: --extra-keys, "metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab", ] +- repo: https://github.com/google/pyink + rev: 23.5.0 + hooks: + - id: pyink diff --git a/pyproject.toml b/pyproject.toml index 4fb4c451b5..6cad172004 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,8 @@ testing = [ "tensorflow", "torch", "nbstripout", + "black[jupyter]==23.7.0", + "pyink==23.5.0", ] [project.urls] @@ -143,4 +145,8 @@ filterwarnings = [ exclude_lines = [ "@abc.abstractmethod", "raise NotImplementedError", -] \ No newline at end of file +] + +[tool.pyink] +pyink-indentation = 2 +pyink-use-majority-quotes = true \ No newline at end of file From 40a6e074e5224d733f964be00e21e0a1cb98bd2e Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 21 Jul 2023 18:57:09 +0000 Subject: [PATCH 2/3] apply pyink --- .github/analytics/get_repo_metrics.py | 128 +-- dev/update_requirements.py | 30 +- docs/_ext/codediff.py | 21 +- docs/_ext/codediff_test.py | 18 +- docs/_ext/flax_module.py | 24 +- docs/conf.py | 11 +- docs/conf_sphinx_patch.py | 293 ++++--- examples/cloud/launch_gce.py | 74 +- examples/imagenet/imagenet_benchmark.py | 12 +- examples/imagenet/input_pipeline.py | 69 +- examples/imagenet/main.py | 13 +- examples/imagenet/models.py | 81 +- examples/imagenet/models_test.py | 10 +- examples/imagenet/train.py | 144 +-- examples/imagenet/train_test.py | 4 +- .../linen_design_test/attention_simple.py | 89 +- examples/linen_design_test/autoencoder.py | 18 +- examples/linen_design_test/dense.py | 10 +- .../linen_design_test/linear_regression.py | 10 +- examples/linen_design_test/mlp_explicit.py | 11 +- examples/linen_design_test/mlp_inline.py | 5 +- examples/linen_design_test/mlp_lazy.py | 3 +- examples/linen_design_test/weight_std.py | 1 + examples/lm1b/input_pipeline.py | 136 +-- examples/lm1b/input_pipeline_test.py | 44 +- examples/lm1b/main.py | 13 +- examples/lm1b/models.py | 142 ++- examples/lm1b/temperature_sampler.py | 49 +- examples/lm1b/temperature_sampler_test.py | 15 +- examples/lm1b/tokenizer.py | 65 +- examples/lm1b/train.py | 218 ++--- examples/mnist/main.py | 13 +- examples/mnist/train.py | 28 +- examples/mnist/train_test.py | 5 +- examples/nlp_seq/input_pipeline.py | 63 +- examples/nlp_seq/input_pipeline_test.py | 54 +- examples/nlp_seq/models.py | 46 +- examples/nlp_seq/train.py | 110 +-- examples/ogbg_molpcba/input_pipeline.py | 73 +- examples/ogbg_molpcba/input_pipeline_test.py | 9 +- examples/ogbg_molpcba/main.py | 13 +- examples/ogbg_molpcba/models.py | 75 +- examples/ogbg_molpcba/models_test.py | 43 +- .../ogbg_molpcba/ogbg_molpcba_benchmark.py | 30 +- examples/ogbg_molpcba/train.py | 106 +-- examples/ogbg_molpcba/train_test.py | 101 ++- examples/ppo/agent.py | 13 +- examples/ppo/configs/default.py | 1 + examples/ppo/env_utils.py | 11 +- examples/ppo/models.py | 17 +- examples/ppo/ppo_lib.py | 106 ++- examples/ppo/ppo_lib_test.py | 64 +- examples/ppo/ppo_main.py | 9 +- examples/ppo/seed_rl_atari_preprocessing.py | 44 +- examples/ppo/test_episodes.py | 3 +- examples/seq2seq/input_pipeline.py | 25 +- examples/seq2seq/models.py | 22 +- examples/seq2seq/train.py | 87 +- examples/seq2seq/train_test.py | 33 +- examples/sst2/build_vocabulary.py | 13 +- examples/sst2/input_pipeline.py | 100 ++- examples/sst2/input_pipeline_test.py | 17 +- examples/sst2/main.py | 13 +- examples/sst2/models.py | 103 ++- examples/sst2/models_test.py | 15 +- examples/sst2/train.py | 105 ++- examples/sst2/vocabulary.py | 35 +- examples/vae/main.py | 26 +- examples/vae/models.py | 2 + examples/vae/train.py | 36 +- examples/vae/utils.py | 22 +- examples/wmt/bleu.py | 51 +- examples/wmt/decode.py | 150 ++-- examples/wmt/input_pipeline.py | 155 ++-- examples/wmt/input_pipeline_test.py | 44 +- examples/wmt/main.py | 13 +- examples/wmt/models.py | 269 +++--- examples/wmt/tokenizer.py | 65 +- examples/wmt/train.py | 231 ++--- flax/__init__.py | 2 +- flax/configurations.py | 23 +- flax/core/__init__.py | 57 +- flax/core/axes_scan.py | 42 +- flax/core/frozen_dict.py | 40 +- flax/core/lift.py | 522 ++++++----- flax/core/meta.py | 33 +- flax/core/nn/__init__.py | 56 +- flax/core/nn/attention.py | 142 +-- flax/core/nn/linear.py | 153 ++-- flax/core/nn/normalization.py | 82 +- flax/core/nn/stochastic.py | 8 +- flax/core/partial_eval.py | 1 + flax/core/scope.py | 241 ++--- flax/errors.py | 108 ++- flax/ids.py | 8 + flax/io.py | 19 +- flax/jax_utils.py | 36 +- flax/linen/__init__.py | 174 ++-- flax/linen/activation.py | 12 +- flax/linen/attention.py | 244 +++--- flax/linen/combinators.py | 8 +- flax/linen/dtypes.py | 11 +- .../experimental/layers_with_named_axes.py | 88 +- flax/linen/initializers.py | 4 +- flax/linen/kw_only_dataclasses.py | 28 +- flax/linen/linear.py | 152 ++-- flax/linen/module.py | 536 ++++++++---- flax/linen/normalization.py | 238 +++-- flax/linen/partitioning.py | 183 ++-- flax/linen/pooling.py | 23 +- flax/linen/recurrent.py | 334 ++++--- flax/linen/spmd.py | 59 +- flax/linen/stochastic.py | 34 +- flax/linen/summary.py | 147 +++- flax/linen/transforms.py | 423 +++++---- flax/metrics/__init__.py | 1 - flax/metrics/tensorboard.py | 29 +- flax/serialization.py | 99 ++- flax/struct.py | 44 +- flax/testing/benchmark.py | 49 +- flax/training/checkpoints.py | 343 +++++--- flax/training/common_utils.py | 5 +- flax/training/dynamic_scale.py | 26 +- flax/training/early_stopping.py | 13 +- flax/training/lr_schedule.py | 54 +- flax/training/orbax_utils.py | 9 +- flax/training/prefetch_iterator.py | 11 +- flax/training/train_state.py | 4 +- flax/traverse_util.py | 33 +- flax/version.py | 1 - tests/checkpoints_test.py | 200 +++-- tests/core/core_frozen_dict_test.py | 23 +- tests/core/core_lift_test.py | 100 ++- tests/core/core_meta_test.py | 137 +-- tests/core/core_scope_test.py | 83 +- tests/core/design/core_attention_test.py | 74 +- tests/core/design/core_auto_encoder_test.py | 65 +- tests/core/design/core_big_resnets_test.py | 44 +- tests/core/design/core_custom_vjp_test.py | 19 +- tests/core/design/core_dense_test.py | 84 +- tests/core/design/core_flow_test.py | 22 +- tests/core/design/core_resnet_test.py | 123 +-- tests/core/design/core_scan_test.py | 35 +- .../core/design/core_tied_autoencoder_test.py | 32 +- tests/core/design/core_vmap_test.py | 59 +- tests/core/design/core_weight_std_test.py | 35 +- tests/early_stopping_test.py | 42 +- tests/io_test.py | 33 +- tests/jax_utils_test.py | 12 +- tests/linen/initializers_test.py | 38 +- tests/linen/kw_only_dataclasses_test.py | 6 +- tests/linen/linen_attention_test.py | 44 +- tests/linen/linen_combinators_test.py | 32 +- tests/linen/linen_dtypes_test.py | 11 +- tests/linen/linen_linear_test.py | 360 ++++---- tests/linen/linen_meta_test.py | 125 +-- tests/linen/linen_module_test.py | 398 ++++----- tests/linen/linen_recurrent_test.py | 113 +-- tests/linen/linen_test.py | 228 ++--- tests/linen/linen_transforms_test.py | 821 ++++++++++++------ tests/linen/partitioning_test.py | 302 ++++--- tests/linen/summary_test.py | 253 +++--- tests/linen/toplevel_test.py | 3 + tests/serialization_test.py | 320 ++++--- tests/struct_test.py | 6 +- tests/tensorboard_test.py | 236 ++--- tests/traceback_util_test.py | 29 +- tests/traverse_util_test.py | 141 +-- 168 files changed, 8114 insertions(+), 5996 deletions(-) diff --git a/.github/analytics/get_repo_metrics.py b/.github/analytics/get_repo_metrics.py index a43b49200e..c822e23962 100644 --- a/.github/analytics/get_repo_metrics.py +++ b/.github/analytics/get_repo_metrics.py @@ -25,17 +25,18 @@ import matplotlib.dates as mdates -token = os.environ["GITHUB_TOKEN"] -endpoint = r"https://api.github.com/graphql" -headers = {"Authorization": f"bearer {token}"} +token = os.environ['GITHUB_TOKEN'] +endpoint = r'https://api.github.com/graphql' +headers = {'Authorization': f'bearer {token}'} -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ # GraphQL -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ # NOTE: This GraphQL logic was ported and adapted from this script: # https://github.com/scientific-python/devstats-data/blob/4c022961abc4ca6061f8719d9c3387e98734b90c/query.py # It contains style differences from Google's style guide. + def load_query_from_file(fname, repo_owner, repo_name) -> str: with open(fname) as fh: query = fh.read() @@ -75,9 +76,7 @@ def send_query(query, query_type, cursor=None): # TODO: Expand this, either by parsing the query type from the query # directly or manually adding more query_types to the set if query_type not in {'issues', 'pullRequests'}: - raise ValueError( - 'Only \'issues\' and \'pullRequests\' queries are currently supported' - ) + raise ValueError("Only 'issues' and 'pullRequests' queries are currently supported") # TODO: Generalize this # WARNING: The cursor injection depends on the specific structure of the # query, this is the main reason why query types are limited to issues/PRs @@ -86,12 +85,13 @@ def send_query(query, query_type, cursor=None): cursor_ind = query.find(cursor_insertion_key) + len(cursor_insertion_key) query = query[:cursor_ind] + f'after:"{cursor}", ' + query[cursor_ind:] # Build request payload - payload = {'query' : query} + payload = {'query': query} response = requests.post(endpoint, json=payload, headers=headers) return json.loads(response.content) + def get_all_responses(query, query_type): - "Helper function to bypass GitHub GraphQL API node limit." + 'Helper function to bypass GitHub GraphQL API node limit.' # Get data from a single response initial_data = send_query(query, query_type) data, last_cursor, total_count = parse_single_query(initial_data, query_type) @@ -105,6 +105,7 @@ def get_all_responses(query, query_type): print('Done.') return data + def parse_single_query(data, query_type): """ Parses the data returned by `send_query` @@ -159,20 +160,21 @@ def __init__(self, query_fname, query_type, repo_owner, repo_name): self.load_query() def load_query(self): - self.query = load_query_from_file( - self.query_fname, self.repo_owner, self.repo_name - ) + self.query = load_query_from_file(self.query_fname, self.repo_owner, self.repo_name) def get(self): self.raw_data = get_all_responses(self.query, self.query_type) -#------------------------------------------------------------------------------ + +# ------------------------------------------------------------------------------ # metrics helpers -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ + def _to_datetime(date_str: str) -> datetime: return datetime.fromisoformat(date_str.replace('Z', '')) + def _get_issues_features(issues): for issue in issues: issue = issue['node'] @@ -191,12 +193,13 @@ def _get_issues_features(issues): time_issue_closed = _to_datetime(event['createdAt']) yield { - 'created_at': created_at, - 'time_labeled_or_converted': time_labeled_or_converted, - 'time_issue_closed': time_issue_closed, - 'issue_closed': issue['state'] == 'CLOSED', + 'created_at': created_at, + 'time_labeled_or_converted': time_labeled_or_converted, + 'time_issue_closed': time_issue_closed, + 'issue_closed': issue['state'] == 'CLOSED', } + def _get_pr_features(prs): for pr in prs: pr = pr['node'] @@ -207,24 +210,21 @@ def _get_pr_features(prs): time_merged_or_closed = None time_review = None - if pr["reviews"]["nodes"]: - review = pr["reviews"]["nodes"][0] - time_review = _to_datetime(review["createdAt"]) + if pr['reviews']['nodes']: + review = pr['reviews']['nodes'][0] + time_review = _to_datetime(review['createdAt']) for event in pr['timelineItems']['edges']: event = event['node'] if ( - time_labeled_or_assigned is None - and event['__typename'] == 'LabeledEvent' - and 'cla:' not in event['label']['name'] + time_labeled_or_assigned is None + and event['__typename'] == 'LabeledEvent' + and 'cla:' not in event['label']['name'] ): time_labeled_or_assigned = _to_datetime(event['createdAt']) - if ( - time_labeled_or_assigned is None - and event['__typename'] == 'AssignedEvent' - ): + if time_labeled_or_assigned is None and event['__typename'] == 'AssignedEvent': time_labeled_or_assigned = _to_datetime(event['createdAt']) if event['__typename'] in {'ClosedEvent', 'MergedEvent'}: @@ -234,17 +234,19 @@ def _get_pr_features(prs): ready_for_review_at = _to_datetime(event['createdAt']) yield { - 'created_at': created_at, - 'ready_for_review_at': ready_for_review_at, - 'time_labeled_or_assigned': time_labeled_or_assigned, - 'time_merged_or_closed': time_merged_or_closed, - 'time_review': time_review, - 'pr_closed': pr['state'] != 'OPEN', + 'created_at': created_at, + 'ready_for_review_at': ready_for_review_at, + 'time_labeled_or_assigned': time_labeled_or_assigned, + 'time_merged_or_closed': time_merged_or_closed, + 'time_review': time_review, + 'pr_closed': pr['state'] != 'OPEN', } + def _start_of_month(date: datetime) -> datetime: return date.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + def _shift_n_months(date: datetime, n: int) -> datetime: month = ((date.month + n - 1) % 12) + 1 @@ -258,14 +260,14 @@ def _shift_n_months(date: datetime, n: int) -> datetime: def _rolling_window( - df: pd.DataFrame, - f: Callable[[pd.DataFrame], pd.Series], - window_size: int = 6, - step: int = 1, + df: pd.DataFrame, + f: Callable[[pd.DataFrame], pd.Series], + window_size: int = 6, + step: int = 1, ) -> pd.DataFrame: # start of month of the first issue start: datetime = df.iloc[0]['created_at'].replace( - day=1, hour=0, minute=0, second=0, microsecond=0 + day=1, hour=0, minute=0, second=0, microsecond=0 ) end = _shift_n_months(start, window_size) @@ -286,56 +288,66 @@ def _rolling_window( return df + def _process_prs(df: pd.DataFrame) -> pd.Series: return pd.Series({ - 'pr_response_time': df['pr_response_time'].dt.days.mean(), - 'pr_resolution_time': df['pr_resolution_time'].dt.days.mean(), + 'pr_response_time': df['pr_response_time'].dt.days.mean(), + 'pr_resolution_time': df['pr_resolution_time'].dt.days.mean(), }) + def _process_issues(df: pd.DataFrame) -> pd.Series: return pd.Series({ - 'issue_response_time': df['issue_response_time'].dt.days.mean(), - 'issue_resolution_time': df['issue_resolution_time'].dt.days.mean(), + 'issue_response_time': df['issue_response_time'].dt.days.mean(), + 'issue_resolution_time': df['issue_resolution_time'].dt.days.mean(), }) -#----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- # main -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- FLAGS = flags.FLAGS flags.DEFINE_string('repo_owner', 'google', 'User name or organization') flags.DEFINE_string('repo_name', 'flax', 'Name of the repository') + def main(_): repo_owner: str = FLAGS.repo_owner repo_name: str = FLAGS.repo_name # Download issue data issues = GithubGrabber( - '.github/analytics/issue_activity_since_date.gql', - 'issues', - repo_owner=repo_owner, - repo_name=repo_name, + '.github/analytics/issue_activity_since_date.gql', + 'issues', + repo_owner=repo_owner, + repo_name=repo_name, ) issues.get() df_issues = df_issues0 = pd.DataFrame(list(_get_issues_features(issues.raw_data))) - df_issues['issue_response_time'] = df_issues['time_labeled_or_converted'] - df_issues['created_at'] - df_issues['issue_resolution_time'] = df_issues['time_issue_closed'] - df_issues['created_at'] + df_issues['issue_response_time'] = ( + df_issues['time_labeled_or_converted'] - df_issues['created_at'] + ) + df_issues['issue_resolution_time'] = ( + df_issues['time_issue_closed'] - df_issues['created_at'] + ) df_issues = _rolling_window(df_issues, _process_issues) prs = GithubGrabber( - '.github/analytics/pr_data_query.gql', - 'pullRequests', - repo_owner=repo_owner, - repo_name=repo_name, + '.github/analytics/pr_data_query.gql', + 'pullRequests', + repo_owner=repo_owner, + repo_name=repo_name, ) prs.get() df_prs = df_prs0 = pd.DataFrame(list(_get_pr_features(prs.raw_data))) time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min(axis=1) df_prs['pr_response_time'] = time_response - df_prs['ready_for_review_at'] - df_prs['pr_resolution_time'] = df_prs['time_merged_or_closed'] - df_prs['ready_for_review_at'] + df_prs['pr_resolution_time'] = ( + df_prs['time_merged_or_closed'] - df_prs['ready_for_review_at'] + ) df_prs = _rolling_window(df_prs, _process_prs) @@ -367,7 +379,6 @@ def main(_): plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5)) plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y')) - # plot for isssue_response_time plt.figure() plt.plot(df_issues['period_end'], df_issues['issue_response_time']) @@ -411,5 +422,6 @@ def main(_): # show plots plt.show() + if __name__ == '__main__': app.run(main) diff --git a/dev/update_requirements.py b/dev/update_requirements.py index 9691382b18..063e872ae7 100644 --- a/dev/update_requirements.py +++ b/dev/update_requirements.py @@ -38,7 +38,7 @@ Alternatively, the list can also be provided from the local environment with: -python dev --versions="$(pip freeze | sed s/==/-/g) flax-0.3.6" +python dev --versions="$(pip freeze | sed s/==/-/g) flax-0.3.6" """ import pathlib @@ -58,7 +58,8 @@ '`--version="$(pip freeze | sed s/==/-/g) flax-0.3.6"` ' '(note the flax version "override") ' 'or from the "install dependencies" step in the github build action ' - 'https://github.com/google/flax/actions/workflows/build.yml') + 'https://github.com/google/flax/actions/workflows/build.yml', +) flags.mark_flag_as_required('versions') flags.DEFINE_bool('verbose', False, 'enables verbose output.') flags.DEFINE_list('ignore', ['jax'], 'packages not to add to requirements.') @@ -67,22 +68,26 @@ import_re = re.compile(r'(?:from|import)\s+(\w+)') # maps `import cv2` to `pip install opencv-python` pkg_map = { - 'absl': 'absl-py', - 'atari_py': 'atari-py', - 'cv2': 'opencv-python', - 'ml_collections': 'ml-collections', - 'PIL': 'Pillow', - 'tensorflow_datasets': 'tensorflow-datasets', - 'tensorflow_text': 'tensorflow-text', + 'absl': 'absl-py', + 'atari_py': 'atari-py', + 'cv2': 'opencv-python', + 'ml_collections': 'ml-collections', + 'PIL': 'Pillow', + 'tensorflow_datasets': 'tensorflow-datasets', + 'tensorflow_text': 'tensorflow-text', } -standard_libs = set('codecs collections dataclasses datetime enum functools math multiprocessing itertools os pathlib random re sys tempfile time typing unicodedata warnings'.split(' ')) +standard_libs = set( + 'codecs collections dataclasses datetime enum functools math multiprocessing itertools os pathlib random re sys tempfile time typing unicodedata warnings'.split( + ' ' + ) +) def main(argv): del argv versions = { - pkg_version[:pkg_version.rindex('-')]: pkg_version[pkg_version.rindex('-') + 1:] + pkg_version[: pkg_version.rindex('-')]: pkg_version[pkg_version.rindex('-') + 1 :] for pkg_version in FLAGS.versions.replace('\n', ' ').split(' ') if '-' in pkg_version } @@ -117,7 +122,8 @@ def main(argv): print(f'{requirements} -', end=' ') with requirements.open('w') as f: for pkg in sorted(pkgs, key=str.casefold): - if pkg in ignore: continue + if pkg in ignore: + continue pkg = pkg_map.get(pkg, pkg) print(f'{pkg}-{versions[pkg]}', end=' ') f.write(f'{pkg}=={versions[pkg]}\n') diff --git a/docs/_ext/codediff.py b/docs/_ext/codediff.py index b7cc7b191b..ece5314821 100644 --- a/docs/_ext/codediff.py +++ b/docs/_ext/codediff.py @@ -38,22 +38,25 @@ MISSING = object() + class CodeDiffParser: def parse( - self, lines, title_left='Base', title_right='Diff', code_sep='---', sync=MISSING): + self, lines, title_left='Base', title_right='Diff', code_sep='---', sync=MISSING + ): sync = sync is not MISSING if code_sep not in lines: - raise ValueError('Code separator not found! Code snippets should be ' - f'separated by {code_sep}.') + raise ValueError( + 'Code separator not found! Code snippets should be ' + f'separated by {code_sep}.' + ) idx = lines.index(code_sep) - code_left = self._code_block(lines[0: idx]) - test_code = lines[idx+1:] + code_left = self._code_block(lines[0:idx]) + test_code = lines[idx + 1 :] code_right = self._code_block(test_code) - output = self._tabs( - (title_left, code_left), (title_right, code_right), sync=sync) + output = self._tabs((title_left, code_left), (title_right, code_right), sync=sync) return output, test_code @@ -88,6 +91,7 @@ def _tabs(self, *contents: Tuple[str, List[str]], sync): return output + class CodeDiffDirective(SphinxDirective): has_content = True option_spec = { @@ -98,8 +102,7 @@ class CodeDiffDirective(SphinxDirective): } def run(self): - table_code, test_code = CodeDiffParser().parse( - list(self.content), **self.options) + table_code, test_code = CodeDiffParser().parse(list(self.content), **self.options) # Create a test node as a comment node so it won't show up in the docs. # We add attribute "testnodetype" so it is be picked up by the doctest diff --git a/docs/_ext/codediff_test.py b/docs/_ext/codediff_test.py index 7d94c008e0..3f815fe6d8 100644 --- a/docs/_ext/codediff_test.py +++ b/docs/_ext/codediff_test.py @@ -22,8 +22,7 @@ class CodeDiffTest(absltest.TestCase): def test_parse(self): - - input_text = r'''@jax.jit #! + input_text = r"""@jax.jit #! def get_initial_params(key): #! init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] @@ -34,9 +33,9 @@ def get_initial_params(key): #! def get_initial_params(key): init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] - return initial_params''' + return initial_params""" - expected_table = r'''+----------------------------------------------------------+----------------------------------------------------------+ + expected_table = r"""+----------------------------------------------------------+----------------------------------------------------------+ | Single device | Ensembling on multiple devices | +----------------------------------------------------------+----------------------------------------------------------+ | .. code-block:: python | .. code-block:: python | @@ -48,21 +47,20 @@ def get_initial_params(key): | initial_params = CNN().init(key, init_val)['params'] | initial_params = CNN().init(key, init_val)['params'] | | extra_line | return initial_params | | return initial_params | | -+----------------------------------------------------------+----------------------------------------------------------+''' ++----------------------------------------------------------+----------------------------------------------------------+""" - expected_testcode = r'''@jax.pmap #! + expected_testcode = r"""@jax.pmap #! def get_initial_params(key): init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] - return initial_params''' + return initial_params""" title_left = 'Single device' title_right = 'Ensembling on multiple devices' actual_table, actual_testcode = CodeDiffParser().parse( - lines=input_text.split('\n'), - title_left=title_left, - title_right=title_right) + lines=input_text.split('\n'), title_left=title_left, title_right=title_right + ) actual_table = '\n'.join(actual_table) actual_testcode = '\n'.join(actual_testcode) diff --git a/docs/_ext/flax_module.py b/docs/_ext/flax_module.py index ece47caf3f..fd823ed526 100644 --- a/docs/_ext/flax_module.py +++ b/docs/_ext/flax_module.py @@ -19,7 +19,6 @@ .. flax_module:: :module: flax.linen :class: Dense - """ from docutils import nodes @@ -37,24 +36,35 @@ def render_module(modname: str, qualname: str, app): parent = importlib.import_module(modname) obj = getattr(parent, qualname) template = ag.AutosummaryRenderer(app) - template_name = "flax_module" + template_name = 'flax_module' imported_members = False recursive = False context = {} return generate_autosummary_content( - qualname, obj, parent, template, template_name, imported_members, - app, recursive, context, modname, qualname) + qualname, + obj, + parent, + template, + template_name, + imported_members, + app, + recursive, + context, + modname, + qualname, + ) + class FlaxModuleDirective(SphinxDirective): has_content = True option_spec = { - 'module': directives.unchanged, - 'class': directives.unchanged, + 'module': directives.unchanged, + 'class': directives.unchanged, } def run(self): module_template = render_module( - self.options['module'], self.options['class'], self.env.app + self.options['module'], self.options['class'], self.env.app ) module_template = module_template.splitlines() diff --git a/docs/conf.py b/docs/conf.py index 1114cbabdb..981106321f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -31,6 +31,7 @@ import os import sys + sys.path.insert(0, os.path.abspath('..')) # Include local extension. sys.path.append(os.path.abspath('./_ext')) @@ -110,9 +111,9 @@ html_theme_options = { 'repository_url': 'https://github.com/google/flax', - 'use_repository_button': True, # add a 'link to repository' button - 'use_issues_button': False, # add an 'Open an Issue' button - 'path_to_docs': 'docs', # used to compute the path to launch notebooks in colab + 'use_repository_button': True, # add a 'link to repository' button + 'use_issues_button': False, # add an 'Open an Issue' button + 'path_to_docs': 'docs', # used to compute the path to launch notebooks in colab 'launch_buttons': { 'colab_url': 'https://colab.research.google.com/', }, @@ -129,8 +130,8 @@ # files that will not be executed. myst_enable_extensions = ['dollarmath'] nb_execution_excludepatterns = [ - 'getting_started.ipynb', # <-- times out - 'optax_update_guide.ipynb', # <-- requires flax<=0.5.3 + 'getting_started.ipynb', # <-- times out + 'optax_update_guide.ipynb', # <-- requires flax<=0.5.3 ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False diff --git a/docs/conf_sphinx_patch.py b/docs/conf_sphinx_patch.py index 44bd876172..64d6ae291e 100644 --- a/docs/conf_sphinx_patch.py +++ b/docs/conf_sphinx_patch.py @@ -27,147 +27,162 @@ import sphinx.ext.autosummary.generate as ag import sphinx.ext.autodoc -def generate_autosummary_content(name: str, obj: Any, parent: Any, - template: ag.AutosummaryRenderer, template_name: str, - imported_members: bool, app: Any, - recursive: bool, context: Dict, - modname: str = None, qualname: str = None) -> str: - doc = ag.get_documenter(app, obj, parent) - - def skip_member(obj: Any, name: str, objtype: str) -> bool: - try: - return app.emit_firstresult('autodoc-skip-member', objtype, name, - obj, False, {}) - except Exception as exc: - ag.logger.warning(__('autosummary: failed to determine %r to be documented, ' - 'the following exception was raised:\n%s'), - name, exc, type='autosummary') - return False - - def get_class_members(obj: Any) -> Dict[str, Any]: - members = sphinx.ext.autodoc.get_class_members(obj, [qualname], ag.safe_getattr) - return {name: member.object for name, member in members.items()} - - def get_module_members(obj: Any) -> Dict[str, Any]: - members = {} - for name in ag.members_of(obj, app.config): - try: - members[name] = ag.safe_getattr(obj, name) - except AttributeError: - continue - return members - - def get_all_members(obj: Any) -> Dict[str, Any]: - if doc.objtype == "module": - return get_module_members(obj) - elif doc.objtype == "class": - return get_class_members(obj) - return {} - - def get_members(obj: Any, types: Set[str], include_public: List[str] = [], - imported: bool = True) -> Tuple[List[str], List[str]]: - items: List[str] = [] - public: List[str] = [] - - all_members = get_all_members(obj) - for name, value in all_members.items(): - documenter = ag.get_documenter(app, value, obj) - if documenter.objtype in types: - # skip imported members if expected - if imported or getattr(value, '__module__', None) == obj.__name__: - skipped = skip_member(value, name, documenter.objtype) - if skipped is True: - pass - elif skipped is False: - # show the member forcedly - items.append(name) - public.append(name) - else: - items.append(name) - if name in include_public or not name.startswith('_'): - # considers member as public - public.append(name) - return public, items - - def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]: - """Find module attributes with docstrings.""" - attrs, public = [], [] - try: - analyzer = ag.ModuleAnalyzer.for_module(name) - attr_docs = analyzer.find_attr_docs() - for namespace, attr_name in attr_docs: - if namespace == '' and attr_name in members: - attrs.append(attr_name) - if not attr_name.startswith('_'): - public.append(attr_name) - except ag.PycodeError: - pass # give up if ModuleAnalyzer fails to parse code - return public, attrs - - def get_modules(obj: Any) -> Tuple[List[str], List[str]]: - items: List[str] = [] - for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__): - fullname = name + '.' + modname - try: - module = ag.import_module(fullname) - if module and hasattr(module, '__sphinx_mock__'): - continue - except ImportError: - pass - - items.append(fullname) - public = [x for x in items if not x.split('.')[-1].startswith('_')] - return public, items - - ns: Dict[str, Any] = {} - ns.update(context) +def generate_autosummary_content( + name: str, + obj: Any, + parent: Any, + template: ag.AutosummaryRenderer, + template_name: str, + imported_members: bool, + app: Any, + recursive: bool, + context: Dict, + modname: str = None, + qualname: str = None, +) -> str: + doc = ag.get_documenter(app, obj, parent) + + def skip_member(obj: Any, name: str, objtype: str) -> bool: + try: + return app.emit_firstresult('autodoc-skip-member', objtype, name, obj, False, {}) + except Exception as exc: + ag.logger.warning( + __( + 'autosummary: failed to determine %r to be documented, ' + 'the following exception was raised:\n%s' + ), + name, + exc, + type='autosummary', + ) + return False + + def get_class_members(obj: Any) -> Dict[str, Any]: + members = sphinx.ext.autodoc.get_class_members(obj, [qualname], ag.safe_getattr) + return {name: member.object for name, member in members.items()} + + def get_module_members(obj: Any) -> Dict[str, Any]: + members = {} + for name in ag.members_of(obj, app.config): + try: + members[name] = ag.safe_getattr(obj, name) + except AttributeError: + continue + return members + + def get_all_members(obj: Any) -> Dict[str, Any]: if doc.objtype == 'module': - scanner = ag.ModuleScanner(app, obj) - ns['members'] = scanner.scan(imported_members) - ns['functions'], ns['all_functions'] = \ - get_members(obj, {'function'}, imported=imported_members) - ns['classes'], ns['all_classes'] = \ - get_members(obj, {'class'}, imported=imported_members) - ns['exceptions'], ns['all_exceptions'] = \ - get_members(obj, {'exception'}, imported=imported_members) - ns['attributes'], ns['all_attributes'] = \ - get_module_attrs(ns['members']) - ispackage = hasattr(obj, '__path__') - if ispackage and recursive: - ns['modules'], ns['all_modules'] = get_modules(obj) + return get_module_members(obj) elif doc.objtype == 'class': - ns['members'] = dir(obj) - ns['inherited_members'] = \ - set(dir(obj)) - set(obj.__dict__.keys()) - ns['methods'], ns['all_methods'] = \ - get_members(obj, {'method'}, ['__init__']) - ns['attributes'], ns['all_attributes'] = \ - get_members(obj, {'attribute', 'property'}) - ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys()) - - if modname is None or qualname is None: - modname, qualname = ag.split_full_qualified_name(name) - - if doc.objtype in ('method', 'attribute', 'property'): - ns['class'] = qualname.rsplit(".", 1)[0] - - if doc.objtype in ('class',): - shortname = qualname - else: - shortname = qualname.rsplit(".", 1)[-1] - - ns['fullname'] = name - ns['module'] = modname - ns['objname'] = qualname - ns['name'] = shortname - - ns['objtype'] = doc.objtype - ns['underline'] = len(name) * '=' - - if template_name: - return template.render(template_name, ns) - else: - return template.render(doc.objtype, ns) + return get_class_members(obj) + return {} + + def get_members( + obj: Any, types: Set[str], include_public: List[str] = [], imported: bool = True + ) -> Tuple[List[str], List[str]]: + items: List[str] = [] + public: List[str] = [] + + all_members = get_all_members(obj) + for name, value in all_members.items(): + documenter = ag.get_documenter(app, value, obj) + if documenter.objtype in types: + # skip imported members if expected + if imported or getattr(value, '__module__', None) == obj.__name__: + skipped = skip_member(value, name, documenter.objtype) + if skipped is True: + pass + elif skipped is False: + # show the member forcedly + items.append(name) + public.append(name) + else: + items.append(name) + if name in include_public or not name.startswith('_'): + # considers member as public + public.append(name) + return public, items + + def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]: + """Find module attributes with docstrings.""" + attrs, public = [], [] + try: + analyzer = ag.ModuleAnalyzer.for_module(name) + attr_docs = analyzer.find_attr_docs() + for namespace, attr_name in attr_docs: + if namespace == '' and attr_name in members: + attrs.append(attr_name) + if not attr_name.startswith('_'): + public.append(attr_name) + except ag.PycodeError: + pass # give up if ModuleAnalyzer fails to parse code + return public, attrs + + def get_modules(obj: Any) -> Tuple[List[str], List[str]]: + items: List[str] = [] + for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__): + fullname = name + '.' + modname + try: + module = ag.import_module(fullname) + if module and hasattr(module, '__sphinx_mock__'): + continue + except ImportError: + pass + + items.append(fullname) + public = [x for x in items if not x.split('.')[-1].startswith('_')] + return public, items + + ns: Dict[str, Any] = {} + ns.update(context) + + if doc.objtype == 'module': + scanner = ag.ModuleScanner(app, obj) + ns['members'] = scanner.scan(imported_members) + ns['functions'], ns['all_functions'] = get_members( + obj, {'function'}, imported=imported_members + ) + ns['classes'], ns['all_classes'] = get_members( + obj, {'class'}, imported=imported_members + ) + ns['exceptions'], ns['all_exceptions'] = get_members( + obj, {'exception'}, imported=imported_members + ) + ns['attributes'], ns['all_attributes'] = get_module_attrs(ns['members']) + ispackage = hasattr(obj, '__path__') + if ispackage and recursive: + ns['modules'], ns['all_modules'] = get_modules(obj) + elif doc.objtype == 'class': + ns['members'] = dir(obj) + ns['inherited_members'] = set(dir(obj)) - set(obj.__dict__.keys()) + ns['methods'], ns['all_methods'] = get_members(obj, {'method'}, ['__init__']) + ns['attributes'], ns['all_attributes'] = get_members(obj, {'attribute', 'property'}) + ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys()) + + if modname is None or qualname is None: + modname, qualname = ag.split_full_qualified_name(name) + + if doc.objtype in ('method', 'attribute', 'property'): + ns['class'] = qualname.rsplit('.', 1)[0] + + if doc.objtype in ('class',): + shortname = qualname + else: + shortname = qualname.rsplit('.', 1)[-1] + + ns['fullname'] = name + ns['module'] = modname + ns['objname'] = qualname + ns['name'] = shortname + + ns['objtype'] = doc.objtype + ns['underline'] = len(name) * '=' + + if template_name: + return template.render(template_name, ns) + else: + return template.render(doc.objtype, ns) + ag.generate_autosummary_content = generate_autosummary_content diff --git a/examples/cloud/launch_gce.py b/examples/cloud/launch_gce.py index 2cfc01ed41..5e265b71f2 100644 --- a/examples/cloud/launch_gce.py +++ b/examples/cloud/launch_gce.py @@ -30,16 +30,20 @@ 'dry_run', False, help='If set, then the command to launch the GCE instance will only be ' - 'printed to stdout.') + 'printed to stdout.', +) flags.DEFINE_bool( 'connect', False, - help='Same as --wait, but directly connect to VM once it is ready.') + help='Same as --wait, but directly connect to VM once it is ready.', +) flags.DEFINE_bool( - 'wait', False, + 'wait', + False, help='If set, then the script will wait until VM is ready. If VM_READY_CMD ' 'is set in environment, then that command will be executed once the VM ' - 'is ready. Useful for sending a notification, e.g. "osascript" (mac).') + 'is ready. Useful for sending a notification, e.g. "osascript" (mac).', +) # Machine configuration. flags.DEFINE_string('project', None, help='Name of the Google Cloud project.') @@ -47,60 +51,63 @@ flags.DEFINE_string( 'machine_type', None, - help='Machine type to use for VM. See "gcloud compute machine-types list".') + help='Machine type to use for VM. See "gcloud compute machine-types list".', +) flags.DEFINE_string( 'accelerator_type', '', help='Type of accelerator to use, or empty. ' - 'See "gcloud compute accelerator-types list".' + 'See "gcloud compute accelerator-types list".', ) flags.DEFINE_integer( 'shutdown_secs', 300, help='How long to wait (after successful/failed training) before shutting ' - 'down the VM. Set to 0 to disable.' + 'down the VM. Set to 0 to disable.', ) -flags.DEFINE_integer( - 'accelerator_count', 8, help='Number of accelerators to use.') +flags.DEFINE_integer('accelerator_count', 8, help='Number of accelerators to use.') # GCS configuration. flags.DEFINE_string( 'gcs_workdir_base', None, help='GCS base directory for model output. The --workdir argument will be ' - 'constructed from {gcs_workdir_base}/{example}/{name}/{timestamp} .') + 'constructed from {gcs_workdir_base}/{example}/{name}/{timestamp} .', +) flags.DEFINE_string( 'tfds_data_dir', '', help='Optional tfds data directory. This can be useful to prepare datasets ' 'on GCS and then point the jobs to this preloaded directory. Dataset will ' - 'be downloaded from the web if not specified.') + 'be downloaded from the web if not specified.', +) # Repo configuration. -flags.DEFINE_string( - 'repo', 'https://github.com/google/flax', help='Git repository') +flags.DEFINE_string('repo', 'https://github.com/google/flax', help='Git repository') flags.DEFINE_string('branch', 'main', help='Git repository') # Example configuration. -flags.DEFINE_string( - 'example', None, help='Name of Flax example (e.g. "imagenet").') +flags.DEFINE_string('example', None, help='Name of Flax example (e.g. "imagenet").') flags.DEFINE_string( 'args', '', help='Any additional command line arguments for {example}_main.py, like ' 'for example --config. Note that --workdir will be provided by the ' - 'script.') + 'script.', +) # Run configuration. flags.DEFINE_string( 'name', None, help='Name of the experiment. Note that the provided name will be ' - 'extended to {example}/{name}/{timestamp}') + 'extended to {example}/{name}/{timestamp}', +) FLAGS = flags.FLAGS flags.mark_flags_as_required( - ['project', 'zone', 'machine_type', 'gcs_workdir_base', 'example', 'name']) + ['project', 'zone', 'machine_type', 'gcs_workdir_base', 'example', 'name'] +) timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') @@ -122,7 +129,7 @@ def generate_startup_file(vm_name: str) -> str: ('__GCS_WORKDIR_BASE__', FLAGS.gcs_workdir_base), ('__TFDS_DATA_DIR__', FLAGS.tfds_data_dir), ('__ACCELERATOR_TYPE__', FLAGS.accelerator_type), - ('__SHUTDOWN_SECS__', str(FLAGS.shutdown_secs)) + ('__SHUTDOWN_SECS__', str(FLAGS.shutdown_secs)), ): startup_script_content = startup_script_content.replace(from_str, to_str) with open(startup_script_dst, 'w', encoding='utf8') as f: @@ -134,13 +141,21 @@ def launch_gce(*, vm_name: str, startup_script: str): # Note : Use `gcloud compute images list --project ml-images` to get a list # of available VM images. args = [ - 'gcloud', 'compute', 'instances', 'create', vm_name, - f'--project={FLAGS.project}', f'--zone={FLAGS.zone}', + 'gcloud', + 'compute', + 'instances', + 'create', + vm_name, + f'--project={FLAGS.project}', + f'--zone={FLAGS.zone}', '--image=c1-deeplearning-tf-2-10-cu113-v20221107-debian-10', - '--image-project=ml-images', f'--machine-type={FLAGS.machine_type}', - '--scopes=cloud-platform,storage-full', '--boot-disk-size=256GB', - '--boot-disk-type=pd-ssd', '--metadata=install-nvidia-driver=True', - f'--metadata-from-file=startup-script={startup_script}' + '--image-project=ml-images', + f'--machine-type={FLAGS.machine_type}', + '--scopes=cloud-platform,storage-full', + '--boot-disk-size=256GB', + '--boot-disk-type=pd-ssd', + '--metadata=install-nvidia-driver=True', + f'--metadata-from-file=startup-script={startup_script}', ] if FLAGS.accelerator_type and FLAGS.accelerator_count: args.extend([ @@ -165,7 +180,8 @@ def launch_gce(*, vm_name: str, startup_script: str): def print_howto(login_args: Sequence[str]): - print(f''' + print( + f""" ############################################################################### ############################################################################### @@ -189,7 +205,8 @@ def print_howto(login_args: Sequence[str]): ############################################################################### ############################################################################### -''') +""" + ) def main(_): @@ -237,8 +254,7 @@ def main(_): login_true_args = login_args[:-1] + ['true'] while True: try: - result = subprocess.run( - login_true_args, timeout=10, stderr=subprocess.PIPE) + result = subprocess.run(login_true_args, timeout=10, stderr=subprocess.PIPE) if result.returncode == 0: break stderr = result.stderr.decode('utf8') diff --git a/examples/imagenet/imagenet_benchmark.py b/examples/imagenet/imagenet_benchmark.py index 863e3dc3b4..9f010d7912 100644 --- a/examples/imagenet/imagenet_benchmark.py +++ b/examples/imagenet/imagenet_benchmark.py @@ -38,8 +38,7 @@ class ImagenetBenchmark(Benchmark): """Benchmarks for the ImageNet Flax example.""" @flagsaver - def _test_8x_v100_half_precision(self, num_epochs: int, min_accuracy, - max_accuracy): + def _test_8x_v100_half_precision(self, num_epochs: int, min_accuracy, max_accuracy): """Utility to benchmark ImageNet on 8xV100 GPUs. Use in your test func.""" # Prepare and set flags defined in main.py. config = config_lib.get_config() @@ -68,13 +67,13 @@ def _test_8x_v100_half_precision(self, num_epochs: int, min_accuracy, # Use the reporting API to report single or multiple metrics/extras. self.report_wall_time(benchmark_time) - self.report_metrics({'sec_per_epoch': sec_per_epoch, - 'accuracy': end_accuracy}) + self.report_metrics({'sec_per_epoch': sec_per_epoch, 'accuracy': end_accuracy}) def test_8x_v100_half_precision_short(self): """Run ImageNet on 8x V100 GPUs in half precision for 2 epochs.""" self._test_8x_v100_half_precision( - num_epochs=2, min_accuracy=0.06, max_accuracy=0.09) + num_epochs=2, min_accuracy=0.06, max_accuracy=0.09 + ) self.report_extras({ 'description': 'Short (2 epochs) 8 x V100 test for ImageNet ResNet50.', 'model_name': 'resnet50', @@ -85,7 +84,8 @@ def test_8x_v100_half_precision_short(self): def test_8x_v100_half_precision_full(self): """Run ImageNet on 8x V100 GPUs in half precision for full 90 epochs.""" self._test_8x_v100_half_precision( - num_epochs=90, min_accuracy=0.76, max_accuracy=0.77) + num_epochs=90, min_accuracy=0.76, max_accuracy=0.77 + ) self.report_extras({ 'description': 'Full (90 epochs) 8 x V100 test for ImageNet ResNet50.', 'model_name': 'resnet50', diff --git a/examples/imagenet/input_pipeline.py b/examples/imagenet/input_pipeline.py index b9145b560b..564eaa653e 100644 --- a/examples/imagenet/input_pipeline.py +++ b/examples/imagenet/input_pipeline.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ImageNet input pipeline. -""" +"""ImageNet input pipeline.""" import jax import tensorflow as tf @@ -26,12 +25,14 @@ STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] -def distorted_bounding_box_crop(image_bytes, - bbox, - min_object_covered=0.1, - aspect_ratio_range=(0.75, 1.33), - area_range=(0.05, 1.0), - max_attempts=100): +def distorted_bounding_box_crop( + image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0), + max_attempts=100, +): """Generates cropped_image using one of the bboxes randomly distorted. See `tf.image.sample_distorted_bounding_box` for more documentation. @@ -63,7 +64,8 @@ def distorted_bounding_box_crop(image_bytes, aspect_ratio_range=aspect_ratio_range, area_range=area_range, max_attempts=max_attempts, - use_image_if_no_bounding_boxes=True) + use_image_if_no_bounding_boxes=True, + ) bbox_begin, bbox_size, _ = sample_distorted_bounding_box # Crop the image to the specified bounding box. @@ -76,8 +78,9 @@ def distorted_bounding_box_crop(image_bytes, def _resize(image, image_size): - return tf.image.resize([image], [image_size, image_size], - method=tf.image.ResizeMethod.BICUBIC)[0] + return tf.image.resize( + [image], [image_size, image_size], method=tf.image.ResizeMethod.BICUBIC + )[0] def _at_least_x_are_equal(a, b, x): @@ -94,16 +97,18 @@ def _decode_and_random_crop(image_bytes, image_size): image_bytes, bbox, min_object_covered=0.1, - aspect_ratio_range=(3. / 4, 4. / 3.), + aspect_ratio_range=(3.0 / 4, 4.0 / 3.0), area_range=(0.08, 1.0), - max_attempts=10) + max_attempts=10, + ) original_shape = tf.io.extract_jpeg_shape(image_bytes) bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) image = tf.cond( bad, lambda: _decode_and_center_crop(image_bytes, image_size), - lambda: _resize(image, image_size)) + lambda: _resize(image, image_size), + ) return image @@ -115,14 +120,18 @@ def _decode_and_center_crop(image_bytes, image_size): image_width = shape[1] padded_center_crop_size = tf.cast( - ((image_size / (image_size + CROP_PADDING)) * - tf.cast(tf.minimum(image_height, image_width), tf.float32)), - tf.int32) + ( + (image_size / (image_size + CROP_PADDING)) + * tf.cast(tf.minimum(image_height, image_width), tf.float32) + ), + tf.int32, + ) offset_height = ((image_height - padded_center_crop_size) + 1) // 2 offset_width = ((image_width - padded_center_crop_size) + 1) // 2 - crop_window = tf.stack([offset_height, offset_width, - padded_center_crop_size, padded_center_crop_size]) + crop_window = tf.stack( + [offset_height, offset_width, padded_center_crop_size, padded_center_crop_size] + ) image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) image = _resize(image, image_size) @@ -172,9 +181,16 @@ def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE): return image -def create_split(dataset_builder, batch_size, train, dtype=tf.float32, - image_size=IMAGE_SIZE, cache=False, shuffle_buffer_size=2_000, - prefetch=10): +def create_split( + dataset_builder, + batch_size, + train, + dtype=tf.float32, + image_size=IMAGE_SIZE, + cache=False, + shuffle_buffer_size=2_000, + prefetch=10, +): """Creates a split from the ImageNet dataset using TensorFlow Datasets. Args: @@ -207,9 +223,12 @@ def decode_example(example): image = preprocess_for_eval(example['image'], dtype, image_size) return {'image': image, 'label': example['label']} - ds = dataset_builder.as_dataset(split=split, decoders={ - 'image': tfds.decode.SkipDecoding(), - }) + ds = dataset_builder.as_dataset( + split=split, + decoders={ + 'image': tfds.decode.SkipDecoding(), + }, + ) options = tf.data.Options() options.experimental_threading.private_threadpool_size = 48 ds = ds.with_options(options) diff --git a/examples/imagenet/main.py b/examples/imagenet/main.py index 4f2f8f1db1..2ad853f964 100644 --- a/examples/imagenet/main.py +++ b/examples/imagenet/main.py @@ -36,7 +36,8 @@ 'config', None, 'File path to the training hyperparameter configuration.', - lock_config=True) + lock_config=True, +) def main(argv): @@ -52,10 +53,12 @@ def main(argv): # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) - platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}') - platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, - FLAGS.workdir, 'workdir') + platform.work_unit().set_task_status( + f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' + ) + platform.work_unit().create_artifact( + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) diff --git a/examples/imagenet/models.py b/examples/imagenet/models.py index 70d195c32b..86d4ee6046 100644 --- a/examples/imagenet/models.py +++ b/examples/imagenet/models.py @@ -28,6 +28,7 @@ class ResNetBlock(nn.Module): """ResNet block.""" + filters: int conv: ModuleDef norm: ModuleDef @@ -35,7 +36,10 @@ class ResNetBlock(nn.Module): strides: Tuple[int, int] = (1, 1) @nn.compact - def __call__(self, x,): + def __call__( + self, + x, + ): residual = x y = self.conv(self.filters, (3, 3), self.strides)(x) y = self.norm()(y) @@ -44,8 +48,9 @@ def __call__(self, x,): y = self.norm(scale_init=nn.initializers.zeros_init())(y) if residual.shape != y.shape: - residual = self.conv(self.filters, (1, 1), - self.strides, name='conv_proj')(residual) + residual = self.conv(self.filters, (1, 1), self.strides, name='conv_proj')( + residual + ) residual = self.norm(name='norm_proj')(residual) return self.act(residual + y) @@ -53,6 +58,7 @@ def __call__(self, x,): class BottleneckResNetBlock(nn.Module): """Bottleneck ResNet block.""" + filters: int conv: ModuleDef norm: ModuleDef @@ -72,8 +78,9 @@ def __call__(self, x): y = self.norm(scale_init=nn.initializers.zeros_init())(y) if residual.shape != y.shape: - residual = self.conv(self.filters * 4, (1, 1), - self.strides, name='conv_proj')(residual) + residual = self.conv(self.filters * 4, (1, 1), self.strides, name='conv_proj')( + residual + ) residual = self.norm(name='norm_proj')(residual) return self.act(residual + y) @@ -81,6 +88,7 @@ def __call__(self, x): class ResNet(nn.Module): """ResNetV1.""" + stage_sizes: Sequence[int] block_cls: ModuleDef num_classes: int @@ -92,53 +100,52 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x, train: bool = True): conv = partial(self.conv, use_bias=False, dtype=self.dtype) - norm = partial(nn.BatchNorm, - use_running_average=not train, - momentum=0.9, - epsilon=1e-5, - dtype=self.dtype, - axis_name='batch') - - - x = conv(self.num_filters, (7, 7), (2, 2), - padding=[(3, 3), (3, 3)], - name='conv_init')(x) + norm = partial( + nn.BatchNorm, + use_running_average=not train, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype, + axis_name='batch', + ) + + x = conv( + self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name='conv_init' + )(x) x = norm(name='bn_init')(x) x = nn.relu(x) x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') for i, block_size in enumerate(self.stage_sizes): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) - x = self.block_cls(self.num_filters * 2 ** i, - strides=strides, - conv=conv, - norm=norm, - act=self.act)(x) + x = self.block_cls( + self.num_filters * 2**i, + strides=strides, + conv=conv, + norm=norm, + act=self.act, + )(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense(self.num_classes, dtype=self.dtype)(x) x = jnp.asarray(x, self.dtype) return x -ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], - block_cls=ResNetBlock) -ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], - block_cls=ResNetBlock) -ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], - block_cls=BottleneckResNetBlock) -ResNet101 = partial(ResNet, stage_sizes=[3, 4, 23, 3], - block_cls=BottleneckResNetBlock) -ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], - block_cls=BottleneckResNetBlock) -ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], - block_cls=BottleneckResNetBlock) +ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock) +ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock) +ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock) +ResNet101 = partial(ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock) +ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock) +ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock) -ResNet18Local = partial(ResNet, stage_sizes=[2, 2, 2, 2], - block_cls=ResNetBlock, conv=nn.ConvLocal) +ResNet18Local = partial( + ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock, conv=nn.ConvLocal +) # Used for testing only. _ResNet1 = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock) -_ResNet1Local = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock, - conv=nn.ConvLocal) +_ResNet1Local = partial( + ResNet, stage_sizes=[1], block_cls=ResNetBlock, conv=nn.ConvLocal +) diff --git a/examples/imagenet/models_test.py b/examples/imagenet/models_test.py index 0476e428f0..efe5118769 100644 --- a/examples/imagenet/models_test.py +++ b/examples/imagenet/models_test.py @@ -33,8 +33,7 @@ def test_resnet_v1_model(self): """Tests ResNet V1 model definition and output (variables).""" rng = jax.random.PRNGKey(0) model_def = models.ResNet50(num_classes=10, dtype=jnp.float32) - variables = model_def.init( - rng, jnp.ones((8, 224, 224, 3), jnp.float32)) + variables = model_def.init(rng, jnp.ones((8, 224, 224, 3), jnp.float32)) self.assertLen(variables, 2) # Resnet50 model will create parameters for the following layers: @@ -43,15 +42,12 @@ def test_resnet_v1_model(self): # Followed by a Dense layer = 1 self.assertLen(variables['params'], 19) - @parameterized.product( - model=(models.ResNet18, models.ResNet18Local) - ) + @parameterized.product(model=(models.ResNet18, models.ResNet18Local)) def test_resnet_18_v1_model(self, model): """Tests ResNet18 V1 model definition and output (variables).""" rng = jax.random.PRNGKey(0) model_def = model(num_classes=2, dtype=jnp.float32) - variables = model_def.init( - rng, jnp.ones((1, 64, 64, 3), jnp.float32)) + variables = model_def.init(rng, jnp.ones((1, 64, 64, 3), jnp.float32)) self.assertLen(variables, 2) self.assertLen(variables['params'], 11) diff --git a/examples/imagenet/train.py b/examples/imagenet/train.py index 85ef10f056..75762ef3b1 100644 --- a/examples/imagenet/train.py +++ b/examples/imagenet/train.py @@ -61,9 +61,11 @@ def create_model(*, model_cls, half_precision, **kwargs): def initialized(key, image_size, model): input_shape = (1, image_size, image_size, 3) + @jax.jit def init(*args): return model.init(*args) + variables = init({'params': key}, jnp.ones(input_shape, model.dtype)) return variables['params'], variables['batch_stats'] @@ -86,37 +88,39 @@ def compute_metrics(logits, labels): def create_learning_rate_fn( - config: ml_collections.ConfigDict, - base_learning_rate: float, - steps_per_epoch: int): + config: ml_collections.ConfigDict, base_learning_rate: float, steps_per_epoch: int +): """Create learning rate schedule.""" warmup_fn = optax.linear_schedule( - init_value=0., end_value=base_learning_rate, - transition_steps=config.warmup_epochs * steps_per_epoch) + init_value=0.0, + end_value=base_learning_rate, + transition_steps=config.warmup_epochs * steps_per_epoch, + ) cosine_epochs = max(config.num_epochs - config.warmup_epochs, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=base_learning_rate, - decay_steps=cosine_epochs * steps_per_epoch) + init_value=base_learning_rate, decay_steps=cosine_epochs * steps_per_epoch + ) schedule_fn = optax.join_schedules( schedules=[warmup_fn, cosine_fn], - boundaries=[config.warmup_epochs * steps_per_epoch]) + boundaries=[config.warmup_epochs * steps_per_epoch], + ) return schedule_fn def train_step(state, batch, learning_rate_fn): """Perform a single training step.""" + def loss_fn(params): """loss function used for training.""" logits, new_model_state = state.apply_fn( {'params': params, 'batch_stats': state.batch_stats}, batch['image'], - mutable=['batch_stats']) + mutable=['batch_stats'], + ) loss = cross_entropy_loss(logits, batch['label']) weight_penalty_params = jax.tree_util.tree_leaves(params) weight_decay = 0.0001 - weight_l2 = sum(jnp.sum(x ** 2) - for x in weight_penalty_params - if x.ndim > 1) + weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) weight_penalty = weight_decay * 0.5 * weight_l2 loss = loss + weight_penalty return loss, (new_model_state, logits) @@ -126,8 +130,7 @@ def loss_fn(params): lr = learning_rate_fn(step) if dynamic_scale: - grad_fn = dynamic_scale.value_and_grad( - loss_fn, has_aux=True, axis_name='batch') + grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True, axis_name='batch') dynamic_scale, is_fin, aux, grads = grad_fn(state.params) # dynamic loss takes care of averaging gradients across replicas else: @@ -140,20 +143,20 @@ def loss_fn(params): metrics['learning_rate'] = lr new_state = state.apply_gradients( - grads=grads, batch_stats=new_model_state['batch_stats']) + grads=grads, batch_stats=new_model_state['batch_stats'] + ) if dynamic_scale: # if is_fin == False the gradients contain Inf/NaNs and optimizer state and # params should be restored (= skip this step). new_state = new_state.replace( opt_state=jax.tree_util.tree_map( - functools.partial(jnp.where, is_fin), - new_state.opt_state, - state.opt_state), + functools.partial(jnp.where, is_fin), new_state.opt_state, state.opt_state + ), params=jax.tree_util.tree_map( - functools.partial(jnp.where, is_fin), - new_state.params, - state.params), - dynamic_scale=dynamic_scale) + functools.partial(jnp.where, is_fin), new_state.params, state.params + ), + dynamic_scale=dynamic_scale, + ) metrics['scale'] = dynamic_scale.scale return new_state, metrics @@ -161,14 +164,14 @@ def loss_fn(params): def eval_step(state, batch): variables = {'params': state.params, 'batch_stats': state.batch_stats} - logits = state.apply_fn( - variables, batch['image'], train=False, mutable=False) + logits = state.apply_fn(variables, batch['image'], train=False, mutable=False) return compute_metrics(logits, batch['label']) def prepare_tf_data(xs): """Convert a input batch from tf Tensors to numpy arrays.""" local_device_count = jax.local_device_count() + def _prepare(x): # Use _numpy() for zero-copy conversion between TF and NumPy. x = x._numpy() # pylint: disable=protected-access @@ -180,12 +183,26 @@ def _prepare(x): return jax.tree_util.tree_map(_prepare, xs) -def create_input_iter(dataset_builder, batch_size, image_size, dtype, train, - cache, shuffle_buffer_size, prefetch): +def create_input_iter( + dataset_builder, + batch_size, + image_size, + dtype, + train, + cache, + shuffle_buffer_size, + prefetch, +): ds = input_pipeline.create_split( - dataset_builder, batch_size, image_size=image_size, dtype=dtype, - train=train, cache=cache, shuffle_buffer_size=shuffle_buffer_size, - prefetch=prefetch) + dataset_builder, + batch_size, + image_size=image_size, + dtype=dtype, + train=train, + cache=cache, + shuffle_buffer_size=shuffle_buffer_size, + prefetch=prefetch, + ) it = map(prepare_tf_data, ds) it = jax_utils.prefetch_to_device(it, 2) return it @@ -219,8 +236,9 @@ def sync_batch_stats(state): return state.replace(batch_stats=cross_replica_mean(state.batch_stats)) -def create_train_state(rng, config: ml_collections.ConfigDict, - model, image_size, learning_rate_fn): +def create_train_state( + rng, config: ml_collections.ConfigDict, model, image_size, learning_rate_fn +): """Create initial training state.""" dynamic_scale = None platform = jax.local_devices()[0].platform @@ -240,12 +258,12 @@ def create_train_state(rng, config: ml_collections.ConfigDict, params=params, tx=tx, batch_stats=batch_stats, - dynamic_scale=dynamic_scale) + dynamic_scale=dynamic_scale, + ) return state -def train_and_evaluate(config: ml_collections.ConfigDict, - workdir: str) -> TrainState: +def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> TrainState: """Execute model training and evaluation loop. Args: @@ -257,7 +275,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, """ writer = metric_writers.create_default_writer( - logdir=workdir, just_logging=jax.process_index() != 0) + logdir=workdir, just_logging=jax.process_index() != 0 + ) rng = random.PRNGKey(0) @@ -279,12 +298,25 @@ def train_and_evaluate(config: ml_collections.ConfigDict, dataset_builder = tfds.builder(config.dataset) train_iter = create_input_iter( - dataset_builder, local_batch_size, image_size, input_dtype, train=True, - cache=config.cache, shuffle_buffer_size=config.shuffle_buffer_size, - prefetch=config.prefetch) + dataset_builder, + local_batch_size, + image_size, + input_dtype, + train=True, + cache=config.cache, + shuffle_buffer_size=config.shuffle_buffer_size, + prefetch=config.prefetch, + ) eval_iter = create_input_iter( - dataset_builder, local_batch_size, image_size, input_dtype, train=False, - cache=config.cache, shuffle_buffer_size=None, prefetch=config.prefetch) + dataset_builder, + local_batch_size, + image_size, + input_dtype, + train=False, + cache=config.cache, + shuffle_buffer_size=None, + prefetch=config.prefetch, + ) steps_per_epoch = ( dataset_builder.info.splits['train'].num_examples // config.batch_size @@ -296,22 +328,21 @@ def train_and_evaluate(config: ml_collections.ConfigDict, num_steps = config.num_train_steps if config.steps_per_eval == -1: - num_validation_examples = dataset_builder.info.splits[ - 'validation'].num_examples + num_validation_examples = dataset_builder.info.splits['validation'].num_examples steps_per_eval = num_validation_examples // config.batch_size else: steps_per_eval = config.steps_per_eval steps_per_checkpoint = steps_per_epoch * 10 - base_learning_rate = config.learning_rate * config.batch_size / 256. + base_learning_rate = config.learning_rate * config.batch_size / 256.0 model_cls = getattr(models, config.model) - model = create_model( - model_cls=model_cls, half_precision=config.half_precision) + model = create_model(model_cls=model_cls, half_precision=config.half_precision) learning_rate_fn = create_learning_rate_fn( - config, base_learning_rate, steps_per_epoch) + config, base_learning_rate, steps_per_epoch + ) state = create_train_state(rng, config, model, image_size, learning_rate_fn) state = restore_checkpoint(state, workdir) @@ -321,7 +352,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, p_train_step = jax.pmap( functools.partial(train_step, learning_rate_fn=learning_rate_fn), - axis_name='batch') + axis_name='batch', + ) p_eval_step = jax.pmap(eval_step, axis_name='batch') train_metrics = [] @@ -343,10 +375,13 @@ def train_and_evaluate(config: ml_collections.ConfigDict, train_metrics = common_utils.get_metrics(train_metrics) summary = { f'train_{k}': v - for k, v in jax.tree_util.tree_map(lambda x: x.mean(), train_metrics).items() + for k, v in jax.tree_util.tree_map( + lambda x: x.mean(), train_metrics + ).items() } summary['steps_per_second'] = config.log_every_steps / ( - time.time() - train_metrics_last_t) + time.time() - train_metrics_last_t + ) writer.write_scalars(step + 1, summary) train_metrics = [] train_metrics_last_t = time.time() @@ -363,10 +398,15 @@ def train_and_evaluate(config: ml_collections.ConfigDict, eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) summary = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics) - logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', - epoch, summary['loss'], summary['accuracy'] * 100) + logging.info( + 'eval epoch: %d, loss: %.4f, accuracy: %.2f', + epoch, + summary['loss'], + summary['accuracy'] * 100, + ) writer.write_scalars( - step + 1, {f'eval_{key}': val for key, val in summary.items()}) + step + 1, {f'eval_{key}': val for key, val in summary.items()} + ) writer.flush() if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps: state = sync_batch_stats(state) diff --git a/examples/imagenet/train_test.py b/examples/imagenet/train_test.py index 8f7def7a9d..8dd86b723f 100644 --- a/examples/imagenet/train_test.py +++ b/examples/imagenet/train_test.py @@ -62,9 +62,7 @@ def test_create_model_local(self): y = model.apply(variables, x, train=False) self.assertEqual(y.shape, (1, 1000)) - @parameterized.product( - model=('_ResNet1', '_ResNet1Local') - ) + @parameterized.product(model=('_ResNet1', '_ResNet1Local')) def test_train_and_evaluate(self, model): """Tests training and evaluation loop using mocked data.""" # Create a temporary directory where tensorboard metrics are written. diff --git a/examples/linen_design_test/attention_simple.py b/examples/linen_design_test/attention_simple.py index 209d33bc77..cf339223f6 100644 --- a/examples/linen_design_test/attention_simple.py +++ b/examples/linen_design_test/attention_simple.py @@ -25,7 +25,6 @@ import numpy as np - class Dense(Module): features: int use_bias: bool = True @@ -37,12 +36,11 @@ class Dense(Module): @compact def __call__(self, inputs): inputs = jnp.asarray(inputs, self.dtype) - kernel = self.param('kernel', self.kernel_init, - (inputs.shape[-1], self.features)) + kernel = self.param('kernel', self.kernel_init, (inputs.shape[-1], self.features)) kernel = jnp.asarray(kernel, self.dtype) - y = lax.dot_general(inputs, kernel, - (((inputs.ndim - 1,), (0,)), ((), ())), - precision=self.precision) + y = lax.dot_general( + inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision + ) if self.use_bias: bias = self.param('bias', self.bias_init, (self.features,)) bias = jnp.asarray(bias, self.dtype) @@ -51,6 +49,7 @@ def __call__(self, inputs): class SoftmaxAttn(Module): + @compact def __call__(self, weights): norm_dims = tuple(range(weights.ndim // 2, weights.ndim)) @@ -62,9 +61,9 @@ class Dropout(Module): @compact def __call__(self, x, deterministic=False, rng=None): - if self.rate == 0.: + if self.rate == 0.0: return x - keep_prob = 1. - self.rate + keep_prob = 1.0 - self.rate if deterministic: return x @@ -95,20 +94,14 @@ def __call__(self, query, key, value, bias=None, dtype=jnp.float32): assert key.ndim == value.ndim n = query.ndim - attn_weights = lax.dot_general( - query, key, - (((n-1,), (n - 1,)), ((), ()))) + attn_weights = lax.dot_general(query, key, (((n - 1,), (n - 1,)), ((), ()))) if bias is not None: attn_weights += bias attn_weights = self.attn_module()(attn_weights) attn_weights = attn_weights.astype(dtype) - contract_dims = ( - tuple(range(n - 1, attn_weights.ndim)), - tuple(range(0, n - 1))) - y = lax.dot_general( - attn_weights, value, - (contract_dims, ((), ()))) + contract_dims = (tuple(range(n - 1, attn_weights.ndim)), tuple(range(0, n - 1))) + y = lax.dot_general(attn_weights, value, (contract_dims, ((), ()))) return y @@ -123,29 +116,33 @@ def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): out_features = self.out_features or inputs_q.shape[-1] QKVDense = functools.partial( - Dense, features=qkv_features, use_bias=False, dtype=dtype) + Dense, features=qkv_features, use_bias=False, dtype=dtype + ) query = QKVDense(name='query')(inputs_q) key = QKVDense(name='key')(inputs_kv) value = QKVDense(name='value')(inputs_kv) y = RawDotProductAttention(attn_module=self.attn_module)( - query, key, value, bias=bias, dtype=dtype) + query, key, value, bias=bias, dtype=dtype + ) y = Dense(features=out_features, dtype=dtype, name='out')(y) return y # Trying out a slightly more compact vmap notation: + def concise_vmap(module, in_axes, out_axes, axis_size=None, **var_specs): - variable_axes = {k: v[0] for k, v in - var_specs.items() if isinstance(v, Sequence)} + variable_axes = {k: v[0] for k, v in var_specs.items() if isinstance(v, Sequence)} splits = {k: v[1] for k, v in var_specs.items() if isinstance(v, Sequence)} - return vmap(module, - in_axes=in_axes, - out_axes=out_axes, - variable_axes=variable_axes, - split_rngs=splits, - axis_size=axis_size) + return vmap( + module, + in_axes=in_axes, + out_axes=out_axes, + variable_axes=variable_axes, + split_rngs=splits, + axis_size=axis_size, + ) class MultiHeadDotProductAttention(Module): @@ -162,20 +159,28 @@ def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): out_features = self.out_features or inputs_q.shape[-1] # Now, vmap attn.__call__ along heads and spatial dims. - Attn = concise_vmap(DotProductAttention, - (None, None, None), -2, - param=(0, True), - dropout=(None, not self.broadcast_dropout), - axis_size=self.num_heads) + Attn = concise_vmap( + DotProductAttention, + (None, None, None), + -2, + param=(0, True), + dropout=(None, not self.broadcast_dropout), + axis_size=self.num_heads, + ) for axis in reversed(sorted(self.batch_axes)): - Attn = concise_vmap(Attn, - (axis, axis, axis), axis, - param=(None, False), - dropout=(None, not self.broadcast_dropout)) - - attn = Attn(attn_module=self.attn_module, - qkv_features=qkv_features // self.num_heads, - out_features=out_features) + Attn = concise_vmap( + Attn, + (axis, axis, axis), + axis, + param=(None, False), + dropout=(None, not self.broadcast_dropout), + ) + + attn = Attn( + attn_module=self.attn_module, + qkv_features=qkv_features // self.num_heads, + out_features=out_features, + ) # evaluate multi-headed-attention. y = attn(inputs_q, inputs_kv, bias) @@ -186,7 +191,6 @@ def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): if __name__ == '__main__': - inputs = jnp.ones((8, 97, 256)) rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)} model = MultiHeadDotProductAttention( @@ -195,7 +199,8 @@ def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): out_features=256, attn_module=functools.partial(SoftmaxAttnWDropout, rate=0.1), num_heads=8, - batch_axes=(0,),) + batch_axes=(0,), + ) y, params = model.init_with_output(rngs, inputs, inputs) diff --git a/examples/linen_design_test/autoencoder.py b/examples/linen_design_test/autoencoder.py index cc16acfef8..25fd1398f5 100644 --- a/examples/linen_design_test/autoencoder.py +++ b/examples/linen_design_test/autoencoder.py @@ -22,7 +22,6 @@ from flax.linen import Module, Dense, compact - # A concise MLP defined via lazy submodule initialization class MLP(Module): widths: Iterable @@ -44,13 +43,13 @@ class AutoEncoder(Module): def setup(self): # Submodules attached in `setup` get names via attribute assignment self.encoder = MLP(self.encoder_widths) - self.decoder = MLP(self.decoder_widths + (jnp.prod(self.input_shape), )) + self.decoder = MLP(self.decoder_widths + (jnp.prod(self.input_shape),)) def __call__(self, x): return self.decode(self.encode(x)) def encode(self, x): - assert x.shape[-len(self.input_shape):] == self.input_shape + assert x.shape[-len(self.input_shape) :] == self.input_shape return self.encoder(jnp.reshape(x, (x.shape[0], -1))) def decode(self, z): @@ -62,21 +61,20 @@ def decode(self, z): # `ae` is a detached module, which has no variables. ae = AutoEncoder( - encoder_widths=(32, 32, 32), - decoder_widths=(32, 32, 32), - input_shape=(28, 28, 1)) + encoder_widths=(32, 32, 32), decoder_widths=(32, 32, 32), input_shape=(28, 28, 1) +) # `ae.initialized` returns a materialized copy of `ae` by # running through an input to create submodules defined lazily. -params = ae.init( - {'params': random.PRNGKey(42)}, - jnp.ones((1, 28, 28, 1))) +params = ae.init({"params": random.PRNGKey(42)}, jnp.ones((1, 28, 28, 1))) # Now you can use `ae` as a normal object, calling any methods defined on AutoEncoder print("reconstruct", jnp.shape(ae.apply(params, jnp.ones((1, 28, 28, 1))))) -print("encoder", jnp.shape(ae.apply(params, jnp.ones((1, 28, 28, 1)), method=ae.encode))) +print( + "encoder", jnp.shape(ae.apply(params, jnp.ones((1, 28, 28, 1)), method=ae.encode)) +) # `ae.variables` is a frozen dict that looks like diff --git a/examples/linen_design_test/dense.py b/examples/linen_design_test/dense.py index 6874d89636..b2422b7b11 100644 --- a/examples/linen_design_test/dense.py +++ b/examples/linen_design_test/dense.py @@ -27,10 +27,12 @@ class Dense(Module): @compact def __call__(self, inputs): - kernel = self.param('kernel', self.kernel_init, - (inputs.shape[-1], self.features)) - y = lax.dot_general(inputs, kernel, - (((inputs.ndim - 1,), (0,)), ((), ())),) + kernel = self.param('kernel', self.kernel_init, (inputs.shape[-1], self.features)) + y = lax.dot_general( + inputs, + kernel, + (((inputs.ndim - 1,), (0,)), ((), ())), + ) if self.use_bias: bias = self.param('bias', self.bias_init, (self.features,)) y = y + bias diff --git a/examples/linen_design_test/linear_regression.py b/examples/linen_design_test/linear_regression.py index 56b88f60f6..b0ed964222 100644 --- a/examples/linen_design_test/linear_regression.py +++ b/examples/linen_design_test/linear_regression.py @@ -23,18 +23,22 @@ model = Dense(features=5) + @jit def predict(params): - return model.apply({'params': params}, X) + return model.apply({"params": params}, X) + @jit def loss_fn(params): return jnp.mean(jnp.abs(Y - predict(params))) + @jit def init_params(rng): - mlp_variables = model.init({'params': rng}, X) - return mlp_variables['params'] + mlp_variables = model.init({"params": rng}, X) + return mlp_variables["params"] + # Get initial parameters params = init_params(jax.random.PRNGKey(42)) diff --git a/examples/linen_design_test/mlp_explicit.py b/examples/linen_design_test/mlp_explicit.py index 4c815ef67d..594145b86f 100644 --- a/examples/linen_design_test/mlp_explicit.py +++ b/examples/linen_design_test/mlp_explicit.py @@ -31,9 +31,16 @@ class DenseExplicit(Dense): def setup(self): # We feed a fake batch through the module, which initialized parameters. # Assuming we're in a jit, should use no FLOPs -- "just shape inference". - self.__call__(jnp.zeros((1, self.in_features, ))) + self.__call__( + jnp.zeros(( + 1, + self.in_features, + )) + ) + class MLP(Module): + def setup(self): self.dense1 = DenseExplicit(in_features=3, features=2) self.dense2 = DenseExplicit(in_features=2, features=1) @@ -44,10 +51,10 @@ def setup(self): # 'kernel': DeviceArray([[ 0.6704609 ], # [-0.90477365]], dtype=float32)}} - def __call__(self, x): return self.dense2(nn.relu(self.dense1(x))) + # Return an initialized instance of MLP by only calling `setup`. rngkey = jax.random.PRNGKey(10) init_variables = MLP().init({'params': rngkey}, jnp.ones((1, 3))) diff --git a/examples/linen_design_test/mlp_inline.py b/examples/linen_design_test/mlp_inline.py index 4382325eb8..73d525acff 100644 --- a/examples/linen_design_test/mlp_inline.py +++ b/examples/linen_design_test/mlp_inline.py @@ -30,10 +30,11 @@ class MLP(Module): @compact def __call__(self, x): for size in self.sizes[:-1]: - x = Dense(size)(x) - x = nn.relu(x) + x = Dense(size)(x) + x = nn.relu(x) return Dense(self.sizes[-1])(x) + # Return an initialized instance of MLP by calling `__call__` with an input batch, # initializing all variables. # diff --git a/examples/linen_design_test/mlp_lazy.py b/examples/linen_design_test/mlp_lazy.py index 9fb337f7fc..7283d483e4 100644 --- a/examples/linen_design_test/mlp_lazy.py +++ b/examples/linen_design_test/mlp_lazy.py @@ -24,6 +24,7 @@ # Here submodules are explicitly defined during init, but still materialized # lazily only once a first input is passed through and shapes are known. class MLP(Module): + def setup(self): self.dense1 = Dense(features=2) self.dense2 = Dense(features=1) @@ -35,6 +36,7 @@ def setup(self): def __call__(self, x): return self.dense2(nn.relu(self.dense1(x))) + # Return an initialized instance of MLP by calling `__call__` with an input batch, # initializing all variables. # @@ -50,4 +52,3 @@ def __call__(self, x): # 'dense2': {'bias': DeviceArray([0.], dtype=float32), # 'kernel': DeviceArray([[ 0.6704609 ], # [-0.90477365]], dtype=float32)}}} - diff --git a/examples/linen_design_test/weight_std.py b/examples/linen_design_test/weight_std.py index 1caffcad8e..24b90c4fe1 100644 --- a/examples/linen_design_test/weight_std.py +++ b/examples/linen_design_test/weight_std.py @@ -29,6 +29,7 @@ def standardize(x, axis, eps=1e-8): x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps) return x + # TODO(avital, levskaya): resurrect this example once interactive api is restored. # A wrapper that calls through a simple module with standardized parameters. diff --git a/examples/lm1b/input_pipeline.py b/examples/lm1b/input_pipeline.py index 8f8d69ccdc..4ac1121e7f 100644 --- a/examples/lm1b/input_pipeline.py +++ b/examples/lm1b/input_pipeline.py @@ -41,8 +41,9 @@ def __call__(self, features: Features) -> Features: return features -def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder, - split: str) -> tf.data.Dataset: +def get_raw_dataset( + dataset_builder: tfds.core.DatasetBuilder, split: str +) -> tf.data.Dataset: """Loads a raw text dataset and normalizes feature keys. Args: @@ -56,17 +57,20 @@ def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder, """ num_examples = dataset_builder.info.splits[split].num_examples per_host_split = deterministic_data.get_read_instruction_for_host( - split, num_examples, drop_remainder=False) + split, num_examples, drop_remainder=False + ) ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False) ds = ds.map( - NormalizeFeatureNamesOp(dataset_builder.info), - num_parallel_calls=AUTOTUNE) + NormalizeFeatureNamesOp(dataset_builder.info), num_parallel_calls=AUTOTUNE + ) return ds -def pack_dataset(dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None) -> tf.data.Dataset: +def pack_dataset( + dataset: tf.data.Dataset, + key2length: Union[int, Dict[str, int]], + keys: Optional[List[str]] = None, +) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. @@ -111,9 +115,10 @@ def pack_dataset(dataset: tf.data.Dataset, keys = list(shapes.keys()) for k in keys: if k not in shapes: - raise ValueError('Key %s not found in dataset. Available keys are %s' % - (k, shapes.keys())) - if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types] + raise ValueError( + 'Key %s not found in dataset. Available keys are %s' % (k, shapes.keys()) + ) + if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types] raise ValueError('Tensors to be packed must be one-dimensional.') # make sure that the length dictionary contains all keys as well as the # keys suffixed by "_segmentation" and "_position" @@ -125,13 +130,12 @@ def pack_dataset(dataset: tf.data.Dataset, # trim to length dataset = dataset.map( - lambda x: {k: x[k][:key2length[k]] for k in keys}, - num_parallel_calls=AUTOTUNE) + lambda x: {k: x[k][: key2length[k]] for k in keys}, num_parallel_calls=AUTOTUNE + ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) - dataset = dataset.padded_batch( - batch_size, padded_shapes={k: [-1] for k in keys}) + dataset = dataset.padded_batch(batch_size, padded_shapes={k: [-1] for k in keys}) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. @@ -141,8 +145,9 @@ def my_fn(x): return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) -def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], - key2length: Dict[str, int]) -> tf.data.Dataset: +def _pack_with_tf_ops( + dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int] +) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. @@ -167,7 +172,8 @@ def write_packed_example(partial, outputs): for k in keys_etc: new_outputs[k] = outputs[k].write( outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) + tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), + ) return new_partial, new_outputs def map_fn(x): @@ -187,9 +193,11 @@ def map_fn(x): outputs = {} for k in keys: outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) def body_fn(i, partial, outputs): """Body function for while_loop. @@ -206,13 +214,13 @@ def body_fn(i, partial, outputs): one_example = {} for k in keys: val = tf.cast(x[k][i], tf.int32) - val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] + val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] one_example[k] = val for k in keys: can_append = tf.logical_and( can_append, - tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) + tf.less_equal(tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]), + ) def false_fn(): return write_packed_example(partial, outputs) @@ -223,12 +231,12 @@ def true_fn(): partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: - new_seq = one_example[k][:key2length[k]] + new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], - tf.range(new_seq_len)], 0) + [partial[k + '_position'], tf.range(new_seq_len)], 0 + ) partial = new_partial return i + 1, partial, outputs @@ -239,17 +247,17 @@ def true_fn(): loop_vars=(i, partial, outputs), shape_invariants=( tf.TensorShape([]), - {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] - {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] + {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] + {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] ), - maximum_iterations=dynamic_batch_size) + maximum_iterations=dynamic_batch_size, + ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: - packed[k + '_segmentation'] = ( - tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) * - tf.cast(tf.not_equal(packed[k], 0), tf.int32)) + packed[k + '_segmentation'] = tf.cumsum( + tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 + ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) @@ -259,19 +267,20 @@ def true_fn(): # ----------------------------------------------------------------------------- # Main dataset prep routines. # ----------------------------------------------------------------------------- -def preprocess_data(dataset, - shuffle: bool, - num_epochs: Optional[int] = 1, - pack_examples: bool = True, - shuffle_buffer_size: int = 1024, - max_length: int = 512, - batch_size: int = 256, - drop_remainder: bool = True, - prefetch_size: int = AUTOTUNE): +def preprocess_data( + dataset, + shuffle: bool, + num_epochs: Optional[int] = 1, + pack_examples: bool = True, + shuffle_buffer_size: int = 1024, + max_length: int = 512, + batch_size: int = 256, + drop_remainder: bool = True, + prefetch_size: int = AUTOTUNE, +): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): - def filter_fn(x): source, target = x['inputs'], x['targets'] l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) @@ -292,15 +301,10 @@ def filter_fn(x): else: # simple (static-shape) padded batching dataset = dataset.padded_batch( batch_size, - padded_shapes={ - 'inputs': max_length, - 'targets': max_length - }, - padding_values={ - 'inputs': 0, - 'targets': 0 - }, - drop_remainder=drop_remainder) + padded_shapes={'inputs': max_length, 'targets': max_length}, + padding_values={'inputs': 0, 'targets': 0}, + drop_remainder=drop_remainder, + ) if prefetch_size: dataset = dataset.prefetch(prefetch_size) @@ -308,10 +312,12 @@ def filter_fn(x): return dataset -def get_datasets(config: ml_collections.ConfigDict, - *, - n_devices: int, - vocab_path: Optional[str] = None): +def get_datasets( + config: ml_collections.ConfigDict, + *, + n_devices: int, + vocab_path: Optional[str] = None +): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model') @@ -330,11 +336,14 @@ def get_datasets(config: ml_collections.ConfigDict, train_data, vocab_path=vocab_path, vocab_size=config.vocab_size, - max_corpus_chars=config.max_corpus_chars) + max_corpus_chars=config.max_corpus_chars, + ) train_data = train_data.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) eval_data = eval_data.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) batch_size = config.per_device_batch_size * n_devices if config.eval_per_device_batch_size > 0: @@ -348,14 +357,16 @@ def get_datasets(config: ml_collections.ConfigDict, num_epochs=None, pack_examples=True, batch_size=batch_size, - max_length=config.max_target_length) + max_length=config.max_target_length, + ) eval_ds = preprocess_data( eval_data, shuffle=False, pack_examples=False, batch_size=eval_batch_size, - max_length=config.max_eval_target_length) + max_length=config.max_eval_target_length, + ) predict_ds = preprocess_data( eval_data, @@ -363,6 +374,7 @@ def get_datasets(config: ml_collections.ConfigDict, pack_examples=False, batch_size=eval_batch_size, max_length=config.max_predict_length, - drop_remainder=False) + drop_remainder=False, + ) return train_ds, eval_ds, predict_ds, sp_tokenizer diff --git a/examples/lm1b/input_pipeline_test.py b/examples/lm1b/input_pipeline_test.py index 41ba00b9ca..4cd14f87bc 100644 --- a/examples/lm1b/input_pipeline_test.py +++ b/examples/lm1b/input_pipeline_test.py @@ -53,7 +53,8 @@ def _get_datasets(self): with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): train_ds, eval_ds, predict_ds, _ = input_pipeline.get_datasets( - n_devices=2, config=config, vocab_path=vocab_path) + n_devices=2, config=config, vocab_path=vocab_path + ) return train_ds, eval_ds, predict_ds def test_train_ds(self): @@ -61,30 +62,39 @@ def test_train_ds(self): # For training we pack multiple short examples in one example. # *_position and *_segmentation indicate the boundaries. for batch in self.train_ds.take(3): - self.assertEqual({k: v.shape.as_list() for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'inputs_position': expected_shape, + 'inputs_segmentation': expected_shape, + 'targets': expected_shape, + 'targets_position': expected_shape, + 'targets_segmentation': expected_shape, + }, + ) def test_eval_ds(self): expected_shape = [4, _EVAL_TARGET_LENGTH] # 2 devices. for batch in self.eval_ds.take(3): - self.assertEqual({k: v.shape.as_list() for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'targets': expected_shape, + }, + ) def test_predict_ds(self): expected_shape = [4, _PREDICT_TARGET_LENGTH] # 2 devices. for batch in self.predict_ds.take(3): - self.assertEqual({k: v.shape.as_list() for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'targets': expected_shape, + }, + ) if __name__ == '__main__': diff --git a/examples/lm1b/main.py b/examples/lm1b/main.py index 610bc87764..74ebec3cd3 100644 --- a/examples/lm1b/main.py +++ b/examples/lm1b/main.py @@ -35,7 +35,8 @@ 'config', 'configs/default.py', 'File path to the training hyperparameter configuration.', - lock_config=True) + lock_config=True, +) flags.mark_flags_as_required(['config', 'workdir']) @@ -52,10 +53,12 @@ def main(argv): # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) - platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}') - platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, - FLAGS.workdir, 'workdir') + platform.work_unit().set_task_status( + f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' + ) + platform.work_unit().create_artifact( + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) diff --git a/examples/lm1b/models.py b/examples/lm1b/models.py index 94448913f4..8f9eca953a 100644 --- a/examples/lm1b/models.py +++ b/examples/lm1b/models.py @@ -31,9 +31,11 @@ import jax.numpy as jnp import numpy as np + @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int output_vocab_size: int share_embeddings: bool = False @@ -58,8 +60,7 @@ def shift_right(x, axis=1): """Shift the input to the right by padding and slicing on axis.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) - padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + padded = jnp.pad(x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) return lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis) @@ -69,13 +70,11 @@ def shift_inputs(x, segment_ids=None, axis=1): # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. if segment_ids is not None: - shifted *= (segment_ids == shift_right(segment_ids, axis=axis)) + shifted *= segment_ids == shift_right(segment_ids, axis=axis) return shifted -def sinusoidal_init(max_len=2048, - min_scale=1.0, - max_scale=10000.0): +def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0): """1D Sinusoidal Position Embedding Initializer. Args: @@ -95,8 +94,8 @@ def init(key, shape, dtype=np.float32): position = np.arange(0, max_len)[:, np.newaxis] scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) - pe[:, :d_feature // 2] = np.sin(position * div_term) - pe[:, d_feature // 2: 2 * (d_feature // 2)] = np.cos(position * div_term) + pe[:, : d_feature // 2] = np.sin(position * div_term) + pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) @@ -110,13 +109,12 @@ class AddPositionEmbs(nn.Module): config: TransformerConfig dataclass containing hyperparameters. decode: whether to run in single-position autoregressive mode. """ + config: TransformerConfig decode: bool = False @nn.compact - def __call__(self, - inputs, - inputs_positions=None): + def __call__(self, inputs, inputs_positions=None): """Applies AddPositionEmbs module. By default this layer uses a fixed sinusoidal embedding table. If a @@ -132,32 +130,29 @@ def __call__(self, """ config = self.config # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - ' but it is: %d' % inputs.ndim) + assert inputs.ndim == 3, ( + 'Number of dimensions should be 3,' ' but it is: %d' % inputs.ndim + ) length = inputs.shape[1] pos_emb_shape = (1, config.max_len, inputs.shape[-1]) if config.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=config.max_len)(None, - pos_emb_shape, - None) + pos_embedding = sinusoidal_init(max_len=config.max_len)(None, pos_emb_shape, None) else: - pos_embedding = self.param('pos_embedding', config.posemb_init, - pos_emb_shape) + pos_embedding = self.param('pos_embedding', config.posemb_init, pos_emb_shape) pe = pos_embedding[:, :length, :] # We use a cache position index for tracking decoding position. if self.decode: is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) + ) if is_initialized: i = cache_index.value cache_index.value = i + 1 _, _, df = pos_embedding.shape - pe = lax.dynamic_slice(pos_embedding, - jnp.array((0, i, 0)), - (1, 1, df)) + pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)), (1, 1, df)) if inputs_positions is None: # normal unpacked case: return inputs + pe @@ -173,6 +168,7 @@ class MlpBlock(nn.Module): config: TransformerConfig dataclass containing hyperparameters. out_dim: optionally specify out dimension. """ + config: TransformerConfig out_dim: Optional[int] = None @@ -180,25 +176,24 @@ class MlpBlock(nn.Module): def __call__(self, inputs): """Applies Transformer MlpBlock module.""" config = self.config - actual_out_dim = (inputs.shape[-1] if self.out_dim is None - else self.out_dim) + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( config.mlp_dim, dtype=config.dtype, kernel_init=config.kernel_init, - bias_init=config.bias_init)( - inputs) + bias_init=config.bias_init, + )(inputs) x = nn.relu(x) - x = nn.Dropout(rate=config.dropout_rate)( - x, deterministic=config.deterministic) + x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=config.deterministic) output = nn.Dense( actual_out_dim, dtype=config.dtype, kernel_init=config.kernel_init, - bias_init=config.bias_init)( - x) + bias_init=config.bias_init, + )(x) output = nn.Dropout(rate=config.dropout_rate)( - output, deterministic=config.deterministic) + output, deterministic=config.deterministic + ) return output @@ -208,13 +203,11 @@ class EncoderDecoder1DBlock(nn.Module): Args: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig @nn.compact - def __call__(self, - inputs, - decoder_mask=None, - encoder_decoder_mask=None): + def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None): """Applies EncoderDecoder1DBlock module. Args: @@ -240,9 +233,9 @@ def __call__(self, broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, deterministic=config.deterministic, - decode=config.decode)(x, decoder_mask) - x = nn.Dropout(rate=config.dropout_rate)( - x, deterministic=config.deterministic) + decode=config.decode, + )(x, decoder_mask) + x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=config.deterministic) x = x + inputs # MLP block. @@ -259,16 +252,19 @@ class Decoder(nn.Module): config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ + config: TransformerConfig shared_embedding: Any = None @nn.compact - def __call__(self, - inputs, - inputs_positions=None, - inputs_segmentation=None, - decoder_mask=None, - encoder_decoder_mask=None): + def __call__( + self, + inputs, + inputs_positions=None, + inputs_segmentation=None, + decoder_mask=None, + encoder_decoder_mask=None, + ): """Applies Transformer model on the inputs. Args: @@ -290,7 +286,8 @@ def __call__(self, output_embed = nn.Embed( num_embeddings=config.output_vocab_size, features=config.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: output_embed = self.shared_embedding @@ -298,21 +295,18 @@ def __call__(self, if not config.decode: y = shift_inputs(y, segment_ids=inputs_segmentation) y = output_embed(y) - y = AddPositionEmbs( - config=config, decode=config.decode, name='posembed_output')( - y, inputs_positions=inputs_positions) - y = nn.Dropout(rate=config.dropout_rate)( - y, deterministic=config.deterministic) + y = AddPositionEmbs(config=config, decode=config.decode, name='posembed_output')( + y, inputs_positions=inputs_positions + ) + y = nn.Dropout(rate=config.dropout_rate)(y, deterministic=config.deterministic) y = y.astype(config.dtype) # Target-Input Decoder for lyr in range(config.num_layers): - y = EncoderDecoder1DBlock( - config=config, name=f'encoderdecoderblock_{lyr}')( - y, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) + y = EncoderDecoder1DBlock(config=config, name=f'encoderdecoderblock_{lyr}')( + y, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask + ) y = nn.LayerNorm(dtype=config.dtype, name='encoderdecoder_norm')(y) # Decoded Logits @@ -327,8 +321,8 @@ def __call__(self, dtype=config.dtype, kernel_init=config.kernel_init, bias_init=config.bias_init, - name='logitdense')( - y) + name='logitdense', + )(y) return logits @@ -338,13 +332,11 @@ class TransformerLM(nn.Module): Args: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig @nn.compact - def __call__(self, - inputs, - inputs_positions=None, - inputs_segmentation=None): + def __call__(self, inputs, inputs_positions=None, inputs_segmentation=None): """Applies TransformerLM on the inputs. Args: @@ -364,23 +356,23 @@ def __call__(self, else: decoder_mask = nn.combine_masks( nn.make_attention_mask(inputs > 0, inputs > 0, dtype=config.dtype), - nn.make_causal_mask(inputs, dtype=config.dtype)) + nn.make_causal_mask(inputs, dtype=config.dtype), + ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: decoder_mask = nn.combine_masks( decoder_mask, nn.make_attention_mask( - inputs_segmentation, - inputs_segmentation, - jnp.equal, - dtype=config.dtype)) - - logits = Decoder( - config=config, shared_embedding=None, name='decoder')( - inputs, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation, - decoder_mask=decoder_mask, - encoder_decoder_mask=None) + inputs_segmentation, inputs_segmentation, jnp.equal, dtype=config.dtype + ), + ) + + logits = Decoder(config=config, shared_embedding=None, name='decoder')( + inputs, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation, + decoder_mask=decoder_mask, + encoder_decoder_mask=None, + ) return logits.astype(self.config.dtype) diff --git a/examples/lm1b/temperature_sampler.py b/examples/lm1b/temperature_sampler.py index 2b52841325..ec35046621 100644 --- a/examples/lm1b/temperature_sampler.py +++ b/examples/lm1b/temperature_sampler.py @@ -24,13 +24,15 @@ EOS_ID = 2 -def temperature_sample(prompt_inputs, - init_cache, - tokens_to_logits, - prng_key, - temperature=1.0, - topk=20, - eos_token=EOS_ID): +def temperature_sample( + prompt_inputs, + init_cache, + tokens_to_logits, + prng_key, + temperature=1.0, + topk=20, + eos_token=EOS_ID, +): """Temperature sampling for language model generation. Args: @@ -72,7 +74,7 @@ def sampling_loop_cond_fn(state): """Sampling loop termination condition.""" (i, _, _, _, ended, _) = state # Have we reached max decoding length? - not_at_end = (i < max_decode_len - 1) + not_at_end = i < max_decode_len - 1 # Have all sampled sequences reached an end marker? all_sequences_ended = jnp.all(ended) return not_at_end & (~all_sequences_ended) @@ -89,30 +91,31 @@ def sampling_loop_body_fn(state): if topk: # Get top-k logits and their indices, sample within these top-k tokens. topk_logits, topk_idxs = lax.top_k(logits, topk) - topk_token = jnp.expand_dims(random.categorical( - rng1, topk_logits / temperature).astype(jnp.int32), axis=-1) + topk_token = jnp.expand_dims( + random.categorical(rng1, topk_logits / temperature).astype(jnp.int32), axis=-1 + ) # Return the original indices corresponding to the sampled top-k tokens. next_token = jnp.squeeze( - jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1) + jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1 + ) else: - next_token = random.categorical( - rng1, logits / temperature).astype(jnp.int32) + next_token = random.categorical(rng1, logits / temperature).astype(jnp.int32) # Only use sampled tokens if we're past provided prefix tokens. - out_of_prompt = (sequences[:, i+1] == 0) - next_token = (next_token * out_of_prompt + - sequences[:, i+1] * ~out_of_prompt) + out_of_prompt = sequences[:, i + 1] == 0 + next_token = next_token * out_of_prompt + sequences[:, i + 1] * ~out_of_prompt # If end-marker reached for batch item, only emit padding tokens. - next_token_or_endpad = (next_token[None] * ~ended) - ended |= (next_token_or_endpad == end_marker) + next_token_or_endpad = next_token[None] * ~ended + ended |= next_token_or_endpad == end_marker # Add current sampled tokens to recorded sequences. new_sequences = lax.dynamic_update_slice( - sequences, next_token_or_endpad, (0, i+1)) - return (i+1, new_sequences, new_cache, next_token_or_endpad, ended, rng2) + sequences, next_token_or_endpad, (0, i + 1) + ) + return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2) # Run sampling loop and collect final state. - final_state = lax.while_loop(sampling_loop_cond_fn, - sampling_loop_body_fn, - sampling_loop_init_state) + final_state = lax.while_loop( + sampling_loop_cond_fn, sampling_loop_body_fn, sampling_loop_init_state + ) # Pick part of the state corresponding to the sampled sequences. final_sequences = final_state[1] diff --git a/examples/lm1b/temperature_sampler_test.py b/examples/lm1b/temperature_sampler_test.py index 6cb1a8266e..0627d68530 100644 --- a/examples/lm1b/temperature_sampler_test.py +++ b/examples/lm1b/temperature_sampler_test.py @@ -24,22 +24,23 @@ class TestTemperatureSampler(absltest.TestCase): - def test_temperature_sampler(self): + def test_temperature_sampler(self): tokens = jnp.array([[5, 0, 0, 0]], dtype=jnp.int32) cache = None key = jax.random.PRNGKey(0) def tokens_to_logits(tokens, cache): - jax.debug.print("tokens: {}", tokens) - logits = jax.nn.one_hot(tokens[..., -1:] + 1, 10) - logits = jnp.where(logits < 0.5, float('-inf'), logits) - logits = logits.squeeze(axis=1) - return logits, cache + jax.debug.print('tokens: {}', tokens) + logits = jax.nn.one_hot(tokens[..., -1:] + 1, 10) + logits = jnp.where(logits < 0.5, float('-inf'), logits) + logits = logits.squeeze(axis=1) + return logits, cache new_tokens = temperature_sample(tokens, cache, tokens_to_logits, key, topk=5) np.testing.assert_array_equal(new_tokens, [[5, 6, 7, 8]]) + if __name__ == '__main__': - absltest.main() \ No newline at end of file + absltest.main() diff --git a/examples/lm1b/tokenizer.py b/examples/lm1b/tokenizer.py index 0b93b2c50c..54655b7a21 100644 --- a/examples/lm1b/tokenizer.py +++ b/examples/lm1b/tokenizer.py @@ -30,9 +30,7 @@ def _dump_chars_to_textfile( - dataset: tf.data.Dataset, - maxchars: int = int(1e7), - data_keys=('inputs', 'targets') + dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=('inputs', 'targets') ) -> Tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. @@ -46,8 +44,7 @@ def _dump_chars_to_textfile( """ char_count = 0 ds_iter = dataset.as_numpy_iterator() - with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars') as outfp: + with tempfile.NamedTemporaryFile(delete=False, prefix='/tmp/ds_chars') as outfp: while char_count < maxchars: example = next(ds_iter) for k in data_keys: @@ -57,14 +54,16 @@ def _dump_chars_to_textfile( return outfp.name, char_count -def _train_sentencepiece(dataset: tf.data.Dataset, - *, - vocab_size: int, - maxchars: int = int(1e7), - model_path: str, - model_type: str = 'unigram', - character_coverage: float = 1.0, - data_keys=('inputs', 'targets')): +def _train_sentencepiece( + dataset: tf.data.Dataset, + *, + vocab_size: int, + maxchars: int = int(1e7), + model_path: str, + model_type: str = 'unigram', + character_coverage: float = 1.0, + data_keys=('inputs', 'targets'), +): """Train SentencePiece tokenizer from subset of tf dataset. Args: @@ -85,15 +84,15 @@ def _train_sentencepiece(dataset: tf.data.Dataset, abs_model_path = model_path else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) - fname, _ = _dump_chars_to_textfile( - dataset, maxchars=maxchars, data_keys=data_keys) - with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp') as model_fp: + fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys) + with tempfile.NamedTemporaryFile(delete=False, prefix='/tmp/sp_tmp') as model_fp: pass # we just want a prefix'd tmp-filename argstr = ' '.join([ - f'--input={fname}', f'--vocab_size={vocab_size}', + f'--input={fname}', + f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', - f'--model_prefix={model_fp.name}', f'--model_type={model_type}' + f'--model_prefix={model_fp.name}', + f'--model_type={model_type}', ]) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: @@ -110,24 +109,26 @@ def _train_sentencepiece(dataset: tf.data.Dataset, return abs_model_path -def _load_sentencepiece_tokenizer(model_path: str, - add_bos: bool = False, - add_eos: bool = True, - reverse: bool = False): +def _load_sentencepiece_tokenizer( + model_path: str, add_bos: bool = False, add_eos: bool = True, reverse: bool = False +): """Load a tf-text SentencePiece tokenizer from given model filepath.""" with tf.io.gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse) + model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse + ) return sp_tokenizer -def load_or_train_tokenizer(dataset: tf.data.Dataset, - *, - vocab_path: str, - vocab_size: int, - max_corpus_chars: int, - data_keys: Tuple[str, str] = ('inputs', 'targets')): +def load_or_train_tokenizer( + dataset: tf.data.Dataset, + *, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + data_keys: Tuple[str, str] = ('inputs', 'targets'), +): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: return _load_sentencepiece_tokenizer(vocab_path) @@ -138,13 +139,13 @@ def load_or_train_tokenizer(dataset: tf.data.Dataset, vocab_size=vocab_size, maxchars=max_corpus_chars, model_path=vocab_path, - data_keys=data_keys) + data_keys=data_keys, + ) return _load_sentencepiece_tokenizer(vocab_path) @dataclasses.dataclass class TokenizeOp: - sp_tokenizer: Any data_keys: Iterable[str] = ('inputs', 'targets') diff --git a/examples/lm1b/train.py b/examples/lm1b/train.py index 9e8136653a..02a89b7b0a 100644 --- a/examples/lm1b/train.py +++ b/examples/lm1b/train.py @@ -63,25 +63,25 @@ def rsqrt_schedule( """ def schedule(count): - return init_value * (count + shift)**-.5 * shift**.5 + return init_value * (count + shift) ** -0.5 * shift**0.5 return schedule def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): """Creates a rsqrt schedule with linear warmup.""" - return optax.join_schedules([ - optax.linear_schedule( - init_value=0, end_value=learning_rate, transition_steps=warmup_steps), - rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), - ], - boundaries=[warmup_steps]) - - -def compute_weighted_cross_entropy(logits, - targets, - weights=None, - label_smoothing=0.0): + return optax.join_schedules( + [ + optax.linear_schedule( + init_value=0, end_value=learning_rate, transition_steps=warmup_steps + ), + rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), + ], + boundaries=[warmup_steps], + ) + + +def compute_weighted_cross_entropy(logits, targets, weights=None, label_smoothing=0.0): """Compute weighted cross entropy and entropy for log probs and targets. Args: @@ -95,16 +95,20 @@ def compute_weighted_cross_entropy(logits, Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: - raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" % - (str(logits.shape), str(targets.shape))) + raise ValueError( + "Incorrect shapes. Got shape %s logits and %s targets" + % (str(logits.shape), str(targets.shape)) + ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( - confidence * jnp.log(confidence) + - (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)) + confidence * jnp.log(confidence) + + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) soft_targets = common_utils.onehot( - targets, vocab_size, on_value=confidence, off_value=low_confidence) + targets, vocab_size, on_value=confidence, off_value=low_confidence + ) loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1) loss = loss - normalizing_constant @@ -129,8 +133,10 @@ def compute_weighted_accuracy(logits, targets, weights=None): Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: - raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" % - (str(logits.shape), str(targets.shape))) + raise ValueError( + "Incorrect shapes. Got shape %s logits and %s targets" + % (str(logits.shape), str(targets.shape)) + ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: @@ -142,8 +148,9 @@ def compute_weighted_accuracy(logits, targets, weights=None): def compute_metrics(logits, labels, weights, label_smoothing=0.0): """Compute summary metrics.""" - loss, weight_sum = compute_weighted_cross_entropy(logits, labels, weights, - label_smoothing) + loss, weight_sum = compute_weighted_cross_entropy( + logits, labels, weights, label_smoothing + ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { "loss": loss, @@ -158,12 +165,9 @@ def compute_metrics(logits, labels, weights, label_smoothing=0.0): # ----------------------------------------------------------------------------- -def train_step(state, - batch, - config, - learning_rate_fn, - label_smoothing=0.0, - dropout_rng=None): +def train_step( + state, batch, config, learning_rate_fn, label_smoothing=0.0, dropout_rng=None +): """Perform a single training step.""" # X_position and X_segmentation are needed only when using "packed examples" # where multiple sequences are packed into the same example with this @@ -171,8 +175,9 @@ def train_step(state, # if such features are not present they are ignored and the example is treated # like a normal, unpacked sequence example. train_keys = ["inputs", "inputs_position", "inputs_segmentation"] - (inputs, inputs_positions, inputs_segmentation - ) = (batch.get(k, None) for k in train_keys) + (inputs, inputs_positions, inputs_segmentation) = ( + batch.get(k, None) for k in train_keys + ) weights = jnp.where(inputs > 0, 1, 0).astype(jnp.float32) @@ -185,10 +190,12 @@ def loss_fn(params): inputs, inputs_positions=inputs_positions, inputs_segmentation=inputs_segmentation, - rngs={"dropout": dropout_rng}) + rngs={"dropout": dropout_rng}, + ) - loss, weight_sum = compute_weighted_cross_entropy(logits, inputs, weights, - label_smoothing) + loss, weight_sum = compute_weighted_cross_entropy( + logits, inputs, weights, label_smoothing + ) mean_loss = loss / weight_sum return mean_loss, logits @@ -213,31 +220,22 @@ def eval_step(params, batch, config, label_smoothing=0.0): return compute_metrics(logits, inputs, weights, label_smoothing) -def predict_step(inputs, - params, - rngkey, - eos_id, - max_decode_len, - config, - temperature, - top_k): +def predict_step( + inputs, params, rngkey, eos_id, max_decode_len, config, temperature, top_k +): """Predict language model on a batch.""" target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] initial_variables = models.TransformerLM(config).init( - jax.random.PRNGKey(0), - jnp.ones(target_shape, config.dtype)) + jax.random.PRNGKey(0), jnp.ones(target_shape, config.dtype) + ) cache = initial_variables["cache"] def tokens_ids_to_logits(flat_ids, flat_cache): """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] flat_logits, new_vars = models.TransformerLM(config).apply( - { - "params": params, - "cache": flat_cache - }, - flat_ids, - mutable=["cache"]) + {"params": params, "cache": flat_cache}, flat_ids, mutable=["cache"] + ) new_flat_cache = new_vars["cache"] # Remove singleton sequence-length dimension: # [batch, 1, vocab] --> [batch, vocab] @@ -253,7 +251,8 @@ def tokens_ids_to_logits(flat_ids, flat_cache): rngkey, temperature=temperature, topk=top_k, - eos_token=eos_id) + eos_token=eos_id, + ) return seqs @@ -291,8 +290,7 @@ def tohost(x): return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims)) -def evaluate(*, p_eval_step, params, eval_ds: tf.data.Dataset, - num_eval_steps: int): +def evaluate(*, p_eval_step, params, eval_ds: tf.data.Dataset, num_eval_steps: int): """Evaluate the target an return a dictionary with the metrics.""" logging.info("Gathering evaluation metrics.") eval_metrics = [] @@ -307,16 +305,21 @@ def evaluate(*, p_eval_step, params, eval_ds: tf.data.Dataset, eval_denominator = eval_metrics_sums.pop("denominator") eval_summary = jax.tree_util.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop - eval_metrics_sums) + eval_metrics_sums, + ) return eval_summary -def generate_prediction(*, p_pred_step, params, - tokenized_prompts, - eos_id, - inference_rng, - decode_tokens, - max_predict_length: int): +def generate_prediction( + *, + p_pred_step, + params, + tokenized_prompts, + eos_id, + inference_rng, + decode_tokens, + max_predict_length: int, +): """Generate text from the prompt.""" n_devices = jax.local_device_count() @@ -324,19 +327,21 @@ def generate_prediction(*, p_pred_step, params, predictions = [] # Use batch of prompts provided by user. for pred_batch in jnp.array_split( - tokenized_prompts, int(np.ceil(len(tokenized_prompts) / n_devices))): + tokenized_prompts, int(np.ceil(len(tokenized_prompts) / n_devices)) + ): cur_pred_batch_size = pred_batch.shape[0] if cur_pred_batch_size % n_devices: - padded_size = int( - np.ceil(cur_pred_batch_size / n_devices) * n_devices) + padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices) pred_batch = jax.tree_util.tree_map( - lambda x: pad_examples(x, padded_size), pred_batch) # pylint: disable=cell-var-from-loop + lambda x: pad_examples(x, padded_size), pred_batch + ) # pylint: disable=cell-var-from-loop pred_batch = common_utils.shard(pred_batch) inference_rng, sub_rng = random.split(inference_rng) inference_rngs = random.split(sub_rng, n_devices) - predicted = p_pred_step(pred_batch, params, inference_rngs, - eos_id, max_predict_length) + predicted = p_pred_step( + pred_batch, params, inference_rngs, eos_id, max_predict_length + ) predicted = tohost(predicted) # Iterate through non-padding examples of batch. for s in predicted[:cur_pred_batch_size]: @@ -371,16 +376,15 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): # --------------------------------------------------------------------------- logging.info("Initializing dataset.") train_ds, eval_ds, _, encoder = input_pipeline.get_datasets( - n_devices=jax.local_device_count(), - config=config, - vocab_path=vocab_path) + n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path + ) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = temperature_sampler.EOS_ID # Default Sentencepiece EOS token. def decode_tokens(toks): - valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) + valid_toks = toks[: np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode("utf-8") def encode_strings(strs, max_len): @@ -388,11 +392,10 @@ def encode_strings(strs, max_len): for i, s in enumerate(strs): toks = encoder.tokenize(s).numpy() # Remove EOS token in prompt. - tokenized_batch[i, :toks.shape[0]-1] = toks[:-1] + tokenized_batch[i, : toks.shape[0] - 1] = toks[:-1] return tokenized_batch - tokenized_prompts = encode_strings( - [config.prompts], config.max_predict_length) + tokenized_prompts = encode_strings([config.prompts], config.max_predict_length) logging.info("Initializing model, optimizer, and step functions.") # Build Model and Optimizer @@ -413,7 +416,8 @@ def encode_strings(strs, max_len): deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.normal(stddev=1e-6)) + bias_init=nn.initializers.normal(stddev=1e-6), + ) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) @@ -424,21 +428,18 @@ def encode_strings(strs, max_len): input_shape = (config.per_device_batch_size, config.max_target_length) m = models.TransformerLM(eval_config) - initial_variables = jax.jit(m.init)(init_rng, - jnp.ones(input_shape, jnp.float32)) + initial_variables = jax.jit(m.init)(init_rng, jnp.ones(input_shape, jnp.float32)) learning_rate_fn = create_learning_rate_schedule( - learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) + learning_rate=config.learning_rate, warmup_steps=config.warmup_steps + ) optimizer = optax.adamw( - learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, - weight_decay=config.weight_decay - ) + learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, weight_decay=config.weight_decay + ) state = train_state.TrainState.create( - apply_fn=m.apply, - params=initial_variables["params"], - tx=optimizer - ) + apply_fn=m.apply, params=initial_variables["params"], tx=optimizer + ) # We access model params only from optimizer below. del initial_variables @@ -449,7 +450,8 @@ def encode_strings(strs, max_len): start_step = int(state.step) writer = metric_writers.create_default_writer( - workdir, just_logging=jax.process_index() > 0) + workdir, just_logging=jax.process_index() > 0 + ) if start_step == 0: writer.write_hparams(dict(config)) @@ -459,23 +461,25 @@ def encode_strings(strs, max_len): # compile multidevice versions of train/eval/predict step fn. p_train_step = jax.pmap( functools.partial( - train_step, - config=train_config, - learning_rate_fn=learning_rate_fn), + train_step, config=train_config, learning_rate_fn=learning_rate_fn + ), axis_name="batch", - donate_argnums=(0,)) # pytype: disable=wrong-arg-types + donate_argnums=(0,), + ) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap( - functools.partial( - eval_step, config=eval_config), - axis_name="batch") + functools.partial(eval_step, config=eval_config), axis_name="batch" + ) p_pred_step = jax.pmap( functools.partial( - predict_step, config=predict_config, + predict_step, + config=predict_config, temperature=config.sampling_temperature, - top_k=config.sampling_top_k), + top_k=config.sampling_top_k, + ), axis_name="batch", - static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant + static_broadcasted_argnums=(3, 4), + ) # eos token, max_length are constant # Main Train Loop # --------------------------------------------------------------------------- @@ -488,11 +492,12 @@ def encode_strings(strs, max_len): logging.info("Starting training loop.") hooks = [] report_progress = periodic_actions.ReportProgress( - num_train_steps=config.num_train_steps, writer=writer) + num_train_steps=config.num_train_steps, writer=writer + ) if jax.process_index() == 0: hooks += [ report_progress, - periodic_actions.Profile(logdir=workdir, num_profile_steps=5) + periodic_actions.Profile(logdir=workdir, num_profile_steps=5), ] train_metrics = [] with metric_writers.ensure_flushes(writer): @@ -502,8 +507,7 @@ def encode_strings(strs, max_len): # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation("train", step_num=step): batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_iter))) - state, metrics = p_train_step( - state, batch, dropout_rng=dropout_rngs) + state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs) train_metrics.append(metrics) # Quick indication that training is happening. @@ -521,8 +525,7 @@ def encode_strings(strs, max_len): denominator = metrics_sums.pop("denominator") summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr - summary["perplexity"] = jnp.clip( - jnp.exp(summary["loss"]), a_max=1.0e4) + summary["perplexity"] = jnp.clip(jnp.exp(summary["loss"]), a_max=1.0e4) summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] @@ -532,12 +535,13 @@ def encode_strings(strs, max_len): p_eval_step=p_eval_step, params=state.params, eval_ds=eval_ds, - num_eval_steps=config.num_eval_steps) + num_eval_steps=config.num_eval_steps, + ) # (clipped) perplexity after averaging log-perplexitie eval_results["perplexity"] = jnp.clip( - jnp.exp(eval_results["loss"]), a_max=1.0e4) - writer.write_scalars( - step, {"eval_" + k: v for k, v in eval_results.items()}) + jnp.exp(eval_results["loss"]), a_max=1.0e4 + ) + writer.write_scalars(step, {"eval_" + k: v for k, v in eval_results.items()}) with report_progress.timed("generate_text"): exemplars = generate_prediction( @@ -547,12 +551,12 @@ def encode_strings(strs, max_len): eos_id=eos_id, inference_rng=inference_rng, decode_tokens=decode_tokens, - max_predict_length=config.max_predict_length) + max_predict_length=config.max_predict_length, + ) writer.write_texts(step, {"samples": exemplars}) # Save a checkpoint on one host after every checkpoint_freq steps. - save_checkpoint = (step % config.checkpoint_every_steps == 0 or - is_last_step) + save_checkpoint = step % config.checkpoint_every_steps == 0 or is_last_step if config.save_checkpoints and save_checkpoint: logging.info("Saving checkpoint step %d.", step) with report_progress.timed("checkpoint"): diff --git a/examples/mnist/main.py b/examples/mnist/main.py index 04b6ea37f2..3a50d0f9de 100644 --- a/examples/mnist/main.py +++ b/examples/mnist/main.py @@ -36,7 +36,8 @@ 'config', None, 'File path to the training hyperparameter configuration.', - lock_config=True) + lock_config=True, +) def main(argv): @@ -52,10 +53,12 @@ def main(argv): # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) - platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}') - platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, - FLAGS.workdir, 'workdir') + platform.work_unit().set_task_status( + f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' + ) + platform.work_unit().create_artifact( + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) diff --git a/examples/mnist/train.py b/examples/mnist/train.py index 79e937d139..664553e7fa 100644 --- a/examples/mnist/train.py +++ b/examples/mnist/train.py @@ -54,6 +54,7 @@ def __call__(self, x): @jax.jit def apply_model(state, images, labels): """Computes gradients, loss and accuracy for a single batch.""" + def loss_fn(params): logits = state.apply_fn({'params': params}, images) one_hot = jax.nn.one_hot(labels, 10) @@ -77,7 +78,7 @@ def train_epoch(state, train_ds, batch_size, rng): steps_per_epoch = train_ds_size // batch_size perms = jax.random.permutation(rng, len(train_ds['image'])) - perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch + perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch perms = perms.reshape((steps_per_epoch, batch_size)) epoch_loss = [] @@ -101,8 +102,8 @@ def get_datasets(): ds_builder.download_and_prepare() train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1)) test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1)) - train_ds['image'] = jnp.float32(train_ds['image']) / 255. - test_ds['image'] = jnp.float32(test_ds['image']) / 255. + train_ds['image'] = jnp.float32(train_ds['image']) / 255.0 + test_ds['image'] = jnp.float32(test_ds['image']) / 255.0 return train_ds, test_ds @@ -111,12 +112,12 @@ def create_train_state(rng, config): cnn = CNN() params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] tx = optax.sgd(config.learning_rate, config.momentum) - return train_state.TrainState.create( - apply_fn=cnn.apply, params=params, tx=tx) + return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx) -def train_and_evaluate(config: ml_collections.ConfigDict, - workdir: str) -> train_state.TrainState: +def train_and_evaluate( + config: ml_collections.ConfigDict, workdir: str +) -> train_state.TrainState: """Execute model training and evaluation loop. Args: @@ -137,16 +138,15 @@ def train_and_evaluate(config: ml_collections.ConfigDict, for epoch in range(1, config.num_epochs + 1): rng, input_rng = jax.random.split(rng) - state, train_loss, train_accuracy = train_epoch(state, train_ds, - config.batch_size, - input_rng) - _, test_loss, test_accuracy = apply_model(state, test_ds['image'], - test_ds['label']) + state, train_loss, train_accuracy = train_epoch( + state, train_ds, config.batch_size, input_rng + ) + _, test_loss, test_accuracy = apply_model(state, test_ds['image'], test_ds['label']) logging.info( 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f' - % (epoch, train_loss, train_accuracy * 100, test_loss, - test_accuracy * 100)) + % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100) + ) summary_writer.scalar('train_loss', train_loss, epoch) summary_writer.scalar('train_accuracy', train_accuracy, epoch) diff --git a/examples/mnist/train_test.py b/examples/mnist/train_test.py index 9a07c3bb57..402042b566 100644 --- a/examples/mnist/train_test.py +++ b/examples/mnist/train_test.py @@ -48,7 +48,10 @@ def test_cnn(self): self.assertEqual((1, 10), output.shape) self.assertEqual( CNN_PARAMS, - sum(np.prod(arr.shape) for arr in jax.tree_util.tree_leaves(variables["params"]))) + sum( + np.prod(arr.shape) for arr in jax.tree_util.tree_leaves(variables["params"]) + ), + ) def test_train_and_evaluate(self): """Tests training and evaluation code by running a single step.""" diff --git a/examples/nlp_seq/input_pipeline.py b/examples/nlp_seq/input_pipeline.py index dd7011ec0a..992789f9ab 100644 --- a/examples/nlp_seq/input_pipeline.py +++ b/examples/nlp_seq/input_pipeline.py @@ -43,6 +43,7 @@ class CoNLLAttributes(enum.Enum): For details, please see: http://universaldependencies.org/format.html. """ + ID = 0 FORM = 1 LEMMA = 2 @@ -79,17 +80,23 @@ def create_vocabs(filename, max_num_forms=100000): # create word form vocab vocabs = {'forms': {}, 'xpos': {}} vocabs['forms'].update(special_tokens) - vocabs['forms'].update({ - form[0]: id for id, form in enumerate( - form_counter.most_common(max_num_forms), start=ROOT_ID + 1) - }) + vocabs['forms'].update( + { + form[0]: id + for id, form in enumerate( + form_counter.most_common(max_num_forms), start=ROOT_ID + 1 + ) + } + ) # create xpos vocab vocabs['xpos'].update(special_tokens) - vocabs['xpos'].update({ - tag[0]: id - for id, tag in enumerate(xpos_counter.most_common(), start=ROOT_ID + 1) - }) + vocabs['xpos'].update( + { + tag[0]: id + for id, tag in enumerate(xpos_counter.most_common(), start=ROOT_ID + 1) + } + ) return vocabs @@ -124,8 +131,7 @@ def create_token(token, attributes, vocabs): elif attribute == CoNLLAttributes.HEAD: selected_attributes.append(int(token[index])) else: - raise ValueError('CoNLL index %s not covered by mapping.' % - str(attribute.name)) + raise ValueError('CoNLL index %s not covered by mapping.' % str(attribute.name)) return selected_attributes @@ -150,10 +156,9 @@ def create_sentence_with_root(attributes, vocabs): return [token] -def sentences_from_conll_data(corpus_filename, - vocabs, - attributes, - max_sentence_length=1000): +def sentences_from_conll_data( + corpus_filename, vocabs, attributes, max_sentence_length=1000 +): """Load and returns conll data in list format. Args: @@ -187,14 +192,16 @@ def sentences_from_conll_data(corpus_filename, yield sentence -def sentence_dataset_dict(filename, - vocabs, - attributes_input, - attributes_target, - batch_size, - bucket_size, - repeat=None, - prefetch_size=tf.data.experimental.AUTOTUNE): +def sentence_dataset_dict( + filename, + vocabs, + attributes_input, + attributes_target, + batch_size, + bucket_size, + repeat=None, + prefetch_size=tf.data.experimental.AUTOTUNE, +): """Combines sentences into a dataset of padded batches. Args: @@ -217,11 +224,13 @@ def sentence_dataset_dict(filename, def generator(): """Generator to create the data.""" input_generator = sentences_from_conll_data( - filename, vocabs, attributes_input, max_sentence_length=bucket_size) + filename, vocabs, attributes_input, max_sentence_length=bucket_size + ) if attributes_target: target_generator = sentences_from_conll_data( - filename, vocabs, attributes_target, max_sentence_length=bucket_size) + filename, vocabs, attributes_target, max_sentence_length=bucket_size + ) for inputs in input_generator: data = {'inputs': inputs} @@ -232,7 +241,8 @@ def generator(): output_types = {k: tf.float32 for k in data_keys} output_shapes = {k: (None,) for k in data_keys} dataset = tf.data.Dataset.from_generator( - generator, output_types=output_types, output_shapes=output_shapes) + generator, output_types=output_types, output_shapes=output_shapes + ) # cache the dataset in memory and repeat. dataset = dataset.cache() @@ -240,8 +250,7 @@ def generator(): # static padding up to bucket size. padded_shapes = {k: [bucket_size] for k in data_keys} - dataset = dataset.padded_batch( - batch_size=batch_size, padded_shapes=(padded_shapes)) + dataset = dataset.padded_batch(batch_size=batch_size, padded_shapes=(padded_shapes)) dataset = dataset.prefetch(prefetch_size) return dataset diff --git a/examples/nlp_seq/input_pipeline_test.py b/examples/nlp_seq/input_pipeline_test.py index 8b368a4e04..28194ea1ed 100644 --- a/examples/nlp_seq/input_pipeline_test.py +++ b/examples/nlp_seq/input_pipeline_test.py @@ -58,7 +58,8 @@ def test_vocab_creation(self): """Tests the creation of the vocab.""" vocabs = input_pipeline.create_vocabs(self._filename) self.assertEqual( - vocabs['forms'], { + vocabs['forms'], + { '

': 0, '': 1, '': 2, @@ -67,7 +68,8 @@ def test_vocab_creation(self): 'books': 5, '.': 6, 'NY': 7, - }) + }, + ) def testInputBatch(self): """Test the batching of the dataset.""" @@ -76,15 +78,26 @@ def testInputBatch(self): attributes_input = [input_pipeline.CoNLLAttributes.FORM] attributes_target = [] # empty target for tagging of unlabeled data. sentence_dataset = input_pipeline.sentence_dataset_dict( - self._filename, vocabs, attributes_input, attributes_target, - batch_size=2, bucket_size=10, repeat=1) + self._filename, + vocabs, + attributes_input, + attributes_target, + batch_size=2, + bucket_size=10, + repeat=1, + ) sentence_dataset_iter = iter(sentence_dataset) batch = next(sentence_dataset_iter) inputs = batch['inputs'].numpy().tolist() - self.assertSameStructure(inputs, [[2., 3., 4., 5., 6., 0., 0., 0., 0., 0.], - [2., 3., 4., 5., 6., 0., 0., 0., 0., 0.]]) + self.assertSameStructure( + inputs, + [ + [2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ) self.assertLen(batch, 1) # make sure target is not included. def testInputTargetBatch(self): @@ -94,19 +107,34 @@ def testInputTargetBatch(self): attributes_input = [input_pipeline.CoNLLAttributes.FORM] attributes_target = [input_pipeline.CoNLLAttributes.XPOS] sentence_dataset = input_pipeline.sentence_dataset_dict( - self._filename, vocabs, attributes_input, attributes_target, - batch_size=2, bucket_size=10, repeat=1) + self._filename, + vocabs, + attributes_input, + attributes_target, + batch_size=2, + bucket_size=10, + repeat=1, + ) sentence_dataset_iter = iter(sentence_dataset) batch = next(sentence_dataset_iter) inputs = batch['inputs'].numpy().tolist() - self.assertSameStructure(inputs, [[2., 3., 4., 5., 6., 0., 0., 0., 0., 0.], - [2., 3., 4., 5., 6., 0., 0., 0., 0., 0.]]) + self.assertSameStructure( + inputs, + [ + [2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ) targets = batch['targets'].numpy().tolist() - self.assertSameStructure(targets, - [[2., 4., 5., 3., 6., 0., 0., 0., 0., 0.], - [2., 4., 5., 3., 6., 0., 0., 0., 0., 0.]]) + self.assertSameStructure( + targets, + [ + [2.0, 4.0, 5.0, 3.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.0, 4.0, 5.0, 3.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ) if __name__ == '__main__': diff --git a/examples/nlp_seq/models.py b/examples/nlp_seq/models.py index c44d6fd9de..bf01b49d79 100644 --- a/examples/nlp_seq/models.py +++ b/examples/nlp_seq/models.py @@ -25,6 +25,7 @@ @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int output_vocab_size: int dtype: Any = jnp.float32 @@ -57,8 +58,7 @@ def init(key, shape, dtype=np.float32): d_feature = shape[-1] pe = np.zeros((max_len, d_feature), dtype=np.float32) position = np.arange(0, max_len)[:, np.newaxis] - div_term = np.exp( - np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) + div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) pe[:, 0::2] = np.sin(position * div_term) pe[:, 1::2] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] @@ -73,6 +73,7 @@ class AddPositionEmbs(nn.Module): Attributes: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig @nn.compact @@ -91,18 +92,16 @@ def __call__(self, inputs): """ config = self.config # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - ' but it is: %d' % inputs.ndim) + assert inputs.ndim == 3, ( + 'Number of dimensions should be 3,' ' but it is: %d' % inputs.ndim + ) length = inputs.shape[1] pos_emb_shape = (1, config.max_len, inputs.shape[-1]) if config.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=config.max_len)(None, - pos_emb_shape, - None) + pos_embedding = sinusoidal_init(max_len=config.max_len)(None, pos_emb_shape, None) else: - pos_embedding = self.param('pos_embedding', config.posemb_init, - pos_emb_shape) + pos_embedding = self.param('pos_embedding', config.posemb_init, pos_emb_shape) pe = pos_embedding[:, :length, :] return inputs + pe @@ -114,6 +113,7 @@ class MlpBlock(nn.Module): config: TransformerConfig dataclass containing hyperparameters. out_dim: optionally specify out dimension. """ + config: TransformerConfig out_dim: Optional[int] = None @@ -121,24 +121,22 @@ class MlpBlock(nn.Module): def __call__(self, inputs, deterministic=True): """Applies Transformer MlpBlock module.""" config = self.config - actual_out_dim = (inputs.shape[-1] if self.out_dim is None - else self.out_dim) + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( config.mlp_dim, dtype=config.dtype, kernel_init=config.kernel_init, - bias_init=config.bias_init)( - inputs) + bias_init=config.bias_init, + )(inputs) x = nn.elu(x) x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic) output = nn.Dense( actual_out_dim, dtype=config.dtype, kernel_init=config.kernel_init, - bias_init=config.bias_init)( - x) - output = nn.Dropout(rate=config.dropout_rate)( - output, deterministic=deterministic) + bias_init=config.bias_init, + )(x) + output = nn.Dropout(rate=config.dropout_rate)(output, deterministic=deterministic) return output @@ -148,6 +146,7 @@ class Encoder1DBlock(nn.Module): Attributes: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig @nn.compact @@ -175,8 +174,8 @@ def __call__(self, inputs, deterministic): use_bias=False, broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, - deterministic=deterministic)( - x) + deterministic=deterministic, + )(x) x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic) x = x + inputs @@ -210,9 +209,8 @@ def __call__(self, *, inputs, train): x = inputs.astype('int32') x = nn.Embed( - num_embeddings=config.vocab_size, features=config.emb_dim, - name='embed')( - x) + num_embeddings=config.vocab_size, features=config.emb_dim, name='embed' + )(x) x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=not train) x = AddPositionEmbs(config)(x) @@ -223,6 +221,6 @@ def __call__(self, *, inputs, train): logits = nn.Dense( config.output_vocab_size, kernel_init=config.kernel_init, - bias_init=config.bias_init)( - x) + bias_init=config.bias_init, + )(x) return logits diff --git a/examples/nlp_seq/train.py b/examples/nlp_seq/train.py index b15dac410f..78a76f24c2 100644 --- a/examples/nlp_seq/train.py +++ b/examples/nlp_seq/train.py @@ -46,29 +46,25 @@ flags.DEFINE_string('experiment', default='xpos', help=('Experiment name.')) -flags.DEFINE_integer( - 'batch_size', default=64, help=('Batch size for training.')) +flags.DEFINE_integer('batch_size', default=64, help=('Batch size for training.')) flags.DEFINE_integer( 'eval_frequency', default=100, - help=('Frequency of eval during training, e.g. every 1000 steps.')) + help=('Frequency of eval during training, e.g. every 1000 steps.'), +) -flags.DEFINE_integer( - 'num_train_steps', default=75000, help=('Number of train steps.')) +flags.DEFINE_integer('num_train_steps', default=75000, help=('Number of train steps.')) flags.DEFINE_float('learning_rate', default=0.05, help=('Learning rate.')) flags.DEFINE_float( - 'weight_decay', - default=1e-1, - help=('Decay factor for AdamW style weight decay.')) + 'weight_decay', default=1e-1, help=('Decay factor for AdamW style weight decay.') +) -flags.DEFINE_integer('max_length', default=256, - help=('Maximum length of examples.')) +flags.DEFINE_integer('max_length', default=256, help=('Maximum length of examples.')) -flags.DEFINE_integer( - 'random_seed', default=0, help=('Integer for PRNG random seed.')) +flags.DEFINE_integer('random_seed', default=0, help=('Integer for PRNG random seed.')) flags.DEFINE_string('train', default='', help=('Path to training data.')) @@ -81,7 +77,8 @@ def create_learning_rate_scheduler( warmup_steps=8000, decay_factor=0.5, steps_per_decay=20000, - steps_per_cycle=100000): + steps_per_cycle=100000, +): """creates learning rate schedule. Interprets factors in the factors string which can consist of: @@ -119,12 +116,10 @@ def step_fn(step): ret *= jnp.sqrt(warmup_steps) ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) elif name == 'decay_every': - ret *= (decay_factor**(step // steps_per_decay)) + ret *= decay_factor ** (step // steps_per_decay) elif name == 'cosine_decay': - progress = jnp.maximum(0.0, - (step - warmup_steps) / float(steps_per_cycle)) - ret *= jnp.maximum(0.0, - 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) + progress = jnp.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle)) + ret *= jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) else: raise ValueError('Unknown factor %s.' % name) return jnp.asarray(ret, dtype=jnp.float32) @@ -144,8 +139,10 @@ def compute_weighted_cross_entropy(logits, targets, weights=None): Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: - raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' % - (str(logits.shape), str(targets.shape))) + raise ValueError( + 'Incorrect shapes. Got shape %s logits and %s targets' + % (str(logits.shape), str(targets.shape)) + ) onehot_targets = common_utils.onehot(targets, logits.shape[-1]) loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1) normalizing_factor = onehot_targets.sum() @@ -168,8 +165,10 @@ def compute_weighted_accuracy(logits, targets, weights=None): Tuple of scalar accuracy and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: - raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' % - (str(logits.shape), str(targets.shape))) + raise ValueError( + 'Incorrect shapes. Got shape %s logits and %s targets' + % (str(logits.shape), str(targets.shape)) + ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: @@ -192,11 +191,7 @@ def compute_metrics(logits, labels, weights): return metrics -def train_step(state, - batch, - model, - learning_rate_fn, - dropout_rng=None): +def train_step(state, batch, model, learning_rate_fn, dropout_rng=None): """Perform a single training step.""" train_keys = ['inputs', 'targets'] (inputs, targets) = (batch.get(k, None) for k in train_keys) @@ -207,8 +202,9 @@ def train_step(state, def loss_fn(params): """loss function used for training.""" - logits = model.apply({'params': params}, inputs=inputs, train=True, - rngs={'dropout': dropout_rng}) + logits = model.apply( + {'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng} + ) loss, weight_sum = compute_weighted_cross_entropy(logits, targets, weights) mean_loss = loss / weight_sum @@ -217,10 +213,10 @@ def loss_fn(params): lr = learning_rate_fn(state.step) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) - grads = jax.lax.pmean(grads, "batch") + grads = jax.lax.pmean(grads, 'batch') new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, targets, weights) - metrics["learning_rate"] = lr + metrics['learning_rate'] = lr return new_state, metrics @@ -255,16 +251,19 @@ def main(argv): if jax.process_index() == 0: train_summary_writer = tensorboard.SummaryWriter( - os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train')) + os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train') + ) eval_summary_writer = tensorboard.SummaryWriter( - os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval')) + os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval') + ) # create the training and development dataset vocabs = input_pipeline.create_vocabs(FLAGS.train) config = models.TransformerConfig( vocab_size=len(vocabs['forms']), output_vocab_size=len(vocabs['xpos']), - max_len=FLAGS.max_length) + max_len=FLAGS.max_length, + ) attributes_input = [input_pipeline.CoNLLAttributes.FORM] attributes_target = [input_pipeline.CoNLLAttributes.XPOS] @@ -274,7 +273,8 @@ def main(argv): attributes_input, attributes_target, batch_size=batch_size, - bucket_size=config.max_len) + bucket_size=config.max_len, + ) train_iter = iter(train_ds) eval_ds = input_pipeline.sentence_dataset_dict( @@ -284,7 +284,8 @@ def main(argv): attributes_target, batch_size=batch_size, bucket_size=config.max_len, - repeat=1) + repeat=1, + ) model = models.Transformer(config) @@ -297,29 +298,26 @@ def initialize_variables(init_rng): init_batch = jnp.ones((config.max_len, 1), jnp.float32) init_variables = model.init(init_rng, inputs=init_batch, train=False) return init_variables + init_variables = initialize_variables(init_rng) - learning_rate_fn = create_learning_rate_scheduler( - base_learning_rate=learning_rate) + learning_rate_fn = create_learning_rate_scheduler(base_learning_rate=learning_rate) optimizer = optax.adamw( - learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, - weight_decay=1e-1) + learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, weight_decay=1e-1 + ) state = train_state.TrainState.create( - apply_fn=model.apply, - params=init_variables["params"], - tx=optimizer) + apply_fn=model.apply, params=init_variables['params'], tx=optimizer + ) # Replicate optimizer. state = jax_utils.replicate(state) p_train_step = jax.pmap( - functools.partial( - train_step, - model=model, - learning_rate_fn=learning_rate_fn), + functools.partial(train_step, model=model, learning_rate_fn=learning_rate_fn), axis_name='batch', - donate_argnums=(0,)) # pytype: disable=wrong-arg-types + donate_argnums=(0,), + ) # pytype: disable=wrong-arg-types def eval_step(params, batch): """Calculate evaluation metrics on a batch.""" @@ -370,7 +368,8 @@ def eval_step(params, batch): if cur_pred_batch_size != batch_size: # pad up to batch size eval_batch = jax.tree_util.tree_map( - lambda x: pad_examples(x, batch_size), eval_batch) + lambda x: pad_examples(x, batch_size), eval_batch + ) eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(state.params, eval_batch) @@ -381,10 +380,15 @@ def eval_step(params, batch): eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_util.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop - eval_metrics_sums) - - logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step, - eval_summary['loss'], eval_summary['accuracy']) + eval_metrics_sums, + ) + + logging.info( + 'eval in step: %d, loss: %.4f, accuracy: %.4f', + step, + eval_summary['loss'], + eval_summary['accuracy'], + ) if best_dev_score < eval_summary['accuracy']: best_dev_score = eval_summary['accuracy'] diff --git a/examples/ogbg_molpcba/input_pipeline.py b/examples/ogbg_molpcba/input_pipeline.py index 8a308bcf68..2f3b9eb012 100644 --- a/examples/ogbg_molpcba/input_pipeline.py +++ b/examples/ogbg_molpcba/input_pipeline.py @@ -25,6 +25,7 @@ class GraphsTupleSize(NamedTuple): """Helper class to represent padding and graph sizes.""" + n_node: int n_edge: int n_graph: int @@ -35,16 +36,16 @@ def get_raw_datasets() -> Dict[str, tf.data.Dataset]: ds_builder = tfds.builder('ogbg_molpcba') ds_builder.download_and_prepare() ds_splits = ['train', 'validation', 'test'] - datasets = { - split: ds_builder.as_dataset(split=split) for split in ds_splits - } + datasets = {split: ds_builder.as_dataset(split=split) for split in ds_splits} return datasets -def get_datasets(batch_size: int, - add_virtual_node: bool = True, - add_undirected_edges: bool = True, - add_self_loops: bool = True) -> Dict[str, tf.data.Dataset]: +def get_datasets( + batch_size: int, + add_virtual_node: bool = True, + add_undirected_edges: bool = True, + add_self_loops: bool = True, +) -> Dict[str, tf.data.Dataset]: """Returns datasets of batched GraphsTuples, organized by split.""" if batch_size <= 1: raise ValueError('Batch size must be > 1 to account for padding graphs.') @@ -62,16 +63,17 @@ def get_datasets(batch_size: int, # Process each split separately. for split_name in datasets: - # Convert to GraphsTuple. datasets[split_name] = datasets[split_name].map( convert_to_graphs_tuple_fn, num_parallel_calls=tf.data.AUTOTUNE, - deterministic=True) + deterministic=True, + ) # Compute the padding budget for the requested batch size. - budget = estimate_padding_budget_for_batch_size(datasets['train'], batch_size, - num_estimation_graphs=100) + budget = estimate_padding_budget_for_batch_size( + datasets['train'], batch_size, num_estimation_graphs=100 + ) # Pad an example graph to see what the output shapes will be. # We will use this shape information when creating the tf.data.Dataset. @@ -81,7 +83,6 @@ def get_datasets(batch_size: int, # Process each split separately. for split_name, dataset_split in datasets.items(): - # Repeat and shuffle the training split. if split_name == 'train': dataset_split = dataset_split.shuffle(100, reshuffle_each_iteration=True) @@ -93,10 +94,11 @@ def get_datasets(batch_size: int, graphs_tuple_iterator=iter(dataset_split), n_node=budget.n_node, n_edge=budget.n_edge, - n_graph=budget.n_graph) + n_graph=budget.n_graph, + ) dataset_split = tf.data.Dataset.from_generator( - batching_fn, - output_signature=padded_graphs_spec) + batching_fn, output_signature=padded_graphs_spec + ) # We cache the validation and test sets, since these are small. if split_name in ['validation', 'test']: @@ -106,10 +108,12 @@ def get_datasets(batch_size: int, return datasets -def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], - add_virtual_node: bool, - add_undirected_edges: bool, - add_self_loops: bool) -> jraph.GraphsTuple: +def convert_to_graphs_tuple( + graph: Dict[str, tf.Tensor], + add_virtual_node: bool, + add_undirected_edges: bool, + add_self_loops: bool, +) -> jraph.GraphsTuple: """Converts a dictionary of tf.Tensors to a GraphsTuple.""" num_nodes = tf.squeeze(graph['num_nodes']) num_edges = tf.squeeze(graph['num_edges']) @@ -124,14 +128,10 @@ def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], # The feature vectors for the virtual node # and the new edges are set to all zeros. if add_virtual_node: - nodes = tf.concat( - [nodes, tf.zeros_like(nodes[0, None])], axis=0) - senders = tf.concat( - [senders, tf.range(num_nodes)], axis=0) - receivers = tf.concat( - [receivers, tf.fill((num_nodes,), num_nodes + 1)], axis=0) - edges = tf.concat( - [edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) + nodes = tf.concat([nodes, tf.zeros_like(nodes[0, None])], axis=0) + senders = tf.concat([senders, tf.range(num_nodes)], axis=0) + receivers = tf.concat([receivers, tf.fill((num_nodes,), num_nodes + 1)], axis=0) + edges = tf.concat([edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) num_edges += num_nodes num_nodes += 1 @@ -164,9 +164,8 @@ def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], def estimate_padding_budget_for_batch_size( - dataset: tf.data.Dataset, - batch_size: int, - num_estimation_graphs: int) -> GraphsTupleSize: + dataset: tf.data.Dataset, batch_size: int, num_estimation_graphs: int +) -> GraphsTupleSize: """Estimates the padding budget for a dataset of unbatched GraphsTuples. Args: @@ -204,7 +203,8 @@ def next_multiple_of_64(val: float): padding_budget = GraphsTupleSize( n_node=next_multiple_of_64(num_nodes_per_graph_estimate * batch_size), n_edge=next_multiple_of_64(num_edges_per_graph_estimate * batch_size), - n_graph=batch_size) + n_graph=batch_size, + ) return padding_budget @@ -218,7 +218,13 @@ def get_tensor_spec(array: np.ndarray): specs = {} for field in [ - 'nodes', 'edges', 'senders', 'receivers', 'globals', 'n_node', 'n_edge' + 'nodes', + 'edges', + 'senders', + 'receivers', + 'globals', + 'n_node', + 'n_edge', ]: field_sample = getattr(graph, field) specs[field] = get_tensor_spec(field_sample) @@ -230,4 +236,5 @@ def get_graphs_tuple_size(graph: jraph.GraphsTuple): return GraphsTupleSize( n_node=np.sum(graph.n_node), n_edge=np.sum(graph.n_edge), - n_graph=np.shape(graph.n_node)[0]) + n_graph=np.shape(graph.n_node)[0], + ) diff --git a/examples/ogbg_molpcba/input_pipeline_test.py b/examples/ogbg_molpcba/input_pipeline_test.py index a43b92ec12..bdb7a25ccc 100644 --- a/examples/ogbg_molpcba/input_pipeline_test.py +++ b/examples/ogbg_molpcba/input_pipeline_test.py @@ -45,7 +45,8 @@ def get_dummy_graphs(): datasets = {} for split in ['train', 'validation', 'test']: datasets[split] = tf.data.Dataset.from_generator( - get_dummy_graphs, output_signature=graphs_spec) + get_dummy_graphs, output_signature=graphs_spec + ) return datasets @@ -61,7 +62,8 @@ def setUp(self): ) def test_estimate_padding_budget_valid(self, valid_batch_size): budget = input_pipeline.estimate_padding_budget_for_batch_size( - self.datasets['train'], valid_batch_size, num_estimation_graphs=1) + self.datasets['train'], valid_batch_size, num_estimation_graphs=1 + ) self.assertEqual(budget.n_graph, valid_batch_size) @parameterized.product( @@ -70,7 +72,8 @@ def test_estimate_padding_budget_valid(self, valid_batch_size): def test_estimate_padding_budget_invalid(self, invalid_batch_size): with self.assertRaises(ValueError): input_pipeline.estimate_padding_budget_for_batch_size( - self.datasets['train'], invalid_batch_size, num_estimation_graphs=1) + self.datasets['train'], invalid_batch_size, num_estimation_graphs=1 + ) if __name__ == '__main__': diff --git a/examples/ogbg_molpcba/main.py b/examples/ogbg_molpcba/main.py index caea210dc3..683fc49692 100644 --- a/examples/ogbg_molpcba/main.py +++ b/examples/ogbg_molpcba/main.py @@ -36,7 +36,8 @@ 'config', None, 'File path to the training hyperparameter configuration.', - lock_config=True) + lock_config=True, +) def main(argv): @@ -53,10 +54,12 @@ def main(argv): # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) - platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}') - platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, - FLAGS.workdir, 'workdir') + platform.work_unit().set_task_status( + f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' + ) + platform.work_unit().create_artifact( + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) diff --git a/examples/ogbg_molpcba/models.py b/examples/ogbg_molpcba/models.py index af1302282c..de513c6c19 100644 --- a/examples/ogbg_molpcba/models.py +++ b/examples/ogbg_molpcba/models.py @@ -21,13 +21,15 @@ import jraph -def add_graphs_tuples(graphs: jraph.GraphsTuple, - other_graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: +def add_graphs_tuples( + graphs: jraph.GraphsTuple, other_graphs: jraph.GraphsTuple +) -> jraph.GraphsTuple: """Adds the nodes, edges and global features from other_graphs to graphs.""" return graphs._replace( nodes=graphs.nodes + other_graphs.nodes, edges=graphs.edges + other_graphs.edges, - globals=graphs.globals + other_graphs.globals) + globals=graphs.globals + other_graphs.globals, + ) class MLP(nn.Module): @@ -44,8 +46,7 @@ def __call__(self, inputs): for size in self.feature_sizes: x = nn.Dense(features=size)(x) x = self.activation(x) - x = nn.Dropout( - rate=self.dropout_rate, deterministic=self.deterministic)(x) + x = nn.Dropout(rate=self.dropout_rate, deterministic=self.deterministic)(x) return x @@ -68,7 +69,8 @@ def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: embedder = jraph.GraphMapFeatures( embed_node_fn=nn.Dense(self.latent_size), embed_edge_fn=nn.Dense(self.latent_size), - embed_global_fn=nn.Dense(self.latent_size)) + embed_global_fn=nn.Dense(self.latent_size), + ) processed_graphs = embedder(graphs) # Now, we will apply a Graph Network once for each message-passing round. @@ -76,29 +78,40 @@ def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: for _ in range(self.message_passing_steps): if self.use_edge_model: update_edge_fn = jraph.concatenated_args( - MLP(mlp_feature_sizes, + MLP( + mlp_feature_sizes, dropout_rate=self.dropout_rate, - deterministic=self.deterministic)) + deterministic=self.deterministic, + ) + ) else: update_edge_fn = None update_node_fn = jraph.concatenated_args( - MLP(mlp_feature_sizes, + MLP( + mlp_feature_sizes, dropout_rate=self.dropout_rate, - deterministic=self.deterministic)) + deterministic=self.deterministic, + ) + ) update_global_fn = jraph.concatenated_args( - MLP(mlp_feature_sizes, + MLP( + mlp_feature_sizes, dropout_rate=self.dropout_rate, - deterministic=self.deterministic)) + deterministic=self.deterministic, + ) + ) graph_net = jraph.GraphNetwork( update_node_fn=update_node_fn, update_edge_fn=update_edge_fn, - update_global_fn=update_global_fn) + update_global_fn=update_global_fn, + ) if self.skip_connections: processed_graphs = add_graphs_tuples( - graph_net(processed_graphs), processed_graphs) + graph_net(processed_graphs), processed_graphs + ) else: processed_graphs = graph_net(processed_graphs) @@ -111,8 +124,7 @@ def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: # Since our graph-level predictions will be at globals, we will # decode to get the required output logits. - decoder = jraph.GraphMapFeatures( - embed_global_fn=nn.Dense(self.output_globals_size)) + decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.output_globals_size)) processed_graphs = decoder(processed_graphs) return processed_graphs @@ -129,8 +141,10 @@ class GraphConvNet(nn.Module): skip_connections: bool = True layer_norm: bool = True deterministic: bool = True - pooling_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], # pytype: disable=annotation-type-mismatch # jax-ndarray - jnp.ndarray] = jraph.segment_mean + pooling_fn: Callable[ + [jnp.ndarray, jnp.ndarray, jnp.ndarray], # pytype: disable=annotation-type-mismatch # jax-ndarray + jnp.ndarray, + ] = jraph.segment_mean def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: """Pooling operation, taken from Jraph.""" @@ -142,10 +156,8 @@ def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: # Example: if you have `n_node=[1,2]`, we construct the tensor [0, 1, 1]. n_graph = graphs.n_node.shape[0] node_graph_indices = jnp.repeat( - jnp.arange(n_graph), - graphs.n_node, - axis=0, - total_repeat_length=sum_n_node) + jnp.arange(n_graph), graphs.n_node, axis=0, total_repeat_length=sum_n_node + ) # We use the aggregation function to pool the nodes per graph. pooled = self.pooling_fn(graphs.nodes, node_graph_indices, n_graph) # pytype: disable=wrong-arg-types # jax-ndarray return graphs._replace(globals=pooled) @@ -153,23 +165,27 @@ def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: @nn.compact def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: # We will first linearly project the original node features as 'embeddings'. - embedder = jraph.GraphMapFeatures( - embed_node_fn=nn.Dense(self.latent_size)) + embedder = jraph.GraphMapFeatures(embed_node_fn=nn.Dense(self.latent_size)) processed_graphs = embedder(graphs) # Now, we will apply the GCN once for each message-passing round. for _ in range(self.message_passing_steps): mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers update_node_fn = jraph.concatenated_args( - MLP(mlp_feature_sizes, + MLP( + mlp_feature_sizes, dropout_rate=self.dropout_rate, - deterministic=self.deterministic)) + deterministic=self.deterministic, + ) + ) graph_conv = jraph.GraphConvolution( - update_node_fn=update_node_fn, add_self_edges=True) + update_node_fn=update_node_fn, add_self_edges=True + ) if self.skip_connections: processed_graphs = add_graphs_tuples( - graph_conv(processed_graphs), processed_graphs) + graph_conv(processed_graphs), processed_graphs + ) else: processed_graphs = graph_conv(processed_graphs) @@ -182,8 +198,7 @@ def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: processed_graphs = self.pool(processed_graphs) # Now, we decode this to get the required output logits. - decoder = jraph.GraphMapFeatures( - embed_global_fn=nn.Dense(self.output_globals_size)) + decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.output_globals_size)) processed_graphs = decoder(processed_graphs) return processed_graphs diff --git a/examples/ogbg_molpcba/models_test.py b/examples/ogbg_molpcba/models_test.py index c8f39fae63..ffd0cd6037 100644 --- a/examples/ogbg_molpcba/models_test.py +++ b/examples/ogbg_molpcba/models_test.py @@ -48,7 +48,8 @@ def setUp(self): ) @parameterized.product( - dropout_rate=[0., 0.5, 1.], output_size=[50, 100], num_layers=[2]) + dropout_rate=[0.0, 0.5, 1.0], output_size=[50, 100], num_layers=[2] + ) def test_mlp(self, dropout_rate, output_size, num_layers): # Input definition. nodes = self.graphs.nodes @@ -58,16 +59,15 @@ def test_mlp(self, dropout_rate, output_size, num_layers): feature_sizes=[output_size] * num_layers, dropout_rate=dropout_rate, activation=lambda x: x, - deterministic=False) + deterministic=False, + ) nodes_after_mlp, _ = mlp.init_with_output(self.rngs, nodes) # Test that dropout actually worked. num_masked_entries = jnp.sum(nodes_after_mlp == 0) num_total_entries = jnp.size(nodes_after_mlp) - self.assertLessEqual(num_masked_entries, - (dropout_rate + 0.05) * num_total_entries) - self.assertLessEqual((dropout_rate - 0.05) * num_total_entries, - num_masked_entries) + self.assertLessEqual(num_masked_entries, (dropout_rate + 0.05) * num_total_entries) + self.assertLessEqual((dropout_rate - 0.05) * num_total_entries, num_masked_entries) # Test the shape of the output. self.assertEqual(nodes_after_mlp.shape[-1], output_size) @@ -77,13 +77,16 @@ def test_mlp(self, dropout_rate, output_size, num_layers): 'latent_size': 5, 'output_globals_size': 15, 'use_edge_model': True, - }, { + }, + { 'latent_size': 5, 'output_globals_size': 15, 'use_edge_model': False, - }) - def test_graph_net(self, latent_size: int, output_globals_size: int, - use_edge_model: bool): + }, + ) + def test_graph_net( + self, latent_size: int, output_globals_size: int, use_edge_model: bool + ): # Input definition. graphs = self.graphs num_nodes = jnp.sum(graphs.n_node) @@ -96,7 +99,8 @@ def test_graph_net(self, latent_size: int, output_globals_size: int, num_mlp_layers=2, message_passing_steps=2, output_globals_size=output_globals_size, - use_edge_model=use_edge_model) + use_edge_model=use_edge_model, + ) output, _ = net.init_with_output(self.rngs, graphs) # Output should be graph with the same topology, but a @@ -110,13 +114,10 @@ def test_graph_net(self, latent_size: int, output_globals_size: int, self.assertEqual(output.edges.shape, (num_edges, latent_size)) self.assertEqual(output.globals.shape, (num_graphs, output_globals_size)) - @parameterized.parameters({ - 'latent_size': 15, - 'output_globals_size': 15 - }, { - 'latent_size': 5, - 'output_globals_size': 5 - }) + @parameterized.parameters( + {'latent_size': 15, 'output_globals_size': 15}, + {'latent_size': 5, 'output_globals_size': 5}, + ) def test_graph_conv_net(self, latent_size: int, output_globals_size: int): graphs = self.graphs num_nodes = jnp.sum(graphs.n_node) @@ -127,7 +128,8 @@ def test_graph_conv_net(self, latent_size: int, output_globals_size: int): latent_size=latent_size, num_mlp_layers=2, message_passing_steps=2, - output_globals_size=output_globals_size) + output_globals_size=output_globals_size, + ) output, _ = net.init_with_output(self.rngs, graphs) # Output should be graph with the same topology, but a @@ -135,8 +137,7 @@ def test_graph_conv_net(self, latent_size: int, output_globals_size: int): self.assertIsInstance(output, jraph.GraphsTuple) self.assertSequenceAlmostEqual(output.n_node, graphs.n_node) self.assertSequenceAlmostEqual(output.n_edge, graphs.n_edge) - self.assertSequenceAlmostEqual(output.edges.flatten(), - graphs.edges.flatten()) + self.assertSequenceAlmostEqual(output.edges.flatten(), graphs.edges.flatten()) self.assertSequenceAlmostEqual(output.senders, graphs.senders) self.assertSequenceAlmostEqual(output.receivers, graphs.receivers) self.assertEqual(output.nodes.shape, (num_nodes, latent_size)) diff --git a/examples/ogbg_molpcba/ogbg_molpcba_benchmark.py b/examples/ogbg_molpcba/ogbg_molpcba_benchmark.py index 69a9b0513c..6515698e18 100644 --- a/examples/ogbg_molpcba/ogbg_molpcba_benchmark.py +++ b/examples/ogbg_molpcba/ogbg_molpcba_benchmark.py @@ -75,16 +75,11 @@ def test_1x_v100(self): # Use the reporting API to report single or multiple metrics/extras. self.report_wall_time(benchmark_time) self.report_metrics({ - 'sec_per_epoch': - sec_per_epoch, - 'test_accuracy': - end_test_accuracy, - 'test_mean_average_precision': - end_test_mean_average_precision, - 'validation_accuracy': - end_validation_accuracy, - 'validation_mean_average_precision': - end_validation_mean_average_precision, + 'sec_per_epoch': sec_per_epoch, + 'test_accuracy': end_test_accuracy, + 'test_mean_average_precision': end_test_mean_average_precision, + 'validation_accuracy': end_validation_accuracy, + 'validation_mean_average_precision': end_validation_mean_average_precision, }) self.report_extras({ 'model_name': 'Graph Convolutional Network', @@ -125,16 +120,11 @@ def test_cpu(self): # Use the reporting API to report single or multiple metrics/extras. self.report_wall_time(benchmark_time) self.report_metrics({ - 'sec_per_epoch': - sec_per_epoch, - 'test_accuracy': - end_test_accuracy, - 'test_mean_average_precision': - end_test_mean_average_precision, - 'validation_accuracy': - end_validation_accuracy, - 'validation_mean_average_precision': - end_validation_mean_average_precision, + 'sec_per_epoch': sec_per_epoch, + 'test_accuracy': end_test_accuracy, + 'test_mean_average_precision': end_test_mean_average_precision, + 'validation_accuracy': end_validation_accuracy, + 'validation_mean_average_precision': end_validation_mean_average_precision, }) self.report_extras({ 'model_name': 'Graph Convolutional Network', diff --git a/examples/ogbg_molpcba/train.py b/examples/ogbg_molpcba/train.py index cf158ea766..03fff0f1c9 100644 --- a/examples/ogbg_molpcba/train.py +++ b/examples/ogbg_molpcba/train.py @@ -40,8 +40,7 @@ import models -def create_model(config: ml_collections.ConfigDict, - deterministic: bool) -> nn.Module: +def create_model(config: ml_collections.ConfigDict, deterministic: bool) -> nn.Module: """Creates a Flax model, as specified by the config.""" if config.model == 'GraphNet': return models.GraphNet( @@ -53,7 +52,8 @@ def create_model(config: ml_collections.ConfigDict, skip_connections=config.skip_connections, layer_norm=config.layer_norm, use_edge_model=config.use_edge_model, - deterministic=deterministic) + deterministic=deterministic, + ) if config.model == 'GraphConvNet': return models.GraphConvNet( latent_size=config.latent_size, @@ -63,25 +63,23 @@ def create_model(config: ml_collections.ConfigDict, dropout_rate=config.dropout_rate, skip_connections=config.skip_connections, layer_norm=config.layer_norm, - deterministic=deterministic) + deterministic=deterministic, + ) raise ValueError(f'Unsupported model: {config.model}.') -def create_optimizer( - config: ml_collections.ConfigDict) -> optax.GradientTransformation: +def create_optimizer(config: ml_collections.ConfigDict) -> optax.GradientTransformation: """Creates an optimizer, as specified by the config.""" if config.optimizer == 'adam': - return optax.adam( - learning_rate=config.learning_rate) + return optax.adam(learning_rate=config.learning_rate) if config.optimizer == 'sgd': - return optax.sgd( - learning_rate=config.learning_rate, - momentum=config.momentum) + return optax.sgd(learning_rate=config.learning_rate, momentum=config.momentum) raise ValueError(f'Unsupported optimizer: {config.optimizer}.') -def binary_cross_entropy_with_mask(*, logits: jnp.ndarray, labels: jnp.ndarray, - mask: jnp.ndarray): +def binary_cross_entropy_with_mask( + *, logits: jnp.ndarray, labels: jnp.ndarray, mask: jnp.ndarray +): """Binary cross entropy loss for unnormalized logits, with masked elements.""" assert logits.shape == labels.shape == mask.shape assert len(logits.shape) == 2 @@ -92,18 +90,18 @@ def binary_cross_entropy_with_mask(*, logits: jnp.ndarray, labels: jnp.ndarray, # Numerically stable implementation of BCE loss. # This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits(). - positive_logits = (logits >= 0) + positive_logits = logits >= 0 relu_logits = jnp.where(positive_logits, logits, 0) abs_logits = jnp.where(positive_logits, logits, -logits) - return relu_logits - (logits * labels) + ( - jnp.log(1 + jnp.exp(-abs_logits))) + return relu_logits - (logits * labels) + (jnp.log(1 + jnp.exp(-abs_logits))) -def predictions_match_labels(*, logits: jnp.ndarray, labels: jnp.ndarray, - **kwargs) -> jnp.ndarray: +def predictions_match_labels( + *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs +) -> jnp.ndarray: """Returns a binary array indicating where predictions match the labels.""" del kwargs # Unused. - preds = (logits > 0) + preds = logits > 0 return (preds == labels).astype(jnp.float32) @@ -114,7 +112,8 @@ def add_prefix_to_keys(result: Dict[str, Any], prefix: str) -> Dict[str, Any]: @flax.struct.dataclass class MeanAveragePrecision( - metrics.CollectingMetric.from_outputs(('labels', 'logits', 'mask'))): + metrics.CollectingMetric.from_outputs(('labels', 'logits', 'mask')) +): """Computes the mean average precision (mAP) over different tasks.""" def compute(self): @@ -137,7 +136,8 @@ def compute(self): is_labeled = mask[:, task] if len(np.unique(labels[is_labeled, task])) >= 2: average_precisions[task] = sklearn.metrics.average_precision_score( - labels[is_labeled, task], probs[is_labeled, task]) + labels[is_labeled, task], probs[is_labeled, task] + ) # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning. if np.isnan(average_precisions).all(): @@ -147,7 +147,6 @@ def compute(self): @flax.struct.dataclass class EvalMetrics(metrics.Collection): - accuracy: metrics.Average.from_fun(predictions_match_labels) loss: metrics.Average.from_output('loss') mean_average_precision: MeanAveragePrecision @@ -155,28 +154,27 @@ class EvalMetrics(metrics.Collection): @flax.struct.dataclass class TrainMetrics(metrics.Collection): - accuracy: metrics.Average.from_fun(predictions_match_labels) loss: metrics.Average.from_output('loss') def replace_globals(graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: """Replaces the globals attribute with a constant feature for each graph.""" - return graphs._replace( - globals=jnp.ones([graphs.n_node.shape[0], 1])) + return graphs._replace(globals=jnp.ones([graphs.n_node.shape[0], 1])) -def get_predicted_logits(state: train_state.TrainState, - graphs: jraph.GraphsTuple, - rngs: Optional[Dict[str, jnp.ndarray]]) -> jnp.ndarray: +def get_predicted_logits( + state: train_state.TrainState, + graphs: jraph.GraphsTuple, + rngs: Optional[Dict[str, jnp.ndarray]], +) -> jnp.ndarray: """Get predicted logits from the network for input graphs.""" pred_graphs = state.apply_fn(state.params, graphs, rngs=rngs) logits = pred_graphs.globals return logits -def get_valid_mask(labels: jnp.ndarray, - graphs: jraph.GraphsTuple) -> jnp.ndarray: +def get_valid_mask(labels: jnp.ndarray, graphs: jraph.GraphsTuple) -> jnp.ndarray: """Gets the binary mask indicating only valid labels and graphs.""" # We have to ignore all NaN values - which indicate labels for which # the current graphs have no label. @@ -194,8 +192,9 @@ def get_valid_mask(labels: jnp.ndarray, @jax.jit def train_step( - state: train_state.TrainState, graphs: jraph.GraphsTuple, - rngs: Dict[str, jnp.ndarray] + state: train_state.TrainState, + graphs: jraph.GraphsTuple, + rngs: Dict[str, jnp.ndarray], ) -> Tuple[train_state.TrainState, metrics.Collection]: """Performs one update step over the current batch of graphs.""" @@ -211,8 +210,7 @@ def loss_fn(params, graphs): # Compute logits and resulting loss. logits = get_predicted_logits(curr_state, graphs, rngs) mask = get_valid_mask(labels, graphs) - loss = binary_cross_entropy_with_mask( - logits=logits, labels=labels, mask=mask) + loss = binary_cross_entropy_with_mask(logits=logits, labels=labels, mask=mask) mean_loss = jnp.sum(jnp.where(mask, loss, 0)) / jnp.sum(mask) return mean_loss, (loss, logits, labels, mask) @@ -222,7 +220,8 @@ def loss_fn(params, graphs): state = state.apply_gradients(grads=grads) metrics_update = TrainMetrics.single_from_model_output( - loss=loss, logits=logits, labels=labels, mask=mask) + loss=loss, logits=logits, labels=labels, mask=mask + ) return state, metrics_update @@ -249,12 +248,15 @@ def evaluate_step( loss = binary_cross_entropy_with_mask(logits=logits, labels=labels, mask=mask) return EvalMetrics.single_from_model_output( - loss=loss, logits=logits, labels=labels, mask=mask) + loss=loss, logits=logits, labels=labels, mask=mask + ) -def evaluate_model(state: train_state.TrainState, - datasets: Dict[str, tf.data.Dataset], - splits: Iterable[str]) -> Dict[str, metrics.Collection]: +def evaluate_model( + state: train_state.TrainState, + datasets: Dict[str, tf.data.Dataset], + splits: Iterable[str], +) -> Dict[str, metrics.Collection]: """Evaluates the model on metrics over the specified splits.""" # Loop over each split independently. @@ -276,8 +278,9 @@ def evaluate_model(state: train_state.TrainState, return eval_metrics # pytype: disable=bad-return-type -def train_and_evaluate(config: ml_collections.ConfigDict, - workdir: str) -> train_state.TrainState: +def train_and_evaluate( + config: ml_collections.ConfigDict, workdir: str +) -> train_state.TrainState: """Execute model training and evaluation loop. Args: @@ -300,7 +303,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, config.batch_size, add_virtual_node=config.add_virtual_node, add_undirected_edges=config.add_undirected_edges, - add_self_loops=config.add_self_loops) + add_self_loops=config.add_self_loops, + ) train_iter = iter(datasets['train']) # Create and initialize the network. @@ -318,8 +322,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict, # Create the training state. net = create_model(config, deterministic=False) - state = train_state.TrainState.create( - apply_fn=net.apply, params=params, tx=tx) + state = train_state.TrainState.create(apply_fn=net.apply, params=params, tx=tx) # Set up checkpointing of the model. # The input pipeline cannot be checkpointed in its current form, @@ -335,7 +338,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, # Hooks called periodically during training. report_progress = periodic_actions.ReportProgress( - num_train_steps=config.num_train_steps, writer=writer) + num_train_steps=config.num_train_steps, writer=writer + ) profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir) hooks = [report_progress, profiler] @@ -343,15 +347,13 @@ def train_and_evaluate(config: ml_collections.ConfigDict, logging.info('Starting training.') train_metrics = None for step in range(initial_step, config.num_train_steps + 1): - # Split PRNG key, to ensure different 'randomness' for every step. rng, dropout_rng = jax.random.split(rng) # Perform one step of training. with jax.profiler.StepTraceAnnotation('train', step_num=step): graphs = jax.tree_util.tree_map(np.asarray, next(train_iter)) - state, metrics_update = train_step( - state, graphs, rngs={'dropout': dropout_rng}) + state, metrics_update = train_step(state, graphs, rngs={'dropout': dropout_rng}) # Update metrics. if train_metrics is None: @@ -365,10 +367,9 @@ def train_and_evaluate(config: ml_collections.ConfigDict, hook(step) # Log, if required. - is_last_step = (step == config.num_train_steps - 1) + is_last_step = step == config.num_train_steps - 1 if step % config.log_every_steps == 0 or is_last_step: - writer.write_scalars(step, - add_prefix_to_keys(train_metrics.compute(), 'train')) + writer.write_scalars(step, add_prefix_to_keys(train_metrics.compute(), 'train')) train_metrics = None # Evaluate on validation and test splits, if required. @@ -380,7 +381,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, eval_metrics = evaluate_model(eval_state, datasets, splits=splits) for split in splits: writer.write_scalars( - step, add_prefix_to_keys(eval_metrics[split].compute(), split)) + step, add_prefix_to_keys(eval_metrics[split].compute(), split) + ) # Checkpoint model, if required. if step % config.checkpoint_every_steps == 0 or is_last_step: diff --git a/examples/ogbg_molpcba/train_test.py b/examples/ogbg_molpcba/train_test.py index 985703970b..120d59c2b0 100644 --- a/examples/ogbg_molpcba/train_test.py +++ b/examples/ogbg_molpcba/train_test.py @@ -73,13 +73,14 @@ def get_dummy_graphs(): datasets = {} for split in ['train', 'validation', 'test']: datasets[split] = tf.data.Dataset.from_generator( - get_dummy_graphs, output_signature=dummy_graph_spec) + get_dummy_graphs, output_signature=dummy_graph_spec + ) return datasets def get_dummy_datasets( - dataset_length: int, - batch_size: Optional[int] = None) -> Dict[str, tf.data.Dataset]: + dataset_length: int, batch_size: Optional[int] = None +) -> Dict[str, tf.data.Dataset]: """Returns dummy datasets, mocking input_pipeline.get_datasets().""" datasets = get_dummy_raw_datasets(dataset_length) @@ -94,24 +95,24 @@ def get_dummy_datasets( # Process each split separately. for split_name in datasets: - # Convert to GraphsTuple. datasets[split_name] = datasets[split_name].map( convert_to_graphs_tuple_fn, num_parallel_calls=tf.data.AUTOTUNE, - deterministic=True) + deterministic=True, + ) # If batch size is None, do not batch. if batch_size is not None: budget = input_pipeline.estimate_padding_budget_for_batch_size( - datasets['train'], batch_size, num_estimation_graphs=1) + datasets['train'], batch_size, num_estimation_graphs=1 + ) # Pad an example graph to see what the output shapes will be. # We will use this shape information when creating the tf.data.Dataset. example_graph = next(datasets['train'].as_numpy_iterator()) example_padded_graph = jraph.pad_with_graphs(example_graph, *budget) - padded_graphs_spec = input_pipeline.specs_from_graphs_tuple( - example_padded_graph) + padded_graphs_spec = input_pipeline.specs_from_graphs_tuple(example_padded_graph) # Batch and pad each split separately. for split, dataset_split in datasets.items(): @@ -120,10 +121,11 @@ def get_dummy_datasets( graphs_tuple_iterator=iter(dataset_split), n_node=budget.n_node, n_edge=budget.n_edge, - n_graph=budget.n_graph) + n_graph=budget.n_graph, + ) datasets[split] = tf.data.Dataset.from_generator( - batching_fn, - output_signature=padded_graphs_spec) + batching_fn, output_signature=padded_graphs_spec + ) return datasets @@ -147,8 +149,12 @@ def setUp(self): @parameterized.product( probs=[[[0.8, 0.9, 0.3, 0.5]]], - labels=[[[1, 0, 1, 1]], [[1, 0, 1, jnp.nan]], [[1, 0, jnp.nan, jnp.nan]], - [[1, jnp.nan, jnp.nan, jnp.nan]]], + labels=[ + [[1, 0, 1, 1]], + [[1, 0, 1, jnp.nan]], + [[1, 0, jnp.nan, jnp.nan]], + [[1, jnp.nan, jnp.nan, jnp.nan]], + ], ) def test_binary_cross_entropy_loss(self, probs, labels): probs = jnp.asarray(probs) @@ -158,10 +164,12 @@ def test_binary_cross_entropy_loss(self, probs, labels): mask = ~jnp.isnan(labels) loss_array = train.binary_cross_entropy_with_mask( - logits=logits, labels=labels, mask=mask) + logits=logits, labels=labels, mask=mask + ) loss = average_with_mask(loss_array, mask) expected_loss_array = -(jnp.log(probs) * labels) - ( - jnp.log(1 - probs) * (1 - labels)) + jnp.log(1 - probs) * (1 - labels) + ) expected_loss = average_with_mask(expected_loss_array, mask) self.assertAlmostEqual(loss, expected_loss, places=5) @@ -169,19 +177,22 @@ def test_binary_cross_entropy_loss(self, probs, labels): @parameterized.named_parameters( dict( testcase_name='no_valid_tasks', - logits=[[-1., 1.], [1., 1.], [2., -1.]], + logits=[[-1.0, 1.0], [1.0, 1.0], [2.0, -1.0]], labels=[[jnp.nan, jnp.nan], [jnp.nan, jnp.nan], [jnp.nan, jnp.nan]], - expected_result=jnp.nan), + expected_result=jnp.nan, + ), dict( testcase_name='1_valid_task', - logits=[[-1., 1.], [1., 1.], [2., -1.]], + logits=[[-1.0, 1.0], [1.0, 1.0], [2.0, -1.0]], labels=[[0, jnp.nan], [1, jnp.nan], [1, jnp.nan]], - expected_result=1.), + expected_result=1.0, + ), dict( testcase_name='2_valid_tasks', - logits=[[-1., 1.], [1., 1.], [2., -1.]], + logits=[[-1.0, 1.0], [1.0, 1.0], [2.0, -1.0]], labels=[[0, jnp.nan], [1, 0], [1, 1]], - expected_result=0.75), + expected_result=0.75, + ), ) def test_mean_average_precision(self, logits, labels, expected_result): logits = jnp.asarray(logits) @@ -189,7 +200,8 @@ def test_mean_average_precision(self, logits, labels, expected_result): mask = ~jnp.isnan(labels) mean_average_precision = train.MeanAveragePrecision.from_model_output( - logits=logits, labels=labels, mask=mask).compute() + logits=logits, labels=labels, mask=mask + ).compute() if jnp.isnan(expected_result): self.assertTrue(jnp.isnan(mean_average_precision)) @@ -198,12 +210,16 @@ def test_mean_average_precision(self, logits, labels, expected_result): @parameterized.parameters( dict( - loss=[[0.5, 1.], [1.5, 1.3], [2., 1.2]], - logits=[[-1., 1.], [1., 1.], [2., 0.]], + loss=[[0.5, 1.0], [1.5, 1.3], [2.0, 1.2]], + logits=[[-1.0, 1.0], [1.0, 1.0], [2.0, 0.0]], labels=[[0, jnp.nan], [1, 0], [0, 1]], mask=[[True, False], [True, True], [False, False]], - expected_results={'loss': 1.1, 'accuracy': 2/3, - 'mean_average_precision': 1.0}), + expected_results={ + 'loss': 1.1, + 'accuracy': 2 / 3, + 'mean_average_precision': 1.0, + }, + ), ) def test_eval_metrics(self, loss, logits, labels, mask, expected_results): loss = jnp.asarray(loss) @@ -215,17 +231,20 @@ def test_eval_metrics(self, loss, logits, labels, mask, expected_results): with warnings.catch_warnings(): warnings.simplefilter('ignore', category=RuntimeWarning) eval_metrics = train.EvalMetrics.single_from_model_output( - loss=loss, logits=logits, labels=labels, mask=mask).compute() + loss=loss, logits=logits, labels=labels, mask=mask + ).compute() for metric in expected_results: self.assertAlmostEqual(expected_results[metric], eval_metrics[metric]) @parameterized.parameters( - dict(loss=[[0.5, 1.], [1.5, 1.3], [2., 1.2]], - logits=[[-1., 1.], [1., 1.], [2., 0.]], - labels=[[0, jnp.nan], [1, 0], [0, 1]], - mask=[[True, False], [True, True], [False, False]], - expected_results={'loss': 1.1, 'accuracy': 2/3}), + dict( + loss=[[0.5, 1.0], [1.5, 1.3], [2.0, 1.2]], + logits=[[-1.0, 1.0], [1.0, 1.0], [2.0, 0.0]], + labels=[[0, jnp.nan], [1, 0], [0, 1]], + mask=[[True, False], [True, True], [False, False]], + expected_results={'loss': 1.1, 'accuracy': 2 / 3}, + ), ) def test_train_metrics(self, loss, logits, labels, mask, expected_results): loss = jnp.asarray(loss) @@ -234,7 +253,8 @@ def test_train_metrics(self, loss, logits, labels, mask, expected_results): mask = jnp.asarray(mask) train_metrics = train.TrainMetrics.single_from_model_output( - loss=loss, logits=logits, labels=labels, mask=mask).compute() + loss=loss, logits=logits, labels=labels, mask=mask + ).compute() for metric in expected_results: self.assertAlmostEqual(expected_results[metric], train_metrics[metric]) @@ -255,18 +275,17 @@ def test_train_step(self): # Create the training state. net = train.create_model(config, deterministic=False) state = train_state.TrainState.create( - apply_fn=net.apply, params=params, tx=optimizer) + apply_fn=net.apply, params=params, tx=optimizer + ) # Perform one step of updates. # We use the same batch of graphs that we used for initialization. - state, train_metrics = train.train_step( - state, init_graphs, rngs={'dropout': rng}) + state, train_metrics = train.train_step(state, init_graphs, rngs={'dropout': rng}) # Check that none of the parameters are NaNs! params = flax.core.unfreeze(state.params) flat_params = { - '/'.join(k): v - for k, v in flax.traverse_util.flatten_dict(params).items() + '/'.join(k): v for k, v in flax.traverse_util.flatten_dict(params).items() } for array in flat_params.values(): self.assertTrue(jnp.all(~jnp.isnan(array))) @@ -285,7 +304,8 @@ def test_evaluate_step(self): _, init_rng = jax.random.split(self.rng) init_graphs = next(self.datasets['train'].as_numpy_iterator()) init_graphs_preprocessed = init_graphs._replace( - globals=jnp.zeros([init_graphs.n_node.shape[0], 1])) + globals=jnp.zeros([init_graphs.n_node.shape[0], 1]) + ) init_net = train.create_model(config, deterministic=True) params = jax.jit(init_net.init)(init_rng, init_graphs_preprocessed) @@ -295,7 +315,8 @@ def test_evaluate_step(self): # Create the evaluation state. eval_net = train.create_model(config, deterministic=True) eval_state = train_state.TrainState.create( - apply_fn=eval_net.apply, params=params, tx=optimizer) + apply_fn=eval_net.apply, params=params, tx=optimizer + ) # Perform one step of evaluation. # We use the same batch of graphs that we used for initialization. diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py index 71ee3d3eb6..b274cd8de2 100644 --- a/examples/ppo/agent.py +++ b/examples/ppo/agent.py @@ -28,9 +28,10 @@ @functools.partial(jax.jit, static_argnums=0) def policy_action( - apply_fn: Callable[..., Any], - params: flax.core.frozen_dict.FrozenDict, - state: np.ndarray): + apply_fn: Callable[..., Any], + params: flax.core.frozen_dict.FrozenDict, + state: np.ndarray, +): """Forward pass of the network. Args: @@ -46,7 +47,8 @@ def policy_action( ExpTuple = collections.namedtuple( - 'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done']) + 'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done'] +) class RemoteSimulator: @@ -59,7 +61,8 @@ def __init__(self, game: str): """Start the remote process and create Pipe() to communicate with it.""" parent_conn, child_conn = multiprocessing.Pipe() self.proc = multiprocessing.Process( - target=rcv_action_send_exp, args=(child_conn, game)) + target=rcv_action_send_exp, args=(child_conn, game) + ) self.proc.daemon = True self.conn = parent_conn self.proc.start() diff --git a/examples/ppo/configs/default.py b/examples/ppo/configs/default.py index 296f7c8a88..aa213a2167 100644 --- a/examples/ppo/configs/default.py +++ b/examples/ppo/configs/default.py @@ -16,6 +16,7 @@ import ml_collections + def get_config(): """Get the default configuration. diff --git a/examples/ppo/env_utils.py b/examples/ppo/env_utils.py index e058745904..03b3a21b36 100644 --- a/examples/ppo/env_utils.py +++ b/examples/ppo/env_utils.py @@ -21,6 +21,7 @@ import seed_rl_atari_preprocessing + class ClipRewardEnv(gym.RewardWrapper): """Adapted from OpenAI baselines. @@ -34,6 +35,7 @@ def reward(self, reward): """Bin reward to {+1, 0, -1} by its sign.""" return np.sign(reward) + class FrameStack: """Implements stacking of `num_frames` last frames of the game. @@ -41,9 +43,8 @@ class FrameStack: """ def __init__( - self, - preproc: seed_rl_atari_preprocessing.AtariPreprocessing, - num_frames: int): + self, preproc: seed_rl_atari_preprocessing.AtariPreprocessing, num_frames: int + ): self.preproc = preproc self.num_frames = num_frames self.frames = collections.deque(maxlen=num_frames) @@ -63,15 +64,17 @@ def _get_array(self): assert len(self.frames) == self.num_frames return np.concatenate(self.frames, axis=-1) + def create_env(game: str, clip_rewards: bool): """Create a FrameStack object that serves as environment for the `game`.""" env = gym.make(game) if clip_rewards: - env = ClipRewardEnv(env) # bin rewards to {-1., 0., 1.} + env = ClipRewardEnv(env) # bin rewards to {-1., 0., 1.} preproc = seed_rl_atari_preprocessing.AtariPreprocessing(env) stack = FrameStack(preproc, num_frames=4) return stack + def get_num_actions(game: str): """Get the number of possible actions of a given Atari game. diff --git a/examples/ppo/models.py b/examples/ppo/models.py index d79eef9217..4892be5108 100644 --- a/examples/ppo/models.py +++ b/examples/ppo/models.py @@ -38,15 +38,18 @@ def __call__(self, x): github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py """ dtype = jnp.float32 - x = x.astype(dtype) / 255. - x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), name='conv1', - dtype=dtype)(x) + x = x.astype(dtype) / 255.0 + x = nn.Conv( + features=32, kernel_size=(8, 8), strides=(4, 4), name='conv1', dtype=dtype + )(x) x = nn.relu(x) - x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), name='conv2', - dtype=dtype)(x) + x = nn.Conv( + features=64, kernel_size=(4, 4), strides=(2, 2), name='conv2', dtype=dtype + )(x) x = nn.relu(x) - x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), name='conv3', - dtype=dtype)(x) + x = nn.Conv( + features=64, kernel_size=(3, 3), strides=(1, 1), name='conv3', dtype=dtype + )(x) x = nn.relu(x) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(features=512, name='hidden', dtype=dtype)(x) diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index b0d5b53790..a5caad5c55 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -42,7 +42,8 @@ def gae_advantages( terminal_masks: np.ndarray, values: np.ndarray, discount: float, - gae_param: float): + gae_param: float, +): """Use Generalized Advantage Estimation (GAE) to compute advantages. As defined by eqs. (11-12) in PPO paper arXiv: 1707.06347. Implementation uses @@ -59,11 +60,13 @@ def gae_advantages( Returns: advantages: calculated advantages shaped (actor_steps, num_agents) """ - assert rewards.shape[0] + 1 == values.shape[0], ('One more value needed; Eq. ' - '(12) in PPO paper requires ' - 'V(s_{t+1}) for delta_t') + assert rewards.shape[0] + 1 == values.shape[0], ( + 'One more value needed; Eq. ' + '(12) in PPO paper requires ' + 'V(s_{t+1}) for delta_t' + ) advantages = [] - gae = 0. + gae = 0.0 for t in reversed(range(len(rewards))): # Masks used to set next state value to 0 for terminal states. value_diff = discount * values[t + 1] * terminal_masks[t] - values[t] @@ -82,7 +85,8 @@ def loss_fn( minibatch: Tuple, clip_param: float, vf_coeff: float, - entropy_coeff: float): + entropy_coeff: float, +): """Evaluate the loss function. Compute loss as a sum of three components: the negative of the PPO clipped @@ -112,18 +116,18 @@ def loss_fn( value_loss = jnp.mean(jnp.square(returns - values), axis=0) - entropy = jnp.sum(-probs*log_probs, axis=1).mean() + entropy = jnp.sum(-probs * log_probs, axis=1).mean() log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions) ratios = jnp.exp(log_probs_act_taken - old_log_probs) # Advantage normalization (following the OpenAI baselines). advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) pg_loss = ratios * advantages - clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios, - 1. + clip_param) + clipped_loss = advantages * jax.lax.clamp(1.0 - clip_param, ratios, 1.0 + clip_param) ppo_loss = -jnp.mean(jnp.minimum(pg_loss, clipped_loss), axis=0) - return ppo_loss + vf_coeff*value_loss - entropy_coeff*entropy + return ppo_loss + vf_coeff * value_loss - entropy_coeff * entropy + @functools.partial(jax.jit, static_argnums=(2,)) def train_step( @@ -133,7 +137,8 @@ def train_step( *, clip_param: float, vf_coeff: float, - entropy_coeff: float): + entropy_coeff: float +): """Compilable train step. Runs an entire epoch of training (i.e. the loop over minibatches within @@ -158,20 +163,24 @@ def train_step( """ iterations = trajectories[0].shape[0] // batch_size trajectories = jax.tree_util.tree_map( - lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories) - loss = 0. + lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories + ) + loss = 0.0 for batch in zip(*trajectories): grad_fn = jax.value_and_grad(loss_fn) - l, grads = grad_fn(state.params, state.apply_fn, batch, clip_param, vf_coeff, - entropy_coeff) + l, grads = grad_fn( + state.params, state.apply_fn, batch, clip_param, vf_coeff, entropy_coeff + ) loss += l state = state.apply_gradients(grads=grads) return state, loss + def get_experience( state: train_state.TrainState, simulators: List[agent.RemoteSimulator], - steps_per_actor: int): + steps_per_actor: int, +): """Collect experience from agents. Runs `steps_per_actor` time steps of the game for each of the `simulators`. @@ -201,12 +210,14 @@ def get_experience( all_experience.append(experiences) return all_experience + def process_experience( experience: List[List[agent.ExpTuple]], actor_steps: int, num_agents: int, gamma: float, - lambda_: float): + lambda_: float, +): """Process experience for training, including advantage estimation. Args: @@ -245,8 +256,9 @@ def process_experience( # After preprocessing, concatenate data from all agents. trajectories = (states, actions, log_probs, returns, advantages) trajectory_len = num_agents * actor_steps - trajectories = tuple(map( - lambda x: np.reshape(x, (trajectory_len,) + x.shape[2:]), trajectories)) + trajectories = tuple( + map(lambda x: np.reshape(x, (trajectory_len,) + x.shape[2:]), trajectories) + ) return trajectories @@ -258,26 +270,21 @@ def get_initial_params(key: np.ndarray, model: nn.Module): return initial_params -def create_train_state(params, model: nn.Module, - config: ml_collections.ConfigDict, train_steps: int) -> train_state.TrainState: +def create_train_state( + params, model: nn.Module, config: ml_collections.ConfigDict, train_steps: int +) -> train_state.TrainState: if config.decaying_lr_and_clip_param: lr = optax.linear_schedule( - init_value=config.learning_rate, end_value=0., - transition_steps=train_steps) + init_value=config.learning_rate, end_value=0.0, transition_steps=train_steps + ) else: lr = config.learning_rate tx = optax.adam(lr) - state = train_state.TrainState.create( - apply_fn=model.apply, - params=params, - tx=tx) + state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) return state -def train( - model: models.ActorCritic, - config: ml_collections.ConfigDict, - model_dir: str): +def train(model: models.ActorCritic, config: ml_collections.ConfigDict, model_dir: str): """Main training loop. Args: @@ -290,8 +297,7 @@ def train( """ game = config.game + 'NoFrameskip-v4' - simulators = [agent.RemoteSimulator(game) - for _ in range(config.num_agents)] + simulators = [agent.RemoteSimulator(game) for _ in range(config.num_agents)] summary_writer = tensorboard.SummaryWriter(model_dir) summary_writer.hparams(dict(config)) loop_steps = config.total_frames // (config.num_agents * config.actor_steps) @@ -300,12 +306,15 @@ def train( # train_step does multiple steps per call for better performance # compute number of steps per call here to convert between the number of # train steps and the inner number of optimizer steps - iterations_per_step = (config.num_agents * config.actor_steps - // config.batch_size) + iterations_per_step = config.num_agents * config.actor_steps // config.batch_size initial_params = get_initial_params(jax.random.PRNGKey(0), model) - state = create_train_state(initial_params, model, config, - loop_steps * config.num_epochs * iterations_per_step) + state = create_train_state( + initial_params, + model, + config, + loop_steps * config.num_epochs * iterations_per_step, + ) del initial_params state = checkpoints.restore_checkpoint(model_dir, state) # number of train iterations done by each train_step @@ -322,22 +331,27 @@ def train( logging.info('Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score) # Core training code. - alpha = 1. - step / loop_steps if config.decaying_lr_and_clip_param else 1. - all_experiences = get_experience( - state, simulators, config.actor_steps) + alpha = 1.0 - step / loop_steps if config.decaying_lr_and_clip_param else 1.0 + all_experiences = get_experience(state, simulators, config.actor_steps) trajectories = process_experience( - all_experiences, config.actor_steps, config.num_agents, config.gamma, - config.lambda_) + all_experiences, + config.actor_steps, + config.num_agents, + config.gamma, + config.lambda_, + ) clip_param = config.clip_param * alpha for _ in range(config.num_epochs): - permutation = np.random.permutation( - config.num_agents * config.actor_steps) + permutation = np.random.permutation(config.num_agents * config.actor_steps) trajectories = tuple(x[permutation] for x in trajectories) state, _ = train_step( - state, trajectories, config.batch_size, + state, + trajectories, + config.batch_size, clip_param=clip_param, vf_coeff=config.vf_coeff, - entropy_coeff=config.entropy_coeff) + entropy_coeff=config.entropy_coeff, + ) if (step + 1) % checkpoint_frequency == 0: checkpoints.save_checkpoint(model_dir, state, step + 1) return state diff --git a/examples/ppo/ppo_lib_test.py b/examples/ppo/ppo_lib_test.py index 5656f4622b..80732aeb86 100644 --- a/examples/ppo/ppo_lib_test.py +++ b/examples/ppo/ppo_lib_test.py @@ -29,38 +29,42 @@ # test GAE class TestGAE(absltest.TestCase): + def test_gae_shape_on_random(self): # create random data, simulating 4 parallel envs and 20 time_steps envs, steps = 10, 100 - rewards = np.random.choice([-1., 0., 1.], size=(steps, envs), - p=[0.01, 0.98, 0.01]) + rewards = np.random.choice( + [-1.0, 0.0, 1.0], size=(steps, envs), p=[0.01, 0.98, 0.01] + ) terminal_masks = np.ones(shape=(steps, envs), dtype=np.float64) values = np.random.random(size=(steps + 1, envs)) discount = 0.99 gae_param = 0.95 - adv = ppo_lib.gae_advantages(rewards, terminal_masks, values, discount, - gae_param) + adv = ppo_lib.gae_advantages(rewards, terminal_masks, values, discount, gae_param) self.assertEqual(adv.shape, (steps, envs)) def test_gae_hardcoded(self): - #test on small example that can be verified by hand - rewards = np.array([[1., 0.], [0., 0.], [-1., 1.]]) - #one of the two episodes terminated in the middle - terminal_masks = np.array([[1., 1.], [0., 1.], [1., 1.]]) - values = np.array([[1., 1.], [1., 1.], [1., 1.], [1., 1.]]) + # test on small example that can be verified by hand + rewards = np.array([[1.0, 0.0], [0.0, 0.0], [-1.0, 1.0]]) + # one of the two episodes terminated in the middle + terminal_masks = np.array([[1.0, 1.0], [0.0, 1.0], [1.0, 1.0]]) + values = np.array([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]) discount = 0.5 gae_param = 0.25 - correct_gae = np.array([[0.375, -0.5546875], [-1., -0.4375], [-1.5, 0.5]]) - actual_gae = ppo_lib.gae_advantages(rewards, terminal_masks, values, - discount, gae_param) + correct_gae = np.array([[0.375, -0.5546875], [-1.0, -0.4375], [-1.5, 0.5]]) + actual_gae = ppo_lib.gae_advantages( + rewards, terminal_masks, values, discount, gae_param + ) np_testing.assert_allclose(actual_gae, correct_gae) + + # test environment and preprocessing class TestEnvironmentPreprocessing(absltest.TestCase): + def choose_random_game(self): - games = ['BeamRider', 'Breakout', 'Pong', - 'Qbert', 'Seaquest', 'SpaceInvaders'] + games = ['BeamRider', 'Breakout', 'Pong', 'Qbert', 'Seaquest', 'SpaceInvaders'] ind = np.random.choice(len(games)) - return games[ind] + "NoFrameskip-v4" + return games[ind] + 'NoFrameskip-v4' def test_creation(self): frame_shape = (84, 84, 4) @@ -78,12 +82,14 @@ def test_step(self): for a in actions: obs, reward, done, info = env.step(a) self.assertEqual(obs.shape, frame_shape) - self.assertTrue(reward <= 1. and reward >= -1.) + self.assertTrue(reward <= 1.0 and reward >= -1.0) self.assertTrue(isinstance(done, bool)) self.assertTrue(isinstance(info, dict)) + # test the model (creation and forward pass) class TestModel(absltest.TestCase): + def choose_random_outputs(self): return np.random.choice([4, 5, 6, 7, 8, 9]) @@ -93,20 +99,20 @@ def test_model(self): params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), module) test_batch_size, obs_shape = 10, (84, 84, 4) random_input = np.random.random(size=(test_batch_size,) + obs_shape) - log_probs, values = agent.policy_action( - module.apply, params, random_input) + log_probs, values = agent.policy_action(module.apply, params, random_input) self.assertEqual(values.shape, (test_batch_size, 1)) sum_probs = np.sum(np.exp(log_probs), axis=1) - self.assertEqual(sum_probs.shape, (test_batch_size, )) - np_testing.assert_allclose(sum_probs, np.ones((test_batch_size, )), - atol=1e-6) + self.assertEqual(sum_probs.shape, (test_batch_size,)) + np_testing.assert_allclose(sum_probs, np.ones((test_batch_size,)), atol=1e-6) + # test one optimization step class TestOptimizationStep(absltest.TestCase): + def generate_random_data(self, num_actions): - data_len = 256 # equal to one default-sized batch + data_len = 256 # equal to one default-sized batch state_shape = (84, 84, 4) - states = np.random.randint(0, 255, size=((data_len, ) + state_shape)) + states = np.random.randint(0, 255, size=((data_len,) + state_shape)) actions = np.random.choice(num_actions, size=data_len) old_log_probs = np.random.random(size=data_len) returns = np.random.random(size=data_len) @@ -123,16 +129,20 @@ def test_optimization_step(self): module = models.ActorCritic(num_outputs) initial_params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), module) config = ml_collections.ConfigDict({ - 'learning_rate': 2.5e-4, - 'decaying_lr_and_clip_param': True, + 'learning_rate': 2.5e-4, + 'decaying_lr_and_clip_param': True, }) state = ppo_lib.create_train_state(initial_params, module, config, 1000) state, _ = ppo_lib.train_step( - state, trn_data, batch_size, + state, + trn_data, + batch_size, clip_param=clip_param, vf_coeff=vf_coeff, - entropy_coeff=entropy_coeff) + entropy_coeff=entropy_coeff, + ) self.assertIsInstance(state, train_state.TrainState) + if __name__ == '__main__': absltest.main() diff --git a/examples/ppo/ppo_main.py b/examples/ppo/ppo_main.py index 4f0d6c8f0f..33059ac7cd 100644 --- a/examples/ppo/ppo_main.py +++ b/examples/ppo/ppo_main.py @@ -30,13 +30,15 @@ flags.DEFINE_string( 'workdir', default='/tmp/ppo_training', - help=('Directory to save checkpoints and logging info.')) + help=('Directory to save checkpoints and logging info.'), +) config_flags.DEFINE_config_file( 'config', - "configs/default.py", + 'configs/default.py', 'File path to the default configuration file.', - lock_config=True) + lock_config=True, +) def main(argv): @@ -49,5 +51,6 @@ def main(argv): model = models.ActorCritic(num_outputs=num_actions) ppo_lib.train(model, config, FLAGS.workdir) + if __name__ == '__main__': app.run(main) diff --git a/examples/ppo/seed_rl_atari_preprocessing.py b/examples/ppo/seed_rl_atari_preprocessing.py index bd5c002963..a33dd447c2 100644 --- a/examples/ppo/seed_rl_atari_preprocessing.py +++ b/examples/ppo/seed_rl_atari_preprocessing.py @@ -51,8 +51,14 @@ class AtariPreprocessing: and R2D2 papers. """ - def __init__(self, environment: gym.Env, frame_skip=4, terminal_on_life_loss=False, - screen_size=84, max_random_noops=0): + def __init__( + self, + environment: gym.Env, + frame_skip=4, + terminal_on_life_loss=False, + screen_size=84, + max_random_noops=0, + ): """Constructor for an Atari 2600 preprocessor. Args: environment: Gym environment whose observations are preprocessed. @@ -67,11 +73,13 @@ def __init__(self, environment: gym.Env, frame_skip=4, terminal_on_life_loss=Fal ValueError: if frame_skip or screen_size are not strictly positive. """ if frame_skip <= 0: - raise ValueError('Frame skip should be strictly positive, got {}'. - format(frame_skip)) + raise ValueError( + 'Frame skip should be strictly positive, got {}'.format(frame_skip) + ) if screen_size <= 0: - raise ValueError('Target screen size should be strictly positive, got {}'. - format(screen_size)) + raise ValueError( + 'Target screen size should be strictly positive, got {}'.format(screen_size) + ) self.environment = environment self.terminal_on_life_loss = terminal_on_life_loss @@ -84,7 +92,7 @@ def __init__(self, environment: gym.Env, frame_skip=4, terminal_on_life_loss=Fal # frames. self.screen_buffer = [ np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8), - np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8) + np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8), ] self.game_over = False @@ -94,8 +102,9 @@ def __init__(self, environment: gym.Env, frame_skip=4, terminal_on_life_loss=Fal def observation_space(self): # Return the observation space adjusted to match the shape of the processed # observations. - return Box(low=0, high=255, shape=(self.screen_size, self.screen_size, 1), - dtype=np.uint8) + return Box( + low=0, high=255, shape=(self.screen_size, self.screen_size, 1), dtype=np.uint8 + ) @property def action_space(self): @@ -169,7 +178,7 @@ def step(self, action): episode is over. info: Gym API's info data structure. """ - accumulated_reward = 0. + accumulated_reward = 0.0 for time_step in range(self.frame_skip): # We bypass the Gym observation altogether and directly fetch the @@ -216,11 +225,14 @@ def _pool_and_resize(self): """ # Pool if there are enough screens to do so. if self.frame_skip > 1: - np.maximum(self.screen_buffer[0], self.screen_buffer[1], - out=self.screen_buffer[0]) - - transformed_image = cv2.resize(self.screen_buffer[0], - (self.screen_size, self.screen_size), - interpolation=cv2.INTER_LINEAR) + np.maximum( + self.screen_buffer[0], self.screen_buffer[1], out=self.screen_buffer[0] + ) + + transformed_image = cv2.resize( + self.screen_buffer[0], + (self.screen_size, self.screen_size), + interpolation=cv2.INTER_LINEAR, + ) int_image = np.asarray(transformed_image, dtype=np.uint8) return np.expand_dims(int_image, axis=2) diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index c8f1d0b303..3f747f74fe 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -28,7 +28,8 @@ def policy_test( n_episodes: int, apply_fn: Callable[..., Any], params: flax.core.frozen_dict.FrozenDict, - game: str): + game: str, +): """Perform a test of the policy in Atari environment. Args: diff --git a/examples/seq2seq/input_pipeline.py b/examples/seq2seq/input_pipeline.py index 6c460903d5..9df6c2927f 100644 --- a/examples/seq2seq/input_pipeline.py +++ b/examples/seq2seq/input_pipeline.py @@ -20,7 +20,7 @@ import jax.numpy as jnp import numpy as np -Array = Any # pylint: disable=invalid-name +Array = Any # pylint: disable=invalid-name class CharacterTable: @@ -28,10 +28,8 @@ class CharacterTable: def __init__(self, chars: str, max_len_query_digit: int = 3) -> None: self._chars = sorted(set(chars)) - self._char_indices = { - ch: idx + 2 for idx, ch in enumerate(self._chars)} - self._indices_char = { - idx + 2: ch for idx, ch in enumerate(self._chars)} + self._char_indices = {ch: idx + 2 for idx, ch in enumerate(self._chars)} + self._indices_char = {idx + 2: ch for idx, ch in enumerate(self._chars)} self._indices_char[self.pad_id] = '_' # Maximum length of a single input digit. self._max_len_query_digit = max_len_query_digit @@ -74,8 +72,7 @@ def decoder_input_shape(self) -> Tuple[int, int, int]: def encode(self, inputs: str) -> np.ndarray: """Encodes from string to list of integers.""" - return np.array( - [self._char_indices[char] for char in inputs] + [self.eos_id]) + return np.array([self._char_indices[char] for char in inputs] + [self.eos_id]) def decode(self, inputs: Array) -> str: """Decodes from list of integers to string.""" @@ -92,7 +89,8 @@ def one_hot(self, tokens: np.ndarray) -> np.ndarray: return vecs def encode_onehot( - self, batch_inputs: Array, max_len: Optional[int] = None) -> np.ndarray: + self, batch_inputs: Array, max_len: Optional[int] = None + ) -> np.ndarray: """One-hot encodes a string input.""" if max_len is None: @@ -102,8 +100,7 @@ def encode_str(s): tokens = self.encode(s) unpadded_len = len(tokens) if unpadded_len > max_len: - raise ValueError( - f'Sequence too long ({len(tokens)}>{max_len}): \'{s}\'') + raise ValueError(f"Sequence too long ({len(tokens)}>{max_len}): '{s}'") tokens = np.pad(tokens, [(0, max_len - len(tokens))], mode='constant') return self.one_hot(tokens) @@ -115,7 +112,8 @@ def decode_onehot(self, batch_inputs: Array) -> np.ndarray: return np.array(list(map(decode_inputs, batch_inputs))) def generate_examples( - self, num_examples: int) -> Generator[Tuple[str, str], None, None]: + self, num_examples: int + ) -> Generator[Tuple[str, str], None, None]: """Yields `num_examples` examples.""" for _ in range(num_examples): max_digit = pow(10, self._max_len_query_digit) - 1 @@ -138,7 +136,8 @@ def get_batch(self, batch_size: int) -> Dict[str, np.ndarray]: def mask_sequences(sequence_batch: Array, lengths: Array) -> Array: """Sets positions beyond the length of each sequence to 0.""" return sequence_batch * ( - lengths[:, np.newaxis] > np.arange(sequence_batch.shape[1])[np.newaxis]) + lengths[:, np.newaxis] > np.arange(sequence_batch.shape[1])[np.newaxis] + ) def get_sequence_lengths(sequence_batch: Array, eos_id: int) -> Array: @@ -150,5 +149,5 @@ def get_sequence_lengths(sequence_batch: Array, eos_id: int) -> Array: return jnp.where( eos_row[jnp.arange(eos_row.shape[0]), eos_idx], eos_idx + 1, - sequence_batch.shape[1] # if there is no EOS, use full length + sequence_batch.shape[1], # if there is no EOS, use full length ) diff --git a/examples/seq2seq/models.py b/examples/seq2seq/models.py index 3c1e3fad7f..e78782c15b 100644 --- a/examples/seq2seq/models.py +++ b/examples/seq2seq/models.py @@ -37,6 +37,7 @@ class DecoderLSTMCell(nn.RNNCellBase): teacher_force: See docstring on Seq2seq module. vocab_size: Size of the vocabulary. """ + features: int teacher_force: bool vocab_size: int @@ -56,8 +57,7 @@ def __call__( categorical_rng = self.make_rng('lstm') predicted_token = jax.random.categorical(categorical_rng, logits) # Convert to one-hot encoding. - prediction = jax.nn.one_hot( - predicted_token, self.vocab_size, dtype=jnp.float32) + prediction = jax.nn.one_hot(predicted_token, self.vocab_size, dtype=jnp.float32) return (lstm_state, prediction), (logits, prediction) @@ -78,14 +78,16 @@ class Seq2seq(nn.Module): vocab_size: the size of the vocabulary. eos_id: EOS id. """ + teacher_force: bool hidden_size: int vocab_size: int eos_id: int = 1 @nn.compact - def __call__(self, encoder_inputs: Array, - decoder_inputs: Array) -> Tuple[Array, Array]: + def __call__( + self, encoder_inputs: Array, decoder_inputs: Array + ) -> Tuple[Array, Array]: """Applies the seq2seq model. Args: @@ -106,13 +108,18 @@ def __call__(self, encoder_inputs: Array, """ # Encode inputs. encoder = nn.RNN(nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder') - decoder = nn.RNN(DecoderLSTMCell(decoder_inputs.shape[-1], self.teacher_force, self.vocab_size), - split_rngs={'params': False, 'lstm': True}, name='decoder') + decoder = nn.RNN( + DecoderLSTMCell(decoder_inputs.shape[-1], self.teacher_force, self.vocab_size), + split_rngs={'params': False, 'lstm': True}, + name='decoder', + ) seq_lengths = self.get_seq_lengths(encoder_inputs) encoder_state, _ = encoder(encoder_inputs, seq_lengths=seq_lengths) - logits, predictions = decoder(decoder_inputs[:, :-1], initial_carry=(encoder_state, decoder_inputs[:, 0])) + logits, predictions = decoder( + decoder_inputs[:, :-1], initial_carry=(encoder_state, decoder_inputs[:, 0]) + ) return logits, predictions @@ -124,4 +131,3 @@ def get_seq_lengths(self, inputs: Array) -> Array: seq_lengths = jnp.argmax(inputs == self.eos_id, axis=-1) return seq_lengths - diff --git a/examples/seq2seq/train.py b/examples/seq2seq/train.py index 513fe902c4..6428461fd7 100644 --- a/examples/seq2seq/train.py +++ b/examples/seq2seq/train.py @@ -43,44 +43,44 @@ flags.DEFINE_string('workdir', default='.', help='Where to store log output.') flags.DEFINE_float( - 'learning_rate', - default=0.003, - help=('The learning rate for the Adam optimizer.')) + 'learning_rate', default=0.003, help=('The learning rate for the Adam optimizer.') +) -flags.DEFINE_integer( - 'batch_size', default=128, help=('Batch size for training.')) +flags.DEFINE_integer('batch_size', default=128, help=('Batch size for training.')) -flags.DEFINE_integer( - 'hidden_size', default=512, help=('Hidden size of the LSTM.')) +flags.DEFINE_integer('hidden_size', default=512, help=('Hidden size of the LSTM.')) -flags.DEFINE_integer( - 'num_train_steps', default=10000, help=('Number of train steps.')) +flags.DEFINE_integer('num_train_steps', default=10000, help=('Number of train steps.')) flags.DEFINE_integer( 'decode_frequency', default=200, - help=('Frequency of decoding during training, e.g. every 1000 steps.')) + help=('Frequency of decoding during training, e.g. every 1000 steps.'), +) flags.DEFINE_integer( - 'max_len_query_digit', - default=3, - help=('Maximum length of a single input digit.')) + 'max_len_query_digit', default=3, help=('Maximum length of a single input digit.') +) def get_model(ctable: CTable, *, teacher_force: bool = False) -> models.Seq2seq: - return models.Seq2seq(teacher_force=teacher_force, - hidden_size=FLAGS.hidden_size, eos_id=ctable.eos_id, - vocab_size=ctable.vocab_size) + return models.Seq2seq( + teacher_force=teacher_force, + hidden_size=FLAGS.hidden_size, + eos_id=ctable.eos_id, + vocab_size=ctable.vocab_size, + ) -def get_initial_params(model: models.Seq2seq, rng: PRNGKey, - ctable: CTable) -> Dict[str, Any]: +def get_initial_params( + model: models.Seq2seq, rng: PRNGKey, ctable: CTable +) -> Dict[str, Any]: """Returns the initial parameters of a seq2seq model.""" rng1, rng2 = jax.random.split(rng) variables = model.init( {'params': rng1, 'lstm': rng2}, jnp.ones(ctable.encoder_input_shape, jnp.float32), - jnp.ones(ctable.decoder_input_shape, jnp.float32) + jnp.ones(ctable.decoder_input_shape, jnp.float32), ) return variables['params'] @@ -90,8 +90,7 @@ def get_train_state(rng: PRNGKey, ctable: CTable) -> train_state.TrainState: model = get_model(ctable) params = get_initial_params(model, rng, ctable) tx = optax.adam(FLAGS.learning_rate) - state = train_state.TrainState.create( - apply_fn=model.apply, params=params, tx=tx) + state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) return state @@ -102,8 +101,7 @@ def cross_entropy_loss(logits: Array, labels: Array, lengths: Array) -> float: return -masked_xe -def compute_metrics(logits: Array, labels: Array, - eos_id: int) -> Dict[str, float]: +def compute_metrics(logits: Array, labels: Array, eos_id: int) -> Dict[str, float]: """Computes metrics and returns them.""" lengths = get_sequence_lengths(labels, eos_id) loss = cross_entropy_loss(logits, labels, lengths) @@ -122,19 +120,18 @@ def compute_metrics(logits: Array, labels: Array, @jax.jit -def train_step(state: train_state.TrainState, batch: Array, lstm_rng: PRNGKey, - eos_id: int) -> Tuple[train_state.TrainState, Dict[str, float]]: +def train_step( + state: train_state.TrainState, batch: Array, lstm_rng: PRNGKey, eos_id: int +) -> Tuple[train_state.TrainState, Dict[str, float]]: """Trains one step.""" labels = batch['answer'][:, 1:] lstm_key = jax.random.fold_in(lstm_rng, state.step) def loss_fn(params): - logits, _ = state.apply_fn({'params': params}, - batch['query'], - batch['answer'], - rngs={'lstm': lstm_key}) - loss = cross_entropy_loss( - logits, labels, get_sequence_lengths(labels, eos_id)) + logits, _ = state.apply_fn( + {'params': params}, batch['query'], batch['answer'], rngs={'lstm': lstm_key} + ) + loss = cross_entropy_loss(logits, labels, get_sequence_lengths(labels, eos_id)) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) @@ -147,28 +144,32 @@ def loss_fn(params): def log_decode(question: str, inferred: str, golden: str): """Logs the given question, inferred query, and correct query.""" - suffix = '(CORRECT)' if inferred == golden else (f'(INCORRECT) ' - f'correct={golden}') + suffix = '(CORRECT)' if inferred == golden else (f'(INCORRECT) ' f'correct={golden}') logging.info('DECODE: %s = %s %s', question, inferred, suffix) @functools.partial(jax.jit, static_argnums=3) -def decode(params: Dict[str, Any], inputs: Array, decode_rng: PRNGKey, - ctable: CTable) -> Array: +def decode( + params: Dict[str, Any], inputs: Array, decode_rng: PRNGKey, ctable: CTable +) -> Array: """Decodes inputs.""" init_decoder_input = ctable.one_hot(ctable.encode('=')[0:1]) - init_decoder_inputs = jnp.tile(init_decoder_input, - (inputs.shape[0], ctable.max_output_len, 1)) + init_decoder_inputs = jnp.tile( + init_decoder_input, (inputs.shape[0], ctable.max_output_len, 1) + ) model = get_model(ctable, teacher_force=False) - _, predictions = model.apply({'params': params}, - inputs, - init_decoder_inputs, - rngs={'lstm': decode_rng}) + _, predictions = model.apply( + {'params': params}, inputs, init_decoder_inputs, rngs={'lstm': decode_rng} + ) return predictions -def decode_batch(state: train_state.TrainState, batch: Dict[str, Array], - decode_rng: PRNGKey, ctable: CTable): +def decode_batch( + state: train_state.TrainState, + batch: Dict[str, Array], + decode_rng: PRNGKey, + ctable: CTable, +): """Decodes and log results for a batch.""" inputs, outputs = batch['query'], batch['answer'][:, 1:] decode_rng = jax.random.fold_in(decode_rng, state.step) diff --git a/examples/seq2seq/train_test.py b/examples/seq2seq/train_test.py index 2a454b5d3b..4655cdbbb8 100644 --- a/examples/seq2seq/train_test.py +++ b/examples/seq2seq/train_test.py @@ -35,12 +35,14 @@ def create_ctable(chars='0123456789+= '): def create_train_state(ctable): - model = models.Seq2seq(teacher_force=False, - hidden_size=train.FLAGS.hidden_size, vocab_size=ctable.vocab_size) + model = models.Seq2seq( + teacher_force=False, + hidden_size=train.FLAGS.hidden_size, + vocab_size=ctable.vocab_size, + ) params = train.get_initial_params(model, jax.random.PRNGKey(0), ctable) tx = optax.adam(train.FLAGS.learning_rate) - state = train_state.TrainState.create( - apply_fn=model.apply, params=params, tx=tx) + state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) return state @@ -58,32 +60,26 @@ def test_character_table(self): def test_mask_sequences(self): np.testing.assert_equal( input_pipeline.mask_sequences( - np.arange(1, 13).reshape((4, 3)), - np.array([3, 2, 1, 0]) + np.arange(1, 13).reshape((4, 3)), np.array([3, 2, 1, 0]) ), - np.array( - [[1, 2, 3], - [4, 5, 0], - [7, 0, 0], - [0, 0, 0]] - ) + np.array([[1, 2, 3], [4, 5, 0], [7, 0, 0], [0, 0, 0]]), ) def test_get_sequence_lengths(self): - oh_sequence_batch = jax.vmap( - functools.partial(jax.nn.one_hot, num_classes=4))( - np.array([[0, 1, 0], [1, 0, 2], [1, 2, 0], [1, 2, 3]])) + oh_sequence_batch = jax.vmap(functools.partial(jax.nn.one_hot, num_classes=4))( + np.array([[0, 1, 0], [1, 0, 2], [1, 2, 0], [1, 2, 3]]) + ) np.testing.assert_equal( input_pipeline.get_sequence_lengths(oh_sequence_batch, eos_id=0), - np.array([1, 2, 3, 3], np.int32) + np.array([1, 2, 3, 3], np.int32), ) np.testing.assert_equal( input_pipeline.get_sequence_lengths(oh_sequence_batch, eos_id=1), - np.array([2, 1, 1, 1], np.int32) + np.array([2, 1, 1, 1], np.int32), ) np.testing.assert_equal( input_pipeline.get_sequence_lengths(oh_sequence_batch, eos_id=2), - np.array([3, 3, 2, 2], np.int32) + np.array([3, 3, 2, 2], np.int32), ) def test_train_one_step(self): @@ -104,5 +100,6 @@ def test_decode_batch(self): state = create_train_state(ctable) train.decode_batch(state, batch, key, ctable) + if __name__ == '__main__': absltest.main() diff --git a/examples/sst2/build_vocabulary.py b/examples/sst2/build_vocabulary.py index 04c44de52e..24c24b4425 100755 --- a/examples/sst2/build_vocabulary.py +++ b/examples/sst2/build_vocabulary.py @@ -26,13 +26,15 @@ def get_tokenized_sequences( - dataset: tf.data.Dataset, - tokenizer: tftext.Tokenizer = tftext.WhitespaceTokenizer(), - input_key: str = 'sentence') -> Iterable[Sequence[bytes]]: + dataset: tf.data.Dataset, + tokenizer: tftext.Tokenizer = tftext.WhitespaceTokenizer(), + input_key: str = 'sentence', +) -> Iterable[Sequence[bytes]]: """Returns tokenized sequences for vocabulary building.""" dataset = dataset.map( lambda example: tokenizer.tokenize(example[input_key]), - num_parallel_calls=tf.data.experimental.AUTOTUNE) + num_parallel_calls=tf.data.experimental.AUTOTUNE, + ) yield from tfds.as_numpy(dataset) @@ -49,8 +51,7 @@ def get_tokenized_sequences( # Builds the vocabulary from the tokenized sequences. # A token needs to appear at least 3 times to be in the vocabulary. You can # play with this. It is there to make sure we don't overfit on rare words. - vocab = vocabulary.Vocabulary( - tokenized_sequences=tokenized_sequences, min_freq=3) + vocab = vocabulary.Vocabulary(tokenized_sequences=tokenized_sequences, min_freq=3) vocab.save('vocab.txt') logging.info('Total time elapsed: %f s', time.time() - start_time) diff --git a/examples/sst2/input_pipeline.py b/examples/sst2/input_pipeline.py index a8a03b37bd..e126feb66f 100755 --- a/examples/sst2/input_pipeline.py +++ b/examples/sst2/input_pipeline.py @@ -113,40 +113,45 @@ def get_bucketed_batches( bucket_batch_sizes, padded_shapes=padded_shapes, pad_to_bucket_boundary=True, - drop_remainder=drop_remainder) + drop_remainder=drop_remainder, + ) if shuffle: # For shuffling we need to know how many training examples we have. num_examples = get_num_examples(dataset) num_batches = num_examples // batch_size - return dataset.shuffle( - num_examples, seed=shuffle_seed, - reshuffle_each_iteration=True).apply(bucket_fn).shuffle( - num_batches, - seed=shuffle_seed, - reshuffle_each_iteration=True).prefetch( - tf.data.experimental.AUTOTUNE) + return ( + dataset.shuffle(num_examples, seed=shuffle_seed, reshuffle_each_iteration=True) + .apply(bucket_fn) + .shuffle(num_batches, seed=shuffle_seed, reshuffle_each_iteration=True) + .prefetch(tf.data.experimental.AUTOTUNE) + ) return dataset.apply(bucket_fn).prefetch(tf.data.experimental.AUTOTUNE) -def vocab_to_hashtable(vocab: vocabulary.Vocabulary, - unk_idx: int) -> tf.lookup.StaticHashTable: +def vocab_to_hashtable( + vocab: vocabulary.Vocabulary, unk_idx: int +) -> tf.lookup.StaticHashTable: """Returns a TF lookup table (token -> ID) from a vocabulary.""" return tf.lookup.StaticHashTable( - tf.lookup.KeyValueTensorInitializer( - list(vocab.keys()), list(vocab.values())), default_value=unk_idx) + tf.lookup.KeyValueTensorInitializer(list(vocab.keys()), list(vocab.values())), + default_value=unk_idx, + ) -def vocab_to_inverse_hashtable(vocab: vocabulary.Vocabulary, - unk_token: bytes) -> tf.lookup.StaticHashTable: +def vocab_to_inverse_hashtable( + vocab: vocabulary.Vocabulary, unk_token: bytes +) -> tf.lookup.StaticHashTable: """Returns an inverse TF lookup table (ID -> token) from a vocabulary.""" return tf.lookup.StaticHashTable( tf.lookup.KeyValueTensorInitializer( list(vocab.values()), list(vocab.keys()), key_dtype=tf.int64, - value_dtype=tf.string), - default_value=unk_token) + value_dtype=tf.string, + ), + default_value=unk_token, + ) def _is_text_field(feature_name_and_type): @@ -164,11 +169,13 @@ def _is_class_label(feature_name_and_type): class TextDataset: """A text dataset with one sequence as input and a label.""" - def __init__(self, - tfds_name: str = 'glue/sst2', - vocab_path: str = 'vocab.txt', - tokenizer: text.Tokenizer = text.WhitespaceTokenizer(), - split='train'): + def __init__( + self, + tfds_name: str = 'glue/sst2', + vocab_path: str = 'vocab.txt', + tokenizer: text.Tokenizer = text.WhitespaceTokenizer(), + split='train', + ): """Initializes the SST2 data source.""" self.dataset, self.info = tfds.load(tfds_name, split=split, with_info=True) @@ -186,7 +193,8 @@ def __init__(self, self.tokenizer = tokenizer self.tf_vocab = vocab_to_hashtable(self.vocab, unk_idx=self.vocab.unk_idx) self.examples = self.dataset.map( - self.prepare_example, num_parallel_calls=AUTOTUNE).cache() + self.prepare_example, num_parallel_calls=AUTOTUNE + ).cache() @property def padded_shapes(self): @@ -200,8 +208,7 @@ def example_length_fn(self, example: Example) -> tf.Tensor: def add_bos_eos(self, sequence: tf.Tensor) -> tf.Tensor: """Prepends BOS ID and appends EOS ID to a sequence of token IDs.""" - return tf.concat( - [[self.vocab.bos_idx], sequence, [self.vocab.eos_idx]], 0) + return tf.concat([[self.vocab.bos_idx], sequence, [self.vocab.eos_idx]], 0) def prepare_example(self, example: Example) -> Example: """Prepares an example by converting text to token IDs.""" @@ -214,34 +221,40 @@ def prepare_example(self, example: Example) -> Example: example['label'] = label return example - def get_batches(self, - batch_size: int, - drop_remainder: bool = False, - shuffle: bool = False, - shuffle_seed: Optional[int] = None, - fixed_pad_length: Optional[int] = None, - dataset: Optional[tf.data.Dataset] = None): + def get_batches( + self, + batch_size: int, + drop_remainder: bool = False, + shuffle: bool = False, + shuffle_seed: Optional[int] = None, + fixed_pad_length: Optional[int] = None, + dataset: Optional[tf.data.Dataset] = None, + ): """Returns an iterator with padded batches for the provided dataset.""" if dataset is None: dataset = self.examples if shuffle: buffer_size = get_num_examples(dataset) dataset = dataset.shuffle( - buffer_size, seed=shuffle_seed, reshuffle_each_iteration=True) + buffer_size, seed=shuffle_seed, reshuffle_each_iteration=True + ) padded_shapes = {k: v for k, v in self.padded_shapes.items()} if fixed_pad_length is not None: padded_shapes['token_ids'] = fixed_pad_length return dataset.padded_batch( - batch_size, padded_shapes=padded_shapes, drop_remainder=drop_remainder) - - def get_bucketed_batches(self, - batch_size: int, - bucket_size: int, - max_input_length: int, - drop_remainder: bool = False, - shuffle: bool = False, - shuffle_seed: Optional[int] = None, - dataset: Optional[tf.data.Dataset] = None): + batch_size, padded_shapes=padded_shapes, drop_remainder=drop_remainder + ) + + def get_bucketed_batches( + self, + batch_size: int, + bucket_size: int, + max_input_length: int, + drop_remainder: bool = False, + shuffle: bool = False, + shuffle_seed: Optional[int] = None, + dataset: Optional[tf.data.Dataset] = None, + ): """Returns an iterator with bucketed batches for the provided dataset.""" if dataset is None: dataset = self.examples @@ -254,4 +267,5 @@ def get_bucketed_batches(self, self.example_length_fn, shuffle=shuffle, shuffle_seed=shuffle_seed, - drop_remainder=drop_remainder) + drop_remainder=drop_remainder, + ) diff --git a/examples/sst2/input_pipeline_test.py b/examples/sst2/input_pipeline_test.py index 42b68716e1..5aa2bbd18b 100644 --- a/examples/sst2/input_pipeline_test.py +++ b/examples/sst2/input_pipeline_test.py @@ -46,8 +46,7 @@ def _get_dataset(self, vocab_path): """Uses mock data to create the dataset.""" # Go two directories up to the root of the flax directory. flax_root_dir = pathlib.Path(__file__).parents[2] - data_dir = str(flax_root_dir) + \ - '/.tfds/metadata' # pylint: disable=unused-variable + data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): return input_pipeline.TextDataset(vocab_path=vocab_path, split='train') @@ -56,8 +55,11 @@ def test_bucketed_dataset(self): batch_size = 2 bucket_size = 8 for batch in self.dataset.get_bucketed_batches( - batch_size=batch_size, - bucket_size=bucket_size, max_input_length=60, shuffle=False).take(3): + batch_size=batch_size, + bucket_size=bucket_size, + max_input_length=60, + shuffle=False, + ).take(3): # Because of bucketing, sequence length must be multiple of bucket_size. length = batch['token_ids'].numpy().shape[-1] self.assertEqual(0, length % bucket_size) @@ -67,8 +69,7 @@ def test_bucketed_dataset(self): def test_batched_dataset(self): """Tests that the length of a batch matches the longest sequence.""" batch_size = 2 - for batch in self.dataset.get_batches( - batch_size=batch_size, shuffle=False).take(1): + for batch in self.dataset.get_batches(batch_size=batch_size, shuffle=False).take(1): # Each batch is padded to the maximum sentence length in the batch. max_length_in_batch = max(batch['length'].numpy()) length = batch['token_ids'].numpy().shape[-1] @@ -81,8 +82,8 @@ def test_batched_dataset_fixed_length(self): batch_size = 2 fixed_pad_length = 77 for batch in self.dataset.get_batches( - batch_size=batch_size, shuffle=False, - fixed_pad_length=fixed_pad_length).take(1): + batch_size=batch_size, shuffle=False, fixed_pad_length=fixed_pad_length + ).take(1): length = batch['token_ids'].numpy().shape[-1] self.assertEqual(fixed_pad_length, length) diff --git a/examples/sst2/main.py b/examples/sst2/main.py index ce25270d34..5085c17559 100644 --- a/examples/sst2/main.py +++ b/examples/sst2/main.py @@ -35,7 +35,8 @@ 'config', None, 'File path to the training hyperparameter configuration.', - lock_config=True) + lock_config=True, +) flags.mark_flags_as_required(['config', 'workdir']) @@ -52,10 +53,12 @@ def main(argv): # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) - platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}') - platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, - FLAGS.workdir, 'workdir') + platform.work_unit().set_task_status( + f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' + ) + platform.work_unit().create_artifact( + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) diff --git a/examples/sst2/models.py b/examples/sst2/models.py index 1ce7501030..502e45e041 100644 --- a/examples/sst2/models.py +++ b/examples/sst2/models.py @@ -85,6 +85,7 @@ class WordDropout(nn.Module): This is basically the same as `nn.Dropout`, but allows specifying the value of dropped out items. """ + dropout_rate: float unk_idx: int deterministic: Optional[bool] = None @@ -92,8 +93,9 @@ class WordDropout(nn.Module): @nn.compact def __call__(self, inputs: Array, deterministic: Optional[bool] = None): deterministic = nn.module.merge_param( - 'deterministic', self.deterministic, deterministic) - if deterministic or self.dropout_rate == 0.: + 'deterministic', self.deterministic, deterministic + ) + if deterministic or self.dropout_rate == 0.0: return inputs rng = self.make_rng('dropout') mask = jax.random.bernoulli(rng, p=self.dropout_rate, shape=inputs.shape) @@ -112,13 +114,13 @@ class Embedder(nn.Module): word_dropout_rate: Percentage of input words to replace with unk_idx. unk_idx: The index (integer) to use to replace inputs for word dropout. """ + vocab_size: int embedding_size: int - embedding_init: Callable[..., - Array] = nn.initializers.normal(stddev=0.1) + embedding_init: Callable[..., Array] = nn.initializers.normal(stddev=0.1) frozen: bool = False - dropout_rate: float = 0. - word_dropout_rate: float = 0. + dropout_rate: float = 0.0 + word_dropout_rate: float = 0.0 unk_idx: Optional[int] = None deterministic: Optional[bool] = None dtype: jnp.dtype = jnp.float32 @@ -127,16 +129,15 @@ def setup(self): self.embedding = self.param( 'embedding', self.embedding_init, - (self.vocab_size, - self.embedding_size), - self.dtype) + (self.vocab_size, self.embedding_size), + self.dtype, + ) self.dropout_layer = nn.Dropout(rate=self.dropout_rate) self.word_dropout_layer = WordDropout( - dropout_rate=self.word_dropout_rate, - unk_idx=self.unk_idx) + dropout_rate=self.word_dropout_rate, unk_idx=self.unk_idx + ) - def __call__(self, inputs: Array, - deterministic: Optional[bool] = None) -> Array: + def __call__(self, inputs: Array, deterministic: Optional[bool] = None) -> Array: """Embeds the input sequences and applies word dropout and dropout. Args: @@ -148,7 +149,8 @@ def __call__(self, inputs: Array, embedding_size]. """ deterministic = nn.module.merge_param( - 'deterministic', self.deterministic, deterministic) + 'deterministic', self.deterministic, deterministic + ) inputs = self.word_dropout_layer(inputs, deterministic=deterministic) embedded_inputs = self.embedding[inputs] @@ -161,13 +163,16 @@ def __call__(self, inputs: Array, class SimpleLSTM(nn.Module): """A simple unidirectional LSTM.""" + hidden_size: int @functools.partial( nn.transforms.scan, variable_broadcast='params', - in_axes=1, out_axes=1, - split_rngs={'params': False}) + in_axes=1, + out_axes=1, + split_rngs={'params': False}, + ) @nn.compact def __call__(self, carry, x): return nn.OptimizedLSTMCell(self.hidden_size)(carry, x) @@ -175,11 +180,13 @@ def __call__(self, carry, x): def initialize_carry(self, input_shape): # Use fixed random key since default state init fn is just zeros. return nn.OptimizedLSTMCell(self.hidden_size, parent=None).initialize_carry( - jax.random.PRNGKey(0), input_shape) + jax.random.PRNGKey(0), input_shape + ) class SimpleBiLSTM(nn.Module): """A simple bi-directional LSTM.""" + hidden_size: int def setup(self): @@ -213,6 +220,7 @@ class MLP(nn.Module): output_bias: If False, do not use a bias term in the last layer. deterministic: Disables dropout if set to True. """ + hidden_size: int output_size: int activation: Callable[..., Any] = nn.tanh @@ -236,7 +244,8 @@ def __call__(self, inputs: Array, deterministic: Optional[bool] = None): The MLP output [batch_size, ..., output_size] """ deterministic = nn.module.merge_param( - 'deterministic', self.deterministic, deterministic) + 'deterministic', self.deterministic, deterministic + ) hidden = self.intermediate_layer(inputs) hidden = self.activation(hidden) hidden = self.dropout_layer(hidden, deterministic=deterministic) @@ -260,6 +269,7 @@ class KeysOnlyMlpAttention(nn.Module): Attributes: hidden_size: The hidden size of the MLP that computes the attention score. """ + hidden_size: int @nn.compact @@ -299,23 +309,25 @@ class AttentionClassifier(nn.Module): of the inputs, and inside the MLP. Applied when `deterministic` is False. deterministic: Disables dropout if True. """ + hidden_size: int output_size: int - dropout_rate: float = 0. + dropout_rate: float = 0.0 deterministic: Optional[bool] = None def setup(self): self.dropout_layer = nn.Dropout(rate=self.dropout_rate) - self.keys_only_mlp_attention = KeysOnlyMlpAttention( - hidden_size=self.hidden_size) + self.keys_only_mlp_attention = KeysOnlyMlpAttention(hidden_size=self.hidden_size) self.mlp = MLP( hidden_size=self.hidden_size, output_size=self.output_size, output_bias=False, - dropout_rate=self.dropout_rate) + dropout_rate=self.dropout_rate, + ) - def __call__(self, encoded_inputs: Array, lengths: Array, - deterministic: Optional[bool] = None) -> Array: + def __call__( + self, encoded_inputs: Array, lengths: Array, deterministic: Optional[bool] = None + ) -> Array: """Applies model to the encoded inputs. Args: @@ -329,9 +341,9 @@ def __call__(self, encoded_inputs: Array, lengths: Array, An array of logits [batch_size, output_size]. """ deterministic = nn.module.merge_param( - 'deterministic', self.deterministic, deterministic) - encoded_inputs = self.dropout_layer( - encoded_inputs, deterministic=deterministic) + 'deterministic', self.deterministic, deterministic + ) + encoded_inputs = self.dropout_layer(encoded_inputs, deterministic=deterministic) # Compute attention. attention.shape: [batch_size, seq_len]. mask = sequence_mask(lengths, encoded_inputs.shape[1]) @@ -366,33 +378,38 @@ def setup(self): embedding_size=self.embedding_size, dropout_rate=self.dropout_rate, word_dropout_rate=self.word_dropout_rate, - unk_idx=self.unk_idx) + unk_idx=self.unk_idx, + ) self.encoder = SimpleBiLSTM(hidden_size=self.hidden_size) self.classifier = AttentionClassifier( hidden_size=self.hidden_size, output_size=self.output_size, - dropout_rate=self.dropout_rate) + dropout_rate=self.dropout_rate, + ) - def embed_token_ids(self, token_ids: Array, - deterministic: Optional[bool] = None) -> Array: + def embed_token_ids( + self, token_ids: Array, deterministic: Optional[bool] = None + ) -> Array: deterministic = nn.module.merge_param( - 'deterministic', self.deterministic, deterministic) + 'deterministic', self.deterministic, deterministic + ) return self.embedder(token_ids, deterministic=deterministic) def logits_from_embedded_inputs( - self, embedded_inputs: Array, lengths: Array, - deterministic: Optional[bool] = None) -> Array: + self, embedded_inputs: Array, lengths: Array, deterministic: Optional[bool] = None + ) -> Array: deterministic = nn.module.merge_param( - 'deterministic', self.deterministic, deterministic) + 'deterministic', self.deterministic, deterministic + ) encoded_inputs = self.encoder(embedded_inputs, lengths) - return self.classifier( - encoded_inputs, lengths, deterministic=deterministic) + return self.classifier(encoded_inputs, lengths, deterministic=deterministic) - def __call__(self, token_ids: Array, lengths: Array, - deterministic: Optional[bool] = None) -> Array: + def __call__( + self, token_ids: Array, lengths: Array, deterministic: Optional[bool] = None + ) -> Array: """Embeds the token IDs, encodes them, and classifies with attention.""" - embedded_inputs = self.embed_token_ids( - token_ids, deterministic=deterministic) + embedded_inputs = self.embed_token_ids(token_ids, deterministic=deterministic) logits = self.logits_from_embedded_inputs( - embedded_inputs, lengths, deterministic=deterministic) - return logits \ No newline at end of file + embedded_inputs, lengths, deterministic=deterministic + ) + return logits diff --git a/examples/sst2/models_test.py b/examples/sst2/models_test.py index 7c77a2c0e7..ab75d6df78 100644 --- a/examples/sst2/models_test.py +++ b/examples/sst2/models_test.py @@ -33,9 +33,7 @@ def test_embedder_returns_correct_output_shape(self): """Tests if the embedder returns the correct shape.""" vocab_size = 5 embedding_size = 3 - model = models.Embedder( - vocab_size=vocab_size, - embedding_size=embedding_size) + model = models.Embedder(vocab_size=vocab_size, embedding_size=embedding_size) rng = jax.random.PRNGKey(0) token_ids = np.array([[2, 4, 3], [2, 6, 3]], dtype=np.int32) output, _ = model.init_with_output(rng, token_ids, deterministic=True) @@ -49,8 +47,7 @@ def test_lstm_returns_correct_output_shape(self): hidden_size = 5 model = models.SimpleLSTM(5) rng = jax.random.PRNGKey(0) - inputs = np.random.RandomState(0).normal( - size=[batch_size, seq_len, embedding_size]) + inputs = np.random.RandomState(0).normal(size=[batch_size, seq_len, embedding_size]) initial_state = model.initialize_carry(inputs[:, 0].shape) (_, output), _ = model.init_with_output(rng, initial_state, inputs) self.assertEqual((batch_size, seq_len, hidden_size), output.shape) @@ -63,8 +60,7 @@ def test_bilstm_returns_correct_output_shape(self): hidden_size = 5 model = models.SimpleBiLSTM(hidden_size=hidden_size) rng = jax.random.PRNGKey(0) - inputs = np.random.RandomState(0).normal( - size=[batch_size, seq_len, embedding_size]) + inputs = np.random.RandomState(0).normal(size=[batch_size, seq_len, embedding_size]) lengths = np.array([2, 3], dtype=np.int32) outputs, _ = model.init_with_output(rng, inputs, lengths) # We expect 2*hidden_size because we concatenate forward+backward LSTMs. @@ -88,7 +84,8 @@ def test_text_classifier_returns_correct_output_shape(self): dropout_rate=dropout_rate, word_dropout_rate=word_dropout_rate, unk_idx=unk_idx, - deterministic=True) + deterministic=True, + ) rng = jax.random.PRNGKey(0) token_ids = np.array([[2, 4, 3], [2, 6, 3]], dtype=np.int32) @@ -99,4 +96,4 @@ def test_text_classifier_returns_correct_output_shape(self): if __name__ == '__main__': - absltest.main() \ No newline at end of file + absltest.main() diff --git a/examples/sst2/train.py b/examples/sst2/train.py index 0619e4515a..37a9bdb4c7 100644 --- a/examples/sst2/train.py +++ b/examples/sst2/train.py @@ -37,6 +37,7 @@ class Metrics(struct.PyTreeNode): """Computed metrics.""" + loss: float accuracy: float count: Optional[int] = None @@ -46,7 +47,7 @@ class Metrics(struct.PyTreeNode): def sigmoid_cross_entropy_with_logits(*, labels: Array, logits: Array) -> Array: """Sigmoid cross entropy loss.""" zeros = jnp.zeros_like(logits, dtype=logits.dtype) - condition = (logits >= zeros) + condition = logits >= zeros relu_logits = jnp.where(condition, logits, zeros) neg_abs_logits = jnp.where(condition, -logits, logits) return relu_logits - logits * labels + jnp.log1p(jnp.exp(neg_abs_logits)) @@ -65,7 +66,8 @@ def create_train_state(rng, config: ml_collections.ConfigDict, model): params = get_initial_params(rng, model) tx = optax.chain( optax.sgd(learning_rate=config.learning_rate, momentum=config.momentum), - optax.add_decayed_weights(weight_decay=config.weight_decay)) + optax.add_decayed_weights(weight_decay=config.weight_decay), + ) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) return state @@ -75,12 +77,11 @@ def compute_metrics(*, labels: Array, logits: Array) -> Metrics: if labels.ndim == 1: # Prevent the labels from broadcasting over the logits. labels = jnp.expand_dims(labels, axis=1) loss = sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) - binary_predictions = (logits >= 0.) + binary_predictions = logits >= 0.0 binary_accuracy = jnp.equal(binary_predictions, labels) return Metrics( - loss=jnp.sum(loss), - accuracy=jnp.sum(binary_accuracy), - count=logits.shape[0]) + loss=jnp.sum(loss), accuracy=jnp.sum(binary_accuracy), count=logits.shape[0] + ) def model_from_config(config: ml_collections.ConfigDict): @@ -92,7 +93,8 @@ def model_from_config(config: ml_collections.ConfigDict): output_size=config.output_size, dropout_rate=config.dropout_rate, word_dropout_rate=config.word_dropout_rate, - unk_idx=config.unk_idx) + unk_idx=config.unk_idx, + ) return model @@ -109,15 +111,13 @@ def train_step( def loss_fn(params): variables = {'params': params} logits = state.apply_fn( - variables, batch['token_ids'], batch['length'], - deterministic=False, - rngs=rngs) + variables, batch['token_ids'], batch['length'], deterministic=False, rngs=rngs + ) labels = batch['label'] if labels.ndim == 1: labels = jnp.expand_dims(labels, 1) - loss = jnp.mean( - sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)) + loss = jnp.mean(sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) @@ -129,28 +129,26 @@ def loss_fn(params): return new_state, metrics -def eval_step(state: TrainState, batch: Dict[str, Array], - rngs: Dict[str, Any]) -> Metrics: +def eval_step( + state: TrainState, batch: Dict[str, Array], rngs: Dict[str, Any] +) -> Metrics: """Evaluate for a single step. Model should be in deterministic mode.""" variables = {'params': state.params} logits = state.apply_fn( - variables, batch['token_ids'], batch['length'], - deterministic=True, - rngs=rngs) + variables, batch['token_ids'], batch['length'], deterministic=True, rngs=rngs + ) metrics = compute_metrics(labels=batch['label'], logits=logits) return metrics -def normalize_batch_metrics( - batch_metrics: Sequence[Metrics]) -> Metrics: +def normalize_batch_metrics(batch_metrics: Sequence[Metrics]) -> Metrics: """Consolidates and normalizes a list of per-batch metrics dicts.""" # Here we sum the metrics that were already summed per batch. total_loss = np.sum([metrics.loss for metrics in batch_metrics]) total_accuracy = np.sum([metrics.accuracy for metrics in batch_metrics]) total = np.sum([metrics.count for metrics in batch_metrics]) # Divide each metric by the total number of items in the data set. - return Metrics( - loss=total_loss.item() / total, accuracy=total_accuracy.item() / total) + return Metrics(loss=total_loss.item() / total, accuracy=total_accuracy.item() / total) def batch_to_numpy(batch: Dict[str, tf.Tensor]) -> Dict[str, Array]: @@ -161,11 +159,11 @@ def batch_to_numpy(batch: Dict[str, tf.Tensor]) -> Dict[str, Array]: def evaluate_model( - eval_step_fn: Callable[..., Any], - state: TrainState, - batches: Union[Iterable[Example], tf.data.Dataset], - epoch: int, - rngs: Optional[Dict[str, Any]] = None + eval_step_fn: Callable[..., Any], + state: TrainState, + batches: Union[Iterable[Example], tf.data.Dataset], + epoch: int, + rngs: Optional[Dict[str, Any]] = None, ) -> Metrics: """Evaluate a model on a dataset.""" batch_metrics = [] @@ -179,17 +177,22 @@ def evaluate_model( batch_metrics = jax.device_get(batch_metrics) metrics = normalize_batch_metrics(batch_metrics) - logging.info('eval epoch %03d loss %.4f accuracy %.2f', epoch, - metrics.loss, metrics.accuracy * 100) + logging.info( + 'eval epoch %03d loss %.4f accuracy %.2f', + epoch, + metrics.loss, + metrics.accuracy * 100, + ) return metrics -def train_epoch(train_step_fn: Callable[..., Tuple[TrainState, Metrics]], - state: TrainState, - train_batches: tf.data.Dataset, - epoch: int, - rngs: Optional[Dict[str, Any]] = None - ) -> Tuple[TrainState, Metrics]: +def train_epoch( + train_step_fn: Callable[..., Tuple[TrainState, Metrics]], + state: TrainState, + train_batches: tf.data.Dataset, + epoch: int, + rngs: Optional[Dict[str, Any]] = None, +) -> Tuple[TrainState, Metrics]: """Train for a single epoch.""" batch_metrics = [] for batch in train_batches: @@ -201,14 +204,17 @@ def train_epoch(train_step_fn: Callable[..., Tuple[TrainState, Metrics]], batch_metrics = jax.device_get(batch_metrics) metrics = normalize_batch_metrics(batch_metrics) - logging.info('train epoch %03d loss %.4f accuracy %.2f', epoch, - metrics.loss, metrics.accuracy * 100) + logging.info( + 'train epoch %03d loss %.4f accuracy %.2f', + epoch, + metrics.loss, + metrics.accuracy * 100, + ) return state, metrics -def train_and_evaluate(config: ml_collections.ConfigDict, - workdir: str) -> TrainState: +def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> TrainState: """Execute model training and evaluation loop. Args: @@ -218,17 +224,16 @@ def train_and_evaluate(config: ml_collections.ConfigDict, The final train state that includes the trained parameters. """ # Prepare datasets. - train_dataset = input_pipeline.TextDataset( - tfds_name='glue/sst2', split='train') - eval_dataset = input_pipeline.TextDataset( - tfds_name='glue/sst2', split='validation') + train_dataset = input_pipeline.TextDataset(tfds_name='glue/sst2', split='train') + eval_dataset = input_pipeline.TextDataset(tfds_name='glue/sst2', split='validation') train_batches = train_dataset.get_bucketed_batches( config.batch_size, config.bucket_size, max_input_length=config.max_input_length, drop_remainder=True, shuffle=True, - shuffle_seed=config.seed) + shuffle_seed=config.seed, + ) eval_batches = eval_dataset.get_batches(batch_size=config.batch_size) # Keep track of vocab size in the config so that the embedder knows it. @@ -249,27 +254,19 @@ def train_and_evaluate(config: ml_collections.ConfigDict, # Main training loop. logging.info('Starting training...') for epoch in range(1, config.num_epochs + 1): - # Train for one epoch. rng, epoch_rng = jax.random.split(rng) rngs = {'dropout': epoch_rng} - state, train_metrics = train_epoch( - train_step_fn, state, train_batches, epoch, rngs) + state, train_metrics = train_epoch(train_step_fn, state, train_batches, epoch, rngs) # Evaluate current model on the validation data. eval_metrics = evaluate_model(eval_step_fn, state, eval_batches, epoch) # Write metrics to TensorBoard. summary_writer.scalar('train_loss', train_metrics.loss, epoch) - summary_writer.scalar( - 'train_accuracy', - train_metrics.accuracy * 100, - epoch) + summary_writer.scalar('train_accuracy', train_metrics.accuracy * 100, epoch) summary_writer.scalar('eval_loss', eval_metrics.loss, epoch) - summary_writer.scalar( - 'eval_accuracy', - eval_metrics.accuracy * 100, - epoch) + summary_writer.scalar('eval_accuracy', eval_metrics.accuracy * 100, epoch) summary_writer.flush() return state diff --git a/examples/sst2/vocabulary.py b/examples/sst2/vocabulary.py index 11e1efb50c..8ea55188c8 100755 --- a/examples/sst2/vocabulary.py +++ b/examples/sst2/vocabulary.py @@ -23,14 +23,16 @@ class Vocabulary: """Represents a vocabulary that can be built from a dataset.""" - def __init__(self, - vocab_path: Optional[str] = None, - tokenized_sequences: Optional[Iterable[Sequence[bytes]]] = None, - min_freq: int = 1, - pad_token: bytes = b'', - unk_token: bytes = b'', - bos_token: bytes = b'', - eos_token: bytes = b''): + def __init__( + self, + vocab_path: Optional[str] = None, + tokenized_sequences: Optional[Iterable[Sequence[bytes]]] = None, + min_freq: int = 1, + pad_token: bytes = b'', + unk_token: bytes = b'', + bos_token: bytes = b'', + eos_token: bytes = b'', + ): """Loads the vocab from disk (if `vocab_path` is given) or builds it from `tokenized_sequences`.""" self.pad_token = pad_token self.unk_token = unk_token @@ -44,12 +46,14 @@ def __init__(self, self.build(tokenized_sequences, min_freq=min_freq) else: raise ValueError( - ('Vocabulary needs either `vocab_path` or `tokenized_sequences` to ' - 'be provided, got %r and %r.') % (vocab_path, tokenized_sequences)) - - def build(self, - tokenized_sequences: Iterable[Sequence[bytes]], - min_freq: int = 1): + ( + 'Vocabulary needs either `vocab_path` or `tokenized_sequences` to ' + 'be provided, got %r and %r.' + ) + % (vocab_path, tokenized_sequences) + ) + + def build(self, tokenized_sequences: Iterable[Sequence[bytes]], min_freq: int = 1): """Builds a vocabulary over tokens with optional minimum frequency. Args: @@ -71,7 +75,8 @@ def build(self, # Sort by frequency (from high to low), and then by token string. # This makes sure high frequency tokens get a low token ID. counter.items(), - key=lambda token_freq: (-token_freq[1], token_freq[0])): + key=lambda token_freq: (-token_freq[1], token_freq[0]), + ): if freq >= min_freq: vocab[token] = len(vocab) diff --git a/examples/vae/main.py b/examples/vae/main.py index 55bbec66a1..72baaab189 100644 --- a/examples/vae/main.py +++ b/examples/vae/main.py @@ -29,24 +29,14 @@ FLAGS = flags.FLAGS flags.DEFINE_float( - 'learning_rate', default=1e-3, - help=('The learning rate for the Adam optimizer.') + 'learning_rate', default=1e-3, help=('The learning rate for the Adam optimizer.') ) -flags.DEFINE_integer( - 'batch_size', default=128, - help=('Batch size for training.') -) +flags.DEFINE_integer('batch_size', default=128, help=('Batch size for training.')) -flags.DEFINE_integer( - 'num_epochs', default=30, - help=('Number of training epochs.') -) +flags.DEFINE_integer('num_epochs', default=30, help=('Number of training epochs.')) -flags.DEFINE_integer( - 'latents', default=20, - help=('Number of latent variables.') -) +flags.DEFINE_integer('latents', default=20, help=('Number of latent variables.')) def main(argv): @@ -56,12 +46,10 @@ def main(argv): # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') - train.train_and_evaluate(FLAGS.batch_size, - FLAGS.learning_rate, - FLAGS.num_epochs, - FLAGS.latents) + train.train_and_evaluate( + FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs, FLAGS.latents + ) if __name__ == '__main__': - app.run(main) diff --git a/examples/vae/models.py b/examples/vae/models.py index 85cd4cb1e2..0f3233c95f 100644 --- a/examples/vae/models.py +++ b/examples/vae/models.py @@ -21,6 +21,7 @@ class Encoder(nn.Module): """VAE Encoder.""" + latents: int @nn.compact @@ -45,6 +46,7 @@ def __call__(self, z): class VAE(nn.Module): """Full VAE model.""" + latents: int = 20 def setup(self): diff --git a/examples/vae/train.py b/examples/vae/train.py index 0922ad9629..6c8f400dd1 100644 --- a/examples/vae/train.py +++ b/examples/vae/train.py @@ -50,28 +50,26 @@ def kl_divergence(mean, logvar): @jax.vmap def binary_cross_entropy_with_logits(logits, labels): logits = nn.log_sigmoid(logits) - return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits))) + return -jnp.sum(labels * logits + (1.0 - labels) * jnp.log(-jnp.expm1(logits))) def compute_metrics(recon_x, x, mean, logvar): bce_loss = binary_cross_entropy_with_logits(recon_x, x).mean() kld_loss = kl_divergence(mean, logvar).mean() - return { - 'bce': bce_loss, - 'kld': kld_loss, - 'loss': bce_loss + kld_loss - } + return {'bce': bce_loss, 'kld': kld_loss, 'loss': bce_loss + kld_loss} def train_step(state, batch, z_rng, latents): def loss_fn(params): - recon_x, mean, logvar = models.model(latents).apply({'params': params}, - batch, z_rng) + recon_x, mean, logvar = models.model(latents).apply( + {'params': params}, batch, z_rng + ) bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean() kld_loss = kl_divergence(mean, logvar).mean() loss = bce_loss + kld_loss return loss + grads = jax.grad(loss_fn)(state.params) return state.apply_gradients(grads=grads) @@ -79,8 +77,9 @@ def loss_fn(params): def eval_f(params, images, z, z_rng, latents): def eval_model(vae): recon_images, mean, logvar = vae(images, z_rng) - comparison = jnp.concatenate([images[:8].reshape(-1, 28, 28, 1), - recon_images[:8].reshape(-1, 28, 28, 1)]) + comparison = jnp.concatenate( + [images[:8].reshape(-1, 28, 28, 1), recon_images[:8].reshape(-1, 28, 28, 1)] + ) generate_images = vae.generate(z) generate_images = generate_images.reshape(-1, 28, 28, 1) @@ -115,7 +114,7 @@ def train_and_evaluate(batch_size, learning_rate, num_epochs, latents): rng, z_key, eval_rng = random.split(rng, 3) z = random.normal(z_key, (64, latents)) - steps_per_epoch = ds_builder.info.splits["train"].num_examples // batch_size + steps_per_epoch = ds_builder.info.splits['train'].num_examples // batch_size for epoch in range(num_epochs): for _ in range(steps_per_epoch): @@ -123,13 +122,12 @@ def train_and_evaluate(batch_size, learning_rate, num_epochs, latents): rng, key = random.split(rng) state = train_step(state, batch, key, latents) - metrics, comparison, sample = eval_f(state.params, test_ds, z, eval_rng, - latents) - vae_utils.save_image( - comparison, f'results/reconstruction_{epoch}.png', nrow=8) + metrics, comparison, sample = eval_f(state.params, test_ds, z, eval_rng, latents) + vae_utils.save_image(comparison, f'results/reconstruction_{epoch}.png', nrow=8) vae_utils.save_image(sample, f'results/sample_{epoch}.png', nrow=8) - print('eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format( - epoch + 1, metrics['loss'], metrics['bce'], metrics['kld'] - )) - + print( + 'eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format( + epoch + 1, metrics['loss'], metrics['bce'], metrics['kld'] + ) + ) diff --git a/examples/vae/utils.py b/examples/vae/utils.py index 0e4b4c081e..ccac6d3252 100644 --- a/examples/vae/utils.py +++ b/examples/vae/utils.py @@ -54,9 +54,12 @@ def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format_img=None): this parameter should always be used. """ - if not (isinstance(ndarray, jnp.ndarray) or - (isinstance(ndarray, list) and all(isinstance(t, jnp.ndarray) for t - in ndarray))): + if not ( + isinstance(ndarray, jnp.ndarray) + or ( + isinstance(ndarray, list) and all(isinstance(t, jnp.ndarray) for t in ndarray) + ) + ): raise TypeError(f'array_like of tensors expected, got {type(ndarray)}') ndarray = jnp.asarray(ndarray) @@ -68,18 +71,19 @@ def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format_img=None): nmaps = ndarray.shape[0] xmaps = min(nrow, nmaps) ymaps = int(math.ceil(float(nmaps) / xmaps)) - height, width = (int(ndarray.shape[1] + padding), - int(ndarray.shape[2] + padding)) + height, width = (int(ndarray.shape[1] + padding), int(ndarray.shape[2] + padding)) num_channels = ndarray.shape[3] - grid = jnp.full((height * ymaps + padding, width * xmaps + padding, - num_channels), pad_value).astype(jnp.float32) + grid = jnp.full( + (height * ymaps + padding, width * xmaps + padding, num_channels), pad_value + ).astype(jnp.float32) k = 0 for y in range(ymaps): for x in range(xmaps): if k >= nmaps: break - grid = grid.at[y * height + padding:(y + 1) * height, - x * width + padding:(x + 1) * width].set(ndarray[k]) + grid = grid.at[ + y * height + padding : (y + 1) * height, x * width + padding : (x + 1) * width + ].set(ndarray[k]) k = k + 1 # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer diff --git a/examples/wmt/bleu.py b/examples/wmt/bleu.py index c3acf33f7f..ee5b0762cd 100644 --- a/examples/wmt/bleu.py +++ b/examples/wmt/bleu.py @@ -60,7 +60,8 @@ def property_chars(self, prefix): return "".join( chr(x) for x in range(sys.maxunicode) - if unicodedata.category(chr(x)).startswith(prefix)) + if unicodedata.category(chr(x)).startswith(prefix) + ) uregex = UnicodeRegex() @@ -109,14 +110,12 @@ def _get_ngrams(segment, max_order): ngram_counts = collections.Counter() for order in range(1, max_order + 1): for i in range(0, len(segment) - order + 1): - ngram = tuple(segment[i:i + order]) + ngram = tuple(segment[i : i + order]) ngram_counts[ngram] += 1 return ngram_counts -def compute_bleu_matches(reference_corpus, - translation_corpus, - max_order=4): +def compute_bleu_matches(reference_corpus, translation_corpus, max_order=4): """Computes BLEU match stats of translations against one or more references. Args: @@ -138,32 +137,36 @@ def compute_bleu_matches(reference_corpus, possible_matches_by_order = [0] * max_order precisions = [] - for (references, translations) in zip(reference_corpus, translation_corpus): + for references, translations in zip(reference_corpus, translation_corpus): reference_length += len(references) translation_length += len(translations) ref_ngram_counts = _get_ngrams(references, max_order) translation_ngram_counts = _get_ngrams(translations, max_order) - overlap = {ngram: min(count, translation_ngram_counts[ngram]) - for ngram, count in ref_ngram_counts.items()} + overlap = { + ngram: min(count, translation_ngram_counts[ngram]) + for ngram, count in ref_ngram_counts.items() + } for ngram in overlap: matches_by_order[len(ngram) - 1] += overlap[ngram] for ngram in translation_ngram_counts: - possible_matches_by_order[len(ngram) - - 1] += translation_ngram_counts[ngram] + possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[ngram] - return (np.array(matches_by_order), - np.array(possible_matches_by_order), - np.array(reference_length), - np.array(translation_length)) + return ( + np.array(matches_by_order), + np.array(possible_matches_by_order), + np.array(reference_length), + np.array(translation_length), + ) def bleu_partial(ref_lines, hyp_lines, case_sensitive=False): """Compute n-gram statistics for two lists of references and translations.""" if len(ref_lines) != len(hyp_lines): - raise ValueError("Reference and translation lists have different " - "numbers of lines.") + raise ValueError( + "Reference and translation lists have different " "numbers of lines." + ) if not case_sensitive: ref_lines = [x.lower() for x in ref_lines] hyp_lines = [x.lower() for x in hyp_lines] @@ -172,12 +175,14 @@ def bleu_partial(ref_lines, hyp_lines, case_sensitive=False): return compute_bleu_matches(ref_tokens, hyp_tokens) -def complete_bleu(matches_by_order, - possible_matches_by_order, - reference_length, - translation_length, - max_order=4, - use_bp=True): +def complete_bleu( + matches_by_order, + possible_matches_by_order, + reference_length, + translation_length, + max_order=4, + use_bp=True, +): """Compute BLEU score from aggregated n-gram statistics.""" precisions = [0] * max_order smooth = 1.0 @@ -207,7 +212,7 @@ def complete_bleu(matches_by_order, elif ratio >= 1.0: bp = 1.0 else: - bp = math.exp(1 - 1. / ratio) + bp = math.exp(1 - 1.0 / ratio) bleu = geo_mean * bp return float(bleu) * 100.0 diff --git a/examples/wmt/decode.py b/examples/wmt/decode.py index 81aa0f2c7f..3641958d42 100644 --- a/examples/wmt/decode.py +++ b/examples/wmt/decode.py @@ -91,12 +91,15 @@ def gather_beams(nested, beam_indices, batch_size, new_beam_size): """ batch_indices = jnp.reshape( jnp.arange(batch_size * new_beam_size) // new_beam_size, - (batch_size, new_beam_size)) + (batch_size, new_beam_size), + ) + def gather_fn(x): if x.ndim == 0: # ignore scalars (e.g. cache index) return x else: return x[batch_indices, beam_indices] + return jax.tree_util.tree_map(gather_fn, nested) @@ -125,6 +128,7 @@ def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size): @flax.struct.dataclass class BeamState: """Holds beam search state data.""" + # The position of the decoding loop in the length dimension. cur_index: jax.Array # scalar int32: current decoded length index # The active sequence log probabilities and finished sequence scores. @@ -144,35 +148,37 @@ def beam_init(batch_size, beam_size, max_decode_len, cache): """Initializes the beam search state data structure.""" cur_index0 = jnp.array(0) live_logprobs0 = jnp.tile( - jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), - [batch_size, 1]) + jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1] + ) finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF - live_seqs0 = jnp.zeros( - (batch_size, beam_size, max_decode_len), jnp.int32) - finished_seqs0 = jnp.zeros( - (batch_size, beam_size, max_decode_len), jnp.int32) + live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) + finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) # add beam dimension to attention cache pytree elements beam_cache0 = jax.tree_util.tree_map(lambda x: add_beam_dim(x, beam_size), cache) - return BeamState(cur_index=cur_index0, - live_logprobs=live_logprobs0, - finished_scores=finished_scores0, - live_seqs=live_seqs0, - finished_seqs=finished_seqs0, - finished_flags=finished_flags0, - cache=beam_cache0) + return BeamState( + cur_index=cur_index0, + live_logprobs=live_logprobs0, + finished_scores=finished_scores0, + live_seqs=live_seqs0, + finished_seqs=finished_seqs0, + finished_flags=finished_flags0, + cache=beam_cache0, + ) # Beam search routine: -def beam_search(inputs, - cache, - tokens_to_logits, - beam_size=4, - alpha=0.6, - eos_id=EOS_ID, - max_decode_len=None): +def beam_search( + inputs, + cache, + tokens_to_logits, + beam_size=4, + alpha=0.6, + eos_id=EOS_ID, + max_decode_len=None, +): """Beam search for transformer machine translation. Args: @@ -198,26 +204,23 @@ def beam_search(inputs, end_marker = jnp.array(eos_id) # initialize beam search state - beam_search_init_state = beam_init(batch_size, - beam_size, - max_decode_len, - cache) + beam_search_init_state = beam_init(batch_size, beam_size, max_decode_len, cache) def beam_search_loop_cond_fn(state): """Beam search loop termination condition.""" # Have we reached max decoding length? - not_at_end = (state.cur_index < max_decode_len - 1) + not_at_end = state.cur_index < max_decode_len - 1 # Is no further progress in the beam search possible? # Get the best possible scores from alive sequences. min_brevity_penalty = brevity_penalty(alpha, max_decode_len) best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty # Get the worst scores from finished sequences. - worst_finished_scores = jnp.min( - state.finished_scores, axis=1, keepdims=True) + worst_finished_scores = jnp.min(state.finished_scores, axis=1, keepdims=True) # Mask out scores from slots without any actual finished sequences. worst_finished_scores = jnp.where( - state.finished_flags, worst_finished_scores, NEG_INF) + state.finished_flags, worst_finished_scores, NEG_INF + ) # If no best possible live score is better than current worst finished # scores, the search cannot improve the finished set further. search_terminated = jnp.all(worst_finished_scores > best_live_scores) @@ -232,10 +235,11 @@ def beam_search_loop_body_fn(state): # autoregressive decoder model. Flatten the beam dimension into batch # dimension for feeding into the model. # --> [batch * beam, 1] - flat_ids = flatten_beam_dim(lax.dynamic_slice( - state.live_seqs, - (0, 0, state.cur_index), - (batch_size, beam_size, 1))) + flat_ids = flatten_beam_dim( + lax.dynamic_slice( + state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1) + ) + ) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} flat_cache = jax.tree_util.tree_map(flatten_beam_dim, state.cache) @@ -250,14 +254,14 @@ def beam_search_loop_body_fn(state): # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} new_cache = jax.tree_util.tree_map( - lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) + lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache + ) # Gather log probabilities from logits candidate_log_probs = jax.nn.log_softmax(logits) # Add new logprobs to existing prefix logprobs. # --> [batch, beam, vocab] - log_probs = (candidate_log_probs + - jnp.expand_dims(state.live_logprobs, axis=2)) + log_probs = candidate_log_probs + jnp.expand_dims(state.live_logprobs, axis=2) # We'll need the vocab size, gather it from the log probability dimension. vocab_size = log_probs.shape[2] @@ -277,9 +281,9 @@ def beam_search_loop_body_fn(state): topk_beam_indices = topk_indices // vocab_size # Gather 2*k top beams. # --> [batch, 2*beams, length] - topk_seq = gather_beams(state.live_seqs, - topk_beam_indices, - batch_size, beams_to_keep) + topk_seq = gather_beams( + state.live_seqs, topk_beam_indices, batch_size, beams_to_keep + ) # Append the most probable 2*K token IDs to the top 2*K sequences # Recover token id by modulo division and expand Id array for broadcasting. @@ -287,13 +291,12 @@ def beam_search_loop_body_fn(state): topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) # Update sequences for the 2*K top-k new sequences. # --> [batch, 2*beams, length] - topk_seq = lax.dynamic_update_slice( - topk_seq, topk_ids, (0, 0, state.cur_index + 1)) + topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids, (0, 0, state.cur_index + 1)) # Update LIVE (in-progress) sequences: # Did any of these sequences reach an end marker? # --> [batch, 2*beams] - newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker) + newly_finished = topk_seq[:, :, state.cur_index + 1] == end_marker # To prevent these newly finished sequences from being added to the LIVE # set of active beam search sequences, set their log probs to a very large # negative value. @@ -305,16 +308,17 @@ def beam_search_loop_body_fn(state): # Gather the top k beams (from top 2*k beams). # --> [batch, beams, length], [batch, beams] top_alive_seq, top_alive_log_probs = gather_beams( - [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size) + [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size + ) # Determine the top k beam indices from the original set of all beams. # --> [batch, beams] top_alive_indices = gather_beams( - topk_beam_indices, new_topk_indices, batch_size, beam_size) + topk_beam_indices, new_topk_indices, batch_size, beam_size + ) # With these, gather the top k beam-associated caches. # --> {[batch, beams, ...], ...} - top_alive_cache = gather_beams( - new_cache, top_alive_indices, batch_size, beam_size) + top_alive_cache = gather_beams(new_cache, top_alive_indices, batch_size, beam_size) # Update FINISHED (reached end of sentence) sequences: # Calculate new seq scores from log probabilities. @@ -327,40 +331,48 @@ def beam_search_loop_body_fn(state): # new finished sequence scores to existing finished scores and select the # best from the new set of beams. finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length] - [state.finished_seqs, topk_seq], axis=1) + [state.finished_seqs, topk_seq], axis=1 + ) finished_scores = jnp.concatenate( # --> [batch, 3*beams] - [state.finished_scores, new_scores], axis=1) + [state.finished_scores, new_scores], axis=1 + ) finished_flags = jnp.concatenate( # --> [batch, 3*beams] - [state.finished_flags, newly_finished], axis=1) + [state.finished_flags, newly_finished], axis=1 + ) # --> [batch, beams, length], [batch, beams], [batch, beams] - top_finished_seq, top_finished_scores, top_finished_flags = ( - gather_topk_beams([finished_seqs, finished_scores, finished_flags], - finished_scores, batch_size, beam_size)) - - return BeamState(cur_index=state.cur_index + 1, - live_logprobs=top_alive_log_probs, - finished_scores=top_finished_scores, - live_seqs=top_alive_seq, - finished_seqs=top_finished_seq, - finished_flags=top_finished_flags, - cache=top_alive_cache) + top_finished_seq, top_finished_scores, top_finished_flags = gather_topk_beams( + [finished_seqs, finished_scores, finished_flags], + finished_scores, + batch_size, + beam_size, + ) + + return BeamState( + cur_index=state.cur_index + 1, + live_logprobs=top_alive_log_probs, + finished_scores=top_finished_scores, + live_seqs=top_alive_seq, + finished_seqs=top_finished_seq, + finished_flags=top_finished_flags, + cache=top_alive_cache, + ) # Run while loop and get final beam search state. - final_state = lax.while_loop(beam_search_loop_cond_fn, - beam_search_loop_body_fn, - beam_search_init_state) + final_state = lax.while_loop( + beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state + ) # Account for the edge-case where there are no finished sequences for a # particular batch item. If so, return live sequences for that batch item. # --> [batch] none_finished = jnp.any(final_state.finished_flags, axis=1) # --> [batch, beams, length] - finished_seqs = jnp.where(none_finished[:, None, None], - final_state.finished_seqs, - final_state.live_seqs) + finished_seqs = jnp.where( + none_finished[:, None, None], final_state.finished_seqs, final_state.live_seqs + ) # --> [batch, beams] - finished_scores = jnp.where(none_finished[:, None], - final_state.finished_scores, - final_state.live_logprobs) + finished_scores = jnp.where( + none_finished[:, None], final_state.finished_scores, final_state.live_logprobs + ) return finished_seqs, finished_scores diff --git a/examples/wmt/input_pipeline.py b/examples/wmt/input_pipeline.py index ebbd6229d6..516b5d2b41 100644 --- a/examples/wmt/input_pipeline.py +++ b/examples/wmt/input_pipeline.py @@ -43,10 +43,12 @@ def __call__(self, features: Features) -> Features: return features -def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder, - split: str, - *, - reverse_translation: bool = False) -> tf.data.Dataset: +def get_raw_dataset( + dataset_builder: tfds.core.DatasetBuilder, + split: str, + *, + reverse_translation: bool = False +) -> tf.data.Dataset: """Loads a raw WMT dataset and normalizes feature keys. Args: @@ -62,18 +64,23 @@ def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder, """ num_examples = dataset_builder.info.splits[split].num_examples per_host_split = deterministic_data.get_read_instruction_for_host( - split, num_examples, drop_remainder=False) + split, num_examples, drop_remainder=False + ) ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False) ds = ds.map( NormalizeFeatureNamesOp( - dataset_builder.info, reverse_translation=reverse_translation), - num_parallel_calls=AUTOTUNE) + dataset_builder.info, reverse_translation=reverse_translation + ), + num_parallel_calls=AUTOTUNE, + ) return ds -def pack_dataset(dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None) -> tf.data.Dataset: +def pack_dataset( + dataset: tf.data.Dataset, + key2length: Union[int, Dict[str, int]], + keys: Optional[List[str]] = None, +) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. @@ -118,9 +125,10 @@ def pack_dataset(dataset: tf.data.Dataset, keys = list(shapes.keys()) for k in keys: if k not in shapes: - raise ValueError('Key %s not found in dataset. Available keys are %s' % - (k, shapes.keys())) - if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types] + raise ValueError( + 'Key %s not found in dataset. Available keys are %s' % (k, shapes.keys()) + ) + if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types] raise ValueError('Tensors to be packed must be one-dimensional.') # make sure that the length dictionary contains all keys as well as the # keys suffixed by "_segmentation" and "_position" @@ -132,13 +140,12 @@ def pack_dataset(dataset: tf.data.Dataset, # trim to length dataset = dataset.map( - lambda x: {k: x[k][:key2length[k]] for k in keys}, - num_parallel_calls=AUTOTUNE) + lambda x: {k: x[k][: key2length[k]] for k in keys}, num_parallel_calls=AUTOTUNE + ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) - dataset = dataset.padded_batch( - batch_size, padded_shapes={k: [-1] for k in keys}) + dataset = dataset.padded_batch(batch_size, padded_shapes={k: [-1] for k in keys}) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. @@ -148,8 +155,9 @@ def my_fn(x): return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) -def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], - key2length: Dict[str, int]) -> tf.data.Dataset: +def _pack_with_tf_ops( + dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int] +) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. @@ -174,7 +182,8 @@ def write_packed_example(partial, outputs): for k in keys_etc: new_outputs[k] = outputs[k].write( outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) + tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), + ) return new_partial, new_outputs def map_fn(x): @@ -194,9 +203,11 @@ def map_fn(x): outputs = {} for k in keys: outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) def body_fn(i, partial, outputs): """Body function for while_loop. @@ -213,13 +224,13 @@ def body_fn(i, partial, outputs): one_example = {} for k in keys: val = tf.cast(x[k][i], tf.int32) - val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] + val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] one_example[k] = val for k in keys: can_append = tf.logical_and( can_append, - tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) + tf.less_equal(tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]), + ) def false_fn(): return write_packed_example(partial, outputs) @@ -230,12 +241,12 @@ def true_fn(): partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: - new_seq = one_example[k][:key2length[k]] + new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], - tf.range(new_seq_len)], 0) + [partial[k + '_position'], tf.range(new_seq_len)], 0 + ) partial = new_partial return i + 1, partial, outputs @@ -245,18 +256,18 @@ def true_fn(): body=body_fn, loop_vars=(i, partial, outputs), shape_invariants=( - tf.TensorShape([]), # type: ignore[wrong-arg-types] - {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] - {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] + tf.TensorShape([]), # type: ignore[wrong-arg-types] + {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] + {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] ), - maximum_iterations=dynamic_batch_size) + maximum_iterations=dynamic_batch_size, + ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: - packed[k + '_segmentation'] = ( - tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) * - tf.cast(tf.not_equal(packed[k], 0), tf.int32)) + packed[k + '_segmentation'] = tf.cumsum( + tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 + ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) @@ -266,19 +277,20 @@ def true_fn(): # ----------------------------------------------------------------------------- # Main dataset prep routines. # ----------------------------------------------------------------------------- -def preprocess_wmt_data(dataset, - shuffle: bool, - num_epochs: Optional[int] = 1, - pack_examples: bool = True, - shuffle_buffer_size: int = 1024, - max_length: int = 512, - batch_size: int = 256, - drop_remainder: bool = True, - prefetch_size: int = AUTOTUNE): +def preprocess_wmt_data( + dataset, + shuffle: bool, + num_epochs: Optional[int] = 1, + pack_examples: bool = True, + shuffle_buffer_size: int = 1024, + max_length: int = 512, + batch_size: int = 256, + drop_remainder: bool = True, + prefetch_size: int = AUTOTUNE, +): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): - def filter_fn(x): source, target = x['inputs'], x['targets'] l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) @@ -299,15 +311,10 @@ def filter_fn(x): else: # simple (static-shape) padded batching dataset = dataset.padded_batch( batch_size, - padded_shapes={ - 'inputs': max_length, - 'targets': max_length - }, - padding_values={ - 'inputs': 0, - 'targets': 0 - }, - drop_remainder=drop_remainder) + padded_shapes={'inputs': max_length, 'targets': max_length}, + padding_values={'inputs': 0, 'targets': 0}, + drop_remainder=drop_remainder, + ) if prefetch_size: dataset = dataset.prefetch(prefetch_size) @@ -315,38 +322,43 @@ def filter_fn(x): return dataset -def get_wmt_datasets(config: ml_collections.ConfigDict, - *, - n_devices: int, - reverse_translation: bool = True, - vocab_path: Optional[str] = None): +def get_wmt_datasets( + config: ml_collections.ConfigDict, + *, + n_devices: int, + reverse_translation: bool = True, + vocab_path: Optional[str] = None +): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: vocab_path = os.path.expanduser('~/wmt_sentencepiece_model') train_ds_builder = tfds.builder(config.dataset_name) train_data = get_raw_dataset( - train_ds_builder, 'train', reverse_translation=reverse_translation) + train_ds_builder, 'train', reverse_translation=reverse_translation + ) if config.eval_dataset_name: eval_ds_builder = tfds.builder(config.eval_dataset_name) else: eval_ds_builder = train_ds_builder eval_data = get_raw_dataset( - eval_ds_builder, - config.eval_split, - reverse_translation=reverse_translation) + eval_ds_builder, config.eval_split, reverse_translation=reverse_translation + ) # Tokenize data. sp_tokenizer = tokenizer.load_or_train_tokenizer( train_data, vocab_path=vocab_path, vocab_size=config.vocab_size, - max_corpus_chars=config.max_corpus_chars) + max_corpus_chars=config.max_corpus_chars, + ) train_data = train_data.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) eval_data = eval_data.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) batch_size = config.per_device_batch_size * n_devices @@ -356,14 +368,16 @@ def get_wmt_datasets(config: ml_collections.ConfigDict, num_epochs=None, pack_examples=True, batch_size=batch_size, - max_length=config.max_target_length) + max_length=config.max_target_length, + ) eval_ds = preprocess_wmt_data( eval_data, shuffle=False, pack_examples=False, batch_size=batch_size, - max_length=config.max_eval_target_length) + max_length=config.max_eval_target_length, + ) predict_ds = preprocess_wmt_data( eval_data, @@ -371,6 +385,7 @@ def get_wmt_datasets(config: ml_collections.ConfigDict, pack_examples=False, batch_size=batch_size, max_length=config.max_predict_length, - drop_remainder=False) + drop_remainder=False, + ) return train_ds, eval_ds, predict_ds, sp_tokenizer diff --git a/examples/wmt/input_pipeline_test.py b/examples/wmt/input_pipeline_test.py index a50e57752f..77b25e73f1 100644 --- a/examples/wmt/input_pipeline_test.py +++ b/examples/wmt/input_pipeline_test.py @@ -53,7 +53,8 @@ def _get_datasets(self): with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): train_ds, eval_ds, predict_ds, _ = input_pipeline.get_wmt_datasets( - n_devices=2, config=config, vocab_path=vocab_path) + n_devices=2, config=config, vocab_path=vocab_path + ) return train_ds, eval_ds, predict_ds def test_train_ds(self): @@ -61,30 +62,39 @@ def test_train_ds(self): # For training we pack multiple short examples in one example. # *_position and *_segmentation indicate the boundaries. for batch in self.train_ds.take(3): - self.assertEqual({k: v.shape.as_list() for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'inputs_position': expected_shape, + 'inputs_segmentation': expected_shape, + 'targets': expected_shape, + 'targets_position': expected_shape, + 'targets_segmentation': expected_shape, + }, + ) def test_eval_ds(self): expected_shape = [2, _EVAL_TARGET_LENGTH] # 2 devices. for batch in self.eval_ds.take(3): - self.assertEqual({k: v.shape.as_list() for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'targets': expected_shape, + }, + ) def test_predict_ds(self): expected_shape = [2, _PREDICT_TARGET_LENGTH] # 2 devices. for batch in self.predict_ds.take(3): - self.assertEqual({k: v.shape.as_list() for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'targets': expected_shape, + }, + ) if __name__ == '__main__': diff --git a/examples/wmt/main.py b/examples/wmt/main.py index e2b3d71486..a78c2ee2ac 100644 --- a/examples/wmt/main.py +++ b/examples/wmt/main.py @@ -36,7 +36,8 @@ 'config', 'configs/default.py', 'File path to the training hyperparameter configuration.', - lock_config=True) + lock_config=True, +) flags.mark_flags_as_required(['config', 'workdir']) @@ -53,10 +54,12 @@ def main(argv): # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) - platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}') - platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, - FLAGS.workdir, 'workdir') + platform.work_unit().set_task_status( + f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' + ) + platform.work_unit().create_artifact( + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) diff --git a/examples/wmt/models.py b/examples/wmt/models.py index e63f0786a2..5e44697542 100644 --- a/examples/wmt/models.py +++ b/examples/wmt/models.py @@ -32,6 +32,7 @@ @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int output_vocab_size: int share_embeddings: bool = False @@ -56,14 +57,11 @@ def shift_right(x, axis=1): """Shift the input to the right by padding on axis 1.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) - padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + padded = jnp.pad(x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) return padded[:, :-1] -def sinusoidal_init(max_len=2048, - min_scale=1.0, - max_scale=10000.0): +def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0): """1D Sinusoidal Position Embedding Initializer. Args: @@ -83,8 +81,8 @@ def init(key, shape, dtype=np.float32): position = np.arange(0, max_len)[:, np.newaxis] scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) - pe[:, :d_feature // 2] = np.sin(position * div_term) - pe[:, d_feature // 2: 2 * (d_feature // 2)] = np.cos(position * div_term) + pe[:, : d_feature // 2] = np.sin(position * div_term) + pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) @@ -98,13 +96,12 @@ class AddPositionEmbs(nn.Module): config: TransformerConfig dataclass containing hyperparameters. decode: whether to run in single-position autoregressive mode. """ + config: TransformerConfig decode: bool = False @nn.compact - def __call__(self, - inputs, - inputs_positions=None): + def __call__(self, inputs, inputs_positions=None): """Applies AddPositionEmbs module. By default this layer uses a fixed sinusoidal embedding table. If a @@ -120,32 +117,29 @@ def __call__(self, """ config = self.config # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - ' but it is: %d' % inputs.ndim) + assert inputs.ndim == 3, ( + 'Number of dimensions should be 3,' ' but it is: %d' % inputs.ndim + ) length = inputs.shape[1] pos_emb_shape = (1, config.max_len, inputs.shape[-1]) if config.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=config.max_len)(None, - pos_emb_shape, - None) + pos_embedding = sinusoidal_init(max_len=config.max_len)(None, pos_emb_shape, None) else: - pos_embedding = self.param('pos_embedding', config.posemb_init, - pos_emb_shape) + pos_embedding = self.param('pos_embedding', config.posemb_init, pos_emb_shape) pe = pos_embedding[:, :length, :] # We use a cache position index for tracking decoding position. if self.decode: is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) + ) if is_initialized: i = cache_index.value cache_index.value = i + 1 _, _, df = pos_embedding.shape - pe = lax.dynamic_slice(pos_embedding, - jnp.array((0, i, 0)), - (1, 1, df)) + pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)), (1, 1, df)) if inputs_positions is None: # normal unpacked case: return inputs + pe @@ -161,6 +155,7 @@ class MlpBlock(nn.Module): config: TransformerConfig dataclass containing hyperparameters. out_dim: optionally specify out dimension. """ + config: TransformerConfig out_dim: Optional[int] = None @@ -168,25 +163,24 @@ class MlpBlock(nn.Module): def __call__(self, inputs): """Applies Transformer MlpBlock module.""" config = self.config - actual_out_dim = (inputs.shape[-1] if self.out_dim is None - else self.out_dim) + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( config.mlp_dim, dtype=config.dtype, kernel_init=config.kernel_init, - bias_init=config.bias_init)( - inputs) + bias_init=config.bias_init, + )(inputs) x = nn.relu(x) - x = nn.Dropout(rate=config.dropout_rate)( - x, deterministic=config.deterministic) + x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=config.deterministic) output = nn.Dense( actual_out_dim, dtype=config.dtype, kernel_init=config.kernel_init, - bias_init=config.bias_init)( - x) + bias_init=config.bias_init, + )(x) output = nn.Dropout(rate=config.dropout_rate)( - output, deterministic=config.deterministic) + output, deterministic=config.deterministic + ) return output @@ -196,12 +190,11 @@ class Encoder1DBlock(nn.Module): Attributes: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig @nn.compact - def __call__(self, - inputs, - encoder_mask=None): + def __call__(self, inputs, encoder_mask=None): """Applies Encoder1DBlock module. Args: @@ -225,10 +218,10 @@ def __call__(self, use_bias=False, broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, - deterministic=config.deterministic)(x, encoder_mask) + deterministic=config.deterministic, + )(x, encoder_mask) - x = nn.Dropout(rate=config.dropout_rate)( - x, deterministic=config.deterministic) + x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=config.deterministic) x = x + inputs # MLP block. @@ -244,14 +237,11 @@ class EncoderDecoder1DBlock(nn.Module): Attributes: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig @nn.compact - def __call__(self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None): + def __call__(self, targets, encoded, decoder_mask=None, encoder_decoder_mask=None): """Applies EncoderDecoder1DBlock module. Args: @@ -278,9 +268,9 @@ def __call__(self, broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, deterministic=config.deterministic, - decode=config.decode)(x, decoder_mask) - x = nn.Dropout(rate=config.dropout_rate)( - x, deterministic=config.deterministic) + decode=config.decode, + )(x, decoder_mask) + x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=config.deterministic) x = x + targets # Encoder-Decoder block. @@ -294,10 +284,10 @@ def __call__(self, use_bias=False, broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, - deterministic=config.deterministic)(y, encoded, encoder_decoder_mask) + deterministic=config.deterministic, + )(y, encoded, encoder_decoder_mask) - y = nn.Dropout(rate=config.dropout_rate)( - y, deterministic=config.deterministic) + y = nn.Dropout(rate=config.dropout_rate)(y, deterministic=config.deterministic) y = y + x # MLP block. @@ -314,14 +304,12 @@ class Encoder(nn.Module): config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ + config: TransformerConfig shared_embedding: Any = None @nn.compact - def __call__(self, - inputs, - inputs_positions=None, - encoder_mask=None): + def __call__(self, inputs, inputs_positions=None, encoder_mask=None): """Applies Transformer model on the inputs. Args: @@ -340,23 +328,22 @@ def __call__(self, input_embed = nn.Embed( num_embeddings=config.vocab_size, features=config.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: input_embed = self.shared_embedding x = inputs.astype('int32') x = input_embed(x) - x = AddPositionEmbs( - config=config, decode=False, name='posembed_input')( - x, inputs_positions=inputs_positions) - x = nn.Dropout(rate=config.dropout_rate)( - x, deterministic=config.deterministic) + x = AddPositionEmbs(config=config, decode=False, name='posembed_input')( + x, inputs_positions=inputs_positions + ) + x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=config.deterministic) x = x.astype(config.dtype) # Input Encoder for lyr in range(config.num_layers): - x = Encoder1DBlock( - config=config, name=f'encoderblock_{lyr}')(x, encoder_mask) + x = Encoder1DBlock(config=config, name=f'encoderblock_{lyr}')(x, encoder_mask) encoded = nn.LayerNorm(dtype=config.dtype, name='encoder_norm')(x) @@ -370,16 +357,19 @@ class Decoder(nn.Module): config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ + config: TransformerConfig shared_embedding: Any = None @nn.compact - def __call__(self, - encoded, - targets, - targets_positions=None, - decoder_mask=None, - encoder_decoder_mask=None): + def __call__( + self, + encoded, + targets, + targets_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + ): """Applies Transformer model on the inputs. Args: @@ -402,7 +392,8 @@ def __call__(self, output_embed = nn.Embed( num_embeddings=config.output_vocab_size, features=config.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: output_embed = self.shared_embedding @@ -410,22 +401,21 @@ def __call__(self, if not config.decode: y = shift_right(y) y = output_embed(y) - y = AddPositionEmbs( - config=config, decode=config.decode, name='posembed_output')( - y, inputs_positions=targets_positions) - y = nn.Dropout(rate=config.dropout_rate)( - y, deterministic=config.deterministic) + y = AddPositionEmbs(config=config, decode=config.decode, name='posembed_output')( + y, inputs_positions=targets_positions + ) + y = nn.Dropout(rate=config.dropout_rate)(y, deterministic=config.deterministic) y = y.astype(config.dtype) # Target-Input Decoder for lyr in range(config.num_layers): - y = EncoderDecoder1DBlock( - config=config, name=f'encoderdecoderblock_{lyr}')( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) + y = EncoderDecoder1DBlock(config=config, name=f'encoderdecoderblock_{lyr}')( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + ) y = nn.LayerNorm(dtype=config.dtype, name='encoderdecoder_norm')(y) # Decoded Logits @@ -440,8 +430,8 @@ def __call__(self, dtype=config.dtype, kernel_init=config.kernel_init, bias_init=config.bias_init, - name='logitdense')( - y) + name='logitdense', + )(y) return logits @@ -451,6 +441,7 @@ class Transformer(nn.Module): Attributes: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig def setup(self): @@ -458,24 +449,21 @@ def setup(self): if config.share_embeddings: if config.output_vocab_size is not None: - assert config.output_vocab_size == config.vocab_size, ( - "can't share embedding with different vocab sizes.") + assert ( + config.output_vocab_size == config.vocab_size + ), "can't share embedding with different vocab sizes." self.shared_embedding = nn.Embed( num_embeddings=config.vocab_size, features=config.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: self.shared_embedding = None - self.encoder = Encoder( - config=config, shared_embedding=self.shared_embedding) - self.decoder = Decoder( - config=config, shared_embedding=self.shared_embedding) + self.encoder = Encoder(config=config, shared_embedding=self.shared_embedding) + self.decoder = Decoder(config=config, shared_embedding=self.shared_embedding) - def encode(self, - inputs, - inputs_positions=None, - inputs_segmentation=None): + def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): """Applies Transformer encoder-branch on the inputs. Args: @@ -488,29 +476,28 @@ def encode(self, """ config = self.config # Make padding attention mask. - encoder_mask = nn.make_attention_mask( - inputs > 0, inputs > 0, dtype=config.dtype) + encoder_mask = nn.make_attention_mask(inputs > 0, inputs > 0, dtype=config.dtype) # Add segmentation block-diagonal attention mask if using segmented data. if inputs_segmentation is not None: encoder_mask = nn.combine_masks( encoder_mask, nn.make_attention_mask( - inputs_segmentation, - inputs_segmentation, - jnp.equal, - dtype=config.dtype)) + inputs_segmentation, inputs_segmentation, jnp.equal, dtype=config.dtype + ), + ) return self.encoder( - inputs, - inputs_positions=inputs_positions, - encoder_mask=encoder_mask) - - def decode(self, - encoded, - inputs, # only needed for masks - targets, - targets_positions=None, - inputs_segmentation=None, - targets_segmentation=None): + inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask + ) + + def decode( + self, + encoded, + inputs, # only needed for masks + targets, + targets_positions=None, + inputs_segmentation=None, + targets_segmentation=None, + ): """Applies Transformer decoder-branch on encoded-input and target. Args: @@ -531,45 +518,49 @@ def decode(self, # for fast autoregressive decoding only a special encoder-decoder mask is used decoder_mask = None encoder_decoder_mask = nn.make_attention_mask( - jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype) + jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype + ) else: decoder_mask = nn.combine_masks( nn.make_attention_mask(targets > 0, targets > 0, dtype=config.dtype), - nn.make_causal_mask(targets, dtype=config.dtype)) + nn.make_causal_mask(targets, dtype=config.dtype), + ) encoder_decoder_mask = nn.make_attention_mask( - targets > 0, inputs > 0, dtype=config.dtype) + targets > 0, inputs > 0, dtype=config.dtype + ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: decoder_mask = nn.combine_masks( decoder_mask, nn.make_attention_mask( - targets_segmentation, - targets_segmentation, - jnp.equal, - dtype=config.dtype)) + targets_segmentation, targets_segmentation, jnp.equal, dtype=config.dtype + ), + ) encoder_decoder_mask = nn.combine_masks( encoder_decoder_mask, nn.make_attention_mask( - targets_segmentation, - inputs_segmentation, - jnp.equal, - dtype=config.dtype)) + targets_segmentation, inputs_segmentation, jnp.equal, dtype=config.dtype + ), + ) logits = self.decoder( encoded, targets, targets_positions=targets_positions, decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) + encoder_decoder_mask=encoder_decoder_mask, + ) return logits.astype(self.config.dtype) - def __call__(self, - inputs, - targets, - inputs_positions=None, - targets_positions=None, - inputs_segmentation=None, - targets_segmentation=None): + def __call__( + self, + inputs, + targets, + inputs_positions=None, + targets_positions=None, + inputs_segmentation=None, + targets_segmentation=None, + ): """Applies Transformer model on the inputs. Args: @@ -583,13 +574,17 @@ def __call__(self, Returns: logits array from full transformer. """ - encoded = self.encode(inputs, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) - - return self.decode(encoded, - inputs, # only used for masks - targets, - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation) + encoded = self.encode( + inputs, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation, + ) + + return self.decode( + encoded, + inputs, # only used for masks + targets, + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation, + ) diff --git a/examples/wmt/tokenizer.py b/examples/wmt/tokenizer.py index 8b0e73296d..86e1ea05f4 100644 --- a/examples/wmt/tokenizer.py +++ b/examples/wmt/tokenizer.py @@ -31,9 +31,7 @@ def _dump_chars_to_textfile( - dataset: tf.data.Dataset, - maxchars: int = int(1e7), - data_keys=('inputs', 'targets') + dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=('inputs', 'targets') ) -> Tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. @@ -47,8 +45,7 @@ def _dump_chars_to_textfile( """ char_count = 0 ds_iter = dataset.as_numpy_iterator() - with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars') as outfp: + with tempfile.NamedTemporaryFile(delete=False, prefix='/tmp/ds_chars') as outfp: while char_count < maxchars: example = next(ds_iter) for k in data_keys: @@ -58,14 +55,16 @@ def _dump_chars_to_textfile( return outfp.name, char_count -def _train_sentencepiece(dataset: tf.data.Dataset, - *, - vocab_size: int, - maxchars: int = int(1e7), - model_path: str, - model_type: str = 'unigram', - character_coverage: float = 1.0, - data_keys=('inputs', 'targets')): +def _train_sentencepiece( + dataset: tf.data.Dataset, + *, + vocab_size: int, + maxchars: int = int(1e7), + model_path: str, + model_type: str = 'unigram', + character_coverage: float = 1.0, + data_keys=('inputs', 'targets'), +): """Train SentencePiece tokenizer from subset of tf dataset. Args: @@ -86,15 +85,15 @@ def _train_sentencepiece(dataset: tf.data.Dataset, abs_model_path = model_path else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) - fname, _ = _dump_chars_to_textfile( - dataset, maxchars=maxchars, data_keys=data_keys) - with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp') as model_fp: + fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys) + with tempfile.NamedTemporaryFile(delete=False, prefix='/tmp/sp_tmp') as model_fp: pass # we just want a prefix'd tmp-filename argstr = ' '.join([ - f'--input={fname}', f'--vocab_size={vocab_size}', + f'--input={fname}', + f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', - f'--model_prefix={model_fp.name}', f'--model_type={model_type}' + f'--model_prefix={model_fp.name}', + f'--model_type={model_type}', ]) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: @@ -111,24 +110,26 @@ def _train_sentencepiece(dataset: tf.data.Dataset, return abs_model_path -def _load_sentencepiece_tokenizer(model_path: str, - add_bos: bool = False, - add_eos: bool = True, - reverse: bool = False): +def _load_sentencepiece_tokenizer( + model_path: str, add_bos: bool = False, add_eos: bool = True, reverse: bool = False +): """Load a tf-text SentencePiece tokenizer from given model filepath.""" with tf.io.gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse) + model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse + ) return sp_tokenizer -def load_or_train_tokenizer(dataset: tf.data.Dataset, - *, - vocab_path: str, - vocab_size: int, - max_corpus_chars: int, - data_keys: Tuple[str, str] = ('inputs', 'targets')): +def load_or_train_tokenizer( + dataset: tf.data.Dataset, + *, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + data_keys: Tuple[str, str] = ('inputs', 'targets'), +): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: return _load_sentencepiece_tokenizer(vocab_path) @@ -139,13 +140,13 @@ def load_or_train_tokenizer(dataset: tf.data.Dataset, vocab_size=vocab_size, maxchars=max_corpus_chars, model_path=vocab_path, - data_keys=data_keys) + data_keys=data_keys, + ) return _load_sentencepiece_tokenizer(vocab_path) @dataclasses.dataclass class TokenizeOp: - sp_tokenizer: Any data_keys: Iterable[str] = ('inputs', 'targets') diff --git a/examples/wmt/train.py b/examples/wmt/train.py index 1e792fe0a4..2861bdf8cb 100644 --- a/examples/wmt/train.py +++ b/examples/wmt/train.py @@ -68,25 +68,25 @@ def rsqrt_schedule( """ def schedule(count): - return init_value * (count + shift)**-.5 * shift**.5 + return init_value * (count + shift) ** -0.5 * shift**0.5 return schedule def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): """Creates a rsqrt schedule with linear warmup.""" - return optax.join_schedules([ - optax.linear_schedule( - init_value=0, end_value=learning_rate, transition_steps=warmup_steps), - rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), - ], - boundaries=[warmup_steps]) - - -def compute_weighted_cross_entropy(logits, - targets, - weights=None, - label_smoothing=0.0): + return optax.join_schedules( + [ + optax.linear_schedule( + init_value=0, end_value=learning_rate, transition_steps=warmup_steps + ), + rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), + ], + boundaries=[warmup_steps], + ) + + +def compute_weighted_cross_entropy(logits, targets, weights=None, label_smoothing=0.0): """Compute weighted cross entropy and entropy for log probs and targets. Args: @@ -100,16 +100,20 @@ def compute_weighted_cross_entropy(logits, Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: - raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" % - (str(logits.shape), str(targets.shape))) + raise ValueError( + "Incorrect shapes. Got shape %s logits and %s targets" + % (str(logits.shape), str(targets.shape)) + ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( - confidence * jnp.log(confidence) + - (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)) + confidence * jnp.log(confidence) + + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) soft_targets = common_utils.onehot( - targets, vocab_size, on_value=confidence, off_value=low_confidence) + targets, vocab_size, on_value=confidence, off_value=low_confidence + ) loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1) loss = loss - normalizing_constant @@ -134,8 +138,10 @@ def compute_weighted_accuracy(logits, targets, weights=None): Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: - raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" % - (str(logits.shape), str(targets.shape))) + raise ValueError( + "Incorrect shapes. Got shape %s logits and %s targets" + % (str(logits.shape), str(targets.shape)) + ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: @@ -147,8 +153,9 @@ def compute_weighted_accuracy(logits, targets, weights=None): def compute_metrics(logits, labels, weights, label_smoothing=0.0): """Compute summary metrics.""" - loss, weight_sum = compute_weighted_cross_entropy(logits, labels, weights, - label_smoothing) + loss, weight_sum = compute_weighted_cross_entropy( + logits, labels, weights, label_smoothing + ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { "loss": loss, @@ -163,12 +170,9 @@ def compute_metrics(logits, labels, weights, label_smoothing=0.0): # ----------------------------------------------------------------------------- -def train_step(state, - batch, - config, - learning_rate_fn, - label_smoothing=0.0, - dropout_rng=None): +def train_step( + state, batch, config, learning_rate_fn, label_smoothing=0.0, dropout_rng=None +): """Perform a single training step.""" # X_position and X_segmentation are needed only when using "packed examples" # where multiple sequences are packed into the same example with this @@ -176,11 +180,21 @@ def train_step(state, # if such features are not present they are ignored and the example is treated # like a normal, unpacked sequence example. train_keys = [ - "inputs", "targets", "inputs_position", "targets_position", - "inputs_segmentation", "targets_segmentation" + "inputs", + "targets", + "inputs_position", + "targets_position", + "inputs_segmentation", + "targets_segmentation", ] - (inputs, targets, inputs_positions, targets_positions, inputs_segmentation, - targets_segmentation) = (batch.get(k, None) for k in train_keys) + ( + inputs, + targets, + inputs_positions, + targets_positions, + inputs_segmentation, + targets_segmentation, + ) = (batch.get(k, None) for k in train_keys) weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32) @@ -196,18 +210,22 @@ def loss_fn(params): targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, - rngs={"dropout": dropout_rng}) + rngs={"dropout": dropout_rng}, + ) - loss, weight_sum = compute_weighted_cross_entropy(logits, targets, weights, - label_smoothing) + loss, weight_sum = compute_weighted_cross_entropy( + logits, targets, weights, label_smoothing + ) mean_loss = loss / weight_sum return mean_loss, logits + step = state.step if state.dynamic_scale: # dynamic scale takes care of averaging gradients across replicas grad_fn = state.dynamic_scale.value_and_grad( - loss_fn, has_aux=True, axis_name="batch") + loss_fn, has_aux=True, axis_name="batch" + ) dynamic_scale, is_fin, (_, logits), grads = grad_fn(state.params) state = state.replace(dynamic_scale=dynamic_scale) else: @@ -225,10 +243,10 @@ def loss_fn(params): select_fn = functools.partial(jnp.where, is_fin) new_state = new_state.replace( opt_state=jax.tree_util.tree_map( - select_fn, new_state.opt_state, state.opt_state), - params=jax.tree_util.tree_map( - select_fn, new_state.params, state.params) - ) + select_fn, new_state.opt_state, state.opt_state + ), + params=jax.tree_util.tree_map(select_fn, new_state.params, state.params), + ) metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"] return new_state, metrics @@ -247,18 +265,14 @@ def initialize_cache(inputs, max_decode_len, config): """Initialize a cache for a given input shape and max decode length.""" target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] initial_variables = models.Transformer(config).init( - jax.random.PRNGKey(0), jnp.ones(inputs.shape, config.dtype), - jnp.ones(target_shape, config.dtype)) + jax.random.PRNGKey(0), + jnp.ones(inputs.shape, config.dtype), + jnp.ones(target_shape, config.dtype), + ) return initial_variables["cache"] -def predict_step(inputs, - params, - cache, - eos_id, - max_decode_len, - config, - beam_size=4): +def predict_step(inputs, params, cache, eos_id, max_decode_len, config, beam_size=4): """Predict translation with fast decoding beam search on a batch.""" # Prepare transformer fast-decoder call for beam search: for beam search, we # need to set up our decoder model to handle a batch size equal to @@ -267,25 +281,24 @@ def predict_step(inputs, # i.e. if we denote each batch element subtensor as el[n]: # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] encoded_inputs = decode.flat_batch_beam_expand( - models.Transformer(config).apply({"params": params}, - inputs, - method=models.Transformer.encode), - beam_size) + models.Transformer(config).apply( + {"params": params}, inputs, method=models.Transformer.encode + ), + beam_size, + ) raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size) def tokens_ids_to_logits(flat_ids, flat_cache): """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] flat_logits, new_vars = models.Transformer(config).apply( - { - "params": params, - "cache": flat_cache - }, + {"params": params, "cache": flat_cache}, encoded_inputs, raw_inputs, # only needed for input padding mask flat_ids, mutable=["cache"], - method=models.Transformer.decode) + method=models.Transformer.decode, + ) new_flat_cache = new_vars["cache"] # Remove singleton sequence-length dimension: # [batch * beam, 1, vocab] --> [batch * beam, vocab] @@ -301,7 +314,8 @@ def tokens_ids_to_logits(flat_ids, flat_cache): beam_size=beam_size, alpha=0.6, eos_id=eos_id, - max_decode_len=max_decode_len) + max_decode_len=max_decode_len, + ) # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension # sorted in increasing order of log-probability. @@ -342,8 +356,7 @@ def tohost(x): return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims)) -def evaluate(*, p_eval_step, params, eval_ds: tf.data.Dataset, - num_eval_steps: int): +def evaluate(*, p_eval_step, params, eval_ds: tf.data.Dataset, num_eval_steps: int): """Evaluate the params an return a dictionary with the metrics.""" logging.info("Gathering evaluation metrics.") eval_metrics = [] @@ -358,13 +371,20 @@ def evaluate(*, p_eval_step, params, eval_ds: tf.data.Dataset, eval_denominator = eval_metrics_sums.pop("denominator") eval_summary = jax.tree_util.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop - eval_metrics_sums) + eval_metrics_sums, + ) return eval_summary -def translate_and_calculate_bleu(*, p_pred_step, p_init_cache, params, - predict_ds: tf.data.Dataset, decode_tokens, - max_predict_length: int): +def translate_and_calculate_bleu( + *, + p_pred_step, + p_init_cache, + params, + predict_ds: tf.data.Dataset, + decode_tokens, + max_predict_length: int, +): """Translates the `predict_ds` and calculates the BLEU score.""" n_devices = jax.local_device_count() logging.info("Translating evaluation dataset.") @@ -377,11 +397,13 @@ def translate_and_calculate_bleu(*, p_pred_step, p_init_cache, params, padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices) pred_batch = jax.tree_util.tree_map( lambda x: pad_examples(x, padded_size), # pylint: disable=cell-var-from-loop - pred_batch) + pred_batch, + ) pred_batch = common_utils.shard(pred_batch) cache = p_init_cache(pred_batch["inputs"]) - predicted = p_pred_step(pred_batch["inputs"], params, cache, decode.EOS_ID, - max_predict_length) + predicted = p_pred_step( + pred_batch["inputs"], params, cache, decode.EOS_ID, max_predict_length + ) predicted = tohost(predicted) inputs = tohost(pred_batch["inputs"]) targets = tohost(pred_batch["targets"]) @@ -390,8 +412,12 @@ def translate_and_calculate_bleu(*, p_pred_step, p_init_cache, params, sources.append(decode_tokens(inputs[i])) references.append(decode_tokens(targets[i])) predictions.append(decode_tokens(s)) - logging.info("Translation: %d predictions %d references %d sources.", - len(predictions), len(references), len(sources)) + logging.info( + "Translation: %d predictions %d references %d sources.", + len(predictions), + len(references), + len(sources), + ) # Calculate BLEU score for translated eval corpus against reference. bleu_matches = bleu.bleu_partial(references, predictions) @@ -437,14 +463,15 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): n_devices=jax.local_device_count(), config=config, reverse_translation=config.reverse_translation, - vocab_path=vocab_path) + vocab_path=vocab_path, + ) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = decode.EOS_ID # Default Sentencepiece EOS token. def decode_tokens(toks): - valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) + valid_toks = toks[: np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode("utf-8") if config.num_predict_steps > 0: @@ -473,7 +500,8 @@ def decode_tokens(toks): deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.normal(stddev=1e-6)) + bias_init=nn.initializers.normal(stddev=1e-6), + ) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) @@ -484,13 +512,14 @@ def decode_tokens(toks): target_shape = (config.per_device_batch_size, config.max_target_length) m = models.Transformer(eval_config) - initial_variables = jax.jit(m.init)(init_rng, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + initial_variables = jax.jit(m.init)( + init_rng, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32) + ) # Create train state with Adam optimizer and weight decay. learning_rate_fn = create_learning_rate_schedule( - learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) + learning_rate=config.learning_rate, warmup_steps=config.warmup_steps + ) dynamic_scale = None if dtype == jnp.float16: dynamic_scale = dynamic_scale_lib.DynamicScale() @@ -517,7 +546,8 @@ def decode_tokens(toks): start_step = int(state.step) writer = metric_writers.create_default_writer( - workdir, just_logging=jax.process_index() > 0) + workdir, just_logging=jax.process_index() > 0 + ) if start_step == 0: writer.write_hparams(dict(config)) @@ -530,24 +560,29 @@ def decode_tokens(toks): train_step, config=train_config, learning_rate_fn=learning_rate_fn, - label_smoothing=config.label_smoothing), + label_smoothing=config.label_smoothing, + ), axis_name="batch", - donate_argnums=(0,)) # pytype: disable=wrong-arg-types + donate_argnums=(0,), + ) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap( - functools.partial( - eval_step, config=eval_config), - axis_name="batch") + functools.partial(eval_step, config=eval_config), axis_name="batch" + ) p_init_cache = jax.pmap( functools.partial( initialize_cache, max_decode_len=config.max_predict_length, - config=predict_config), - axis_name="batch") + config=predict_config, + ), + axis_name="batch", + ) p_pred_step = jax.pmap( functools.partial( - predict_step, config=predict_config, beam_size=config.beam_size), + predict_step, config=predict_config, beam_size=config.beam_size + ), axis_name="batch", - static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant + static_broadcasted_argnums=(3, 4), + ) # eos token, max_length are constant # Main Train Loop # --------------------------------------------------------------------------- @@ -560,11 +595,12 @@ def decode_tokens(toks): logging.info("Starting training loop.") hooks = [] report_progress = periodic_actions.ReportProgress( - num_train_steps=config.num_train_steps, writer=writer) + num_train_steps=config.num_train_steps, writer=writer + ) if jax.process_index() == 0: hooks += [ report_progress, - periodic_actions.Profile(logdir=workdir, num_profile_steps=5) + periodic_actions.Profile(logdir=workdir, num_profile_steps=5), ] train_metrics = [] with metric_writers.ensure_flushes(writer): @@ -574,8 +610,7 @@ def decode_tokens(toks): # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation("train", step_num=step): batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_iter))) - state, metrics = p_train_step( - state, batch, dropout_rng=dropout_rngs) + state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs) train_metrics.append(metrics) # Quick indication that training is happening. @@ -602,9 +637,9 @@ def decode_tokens(toks): p_eval_step=p_eval_step, params=state.params, eval_ds=eval_ds, - num_eval_steps=config.num_eval_steps) - writer.write_scalars( - step, {"eval_" + k: v for k, v in eval_results.items()}) + num_eval_steps=config.num_eval_steps, + ) + writer.write_scalars(step, {"eval_" + k: v for k, v in eval_results.items()}) with report_progress.timed("translate_and_bleu"): exemplars, bleu_score = translate_and_calculate_bleu( @@ -613,14 +648,14 @@ def decode_tokens(toks): params=state.params, predict_ds=predict_ds, decode_tokens=decode_tokens, - max_predict_length=config.max_predict_length) + max_predict_length=config.max_predict_length, + ) writer.write_scalars(step, {"bleu": bleu_score}) writer.write_texts(step, {"samples": exemplars}) # Save a checkpoint on one host after every checkpoint_freq steps. - save_checkpoint = (step % config.checkpoint_every_steps == 0 or - is_last_step) - if (config.save_checkpoints and save_checkpoint): + save_checkpoint = step % config.checkpoint_every_steps == 0 or is_last_step + if config.save_checkpoints and save_checkpoint: logging.info("Saving checkpoint step %d.", step) with report_progress.timed("checkpoint"): checkpoints.save_checkpoint_multiprocess( diff --git a/flax/__init__.py b/flax/__init__.py index 07bc25107b..c98085218b 100644 --- a/flax/__init__.py +++ b/flax/__init__.py @@ -27,4 +27,4 @@ # DO NOT REMOVE - Marker for internal deprecated API. # DO NOT REMOVE - Marker for internal logging. -from .version import __version__ \ No newline at end of file +from .version import __version__ diff --git a/flax/configurations.py b/flax/configurations.py index a6e0ac2470..393cd3cc9b 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -18,7 +18,6 @@ To modify a config value on run time, call: ``flax.config.update('flax_', )`` - """ import os @@ -32,6 +31,7 @@ # Config parsing utils + def define_bool_state(name, default, help): """Set up a boolean flag using JAX's config system. @@ -65,7 +65,8 @@ def static_bool_env(varname: str, default: bool) -> bool: return False else: raise ValueError( - 'invalid truth value {!r} for environment {!r}'.format(val, varname)) + 'invalid truth value {!r} for environment {!r}'.format(val, varname) + ) @contextmanager @@ -92,22 +93,26 @@ def temp_flip_flag(var_name: str, var_value: bool): flax_filter_frames = define_bool_state( name='filter_frames', default=True, - help=('Whether to hide flax-internal stack frames from tracebacks.')) + help=('Whether to hide flax-internal stack frames from tracebacks.'), +) flax_profile = define_bool_state( name='profile', default=True, - help=('Whether to run Module methods under jax.named_scope for profiles.')) + help=('Whether to run Module methods under jax.named_scope for profiles.'), +) flax_use_orbax_checkpointing = define_bool_state( name='use_orbax_checkpointing', default=True, - help=('Whether to use Orbax to save checkpoints.')) + help=('Whether to use Orbax to save checkpoints.'), +) flax_preserve_adopted_names = define_bool_state( name='preserve_adopted_names', default=False, - help=("When adopting outside modules, don't clobber existing names.")) + help=("When adopting outside modules, don't clobber existing names."), +) # TODO(marcuschiam): remove this feature flag once regular dict migration is complete flax_return_frozendict = define_bool_state( @@ -117,7 +122,9 @@ def temp_flip_flag(var_name: str, var_value: bool): ) flax_fix_rng = define_bool_state( - name ='fix_rng_separator', + name='fix_rng_separator', default=False, - help=('Whether to add separator characters when folding in static data into PRNG keys.') + help=( + 'Whether to add separator characters when folding in static data into PRNG keys.' + ), ) diff --git a/flax/core/__init__.py b/flax/core/__init__.py index d23000ea04..0a344ef608 100644 --- a/flax/core/__init__.py +++ b/flax/core/__init__.py @@ -14,43 +14,44 @@ from .axes_scan import broadcast as broadcast from .frozen_dict import ( - FrozenDict as FrozenDict, - freeze as freeze, - unfreeze as unfreeze, - copy as copy, - pop as pop, - pretty_repr as pretty_repr + FrozenDict as FrozenDict, + freeze as freeze, + unfreeze as unfreeze, + copy as copy, + pop as pop, + pretty_repr as pretty_repr, ) from .tracers import ( - current_trace as current_trace, - trace_level as trace_level, - check_trace_level as check_trace_level + current_trace as current_trace, + trace_level as trace_level, + check_trace_level as check_trace_level, ) from .scope import ( - Scope as Scope, - Array as Array, - DenyList as DenyList, - apply as apply, - init as init, - lazy_init as lazy_init, - bind as bind) + Scope as Scope, + Array as Array, + DenyList as DenyList, + apply as apply, + init as init, + lazy_init as lazy_init, + bind as bind, +) from .lift import ( - scan as scan, - vmap as vmap, - jit as jit, - remat as remat, - remat_scan as remat_scan, - while_loop as while_loop, - custom_vjp as custom_vjp, - vjp as vjp, - jvp as jvp + scan as scan, + vmap as vmap, + jit as jit, + remat as remat, + remat_scan as remat_scan, + while_loop as while_loop, + custom_vjp as custom_vjp, + vjp as vjp, + jvp as jvp, ) from .meta import ( - AxisMetadata as AxisMetadata, - unbox as unbox, - map_axis_meta as map_axis_meta, + AxisMetadata as AxisMetadata, + unbox as unbox, + map_axis_meta as map_axis_meta, ) diff --git a/flax/core/axes_scan.py b/flax/core/axes_scan.py index 7cce9ae230..7a6e84d9b8 100644 --- a/flax/core/axes_scan.py +++ b/flax/core/axes_scan.py @@ -31,6 +31,7 @@ class _Broadcast: pass + broadcast = _Broadcast() @@ -40,7 +41,8 @@ def scan( out_axes: Any, length: Optional[int] = None, reverse: bool = False, - unroll: int = 1): + unroll: int = 1, +): """A wrapper around `jax.lax.scan` with in_axes/out_axes api. Example:: @@ -84,10 +86,12 @@ def transpose_to_front(ax, xs): return () if ax == 0: return xs + def trans(x): perm = tuple(range(x.ndim)) perm = (ax,) + tuple(np.delete(perm, ax)) return jnp.transpose(x, perm) + return jax.tree_util.tree_map(trans, xs) def transpose_from_front(ax, xs): @@ -95,6 +99,7 @@ def transpose_from_front(ax, xs): return () if ax == 0: return xs + def trans(x): if ax < 0: pax = x.ndim - ax @@ -103,6 +108,7 @@ def trans(x): assert pax < x.ndim perm = tuple(range(1, pax + 1)) + (0,) + tuple(range(pax + 1, x.ndim)) return jnp.transpose(x, perm) + return jax.tree_util.tree_map(trans, xs) def scan_fn(broadcast_in, init, *args): @@ -111,47 +117,53 @@ def scan_fn(broadcast_in, init, *args): def body_fn(c, xs, init_mode=False): # inject constants xs = jax.tree_util.tree_map( - lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs) + lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs + ) broadcast_out, c, ys = fn(broadcast_in, c, *xs) if init_mode: ys = jax.tree_util.tree_map( - lambda ax, y: (y if ax is broadcast else ()), out_axes, ys) + lambda ax, y: (y if ax is broadcast else ()), out_axes, ys + ) return broadcast_out, ys else: ys = jax.tree_util.tree_map( - lambda ax, y: (() if ax is broadcast else y), out_axes, ys) + lambda ax, y: (() if ax is broadcast else y), out_axes, ys + ) return c, ys + broadcast_body = functools.partial(body_fn, init_mode=True) carry_avals = jax.tree_util.tree_map( - lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), - init) + lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), init + ) scan_avals = jax.tree_util.tree_map( - lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), - xs) + lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), xs + ) input_avals = (carry_avals, scan_avals) in_avals, in_tree = jax.tree_util.tree_flatten(input_avals) f_flat, out_tree = jax.api_util.flatten_fun_nokwargs( - lu.wrap_init(broadcast_body), in_tree) + lu.wrap_init(broadcast_body), in_tree + ) in_pvals = list(map(pe.PartialVal.unknown, in_avals)) _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals) out_flat = [] for pv, const in out_pvals: if pv is not None: - raise ValueError( - 'broadcasted variable has a data dependency on the scan body.') + raise ValueError('broadcasted variable has a data dependency on the scan body.') out_flat.append(const) broadcast_in, constants_out = jax.tree_util.tree_unflatten(out_tree(), out_flat) - c, ys = lax.scan(body_fn, init, xs, length=length, - reverse=reverse, unroll=unroll) + c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse, unroll=unroll) ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys) ys = jax.tree_util.tree_map( - lambda ax, const, y: (const if ax is broadcast else y), out_axes, - constants_out, ys) + lambda ax, const, y: (const if ax is broadcast else y), + out_axes, + constants_out, + ys, + ) return broadcast_in, c, ys return scan_fn diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index 2d55e8b9b6..ed711a0829 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -50,6 +50,7 @@ def _indent(x, num_spaces): @jax.tree_util.register_pytree_with_keys_class class FrozenDict(Mapping[K, V]): """An immutable variant of the Python dict.""" + __slots__ = ('_dict', '_hash') def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs): # pylint: disable=invalid-name @@ -88,6 +89,7 @@ def __reduce__(self): def pretty_repr(self, num_spaces=4): """Returns an indented representation of the nested dictionary.""" + def pretty_dict(x): if not isinstance(x, dict): return repr(x) @@ -98,6 +100,7 @@ def pretty_dict(x): return '{\n' + _indent(rep, num_spaces) + '}' else: return '{}' + return f'FrozenDict({pretty_dict(self._dict)})' def __hash__(self): @@ -110,7 +113,7 @@ def __hash__(self): def copy(self, add_or_replace: Mapping[K, V]) -> 'FrozenDict[K, V]': """Create a new FrozenDict with additional or replaced entries.""" - return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type] + return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type] def keys(self): return FrozenKeysView(self) @@ -218,7 +221,10 @@ def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]: return x -def copy(x: Union[FrozenDict, Dict[str, Any]], add_or_replace: Union[FrozenDict, Dict[str, Any]]) -> Union[FrozenDict, Dict[str, Any]]: +def copy( + x: Union[FrozenDict, Dict[str, Any]], + add_or_replace: Union[FrozenDict, Dict[str, Any]], +) -> Union[FrozenDict, Dict[str, Any]]: """Create a new dict with additional and/or replaced entries. This is a utility function that can act on either a FrozenDict or regular dict and mimics the behavior of `FrozenDict.copy`. @@ -237,13 +243,15 @@ def copy(x: Union[FrozenDict, Dict[str, Any]], add_or_replace: Union[FrozenDict, if isinstance(x, FrozenDict): return x.copy(add_or_replace) elif isinstance(x, dict): - new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x + new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x new_dict.update(add_or_replace) return new_dict raise TypeError(f'Expected FrozenDict or dict, got {type(x)}') -def pop(x: Union[FrozenDict, Dict[str, Any]], key: str) -> Tuple[Union[FrozenDict, Dict[str, Any]], Any]: +def pop( + x: Union[FrozenDict, Dict[str, Any]], key: str +) -> Tuple[Union[FrozenDict, Dict[str, Any]], Any]: """Create a new dict where one entry is removed. This is a utility function that can act on either a FrozenDict or regular dict and mimics the behavior of `FrozenDict.pop`. @@ -262,7 +270,7 @@ def pop(x: Union[FrozenDict, Dict[str, Any]], key: str) -> Tuple[Union[FrozenDic if isinstance(x, FrozenDict): return x.pop(key) elif isinstance(x, dict): - new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x + new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x value = new_dict.pop(key) return new_dict, value raise TypeError(f'Expected FrozenDict or dict, got {type(x)}') @@ -284,6 +292,7 @@ def pretty_repr(x: Any, num_spaces: int = 4) -> str: if isinstance(x, FrozenDict): return x.pretty_repr() else: + def pretty_dict(x): if not isinstance(x, dict): return repr(x) @@ -294,6 +303,7 @@ def pretty_dict(x): return '{\n' + _indent(rep, num_spaces) + '}' else: return '{}' + return pretty_dict(x) @@ -304,16 +314,20 @@ def _frozen_dict_state_dict(xs): def _restore_frozen_dict(xs, states): diff = set(map(str, xs.keys())).difference(states.keys()) if diff: - raise ValueError('The target dict keys and state dict keys do not match,' - f' target dict contains keys {diff} which are not present in state dict ' - f'at path {serialization.current_path()}') + raise ValueError( + 'The target dict keys and state dict keys do not match,' + f' target dict contains keys {diff} which are not present in state dict ' + f'at path {serialization.current_path()}' + ) return FrozenDict( - {key: serialization.from_state_dict(value, states[key], name=key) - for key, value in xs.items()}) + { + key: serialization.from_state_dict(value, states[key], name=key) + for key, value in xs.items() + } + ) serialization.register_serialization_state( - FrozenDict, - _frozen_dict_state_dict, - _restore_frozen_dict) + FrozenDict, _frozen_dict_state_dict, _restore_frozen_dict +) diff --git a/flax/core/lift.py b/flax/core/lift.py index b1e6cc4419..8c148b00af 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -17,8 +17,20 @@ import collections import dataclasses import functools -from typing import (Any, Callable, Dict, Generic, Iterable, List, Mapping, - Optional, Sequence, Tuple, TypeVar, Union) +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) import warnings from . import axes_scan @@ -28,10 +40,19 @@ from .frozen_dict import unfreeze import jax from jax import random -from .scope import (CollectionFilter, DenyList, PRNGSequenceFilter, # pylint: disable=g-multiple-import - Filter, Scope, group_collections, in_filter, - intersect_filters, is_filter_empty, subtract_filters, - union_filters) +from .scope import ( + CollectionFilter, + DenyList, + PRNGSequenceFilter, # pylint: disable=g-multiple-import + Filter, + Scope, + group_collections, + in_filter, + intersect_filters, + is_filter_empty, + subtract_filters, + union_filters, +) traceback_util.register_exclusion(__file__) @@ -41,7 +62,8 @@ def tree_map_rngs(fn, tree): """Needed for mapping JAX random.* functions over KeyArray leaves.""" return jax.tree_util.tree_map( - fn, tree, is_leaf=lambda x: isinstance(x, random.KeyArray)) + fn, tree, is_leaf=lambda x: isinstance(x, random.KeyArray) + ) def _dedup_scopes(scopes): @@ -82,12 +104,14 @@ def _transpose(xs): return tuple(zip(*xs)) -def pack(fn: Callable[..., Any], - in_variable_filters: Sequence[CollectionFilter], - out_variable_filters: Sequence[CollectionFilter], - rng_filters: Sequence[PRNGSequenceFilter], - name=None, - enable_kwargs=False) -> Callable[..., Any]: +def pack( + fn: Callable[..., Any], + in_variable_filters: Sequence[CollectionFilter], + out_variable_filters: Sequence[CollectionFilter], + rng_filters: Sequence[PRNGSequenceFilter], + name=None, + enable_kwargs=False, +) -> Callable[..., Any]: """Pack variables and rngs for functional transformations. The pack function is the building block for all other lifted transformations. @@ -104,10 +128,11 @@ def pack(fn: Callable[..., Any], Returns: A callable which expects a scope as the first argument. """ + @functools.wraps(fn) def wrapper(scope_tree: Scope, *args, **kwargs): if not enable_kwargs and kwargs: - msg = 'kwargs are not supported in {}, so \"{}\" is(are) ignored' + msg = 'kwargs are not supported in {}, so "{}" is(are) ignored' warnings.warn(msg.format(name, ', '.join(kwargs.keys())), RuntimeWarning) # pylint: disable=protected-access scopes, treedef = jax.tree_util.tree_flatten(scope_tree) @@ -118,8 +143,9 @@ def wrapper(scope_tree: Scope, *args, **kwargs): for scope in scopes: scope._validate_trace_level() scope._populate_collections() - variable_groups_xs.append(group_collections( - scope._variables, in_variable_filters)) + variable_groups_xs.append( + group_collections(scope._variables, in_variable_filters) + ) variable_groups_xs_t = _transpose(variable_groups_xs) # Make sure that in-only variable collections are frozen @@ -127,8 +153,8 @@ def wrapper(scope_tree: Scope, *args, **kwargs): for variable_group in variable_group_xs: for col_name, collection in variable_group.items(): col_in_out = any( - in_filter(col_filter, col_name) - for col_filter in out_variable_filters) + in_filter(col_filter, col_name) for col_filter in out_variable_filters + ) if not col_in_out: variable_group[col_name] = freeze(collection) rng_groups_xs = [] @@ -142,9 +168,9 @@ def wrapper(scope_tree: Scope, *args, **kwargs): inner_scopes: List[Scope] = [] - def scope_fn(variable_groups_xs_t, - rng_groups_xs_t, - mutable_filter: CollectionFilter = True): + def scope_fn( + variable_groups_xs_t, rng_groups_xs_t, mutable_filter: CollectionFilter = True + ): nonlocal inner_scopes for inner_scope in inner_scopes: inner_scope.invalidate() @@ -155,13 +181,13 @@ def scope_fn(variable_groups_xs_t, # could be () in the edge case where no rngs or variable_groups are lifted # in this case fallback to ((),) * len(scopes) to make sure the zip has # something to iterate over for each scope. - variable_groups_xs = _transpose(variable_groups_xs_t) or ( - (),) * len(scopes) + variable_groups_xs = _transpose(variable_groups_xs_t) or ((),) * len(scopes) rng_groups_xs = _transpose(rng_groups_xs_t) or ((),) * len(scopes) assert len(variable_groups_xs) == len(scopes) assert len(rng_groups_xs) == len(scopes) for variable_groups, rng_groups, scope, rng_counters in zip( - variable_groups_xs, rng_groups_xs, scopes, inner_rng_counters): + variable_groups_xs, rng_groups_xs, scopes, inner_rng_counters + ): variables = {} rngs = {} for variable_group in variable_groups: @@ -172,7 +198,8 @@ def scope_fn(variable_groups_xs_t, # sharing. variables = jax.tree_util.tree_map(lambda x: x, variables) scope_mutable = intersect_filters( - intersect_filters(scope.mutable, mutable), mutable_filter) + intersect_filters(scope.mutable, mutable), mutable_filter + ) new_path = scope.path if name: if new_path: @@ -180,9 +207,14 @@ def scope_fn(variable_groups_xs_t, else: new_path = (f'{name}()',) inner_scope = Scope( - variables, name=scope.name, rngs=rngs, - mutable=scope_mutable, parent=None, - path=new_path, flags=scope.flags) + variables, + name=scope.name, + rngs=rngs, + mutable=scope_mutable, + parent=None, + path=new_path, + flags=scope.flags, + ) inner_scope.rng_counters = rng_counters inner_scopes.append(inner_scope) inner_scopes = _dup_scopes(scopes, inner_scopes, paths) @@ -197,11 +229,14 @@ def repack(inner_scope_tree): for inner_scope in inner_scopes: inner_scope.invalidate() inner_scope._validate_trace_level() - mutable_variables = {key: val for key, val - in inner_scope._variables.items() - if in_filter(inner_scope.mutable, key)} + mutable_variables = { + key: val + for key, val in inner_scope._variables.items() + if in_filter(inner_scope.mutable, key) + } out_variable_groups = group_collections( - mutable_variables, tuple(out_variable_filters) + (True,)) + mutable_variables, tuple(out_variable_filters) + (True,) + ) remainder = tuple(out_variable_groups[-1].keys()) if remainder: raise ValueError(f'unmapped output variables: {remainder}') @@ -212,21 +247,19 @@ def repack(inner_scope_tree): try: if enable_kwargs: y, out_variable_groups_xs_t = fn( - scope_fn, repack, - variable_groups_xs_t, rng_groups_xs_t, - *args, **kwargs) + scope_fn, repack, variable_groups_xs_t, rng_groups_xs_t, *args, **kwargs + ) else: y, out_variable_groups_xs_t = fn( - scope_fn, repack, - variable_groups_xs_t, rng_groups_xs_t, - *args) + scope_fn, repack, variable_groups_xs_t, rng_groups_xs_t, *args + ) finally: for inner_scope in inner_scopes: inner_scope.invalidate() out_variable_groups_xs = _transpose(out_variable_groups_xs_t) - for scope, out_variable_groups, rng_counters in zip(scopes, - out_variable_groups_xs, - inner_rng_counters): + for scope, out_variable_groups, rng_counters in zip( + scopes, out_variable_groups_xs, inner_rng_counters + ): for out_variable_group in out_variable_groups: for col_name, collection in out_variable_group.items(): if not scope.is_mutable_collection(col_name): @@ -235,20 +268,23 @@ def repack(inner_scope_tree): for var_name, value in collection.items(): scope.put_variable(col_name, var_name, value) return y + return wrapper id_fn = lambda x: x -def map_variables(fn: Callable[..., Any], - mapped_collections: CollectionFilter, - map_in_fn: Callable[..., Any] = id_fn, - map_out_fn: Callable[..., Any] = id_fn, - init: bool = False, - mutable: bool = False, - rngs: PRNGSequenceFilter = True, - variables: CollectionFilter = True) -> Callable[..., Any]: +def map_variables( + fn: Callable[..., Any], + mapped_collections: CollectionFilter, + map_in_fn: Callable[..., Any] = id_fn, + map_out_fn: Callable[..., Any] = id_fn, + init: bool = False, + mutable: bool = False, + rngs: PRNGSequenceFilter = True, + variables: CollectionFilter = True, +) -> Callable[..., Any]: """Map Variables inside a scope. Args: @@ -271,8 +307,10 @@ def wrapper(scope_fn, repack, variable_groups, rng_groups, *args, **kwargs): target, variables = variable_groups if init: scopes = scope_fn((target, variables), rng_groups) - has_mutable_cols = any(not is_filter_empty(scope.mutable) - for scope in jax.tree_util.tree_leaves(scopes)) + has_mutable_cols = any( + not is_filter_empty(scope.mutable) + for scope in jax.tree_util.tree_leaves(scopes) + ) if has_mutable_cols: fn(scopes, *args, **kwargs) target, _ = repack(scopes) @@ -291,19 +329,19 @@ def wrapper(scope_fn, repack, variable_groups, rng_groups, *args, **kwargs): return y, (out_target, out_vars) in_vars = (mapped_collections, variables) - out_vars = in_vars if is_target_out else (False, - subtract_filters( - variables, mapped_collections)) + out_vars = ( + in_vars + if is_target_out + else (False, subtract_filters(variables, mapped_collections)) + ) return pack( - wrapper, - in_vars, - out_vars, (rngs,), - enable_kwargs=True, - name='map_variables') + wrapper, in_vars, out_vars, (rngs,), enable_kwargs=True, name='map_variables' + ) def swap_collection(fn: Callable[..., Any], col_a: str, col_b: str): """Swap two collections.""" + def swap(target): a = target[col_a] if col_a in target else {} b = target[col_b] if col_b in target else {} @@ -316,12 +354,14 @@ def swap(target): @dataclasses.dataclass(frozen=True) class In(Generic[T]): """Specifies a variable collection should only be lifted as input.""" + axis: T @dataclasses.dataclass(frozen=True) class Out(Generic[T]): """Specifies a variable collection should only be lifted as output.""" + axis: T @@ -407,8 +447,10 @@ def f(scope, x, y): ``(primals_out, vjpfun, aux)`` tuple where ``aux`` is the auxiliary data returned by ``fn``. """ + def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): vjp_vars, other_vars = variable_groups + @functools.wraps(fn) def wrapper(vjp_vars, *args): variable_groups = (vjp_vars, other_vars) @@ -419,20 +461,25 @@ def wrapper(vjp_vars, *args): y = fn(scope, *args) aux = () return y, (aux, repack_fn(scope)) + y, bwd, (aux, out_vars) = jax.vjp( - wrapper, vjp_vars, *args, - reduce_axes=reduce_axes, has_aux=True) + wrapper, vjp_vars, *args, reduce_axes=reduce_axes, has_aux=True + ) treedef = jax.tree_util.tree_structure(scope) - bwd = jax.tree_util.Partial( - functools.partial(_bwd_wrapper, treedef), bwd) + bwd = jax.tree_util.Partial(functools.partial(_bwd_wrapper, treedef), bwd) if has_aux: return (y, bwd, aux), out_vars else: return (y, bwd), out_vars + return pack( - inner, (vjp_variables, variables), (variables,), (rngs,), + inner, + (vjp_variables, variables), + (variables,), + (rngs,), name='vjp', - enable_kwargs=False)(scope, *primals) + enable_kwargs=False, + )(scope, *primals) def jvp( @@ -443,7 +490,7 @@ def jvp( variable_tangents, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, - ) -> Tuple[Any, Any]: +) -> Tuple[Any, Any]: """A lifted version of ``jax.jvp``. See ``jax.jvp`` for the unlifted Jacobian-vector product (forward gradient). @@ -491,8 +538,10 @@ def f(scope, x): ``tangents_out`` value has the same Python tree structure and shapes as ``primals_out``. """ + def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): jvp_vars, other_vars = variable_groups + @functools.wraps(fn) def wrapper(vars_primals, args): variable_groups = (vars_primals, other_vars) @@ -500,32 +549,36 @@ def wrapper(vars_primals, args): y = fn(scope, *args) return y, repack_fn(scope) - (y, out_vars), out_tangents = jax.jvp(wrapper, (jvp_vars, args), - (variable_tangents, tangents)) + (y, out_vars), out_tangents = jax.jvp( + wrapper, (jvp_vars, args), (variable_tangents, tangents) + ) return (y, out_tangents[0]), out_vars + # filter out empty tangent collections because JAX will error on non-equal # tree structure for example: {"params": {}} != {}. treedef = jax.tree_util.tree_structure(scope) - variable_tangents = tuple({k: v # pylint: disable=g-complex-comprehension - for k, v in vt.items() - if v} - for vt in treedef.flatten_up_to(variable_tangents)) + variable_tangents = tuple( + {k: v for k, v in vt.items() if v} # pylint: disable=g-complex-comprehension + for vt in treedef.flatten_up_to(variable_tangents) + ) target = tuple(variable_tangents[0].keys()) return pack( - inner, (target, variables), (variables,), (rngs,), - name='jvp', enable_kwargs=False)(scope, *primals) - - -def vmap(fn: Callable[..., Any], - variable_axes: Mapping[CollectionFilter, InOutAxis], - split_rngs: Mapping[PRNGSequenceFilter, bool], - in_axes=0, - out_axes=0, - axis_size: Optional[int] = None, - axis_name: Optional[str] = None, - spmd_axis_name: Optional[str] = None, - metadata_params: Dict[Any, Any] = {}) -> Callable[..., Any]: + inner, (target, variables), (variables,), (rngs,), name='jvp', enable_kwargs=False + )(scope, *primals) + + +def vmap( + fn: Callable[..., Any], + variable_axes: Mapping[CollectionFilter, InOutAxis], + split_rngs: Mapping[PRNGSequenceFilter, bool], + in_axes=0, + out_axes=0, + axis_size: Optional[int] = None, + axis_name: Optional[str] = None, + spmd_axis_name: Optional[str] = None, + metadata_params: Dict[Any, Any] = {}, +) -> Callable[..., Any]: """A lifted version of ``jax.vmap``. See ``jax.vmap`` for the unlifted batch transform in Jax. @@ -597,12 +650,12 @@ def find_axis_size(axis, x): return () # split rngs - axis_sizes = jax.tree_util.tree_map(find_axis_size, - (variable_in_axes, in_axes), - (variable_groups, args)) + axis_sizes = jax.tree_util.tree_map( + find_axis_size, (variable_in_axes, in_axes), (variable_groups, args) + ) axis_sizes = set(jax.tree_util.tree_leaves(axis_sizes)) if axis_size is None and len(axis_sizes) == 1: - d_axis_size, = axis_sizes + (d_axis_size,) = axis_sizes elif len(axis_sizes) > 1: raise ValueError(f'Inconsistent batch axis sizes: {axis_sizes}') elif axis_size is None: @@ -613,13 +666,13 @@ def find_axis_size(axis, x): rng_groups = tuple( tree_map_rngs(split_fn, rng_group) if split else rng_group - for rng_group, split in zip(rng_groups, rng_splits)) + for rng_group, split in zip(rng_groups, rng_splits) + ) new_variable_groups = [] for var_group, axis in zip(variable_groups, variable_in_axes): if axis is not None: - new_variable_groups.append(meta.remove_axis( - var_group, axis, metadata_params)) + new_variable_groups.append(meta.remove_axis(var_group, axis, metadata_params)) else: new_variable_groups.append(var_group) variable_groups = tuple(new_variable_groups) @@ -630,7 +683,8 @@ def find_axis_size(axis, x): out_axes=(out_axes, variable_out_axes), axis_name=axis_name, axis_size=axis_size, - spmd_axis_name=spmd_axis_name) + spmd_axis_name=spmd_axis_name, + ) @functools.wraps(fn) def mapped(variable_groups, rng_groups, args): scope = scope_fn(variable_groups, rng_groups) @@ -647,26 +701,27 @@ def mapped(variable_groups, rng_groups, args): vars_out = tuple(new_vars_out) return y, vars_out - return pack( - inner, variable_in_groups, variable_out_groups, rng_groups, - name='vmap') + return pack(inner, variable_in_groups, variable_out_groups, rng_groups, name='vmap') ScanAxis = int InOutScanAxis = Union[ScanAxis, In[ScanAxis], Out[ScanAxis]] -def scan(fn: Callable[..., Any], - variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {}, - variable_broadcast: CollectionFilter = False, - variable_carry: CollectionFilter = False, - split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, - in_axes=0, out_axes=0, - length: Optional[int] = None, - reverse: bool = False, - unroll: int = 1, - data_transform: Optional[Callable[..., Any]] = None, - metadata_params: Dict[Any, Any] = {}, - ) -> Callable[..., Any]: + +def scan( + fn: Callable[..., Any], + variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {}, + variable_broadcast: CollectionFilter = False, + variable_carry: CollectionFilter = False, + split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, + in_axes=0, + out_axes=0, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + data_transform: Optional[Callable[..., Any]] = None, + metadata_params: Dict[Any, Any] = {}, +) -> Callable[..., Any]: """A lifted version of ``jax.lax.scan``. See ``jax.lax.scan`` for the unlifted scan in Jax. @@ -743,23 +798,21 @@ def body_fn(scope, c, x): assert all(isinstance(ax, int) for ax in variable_in_axes) assert all(isinstance(ax, int) for ax in variable_out_axes) rng_groups, rng_splits = _unzip2(split_rngs.items()) - rng_axes = tuple(0 if rng_split else axes_scan.broadcast - for rng_split in rng_splits) + rng_axes = tuple(0 if rng_split else axes_scan.broadcast for rng_split in rng_splits) - def inner(scope_fn, repack_fn, - variable_groups, rng_groups, - init, *args): + def inner(scope_fn, repack_fn, variable_groups, rng_groups, init, *args): def find_length(axis, x): if axis is not axes_scan.broadcast: leaves = jax.tree_util.tree_leaves(x) if leaves: return leaves[0].shape[axis] return () + # split rngs lengths = jax.tree_util.tree_map(find_length, in_axes, args) lengths = set(jax.tree_util.tree_leaves(lengths)) if length is None and len(lengths) == 1: - d_length, = lengths + (d_length,) = lengths elif len(lengths) > 1: raise ValueError(f'Inconsistent scan lengths: {lengths}') elif length is None: @@ -770,19 +823,22 @@ def find_length(axis, x): rng_groups = tuple( tree_map_rngs(split_fn, rng_group) if split else rng_group - for rng_group, split in zip(rng_groups, rng_splits)) + for rng_group, split in zip(rng_groups, rng_splits) + ) - @functools.partial(axes_scan.scan, - in_axes=(variable_in_axes, rng_axes, in_axes), - out_axes=(out_axes, variable_out_axes), - length=length, reverse=reverse, - unroll=unroll) + @functools.partial( + axes_scan.scan, + in_axes=(variable_in_axes, rng_axes, in_axes), + out_axes=(out_axes, variable_out_axes), + length=length, + reverse=reverse, + unroll=unroll, + ) def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args): carry_vars, c = carry variable_groups = (broadcast_vars, carry_vars) + scan_variable_groups if data_transform is not None: - variable_groups, rng_groups = data_transform(variable_groups, - rng_groups) + variable_groups, rng_groups = data_transform(variable_groups, rng_groups) scope = scope_fn(variable_groups, rng_groups) c, y = fn(scope, c, *args) out_vars = repack_fn(scope) @@ -804,13 +860,16 @@ def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args): for scan_group, axis in zip(scan_vars, variable_in_axes): new_scan_vars.append(meta.remove_axis(scan_group, axis, metadata_params)) broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned( - broadcast_vars, (carry_vars, init), tuple(new_scan_vars), - rng_groups, args) + broadcast_vars, (carry_vars, init), tuple(new_scan_vars), rng_groups, args + ) new_scan_vars = [] for scan_group, axis in zip(scan_vars, variable_out_axes): new_scan_vars.append(meta.add_axis(scan_group, axis, metadata_params)) scan_vars = tuple(new_scan_vars) - out_vars = (broadcast_vars, carry_vars,) + scan_vars + out_vars = ( + broadcast_vars, + carry_vars, + ) + scan_vars return (c, ys), out_vars return pack( @@ -818,18 +877,22 @@ def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args): (variable_broadcast, variable_carry) + variable_in_groups, (variable_broadcast, variable_carry) + variable_out_groups, rng_groups, - name='scan') + name='scan', + ) C = TypeVar('C') -def while_loop(cond_fn: Callable[[Scope, C], bool], - body_fn: Callable[[Scope, C], C], - scope: Scope, init: C, - carry_variables: CollectionFilter = False, - broadcast_variables: CollectionFilter = True, - split_rngs: Mapping[PRNGSequenceFilter, bool] = {}) -> C: +def while_loop( + cond_fn: Callable[[Scope, C], bool], + body_fn: Callable[[Scope, C], C], + scope: Scope, + init: C, + carry_variables: CollectionFilter = False, + broadcast_variables: CollectionFilter = True, + split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, +) -> C: """Lifted version of jax.lax.while_loop. The lifted scope is passed to `cond_fn` and `body_fn`. @@ -871,37 +934,35 @@ def body_fn(scope, c): """ rng_groups, rng_splits = _unzip2(split_rngs.items()) - def inner(scope_fn, repack_fn, - variable_groups, rng_groups): + def inner(scope_fn, repack_fn, variable_groups, rng_groups): carry_variables, broadcast_variables = variable_groups def make_loop_rngs(i): local_rng_groups = [] for rng_group, rng_split in zip(rng_groups, rng_splits): if rng_split: - rng_group = tree_map_rngs(lambda rng: random.fold_in(rng, i), - rng_group) + rng_group = tree_map_rngs(lambda rng: random.fold_in(rng, i), rng_group) local_rng_groups.append(rng_group) return local_rng_groups def cond_wrapper(c): i, carry_variables, carry = c - scope = scope_fn((carry_variables, broadcast_variables), - make_loop_rngs(-i), - mutable_filter=False) + scope = scope_fn( + (carry_variables, broadcast_variables), + make_loop_rngs(-i), + mutable_filter=False, + ) return cond_fn(scope, carry) def body_wrapper(c): i, carry_variables, carry = c - scope = scope_fn((carry_variables, broadcast_variables), - make_loop_rngs(i)) + scope = scope_fn((carry_variables, broadcast_variables), make_loop_rngs(i)) carry = body_fn(scope, carry) - carry_variables, = repack_fn(scope) + (carry_variables,) = repack_fn(scope) return (i + 1, carry_variables, carry) c = (0, carry_variables, init) - _, carry_variables, carry = jax.lax.while_loop(cond_wrapper, body_wrapper, - c) + _, carry_variables, carry = jax.lax.while_loop(cond_wrapper, body_wrapper, c) return carry, (carry_variables,) return pack( @@ -909,14 +970,19 @@ def body_wrapper(c): (carry_variables, broadcast_variables), (carry_variables,), rng_groups, - name='while_loop')(scope) + name='while_loop', + )(scope) -def cond(pred: Any, - true_fun: Callable[..., C], false_fun: Callable[..., C], - scope: Scope, *operands, - variables: CollectionFilter = True, - rngs: PRNGSequenceFilter = True) -> C: +def cond( + pred: Any, + true_fun: Callable[..., C], + false_fun: Callable[..., C], + scope: Scope, + *operands, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> C: """Lifted version of ``jax.lax.cond``. The returned values from ``true_fun`` and ``false_fun`` @@ -957,31 +1023,29 @@ def false_fn(scope, x): The result of the evaluated branch (``true_fun`` or ``false_fun``). """ branches = [true_fun, false_fun] - def inner(scope_fn, repack_fn, - variable_groups, rng_groups): + + def inner(scope_fn, repack_fn, variable_groups, rng_groups): def branch_wrapper(branch_fn, *operands): scope = scope_fn(variable_groups, rng_groups) y = branch_fn(scope, *operands) return y, repack_fn(scope) + pure_branches = [ - functools.partial(branch_wrapper, branch_fn) - for branch_fn in branches] - return jax.lax.cond( - pred, pure_branches[0], pure_branches[1], *operands) + functools.partial(branch_wrapper, branch_fn) for branch_fn in branches + ] + return jax.lax.cond(pred, pure_branches[0], pure_branches[1], *operands) - return pack( - inner, - (variables,), - (variables,), - (rngs,), - name='cond')(scope) + return pack(inner, (variables,), (variables,), (rngs,), name='cond')(scope) -def switch(index: Any, - branches: Sequence[Callable[..., C]], - scope: Scope, *operands, - variables: CollectionFilter = True, - rngs: PRNGSequenceFilter = True) -> C: +def switch( + index: Any, + branches: Sequence[Callable[..., C]], + scope: Scope, + *operands, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> C: """Lifted version of ``jax.lax.switch``. The returned values from ``branches`` @@ -1046,30 +1110,27 @@ def c_fn(scope, x): The result of the evaluated branch. """ - def inner(scope_fn, repack_fn, - variable_groups, rng_groups): + def inner(scope_fn, repack_fn, variable_groups, rng_groups): def branch_wrapper(branch_fn, *operands): scope = scope_fn(variable_groups, rng_groups) y = branch_fn(scope, *operands) return y, repack_fn(scope) + pure_branches = [ - functools.partial(branch_wrapper, branch_fn) - for branch_fn in branches] + functools.partial(branch_wrapper, branch_fn) for branch_fn in branches + ] return jax.lax.switch(index, pure_branches, *operands) - return pack( - inner, - (variables,), - (variables,), - (rngs,), - name='switch')(scope) + return pack(inner, (variables,), (variables,), (rngs,), name='switch')(scope) -def custom_vjp(fn: Callable[..., Any], - forward_fn: Callable[..., Any], - backward_fn: Callable[..., Any], - grad_vars: CollectionFilter = 'params', - nondiff_argnums=()): +def custom_vjp( + fn: Callable[..., Any], + forward_fn: Callable[..., Any], + backward_fn: Callable[..., Any], + grad_vars: CollectionFilter = 'params', + nondiff_argnums=(), +): """Lifted version of `jax.custom_vjp`. `forward_fn` and `backward_fn` together define a custom vjp for `fn`. @@ -1119,6 +1180,7 @@ def bwd(features, vjp_fn, y_t): Returns: A function with the same signature as `fn` with the custom vjp. """ + def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): grad_variables, other_variables = variable_groups scopes_treedef = None @@ -1128,6 +1190,7 @@ def f(grad_variables, *args): y = fn(scope, *args) vars_out = repack_fn(scope) return y, vars_out + f = jax.custom_vjp(f, nondiff_argnums=nondiff_argnums) def f_fwd(grad_variables, *args): @@ -1160,18 +1223,19 @@ def f_bwd(*args): variable_out_groups = (grad_vars, True) rng_groups = (True,) return pack( - inner, variable_in_groups, variable_out_groups, rng_groups, - name='custom_vjp') - - -def checkpoint(fn: Callable[..., Any], - variables: CollectionFilter = True, - rngs: PRNGSequenceFilter = True, - concrete: bool = False, - prevent_cse: bool = True, - static_argnums: Union[int, Tuple[int, ...]] = (), - policy: Optional[Callable[..., bool]] = None, - ) -> Callable[..., Any]: + inner, variable_in_groups, variable_out_groups, rng_groups, name='custom_vjp' + ) + + +def checkpoint( + fn: Callable[..., Any], + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, + concrete: bool = False, + prevent_cse: bool = True, + static_argnums: Union[int, Tuple[int, ...]] = (), + policy: Optional[Callable[..., bool]] = None, +) -> Callable[..., Any]: """Lifted version of ``jax.checkpoint``. This function is aliased to ``lift.remat`` just like ``jax.remat``. @@ -1205,12 +1269,18 @@ def checkpoint(fn: Callable[..., Any], A wrapped version of ``fn``. When computing gradients intermediate computations will be re-computed when computing gradients. """ + def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args, **kwargs): # add 2 to each static_argnums because we add two initial arguments to rematted static_argnums_ = jax.tree_util.tree_map(lambda x: x + 2, static_argnums) - @functools.partial(jax.remat, - concrete=concrete, static_argnums=static_argnums_, - prevent_cse=prevent_cse, policy=policy) + + @functools.partial( + jax.remat, + concrete=concrete, + static_argnums=static_argnums_, + prevent_cse=prevent_cse, + policy=policy, + ) @functools.wraps(fn) def rematted(variable_groups, rng_groups, *args, **kwargs): scope = scope_fn(variable_groups, rng_groups) @@ -1220,9 +1290,8 @@ def rematted(variable_groups, rng_groups, *args, **kwargs): return rematted(variable_groups, rng_groups, *args, **kwargs) return pack( - inner, (variables,), (variables,), (rngs,), - name='remat', - enable_kwargs=True) + inner, (variables,), (variables,), (rngs,), name='remat', enable_kwargs=True + ) remat = checkpoint @@ -1233,19 +1302,19 @@ def _hashable_filter(x): if isinstance(x, Iterable): return tuple(x) # convert un-hashable list & sets to tuple if isinstance(x, DenyList): - return DenyList(_hashable_filter( - x.deny)) # convert inner filter recursively + return DenyList(_hashable_filter(x.deny)) # convert inner filter recursively return x -def jit(fn: Callable[..., Any], - variables: CollectionFilter = True, - rngs: PRNGSequenceFilter = True, - static_argnums: Union[int, Iterable[int]] = (), - donate_argnums: Union[int, Iterable[int]] = (), - device=None, - backend: Union[str, None] = None, - ) -> Callable[..., Any]: +def jit( + fn: Callable[..., Any], + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, + static_argnums: Union[int, Iterable[int]] = (), + donate_argnums: Union[int, Iterable[int]] = (), + device=None, + backend: Union[str, None] = None, +) -> Callable[..., Any]: """Lifted version of ``jax.jit``. Args: @@ -1297,10 +1366,14 @@ def jit(fn: Callable[..., Any], # where scope_fn or repack_fn actually produce non-identical results. scope_fn = None # type: Optional[Callable] repack_fn = None # type: Optional[Callable] - @functools.partial(jax.jit, - static_argnums=static_argnums, - donate_argnums=donate_argnums, - device=device, backend=backend) + + @functools.partial( + jax.jit, + static_argnums=static_argnums, + donate_argnums=donate_argnums, + device=device, + backend=backend, + ) @functools.wraps(fn) def jitted(fingerprint, variable_groups, rng_groups, *args): nonlocal scope_fn, repack_fn @@ -1332,7 +1405,7 @@ def remat_scan( variable_broadcast: CollectionFilter = False, variable_carry: CollectionFilter = False, variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {True: 0}, - split_rngs: Mapping[PRNGSequenceFilter, bool] = {True: True} + split_rngs: Mapping[PRNGSequenceFilter, bool] = {True: True}, ) -> Callable[..., Any]: """Combines `lift.remat` and `lift.scan` for memory efficiency and constant time compilation. @@ -1375,18 +1448,29 @@ def body_fn(scope, x): variable_broadcast=variable_broadcast, variable_carry=variable_carry, variable_axes=variable_axes, - split_rngs=split_rngs) + split_rngs=split_rngs, + ) if len(lengths) == 1: + def wrapper(scope, carry): return body_fn(scope, carry), () + fn = lambda scope, c: scan_fn(wrapper, length=lengths[0])(scope, c)[0] else: + @functools.partial(remat, policy=policy, prevent_cse=False) def inner_loop(scope, carry): - carry = remat_scan(body_fn, lengths[1:], policy, - variable_broadcast, variable_carry, - variable_axes, split_rngs)(scope, carry) + carry = remat_scan( + body_fn, + lengths[1:], + policy, + variable_broadcast, + variable_carry, + variable_axes, + split_rngs, + )(scope, carry) return carry, () + fn = lambda scope, c: scan_fn(inner_loop, length=lengths[0])(scope, c)[0] return fn diff --git a/flax/core/meta.py b/flax/core/meta.py index 93b9b1be47..509b2a2926 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -31,7 +31,7 @@ from jax.experimental import maps -TAxisMetadata = Any # TypeVar('TAxisMetadata', bound='AxisMetadata') +TAxisMetadata = Any # TypeVar('TAxisMetadata', bound='AxisMetadata') class AxisMetadata(metaclass=abc.ABCMeta): @@ -82,8 +82,9 @@ def replace_boxed(self, val: Any) -> TAxisMetadata: pass @abc.abstractmethod - def add_axis(self: TAxisMetadata, index: int, - params: Dict[Any, Any]) -> TAxisMetadata: + def add_axis( + self: TAxisMetadata, index: int, params: Dict[Any, Any] + ) -> TAxisMetadata: """Adds a new axis to the axis metadata. Note that add_axis and remove_axis should act as each other's inverse @@ -102,8 +103,9 @@ def add_axis(self: TAxisMetadata, index: int, pass @abc.abstractmethod - def remove_axis(self: TAxisMetadata, index: int, - params: Dict[Any, Any]) -> TAxisMetadata: + def remove_axis( + self: TAxisMetadata, index: int, params: Dict[Any, Any] + ) -> TAxisMetadata: """Removes an axis from the axis metadata. Note that add_axis and remove_axis should act as each other's inverse @@ -129,11 +131,13 @@ def is_axis_metadata(val: Any) -> bool: def map_axis_meta(fn: Callable[[AxisMetadata], Any], tree: Any) -> Any: """Maps over all PyTree nodes that are AxisMetadata instances.""" + def wrapper(x): if isinstance(x, AxisMetadata): return fn(x) else: return x + return jax.tree_map(wrapper, tree, is_leaf=is_axis_metadata) @@ -154,11 +158,13 @@ def unbox(tree: Any) -> Any: def replace_boxed(tree: Any, updates: Any) -> Any: """Updates all AxisMetadata boxes with the values in updates.""" + def inner_update(c, v): if isinstance(c, AxisMetadata): return c.replace_boxed(replace_boxed(c.unbox(), v)) else: return v + return jax.tree_map(inner_update, tree, updates, is_leaf=is_axis_metadata) @@ -228,6 +234,7 @@ def body(mdl, c): return c """ + value: Any names: LogicalNames = struct.field(pytree_node=False) mesh: Optional[jax.sharding.Mesh] = struct.field(default=None, pytree_node=False) @@ -239,8 +246,7 @@ def unbox(self, apply_constraint=True) -> Any: if self.mesh is not None: sharding = jax.sharding.NamedSharding(self.mesh, axis_resource) return jax.lax.with_sharding_constraint(self.value, sharding) - return jax.lax.with_sharding_constraint( - self.value, axis_resource) + return jax.lax.with_sharding_constraint(self.value, axis_resource) else: return self.value @@ -256,8 +262,8 @@ def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: axis_name = self._get_partition_name(params) names = list(self.names) while len(names) < index: - names.append(None) # type: ignore - names.insert(index, axis_name) # type: ignore + names.append(None) # type: ignore + names.insert(index, axis_name) # type: ignore return self.replace(names=tuple(names)) def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: @@ -279,7 +285,7 @@ def with_partitioning( fn: Callable[..., Any], names: LogicalNames, mesh: Optional[jax.sharding.Mesh] = None, - ) -> Callable[..., Partitioned]: +) -> Callable[..., Partitioned]: """Wraps a function's return value with Partitioned. Example:: @@ -296,14 +302,17 @@ def with_partitioning( Returns: A function wrapping ``fn`` that will return an instance of ``Partitioned``. """ + @functools.wraps(fn) def wrapper(*args, **kwargs): return Partitioned(fn(*args, **kwargs), names, mesh=mesh) + return wrapper def get_partition_spec(tree: Any) -> Any: """Extracts a PartitionSpec tree from a PyTree containing ``Partitioned`` values.""" + def f(x): if isinstance(x, Partitioned): return x.get_partition_spec() @@ -312,8 +321,8 @@ def f(x): return jax.sharding.PartitionSpec() else: return None - return jax.tree_map(f, tree, - is_leaf=lambda x: isinstance(x, Partitioned)) + + return jax.tree_map(f, tree, is_leaf=lambda x: isinstance(x, Partitioned)) def get_sharding(tree: Any, mesh: jax.sharding.Mesh) -> Any: diff --git a/flax/core/nn/__init__.py b/flax/core/nn/__init__.py index 55e371dda9..528bbe38aa 100644 --- a/flax/core/nn/__init__.py +++ b/flax/core/nn/__init__.py @@ -17,43 +17,41 @@ # pylint: disable=g-multiple-import # re-export commonly used modules and functions from .attention import ( - dot_product_attention as dot_product_attention, - multi_head_dot_product_attention as multi_head_dot_product_attention + dot_product_attention as dot_product_attention, + multi_head_dot_product_attention as multi_head_dot_product_attention, ) from flax.linen import activation as activation from flax.linen import initializers as initializers from flax.linen.activation import ( - celu as celu, - elu as elu, - gelu as gelu, - glu as glu, - leaky_relu as leaky_relu, - log_sigmoid as log_sigmoid, - log_softmax as log_softmax, - relu as relu, - sigmoid as sigmoid, - silu as silu, - soft_sign as soft_sign, - softmax as softmax, - softplus as softplus, - swish as swish, - tanh as tanh) -from flax.linen.pooling import ( - avg_pool as avg_pool, - max_pool as max_pool + celu as celu, + elu as elu, + gelu as gelu, + glu as glu, + leaky_relu as leaky_relu, + log_sigmoid as log_sigmoid, + log_softmax as log_softmax, + relu as relu, + sigmoid as sigmoid, + silu as silu, + soft_sign as soft_sign, + softmax as softmax, + softplus as softplus, + swish as swish, + tanh as tanh, ) +from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool) from .linear import ( - Embedding as Embedding, - conv as conv, - conv_transpose as conv_transpose, - dense as dense, - dense_general as dense_general, - embedding as embedding + Embedding as Embedding, + conv as conv, + conv_transpose as conv_transpose, + dense as dense, + dense_general as dense_general, + embedding as embedding, ) from .normalization import ( - batch_norm as batch_norm, - group_norm as group_norm, - layer_norm as layer_norm + batch_norm as batch_norm, + group_norm as group_norm, + layer_norm as layer_norm, ) from .stochastic import dropout as dropout diff --git a/flax/core/nn/attention.py b/flax/core/nn/attention.py index 1f9c5f1afd..16eead937d 100644 --- a/flax/core/nn/attention.py +++ b/flax/core/nn/attention.py @@ -33,18 +33,20 @@ import numpy as np -def dot_product_attention(scope, - query, - key, - value, - dtype=jnp.float32, - bias=None, - axis=None, - broadcast_dropout=True, - dropout_rng=None, - dropout_rate=0., - deterministic=False, - precision=None): +def dot_product_attention( + scope, + query, + key, + value, + dtype=jnp.float32, + bias=None, + axis=None, + broadcast_dropout=True, + dropout_rng=None, + dropout_rate=0.0, + deterministic=False, + precision=None, +): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -75,8 +77,7 @@ def dot_product_attention(scope, Output of shape `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`. """ assert key.shape[:-1] == value.shape[:-1] - assert (query.shape[0:1] == key.shape[0:1] and - query.shape[-1] == key.shape[-1]) + assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1] if axis is None: axis = tuple(range(1, key.ndim - 2)) @@ -86,8 +87,9 @@ def dot_product_attention(scope, assert key.ndim == value.ndim for ax in axis: if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2): - raise ValueError('Attention axis must be between the batch ' - 'axis and the last-two axes.') + raise ValueError( + 'Attention axis must be between the batch ' 'axis and the last-two axes.' + ) depth = query.shape[-1] n = key.ndim # batch_dims is , num_heads> @@ -104,8 +106,10 @@ def dot_product_attention(scope, batch_dims_t = tuple(range(len(batch_dims))) attn_weights = lax.dot_general( query, - key, (((n - 1,), (n - 1,)), (batch_dims_t, batch_dims_t)), - precision=precision) + key, + (((n - 1,), (n - 1,)), (batch_dims_t, batch_dims_t)), + precision=precision, + ) # apply attention bias: masking, droput, proximity bias, ect. if bias is not None: @@ -114,32 +118,34 @@ def dot_product_attention(scope, # normalize the attention weights norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim)) attn_weights = lax.exp( - attn_weights - - jax.scipy.special.logsumexp(attn_weights, axis=norm_dims, keepdims=True)) + attn_weights + - jax.scipy.special.logsumexp(attn_weights, axis=norm_dims, keepdims=True) + ) attn_weights = attn_weights.astype(dtype) # apply dropout - if not deterministic and dropout_rate > 0.: + if not deterministic and dropout_rate > 0.0: if dropout_rng is None: dropout_rng = scope.make_rng('dropout') keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate) if broadcast_dropout: # dropout is broadcast across the batch+head+non-attention dimension - dropout_dims = attn_weights.shape[-(2 * len(axis)):] - dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims) + dropout_dims = attn_weights.shape[-(2 * len(axis)) :] + dropout_shape = tuple([1] * len(batch_dims_t)) + dropout_dims keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) else: keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) - multiplier = (keep.astype(attn_weights.dtype) / - jnp.asarray(keep_prob, dtype=dtype)) + multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype) attn_weights = attn_weights * multiplier # compute the new values given the attention weights wv_contracting_dims = (norm_dims, range(value.ndim - len(axis), value.ndim)) y = lax.dot_general( attn_weights, - value, (wv_contracting_dims, (batch_dims_t, batch_dims_t)), - precision=precision) + value, + (wv_contracting_dims, (batch_dims_t, batch_dims_t)), + precision=precision, + ) # back to (bs, dim1, dim2, ..., dimN, num_heads, channels) perm_inv = _invert_perm(qk_perm) @@ -177,13 +183,14 @@ def multi_head_dot_product_attention( cache=False, broadcast_dropout=True, dropout_rng=None, - dropout_rate=0., + dropout_rate=0.0, deterministic=False, precision=None, kernel_init=default_kernel_init, bias_init=initializers.zeros_init(), bias=True, - attention_fn=dot_product_attention): + attention_fn=dot_product_attention, +): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, @@ -231,8 +238,7 @@ def multi_head_dot_product_attention( output of shape `[bs, dim1, dim2, ..., dimN, features]`. """ - assert causal_mask or not cache, ( - 'Caching is only support for causal attention.') + assert causal_mask or not cache, 'Caching is only support for causal attention.' if inputs_kv is None: inputs_kv = inputs_q @@ -243,8 +249,9 @@ def multi_head_dot_product_attention( features = out_features or inputs_q.shape[-1] qkv_features = qkv_features or inputs_q.shape[-1] - assert qkv_features % num_heads == 0, ( - 'Memory dimension must be divisible by number of heads.') + assert ( + qkv_features % num_heads == 0 + ), 'Memory dimension must be divisible by number of heads.' head_dim = qkv_features // num_heads dense = functools.partial( @@ -255,7 +262,8 @@ def multi_head_dot_product_attention( kernel_init=kernel_init, bias_init=bias_init, bias=bias, - precision=precision) + precision=precision, + ) # project inputs_q to multi-headed q/k/v # dimensions are then [bs, dims..., n_heads, n_features_per_head] query = scope.child(dense, 'query')(inputs_q) @@ -266,15 +274,20 @@ def multi_head_dot_product_attention( cache_entry: Union[Callable[[Any], CacheEntry], CacheEntry] if not scope.has_variable('cache', 'entry'): ndim, tail_shape = (key.ndim, key.shape[-2:]) + def init_fn(shape, dtype=jnp.float32): full_shape = shape + tail_shape if len(full_shape) != ndim: - raise ValueError('Shape should be a tuple with the shape of the batch' - 'and attention dims.') + raise ValueError( + 'Shape should be a tuple with the shape of the batch' + 'and attention dims.' + ) return CacheEntry( key=jnp.zeros(full_shape, dtype), value=jnp.zeros(full_shape, dtype), - i=jnp.zeros((), jnp.uint32)) + i=jnp.zeros((), jnp.uint32), + ) + cache_entry = init_fn else: cache_entry = scope.get_variable('cache', 'entry') @@ -286,9 +299,10 @@ def init_fn(shape, dtype=jnp.float32): expected_shape[attn_dim] = 1 expected_shape = tuple(expected_shape) + inputs_q.shape[-1:] if expected_shape != inputs_q.shape: - raise ValueError('Invalid shape provided, ' - 'expected shape %s instead got %s.' % - (expected_shape, inputs_q.shape)) + raise ValueError( + 'Invalid shape provided, ' + 'expected shape %s instead got %s.' % (expected_shape, inputs_q.shape) + ) cshape = cache_entry.key.shape indices = [0] * len(cshape) @@ -299,16 +313,15 @@ def init_fn(shape, dtype=jnp.float32): indices[attn_dim] = i // attn_size i = i % attn_size - key = lax.dynamic_update_slice(cache_entry.key, key, indices) # type: ignore - value = lax.dynamic_update_slice(cache_entry.value, value, indices) # type: ignore + key = lax.dynamic_update_slice(cache_entry.key, key, indices) # type: ignore + value = lax.dynamic_update_slice(cache_entry.value, value, indices) # type: ignore one = jnp.array(1, jnp.uint32) - cache_entry = cache_entry.replace(i=cache_entry.i + one, - key=key, - value=value) + cache_entry = cache_entry.replace(i=cache_entry.i + one, key=key, value=value) # TODO(levskaya): verify this is still needed in translation decoding. key_padding_mask = jnp.broadcast_to( - (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2]) + (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2] + ) key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None] scope.put_variable('cache', 'entry', cache_entry) @@ -334,7 +347,8 @@ def init_fn(shape, dtype=jnp.float32): padding_mask_key=key_padding_mask, query_shape=query.shape, key_shape=key.shape, - attention_axis=attention_axis) + attention_axis=attention_axis, + ) mask_components.append(padding_mask) if segmentation is not None: @@ -346,7 +360,8 @@ def init_fn(shape, dtype=jnp.float32): query_shape=query.shape, key_shape=key.shape, attention_axis=attention_axis, - segmentation_mask=True) + segmentation_mask=True, + ) mask_components.append(segmentation_mask) if mask_components: @@ -356,8 +371,10 @@ def init_fn(shape, dtype=jnp.float32): # attention mask in the form of attention bias attention_bias = lax.select( - attention_mask > 0, jnp.full(attention_mask.shape, 0.).astype(dtype), - jnp.full(attention_mask.shape, -1e10).astype(dtype)) + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(dtype), + jnp.full(attention_mask.shape, -1e10).astype(dtype), + ) else: attention_bias = None @@ -373,7 +390,8 @@ def init_fn(shape, dtype=jnp.float32): dropout_rng=dropout_rng, dropout_rate=dropout_rate, broadcast_dropout=broadcast_dropout, - deterministic=deterministic) + deterministic=deterministic, + ) # back to the original inputs dimensions out = scope.child(dense_general, name='out')( @@ -384,22 +402,25 @@ def init_fn(shape, dtype=jnp.float32): bias_init=bias_init, bias=bias, dtype=dtype, - precision=precision) + precision=precision, + ) return out # TODO(flax-dev): Consider refactoring MultiHeadDotProductAttention and moving # causal_mask and cache support into this class instead. -#SelfAttention = MultiHeadDotProductAttention.partial(inputs_kv=None) +# SelfAttention = MultiHeadDotProductAttention.partial(inputs_kv=None) -def make_padding_mask(padding_mask_query, - padding_mask_key, - query_shape, - key_shape, - attention_axis=None, - segmentation_mask=False): +def make_padding_mask( + padding_mask_query, + padding_mask_key, + query_shape, + key_shape, + attention_axis=None, + segmentation_mask=False, +): """Makes padding mask for attention weights. In case of 1d inputs (i.e., `[bs, len, features]`, the attention weights will @@ -491,7 +512,8 @@ def tri(n, m, k=0): y = lax.tie_in(key, jnp.arange(m, dtype=jnp.int32)) mask = lax.ge( (lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,))) + k, - lax.broadcast(y, [n])) + lax.broadcast(y, [n]), + ) return mask k = -1 if self_mask else 0 diff --git a/flax/core/nn/linear.py b/flax/core/nn/linear.py index cc82d905db..e36ea79921 100644 --- a/flax/core/nn/linear.py +++ b/flax/core/nn/linear.py @@ -42,7 +42,8 @@ def dense_general( dtype=jnp.float32, kernel_init=default_kernel_init, bias_init=initializers.zeros_init(), - precision=None): + precision=None, +): """Applies a linear transformation to the inputs along multiple dimensions. Args: @@ -72,8 +73,10 @@ def dense_general( if batch_dims: max_dim = np.max(batch_dims) if set(batch_dims) != set(range(max_dim + 1)): - raise ValueError('batch_dims %s must be consecutive leading ' - 'dimensions starting from 0.' % str(batch_dims)) + raise ValueError( + 'batch_dims %s must be consecutive leading ' + 'dimensions starting from 0.' % str(batch_dims) + ) ndim = inputs.ndim n_batch_dims = len(batch_dims) @@ -83,10 +86,13 @@ def dense_general( def kernel_init_wrap(rng, shape, dtype=jnp.float32): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) - flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]), - np.prod(shape[-n_features:]),) - kernel = jnp.concatenate([kernel_init(rng, flat_shape, dtype) - for _ in range(size_batch_dims)], axis=0) + flat_shape = ( + np.prod(shape[n_batch_dims : n_axis + n_batch_dims]), + np.prod(shape[-n_features:]), + ) + kernel = jnp.concatenate( + [kernel_init(rng, flat_shape, dtype) for _ in range(size_batch_dims)], axis=0 + ) return jnp.reshape(kernel, shape) batch_shape = tuple(inputs.shape[ax] for ax in batch_dims) @@ -96,23 +102,26 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32): batch_ind = tuple(range(n_batch_dims)) contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) - out = lax.dot_general(inputs, - kernel, - ((axis, contract_ind), (batch_dims, batch_ind)), - precision=precision) + out = lax.dot_general( + inputs, + kernel, + ((axis, contract_ind), (batch_dims, batch_ind)), + precision=precision, + ) if bias: + def bias_init_wrap(rng, shape, dtype=jnp.float32): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) flat_shape = (np.prod(shape[-n_features:]),) - bias = jnp.concatenate([bias_init(rng, flat_shape, dtype) - for _ in range(size_batch_dims)], axis=0) + bias = jnp.concatenate( + [bias_init(rng, flat_shape, dtype) for _ in range(size_batch_dims)], axis=0 + ) return jnp.reshape(bias, shape) bias = scope.param('bias', bias_init_wrap, batch_shape + features) # Reshape bias for broadcast. - expand_dims = sorted( - set(range(inputs.ndim)) - set(axis) - set(batch_dims)) + expand_dims = sorted(set(range(inputs.ndim)) - set(axis) - set(batch_dims)) for ax in expand_dims: bias = jnp.expand_dims(bias, ax) bias = jnp.asarray(bias, dtype) @@ -120,14 +129,16 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): return out -def dense(scope, - inputs, - features, - bias=True, - dtype=jnp.float32, - precision=None, - kernel_init=default_kernel_init, - bias_init=initializers.zeros_init()): +def dense( + scope, + inputs, + features, + bias=True, + dtype=jnp.float32, + precision=None, + kernel_init=default_kernel_init, + bias_init=initializers.zeros_init(), +): """Applies a linear transformation to the inputs along the last dimension. Args: @@ -145,9 +156,9 @@ def dense(scope, inputs = jnp.asarray(inputs, dtype) kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features)) kernel = jnp.asarray(kernel, dtype) - y = lax.dot_general(inputs, kernel, - (((inputs.ndim - 1,), (0,)), ((), ())), - precision=precision) + y = lax.dot_general( + inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=precision + ) if bias: bias = scope.param('bias', bias_init, (features,)) bias = jnp.asarray(bias, dtype) @@ -164,20 +175,22 @@ def _conv_dimension_numbers(input_shape): return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) -def conv(scope, - inputs, - features, - kernel_size, - strides=None, - padding='SAME', - input_dilation=None, - kernel_dilation=None, - feature_group_count=1, - bias=True, - dtype=jnp.float32, - precision=None, - kernel_init=default_kernel_init, - bias_init=initializers.zeros_init()): +def conv( + scope, + inputs, + features, + kernel_size, + strides=None, + padding='SAME', + input_dilation=None, + kernel_dilation=None, + feature_group_count=1, + bias=True, + dtype=jnp.float32, + precision=None, + kernel_init=default_kernel_init, + bias_init=initializers.zeros_init(), +): """Applies a convolution to the inputs. Args: @@ -230,7 +243,8 @@ def conv(scope, rhs_dilation=kernel_dilation, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, - precision=precision) + precision=precision, + ) if bias: bias = scope.param('bias', bias_init, (features,)) @@ -239,18 +253,20 @@ def conv(scope, return y -def conv_transpose(scope, - inputs, - features, - kernel_size, - strides=None, - padding='SAME', - kernel_dilation=None, - bias=True, - dtype=jnp.float32, - precision=None, - kernel_init=default_kernel_init, - bias_init=initializers.zeros_init()): +def conv_transpose( + scope, + inputs, + features, + kernel_size, + strides=None, + padding='SAME', + kernel_dilation=None, + bias=True, + dtype=jnp.float32, + precision=None, + kernel_init=default_kernel_init, + bias_init=initializers.zeros_init(), +): """Applies a transposed convolution to the inputs. Behaviour mirrors that of `jax.lax.conv_transpose`. @@ -285,8 +301,14 @@ def conv_transpose(scope, kernel = scope.param('kernel', kernel_init, kernel_shape) kernel = jnp.asarray(kernel, dtype) - y = lax.conv_transpose(inputs, kernel, strides, padding, - rhs_dilation=kernel_dilation, precision=precision) + y = lax.conv_transpose( + inputs, + kernel, + strides, + padding, + rhs_dilation=kernel_dilation, + precision=precision, + ) if bias: bias = scope.param('bias', bias_init, (features,)) @@ -295,8 +317,7 @@ def conv_transpose(scope, return y -default_embed_init = initializers.variance_scaling(1.0, 'fan_in', 'normal', - out_axis=0) +default_embed_init = initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0) @struct.dataclass @@ -333,16 +354,18 @@ def attend(self, query): return jnp.dot(query, self.table.T) -def embedding(scope: Scope, num_embeddings: int, features: int, init_fn=default_embed_init) -> Embedding: +def embedding( + scope: Scope, num_embeddings: int, features: int, init_fn=default_embed_init +) -> Embedding: """Creates embedding dataclass. - Args: - num_embeddings: number of embeddings. - features: Number of feature dimensions for each embedding. - embedding_init: embedding initializer. + Args: + num_embeddings: number of embeddings. + features: Number of feature dimensions for each embedding. + embedding_init: embedding initializer. - Returns: - Embedding dataclass with lookup and attend methods. + Returns: + Embedding dataclass with lookup and attend methods. """ table = scope.param('table', init_fn, (num_embeddings, features)) - return Embedding(table) # type: ignore + return Embedding(table) # type: ignore diff --git a/flax/core/nn/normalization.py b/flax/core/nn/normalization.py index 135f890444..d20333aa88 100644 --- a/flax/core/nn/normalization.py +++ b/flax/core/nn/normalization.py @@ -24,16 +24,22 @@ def _absolute_dims(ndim, dims): return tuple(ndim + dim if dim < 0 else dim for dim in dims) -def batch_norm(scope: Scope, - x, - use_running_average=False, - axis=-1, momentum=0.99, epsilon=1e-5, - dtype=jnp.float32, - bias=True, scale=True, - bias_init=initializers.zeros_init(), scale_init=initializers.ones_init(), - axis_name=None, axis_index_groups=None, - kind='batch_stats'): - +def batch_norm( + scope: Scope, + x, + use_running_average=False, + axis=-1, + momentum=0.99, + epsilon=1e-5, + dtype=jnp.float32, + bias=True, + scale=True, + bias_init=initializers.zeros_init(), + scale_init=initializers.ones_init(), + axis_name=None, + axis_index_groups=None, + kind='batch_stats', +): x = jnp.asarray(x, jnp.float32) axis = axis if isinstance(axis, tuple) else (axis,) axis = _absolute_dims(x.ndim, axis) @@ -61,18 +67,16 @@ def pmean(x): var = jnp.reshape(ra_var.value, var.shape) else: if not is_init: - beta = 1. - momentum + beta = 1.0 - momentum ra_mean.value += beta * (jnp.squeeze(mean) - ra_mean.value) ra_var.value += beta * (jnp.squeeze(var) - ra_var.value) y = x - mean mul = lax.rsqrt(var + epsilon) if scale: - mul = mul * scope.param( - 'scale', scale_init, squeeze_shape).reshape(mean.shape) + mul = mul * scope.param('scale', scale_init, squeeze_shape).reshape(mean.shape) y = y * mul if bias: - y = y + scope.param( - 'bias', bias_init, squeeze_shape).reshape(mean.shape) + y = y + scope.param('bias', bias_init, squeeze_shape).reshape(mean.shape) return jnp.asarray(y, dtype) @@ -84,7 +88,8 @@ def layer_norm( bias=True, scale=True, bias_init=initializers.zeros_init(), - scale_init=initializers.ones_init()): + scale_init=initializers.ones_init(), +): """Applies layer normalization on the input. It normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. @@ -109,24 +114,25 @@ def layer_norm( var = mean2 - lax.square(mean) mul = lax.rsqrt(var + epsilon) if scale: - mul = mul * jnp.asarray(scope.param('scale', scale_init, (features,)), - dtype) + mul = mul * jnp.asarray(scope.param('scale', scale_init, (features,)), dtype) y = (x - mean) * mul if bias: y = y + jnp.asarray(scope.param('bias', bias_init, (features,)), dtype) return y -def group_norm(scope, - x, - num_groups=32, - group_size=None, - epsilon=1e-6, - dtype=jnp.float32, - bias=True, - scale=True, - bias_init=initializers.zeros_init(), - scale_init=initializers.ones_init()): +def group_norm( + scope, + x, + num_groups=32, + group_size=None, + epsilon=1e-6, + dtype=jnp.float32, + bias=True, + scale=True, + bias_init=initializers.zeros_init(), + scale_init=initializers.ones_init(), +): """Applies group normalization to the input (arxiv.org/abs/1803.08494). This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. @@ -153,16 +159,21 @@ def group_norm(scope, Normalized inputs (the same shape as inputs). """ x = jnp.asarray(x, jnp.float32) - if ((num_groups is None and group_size is None) or - (num_groups is not None and group_size is not None)): - raise ValueError('Either `num_groups` or `group_size` should be ' - 'specified, but not both of them.') + if (num_groups is None and group_size is None) or ( + num_groups is not None and group_size is not None + ): + raise ValueError( + 'Either `num_groups` or `group_size` should be ' + 'specified, but not both of them.' + ) if group_size is not None: channels = x.shape[-1] if channels % group_size != 0: - raise ValueError('Number of channels ({}) is not multiple of the ' - 'group size ({}).'.format(channels, group_size)) + raise ValueError( + 'Number of channels ({}) is not multiple of the ' + 'group size ({}).'.format(channels, group_size) + ) num_groups = channels // group_size input_shape = x.shape @@ -173,8 +184,7 @@ def group_norm(scope, reduction_axis = list(range(1, x.ndim - 2)) + [x.ndim - 1] mean = jnp.mean(x, axis=reduction_axis, keepdims=True) - mean_of_squares = jnp.mean(jnp.square(x), axis=reduction_axis, - keepdims=True) + mean_of_squares = jnp.mean(jnp.square(x), axis=reduction_axis, keepdims=True) var = mean_of_squares - jnp.square(mean) x = (x - mean) * lax.rsqrt(var + epsilon) diff --git a/flax/core/nn/stochastic.py b/flax/core/nn/stochastic.py index 0711cbd8b9..837b6694ca 100644 --- a/flax/core/nn/stochastic.py +++ b/flax/core/nn/stochastic.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Stochastic modules. -""" +"""Stochastic modules.""" from jax import lax from jax import random import jax.numpy as jnp - def dropout(scope, inputs, rate, deterministic=False, rng=None): """Applies a random dropout mask to the input. Args: @@ -34,9 +32,9 @@ def dropout(scope, inputs, rate, deterministic=False, rng=None): Returns: The masked inputs. """ - if rate == 0.: + if rate == 0.0: return inputs - keep_prob = 1. - rate + keep_prob = 1.0 - rate if deterministic: return inputs diff --git a/flax/core/partial_eval.py b/flax/core/partial_eval.py index c92fc19e36..0db665c11d 100644 --- a/flax/core/partial_eval.py +++ b/flax/core/partial_eval.py @@ -48,6 +48,7 @@ def lazy_init(fn): A new function that accepts a mix of concrete values and ``jax.ShapeDtypeStruct`` instances. """ + @functools.wraps(fn) def wrapper(*args, **kwargs): # TODO(mattjj,jheek): use a public JAX API diff --git a/flax/core/scope.py b/flax/core/scope.py index 42bfb5a873..a554df2171 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -20,8 +20,20 @@ import functools import hashlib import typing -from typing import (Any, Callable, Dict, Generic, Iterable, Mapping, Optional, - Sequence, Set, Tuple, TypeVar, Union) +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + Mapping, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, +) from flax.ids import uuid from flax import config as config @@ -57,6 +69,7 @@ # When conditioning on filters we require explicit boolean comparisons. # pylint: disable=g-bool-id-comparison + @dataclasses.dataclass(frozen=True, eq=True) class DenyList: """DenyList represents an opt-out based mutability filter. @@ -71,6 +84,7 @@ class DenyList: deny: The filter representing the collections that are not mutable. """ + deny: Filter @@ -89,6 +103,7 @@ class DenyList: class LazyRng(struct.PyTreeNode): """Wrapper around JAX PRNGKey that lazily maintains a tuple of static data to be folded into the rng.""" + rng: PRNGKey suffix: Tuple[PRNGFoldable, ...] = struct.field(pytree_node=False) @@ -96,8 +111,7 @@ def as_jax_rng(self) -> PRNGKey: return _fold_in_static(self.rng, self.suffix) @staticmethod - def create(rng: Union['LazyRng', PRNGKey], - *suffix: PRNGFoldable) -> 'LazyRng': + def create(rng: Union['LazyRng', PRNGKey], *suffix: PRNGFoldable) -> 'LazyRng': if not legacy_config.flax_lazy_rng: if isinstance(rng, LazyRng): assert not rng.suffix @@ -125,8 +139,7 @@ def _legacy_rng_fold_in(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey: return rng -def _fold_in_static(rng: PRNGKey, - data: typing.Collection[PRNGFoldable]) -> PRNGKey: +def _fold_in_static(rng: PRNGKey, data: typing.Collection[PRNGFoldable]) -> PRNGKey: """Folds static data (strings & ints) into a jax.random.PRNGKey using its SHA-1 hash. This is faster than splitting an PRNGKey because it allows generating new PRNG @@ -154,7 +167,7 @@ def _fold_in_static(rng: PRNGKey, raise ValueError(f'Expected int or string, got: {x}') d = m.digest() hash_int = int.from_bytes(d[:4], byteorder='big') - return random.fold_in(rng, jnp.uint32(hash_int)) # type: ignore + return random.fold_in(rng, jnp.uint32(hash_int)) # type: ignore def is_filter_empty(filter_like: Filter) -> bool: @@ -303,8 +316,8 @@ def intersect_filters(a: Filter, b: Filter) -> Filter: def group_collections( - xs: VariableDict, - col_filters: Sequence[CollectionFilter]) -> Sequence[MutableVariableDict]: + xs: VariableDict, col_filters: Sequence[CollectionFilter] +) -> Sequence[MutableVariableDict]: """Groups variables by collection filters. Iteratively applies the filters in `col_filters` to `xs`, and adds the result @@ -343,8 +356,7 @@ class Variable(Generic[T]): content and can be assigned to for mutation. """ - def __init__( - self, scope: 'Scope', collection: str, name: str, unbox: bool): + def __init__(self, scope: 'Scope', collection: str, name: str, unbox: bool): """Initializes a variable. Args: @@ -370,12 +382,9 @@ def value(self, value: T): """Updates the value of this Variable.""" if self.unbox: cur = self.scope.get_variable(self.collection, self.name) - cur_struct = tree_util.tree_structure(cur, - is_leaf=meta.is_axis_metadata) - value_struct = tree_util.tree_structure(value, - is_leaf=meta.is_axis_metadata) - has_meta = any(map(meta.is_axis_metadata, - cur_struct.flatten_up_to(cur))) + cur_struct = tree_util.tree_structure(cur, is_leaf=meta.is_axis_metadata) + value_struct = tree_util.tree_structure(value, is_leaf=meta.is_axis_metadata) + has_meta = any(map(meta.is_axis_metadata, cur_struct.flatten_up_to(cur))) if cur_struct == value_struct and has_meta: value = meta.replace_boxed(cur, value) @@ -388,12 +397,16 @@ def is_mutable(self) -> bool: class _ChildRNGSentinel: pass + + # used to identify that an rng counter is meant for a child scope child_rng_token = _ChildRNGSentinel() class _DefaultSentinel: pass + + # used to denote no default flag value on scope no_flag = _DefaultSentinel() @@ -409,16 +422,19 @@ class Scope: `_ for a number of examples using ``Scopes``. """ + reservations: Dict[str, Set[Optional[str]]] - def __init__(self, - variables: MutableVariableDict, - rngs: Optional[Dict[str, Union[PRNGKey, LazyRng]]] = None, - name: Optional[str] = None, - mutable: CollectionFilter = False, - parent: Optional['Scope'] = None, - path: Iterable[str] = (), - flags: Optional[Mapping] = None): + def __init__( + self, + variables: MutableVariableDict, + rngs: Optional[Dict[str, Union[PRNGKey, LazyRng]]] = None, + name: Optional[str] = None, + mutable: CollectionFilter = False, + parent: Optional['Scope'] = None, + path: Iterable[str] = (), + flags: Optional[Mapping] = None, + ): """Initializes a Scope. Args: @@ -455,7 +471,11 @@ def __eq__(self, other: Any) -> bool: return False if self is other: return True - return self.root._variables is other.root._variables and self.path == other.path and self.rng_counters is other.rng_counters + return ( + self.root._variables is other.root._variables + and self.path == other.path + and self.rng_counters is other.rng_counters + ) def __hash__(self) -> int: # see __eq__ @@ -494,8 +514,7 @@ def invalidate(self): def mutable_variables(self) -> Union[VariableDict, Dict[str, Any]]: """Returns an immutable copy of the mutable variables belonging to this Scope.""" self._populate_collections() - xs = {k: v for k, v in self._variables.items() - if in_filter(self.mutable, k)} + xs = {k: v for k, v in self._variables.items() if in_filter(self.mutable, k)} if config.flax_return_frozendict: return freeze(xs) return xs @@ -521,8 +540,15 @@ def rewound(self, rewind_rngs: bool = False) -> 'Scope': emptied, and the rng counter is optionally rewound. """ self._check_valid() - scope = Scope(self._variables, self.rngs, self.name, self.mutable, - self.parent, path=self.path, flags=self.flags) + scope = Scope( + self._variables, + self.rngs, + self.name, + self.mutable, + self.parent, + path=self.path, + flags=self.flags, + ) if not rewind_rngs: scope.rng_counters = self.rng_counters return scope @@ -537,9 +563,12 @@ def name_reserved(self, name: str, col: Optional[str] = None) -> bool: if name in self.reservations: # allow the same name for two variables in # different collections, otherwise raise error. - if (None in self.reservations[name] or col is None - or col in self.reservations[name]): - return True + if ( + None in self.reservations[name] + or col is None + or col in self.reservations[name] + ): + return True return False def reserve(self, name: str, col: Optional[str] = None): @@ -552,8 +581,9 @@ def reserve(self, name: str, col: Optional[str] = None): col: if a variable, the collection used. """ if not isinstance(name, str): - raise TypeError('The type of scope "{name}" should be string but ' - f'it is {type(name)}') + raise TypeError( + 'The type of scope "{name}" should be string but ' f'it is {type(name)}' + ) if self.name_reserved(name, col): raise ValueError(f'Duplicate use of scope name: "{name}"') self.reservations[name].add(col) @@ -574,10 +604,7 @@ def default_name(self, prefix: str) -> str: return name i += 1 - def push(self, - name: Optional[str] = None, - prefix: str = '', - reuse=False) -> 'Scope': + def push(self, name: Optional[str] = None, prefix: str = '', reuse=False) -> 'Scope': """Creates a child Scope. Args: @@ -598,26 +625,30 @@ def push(self, rngs = {key: LazyRng.create(rng, name) for key, rng in self.rngs.items()} rng_key = (child_rng_token, name) if rng_key in self.rng_counters: - rng_counters = self.rng_counters.get(rng_key) # type: ignore + rng_counters = self.rng_counters.get(rng_key) # type: ignore else: rng_counters = {key: 0 for key in rngs} - self.rng_counters[rng_key] = rng_counters # type: ignore - scope = Scope({}, - name=name, - rngs=rngs, - parent=self, - mutable=self.mutable, - path=self.path + (name,), - flags=self.flags) + self.rng_counters[rng_key] = rng_counters # type: ignore + scope = Scope( + {}, + name=name, + rngs=rngs, + parent=self, + mutable=self.mutable, + path=self.path + (name,), + flags=self.flags, + ) scope.rng_counters = rng_counters return scope - def child(self, - fn: Callable[..., Any], - name: Optional[str] = None, - prefix: Optional[str] = None, - named_call: bool = True, - **partial_kwargs) -> Callable[..., Any]: + def child( + self, + fn: Callable[..., Any], + name: Optional[str] = None, + prefix: Optional[str] = None, + named_call: bool = True, + **partial_kwargs, + ) -> Callable[..., Any]: """Partially applies a child scope to fn. When calling the returned function multiple times variables will be reused. @@ -761,8 +792,7 @@ def put_variable(self, col: str, name: str, value: Any): # Make sure reference sharing of child variable dictionaries isn't broken. # See https://github.com/google/flax/issues/2022 for more details. def put(target, key, val): - if (key in target and isinstance(target[key], dict) and - isinstance(val, Mapping)): + if key in target and isinstance(target[key], dict) and isinstance(val, Mapping): for k, v in val.items(): put(target[key], k, v) else: @@ -770,9 +800,14 @@ def put(target, key, val): put(variables, name, value) - def variable(self, col: str, name: str, # pylint: disable=keyword-arg-before-vararg - init_fn: Optional[Callable[..., T]] = None, - *init_args, unbox: bool = True) -> Variable[T]: + def variable( + self, + col: str, + name: str, # pylint: disable=keyword-arg-before-vararg + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: bool = True, + ) -> Variable[T]: """Creates a variable if it doesn't exist yet in this scope and returns it. Args: @@ -798,8 +833,9 @@ def variable(self, col: str, name: str, # pylint: disable=keyword-arg-before-va self.put_variable(col, name, init_value) return Variable(self, col, name, unbox=unbox) - def param(self, name: str, init_fn: Callable[..., T], *init_args, - unbox: bool = True) -> T: + def param( + self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True + ) -> T: """Creates a parameter if it doesn't exist yet in this scope and returns it. If the parameter exists already, the existing value is simply returned. @@ -817,8 +853,7 @@ def param(self, name: str, init_fn: Callable[..., T], *init_args, """ self.reserve(name, 'params') if self.has_variable('params', name): - abs_rng = jax.ShapeDtypeStruct(random.default_prng_impl().key_shape, - jnp.uint32) + abs_rng = jax.ShapeDtypeStruct(random.default_prng_impl().key_shape, jnp.uint32) value = self.get_variable('params', name) # Validate that the shape of the init_fn output is the same as the shape # of the existing parameter. This is to make sure that the hparams set up @@ -833,8 +868,9 @@ def param(self, name: str, init_fn: Callable[..., T], *init_args, # usefuleness is less obvious. We might intentionally change the dtype # for inference to a half float type for example. if jnp.shape(val) != jnp.shape(abs_val): - raise errors.ScopeParamShapeError(name, self.path_text, - jnp.shape(abs_val), jnp.shape(val)) + raise errors.ScopeParamShapeError( + name, self.path_text, jnp.shape(abs_val), jnp.shape(val) + ) else: if not self.is_mutable_collection('params'): if self.is_collection_empty('params'): @@ -870,10 +906,12 @@ def _unfreeze_variables(variables, mutable): return new_variables -def bind(variables: VariableDict, - rngs: Optional[RNGSequences] = None, - mutable: CollectionFilter = False, - flags: Optional[Mapping] = None): +def bind( + variables: VariableDict, + rngs: Optional[RNGSequences] = None, + mutable: CollectionFilter = False, + flags: Optional[Mapping] = None, +): """Binds variables and rngs to a new ``Scope``. bind provides a ``Scope`` instance without transforming a function with @@ -899,14 +937,17 @@ def bind(variables: VariableDict, raise errors.ApplyScopeInvalidVariablesTypeError() if rngs is not None and not _is_valid_rngs(rngs): raise errors.InvalidRngError( - 'rngs should be a dictionary mapping strings to `jax.PRNGKey`.') + 'rngs should be a dictionary mapping strings to `jax.PRNGKey`.' + ) new_variables = _unfreeze_variables(variables, mutable) return Scope(new_variables, rngs=rngs, mutable=mutable, flags=flags) -def apply(fn: Callable[..., Any], - mutable: CollectionFilter = False, - flags: Optional[Mapping] = None) -> Callable[..., Any]: +def apply( + fn: Callable[..., Any], + mutable: CollectionFilter = False, + flags: Optional[Mapping] = None, +) -> Callable[..., Any]: """Functionalize a `Scope` function. Args: @@ -919,18 +960,18 @@ def apply(fn: Callable[..., Any], """ @functools.wraps(fn) - def wrapper(variables: VariableDict, - *args, - rngs: Optional[RNGSequences] = None, - **kwargs) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]: + def wrapper( + variables: VariableDict, *args, rngs: Optional[RNGSequences] = None, **kwargs + ) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]: # Try to detect if user accidentally passed {'params': {'params': ...}. - if 'params' in variables and isinstance( - variables['params'], - (dict, FrozenDict)) and 'params' in variables['params']: + if ( + 'params' in variables + and isinstance(variables['params'], (dict, FrozenDict)) + and 'params' in variables['params'] + ): raise errors.ApplyScopeInvalidVariablesStructureError(variables) - with bind(variables, rngs=rngs, mutable=mutable, - flags=flags).temporary() as root: + with bind(variables, rngs=rngs, mutable=mutable, flags=flags).temporary() as root: y = fn(root, *args, **kwargs) if mutable is not False: return y, root.mutable_variables() @@ -940,9 +981,11 @@ def wrapper(variables: VariableDict, return wrapper -def init(fn: Callable[..., Any], - mutable: CollectionFilter = True, - flags: Optional[Mapping] = None) -> Callable[..., Any]: +def init( + fn: Callable[..., Any], + mutable: CollectionFilter = True, + flags: Optional[Mapping] = None, +) -> Callable[..., Any]: """Functionalize a `Scope` function for initialization. Args: @@ -957,21 +1000,24 @@ def init(fn: Callable[..., Any], @functools.wraps(fn) def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]: if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): - raise ValueError('First argument passed to an init function should be a ' - '`jax.PRNGKey` or a dictionary mapping strings to ' - '`jax.PRNGKey`.') + raise ValueError( + 'First argument passed to an init function should be a ' + '`jax.PRNGKey` or a dictionary mapping strings to ' + '`jax.PRNGKey`.' + ) if not isinstance(rngs, dict): rngs = {'params': rngs} init_flags = {**(flags if flags is not None else {}), 'initializing': True} - return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs, - **kwargs) + return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs, **kwargs) return wrapper -def lazy_init(fn: Callable[..., Any], - mutable: CollectionFilter = True, - flags: Optional[Mapping] = None) -> Callable[..., Any]: +def lazy_init( + fn: Callable[..., Any], + mutable: CollectionFilter = True, + flags: Optional[Mapping] = None, +) -> Callable[..., Any]: """Functionalizes a `Scope` function for lazy initialization. Similair to ``init`` except that the init function now accepts @@ -998,7 +1044,9 @@ def f(scope, x): `fn` with the scope partially applied. Unlike ``init`` which returns a tuple of function output and variables, the lazy init function only returns the variables. """ - return partial_eval.lazy_init(lambda *args, **kwargs: init(fn, mutable, flags)(*args, **kwargs)[1]) + return partial_eval.lazy_init( + lambda *args, **kwargs: init(fn, mutable, flags)(*args, **kwargs)[1] + ) def _is_valid_collection(col: VariableDict): @@ -1031,15 +1079,14 @@ def _is_valid_variables(variables: VariableDict) -> bool: def _is_valid_rng(rng: Array): """Checks whether rng is a valid JAX PRNGKey, also handling custom prngs.""" # New-style JAX KeyArrays have a base type. - if jax_config.jax_enable_custom_prng: # type: ignore[attr-defined] + if jax_config.jax_enable_custom_prng: # type: ignore[attr-defined] if not isinstance(rng, jax.random.KeyArray): return False # Old-style JAX PRNGKeys are plain uint32 arrays. else: if not isinstance(rng, (np.ndarray, jnp.ndarray)): return False - if (rng.shape != random.default_prng_impl().key_shape or - rng.dtype != jnp.uint32): + if rng.shape != random.default_prng_impl().key_shape or rng.dtype != jnp.uint32: return False return True diff --git a/flax/errors.py b/flax/errors.py index b48bc58906..baa4e0ee15 100644 --- a/flax/errors.py +++ b/flax/errors.py @@ -87,7 +87,8 @@ def __call__(self, x): def __init__(self, partial_val): super().__init__( f'Lazy init encountered a value that could with ' - f'the given inputs (shape: {partial_val}).') + f'the given inputs (shape: {partial_val}).' + ) ################################################# @@ -164,9 +165,11 @@ class ApplyScopeInvalidVariablesTypeError(FlaxError): """ def __init__(self): - super().__init__('The first argument passed to an apply function should be ' - 'a dictionary of collections. Each collection should be a ' - 'dictionary with string keys.') + super().__init__( + 'The first argument passed to an apply function should be ' + 'a dictionary of collections. Each collection should be a ' + 'dictionary with string keys.' + ) class ApplyScopeInvalidVariablesStructureError(FlaxError): @@ -181,7 +184,8 @@ def __init__(self, variables): 'Expect the `variables` (first argument) passed to apply() ' 'to be a dict with the structure {"params": ...}, but got a dict ' 'with an extra params layer, i.e. {"params": {"params": ... } }. ' - f'You should instead pass in your dict\'s ["params"].') + f'You should instead pass in your dict\'s ["params"].' + ) class ScopeParamNotFoundError(FlaxError): @@ -208,8 +212,9 @@ def __call__(self, inputs, embed_name='embedding'): """ def __init__(self, param_name, scope_path): - super().__init__(f'Could not find parameter named "{param_name}" in scope ' - f'"{scope_path}".') + super().__init__( + f'Could not find parameter named "{param_name}" in scope ' f'"{scope_path}".' + ) class ScopeCollectionNotFound(FlaxError): @@ -228,7 +233,8 @@ class ScopeCollectionNotFound(FlaxError): def __init__(self, col_name, var_name, scope_path): super().__init__( f'Tried to access "{var_name}" from collection "{col_name}" in ' - f'"{scope_path}" but the collection is empty.') + f'"{scope_path}" but the collection is empty.' + ) class ScopeParamShapeError(FlaxError): @@ -260,9 +266,11 @@ def __call__(self, x): """ def __init__(self, param_name, scope_path, value_shape, init_shape): - super().__init__(f'Initializer expected to generate shape {init_shape} ' - f'but got shape {value_shape} instead for parameter ' - f'"{param_name}" in "{scope_path}".') + super().__init__( + f'Initializer expected to generate shape {init_shape} ' + f'but got shape {value_shape} instead for parameter ' + f'"{param_name}" in "{scope_path}".' + ) class ScopeVariableNotFoundError(FlaxError): @@ -274,8 +282,10 @@ class ScopeVariableNotFoundError(FlaxError): """ def __init__(self, name, col, scope_path): - super().__init__(f'No Variable named "{name}" for collection "{col}" ' - f'exists in "{scope_path}".') + super().__init__( + f'No Variable named "{name}" for collection "{col}" ' + f'exists in "{scope_path}".' + ) class InvalidFilterError(FlaxError): @@ -288,10 +298,10 @@ def __init__(self, filter_like): class InvalidScopeError(FlaxError): """A temporary Scope is only valid within the context in which it is created:: - with Scope(variables, rngs=rngs).temporary() as root: - y = fn(root, *args, **kwargs) - # Here root is valid. - # Here root is invalid. + with Scope(variables, rngs=rngs).temporary() as root: + y = fn(root, *args, **kwargs) + # Here root is valid. + # Here root is invalid. """ def __init__(self, scope_name): @@ -319,8 +329,10 @@ def __call__(self, x): """ def __init__(self, col, variable_name, scope_path): - super().__init__(f'Cannot update variable "{variable_name}" in ' - f'"{scope_path}" because collection "{col}" is immutable.') + super().__init__( + f'Cannot update variable "{variable_name}" in ' + f'"{scope_path}" because collection "{col}" is immutable.' + ) class JaxTransformError(FlaxError): @@ -351,7 +363,8 @@ class PartitioningUnspecifiedError(FlaxError): def __init__(self, target): super().__init__( f'Trying to transform a Partitioned variable but "partition_name"' - f' is not specified in metadata_params: {target}') + f' is not specified in metadata_params: {target}' + ) ################################################# @@ -410,8 +423,10 @@ def __call__(self, inputs): def __init__(self, key_type, value, module_name): # key_type is in {param, variable, submodule}. - super().__init__(f'Could not create {key_type} "{value}" in Module ' - f'{module_name}: Name in use.') + super().__init__( + f'Could not create {key_type} "{value}" in Module ' + f'{module_name}: Name in use.' + ) class AssignSubModuleError(FlaxError): @@ -453,8 +468,10 @@ def __call__(self, x): """ def __init__(self, cls): - super().__init__(f'Submodule {cls} must be defined in `setup()` or in a ' - 'method wrapped in `@compact`') + super().__init__( + f'Submodule {cls} must be defined in `setup()` or in a ' + 'method wrapped in `@compact`' + ) class SetAttributeInModuleSetupError(FlaxError): @@ -525,9 +542,11 @@ def __call__(self, x): """ def __init__(self, module_cls, attr_name, attr_val): - super().__init__(f'Can\'t set {attr_name}={attr_val} for Module of type ' - f'{module_cls}: Module instance is frozen outside of ' - 'setup method.') + super().__init__( + f"Can't set {attr_name}={attr_val} for Module of type " + f'{module_cls}: Module instance is frozen outside of ' + 'setup method.' + ) class MultipleMethodsCompactError(FlaxError): @@ -558,8 +577,7 @@ class ReservedModuleAttributeError(FlaxError): """ def __init__(self, annotations): - super().__init__(f'properties `parent` and `name` are reserved: ' - f'{annotations}') + super().__init__(f'properties `parent` and `name` are reserved: ' f'{annotations}') class ApplyModuleInvalidMethodError(FlaxError): @@ -574,8 +592,9 @@ class ApplyModuleInvalidMethodError(FlaxError): """ def __init__(self, method): - super().__init__(f'Cannot call apply(): {method} is not a valid function ' - 'for apply().') + super().__init__( + f'Cannot call apply(): {method} is not a valid function ' 'for apply().' + ) class CallCompactUnboundModuleError(FlaxError): @@ -601,7 +620,7 @@ class CallCompactUnboundModuleError(FlaxError): """ def __init__(self): - super().__init__('Can\'t call compact methods on unbound modules') + super().__init__("Can't call compact methods on unbound modules") class CallSetupUnboundModuleError(FlaxError): @@ -634,7 +653,8 @@ def get_submodule(module): """ def __init__(self): - super().__init__('Can\'t call compact methods on unbound modules') + super().__init__("Can't call compact methods on unbound modules") + class CallUnbindOnUnboundModuleError(FlaxError): """This error occurs when you are trying to call ``.unbind()`` on an unbound @@ -658,8 +678,10 @@ def __call__(self, x): ... # do something with bound_module module = bound_module.unbind() # <-- OK! """ + def __init__(self): - super().__init__('Can\'t call `unbind()` on unbound modules') + super().__init__("Can't call `unbind()` on unbound modules") + class InvalidInstanceModuleError(FlaxError): """This error occurs when you are trying to call `.init()`, `.init_with_output()`, `.apply() or `.bind()` @@ -708,7 +730,8 @@ def __call__(self, input): def __init__(self): super().__init__( - 'Overrode `.__post_init__()` without calling `super().__post_init__()`') + 'Overrode `.__post_init__()` without calling `super().__post_init__()`' + ) class DescriptorAttributeError(FlaxError): @@ -763,7 +786,8 @@ def __init__(self, path, step): super().__init__( f'Checkpoint failed at step: "{step}" and path: "{path}": Target ' 'contains a multiprocess array should be saved/restored with a ' - 'GlobalAsyncCheckpointManager.') + 'GlobalAsyncCheckpointManager.' + ) class MPARestoreTargetRequiredError(FlaxError): @@ -781,7 +805,8 @@ def __init__(self, path, step, key=None): f'Restore checkpoint failed at step: "{step}" and path: "{path}": ' 'Checkpoints containing a multiprocess array need to be restored with ' 'a target with pre-created arrays. If you cannot provide a full valid ' - 'target, consider ``allow_partial_mpa_restoration=True``. ') + 'target, consider ``allow_partial_mpa_restoration=True``. ' + ) if key: error_msg += f'This error fired when trying to restore array at {key}.' super().__init__(error_msg) @@ -797,7 +822,8 @@ def __init__(self, step, path): super().__init__( f'Restore checkpoint failed at step: "{step}" on multiprocess array at ' f' "{path}": No "commit_success.txt" found on this "_gda" directory. ' - 'Was its save halted before completion?') + 'Was its save halted before completion?' + ) ################################################# @@ -810,7 +836,8 @@ class TransformedMethodReturnValueError(FlaxError): def __init__(self, name): super().__init__( - f'Transformed module method {name} cannot return Modules or Variables.') + f'Transformed module method {name} cannot return Modules or Variables.' + ) class TransformTargetError(FlaxError): @@ -842,7 +869,8 @@ def __init__(self, target): super().__init__( 'Linen transformations must be applied to Modules classes or' ' functions taking a Module instance as the first argument.' - f' The provided target is not a Module class or callable: {target}') + f' The provided target is not a Module class or callable: {target}' + ) ################################################# diff --git a/flax/ids.py b/flax/ids.py index c80fbdfc54..ba69a33e84 100644 --- a/flax/ids.py +++ b/flax/ids.py @@ -28,6 +28,7 @@ class UUIDManager: instead. - We need to handle copy/deepcopy uniqueness via a wrapped type. """ + def __init__(self): self._lock = threading.Lock() self._id = 0 @@ -37,21 +38,28 @@ def __call__(self): self._id += 1 return FlaxId(self._id) + uuid = UUIDManager() class FlaxId: """Hashable wrapper for ids that handles uniqueness of copies.""" + def __init__(self, rawid): self.id = rawid + def __eq__(self, other): return isinstance(other, FlaxId) and other.id == self.id + def __hash__(self): return hash(self.id) + def __repr__(self): return f"FlaxId({self.id})" + def __deepcopy__(self, memo): del memo return uuid() + def __copy__(self): return uuid() diff --git a/flax/io.py b/flax/io.py index 8b3ab38c24..1e32f6e468 100644 --- a/flax/io.py +++ b/flax/io.py @@ -34,16 +34,20 @@ class BackendMode(Enum): DEFAULT = 0 TF = 1 + io_mode = None gfile = None -if importlib.util.find_spec('tensorflow'): +if importlib.util.find_spec("tensorflow"): from tensorflow.io import gfile # type: ignore + io_mode = BackendMode.TF else: - logging.warning("Tensorflow library not found, tensorflow.io.gfile " - "operations will use native shim calls. " - "GCS paths (i.e. 'gs://...') cannot be accessed.") + logging.warning( + "Tensorflow library not found, tensorflow.io.gfile " + "operations will use native shim calls. " + "GCS paths (i.e. 'gs://...') cannot be accessed." + ) io_mode = BackendMode.DEFAULT @@ -52,6 +56,7 @@ class BackendMode(Enum): if io_mode == BackendMode.TF: from tensorflow import errors as tf_errors # type: ignore + NotFoundError = tf_errors.NotFoundError else: NotFoundError = FileNotFoundError @@ -91,10 +96,10 @@ def set_mode(override: BackendMode): def GFile(name, mode): # pylint: disable=invalid-name if io_mode == BackendMode.DEFAULT: - if 'b' in mode: + if "b" in mode: return open(name, mode) # pylint: disable=unspecified-encoding else: - return open(name, mode, encoding='utf-8') + return open(name, mode, encoding="utf-8") elif io_mode == BackendMode.TF: return gfile.GFile(name, mode) else: @@ -162,7 +167,7 @@ def makedirs(path): def glob(pattern): if io_mode == BackendMode.DEFAULT: - return [ path.rstrip('/') for path in glob_module.glob(pattern, recursive=False) ] + return [path.rstrip("/") for path in glob_module.glob(pattern, recursive=False)] elif io_mode == BackendMode.TF: return gfile.glob(pattern) else: diff --git a/flax/jax_utils.py b/flax/jax_utils.py index 3d303a6f73..b6bd9387da 100644 --- a/flax/jax_utils.py +++ b/flax/jax_utils.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utilities we could consider upstreaming to Jax. -""" +"""Utilities we could consider upstreaming to Jax.""" import collections from collections.abc import Iterable # pylint: disable=g-importing-member @@ -53,8 +52,7 @@ def unreplicate(tree): def pmean(xs, axis_name): - warnings.warn('use jax.lax.pmean instead', - DeprecationWarning) + warnings.warn("use jax.lax.pmean instead", DeprecationWarning) return lax.pmean(xs, axis_name) @@ -83,11 +81,14 @@ def partial_eval_by_shape(fn, input_spec, *args, **kwargs): input_structs = [_parse_spec(spec) for spec in input_spec] inputs_flat, in_tree = jax.tree_util.tree_flatten(input_structs) f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(f), in_tree) - in_pvals = [pe.PartialVal.unknown(core.ShapedArray(x.shape, x.dtype)) - for x in inputs_flat] + in_pvals = [ + pe.PartialVal.unknown(core.ShapedArray(x.shape, x.dtype)) for x in inputs_flat + ] _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals) - out_flat = [const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype) - for pv, const in out_pvals] + out_flat = [ + const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype) + for pv, const in out_pvals + ] return jax.tree_util.tree_unflatten(out_tree(), out_flat) @@ -160,8 +161,10 @@ def _scan_nd(body_fn, init, xs, n=1, unroll=(1,)): if n == 1: return lax.scan(body_fn, init, xs, unroll=unroll[0]) else: + def scan_body(c, x): - return _scan_nd(body_fn, c, x, n=n-1, unroll=unroll[1:]) + return _scan_nd(body_fn, c, x, n=n - 1, unroll=unroll[1:]) + return lax.scan(scan_body, init, xs, unroll=unroll[0]) @@ -207,6 +210,7 @@ def scan_in_dim(body_fn, init, xs, axis=(0,), unroll=(1,), keepdims=False): def transpose_in(x): perm = axis + tuple(np.delete(np.arange(x.ndim), axis)) return x.transpose(perm) + def transpose_out(x): perm = axis + tuple(np.delete(np.arange(x.ndim), axis)) return x.transpose(_invert_perm(perm)) @@ -218,7 +222,7 @@ def body_wrapper(c, xs): c, ys = body_fn(c, xs) if keepdims: ys = jax.tree_util.tree_map(transpose_in, ys) - ys = jax.tree_util.tree_map(lambda x: x.reshape(x.shape[len(axis):]), ys) + ys = jax.tree_util.tree_map(lambda x: x.reshape(x.shape[len(axis) :]), ys) return c, ys xs = jax.tree_util.tree_map(transpose_in, xs) @@ -228,8 +232,9 @@ def body_wrapper(c, xs): # Copied from https://github.com/google-research/big_vision -def pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=(), - static_return=False): +def pad_shard_unpad( + wrapped, static_argnums=(0,), static_argnames=(), static_return=False +): """Wraps a function with code that pads, shards, then un-shards, un-pads. Args: @@ -284,12 +289,14 @@ def pad(x): db += 1 if min_device_batch and db < min_device_batch: x = np.concatenate( - [x, np.zeros((d * (min_device_batch - db), *shape), x.dtype)]) + [x, np.zeros((d * (min_device_batch - db), *shape), x.dtype)] + ) db = min_device_batch return x.reshape(d, db, *shape) def maybe_pad(tree, actually_pad=True): - if not actually_pad: return tree # For call-site convenience below. + if not actually_pad: + return tree # For call-site convenience below. return jax.tree_util.tree_map(pad, tree) args = [maybe_pad(a, i not in static_argnums) for i, a in enumerate(args)] @@ -299,6 +306,7 @@ def maybe_pad(tree, actually_pad=True): def unpad(x): # Transfer back before cutting, to reduce on-device shape diversity. return jax.device_get(x).reshape([np.prod(x.shape[:2]), *x.shape[2:]])[:b] + return out if static_return else jax.tree_util.tree_map(unpad, out) return pad_shard_unpad_wrapper diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 2aaf2a40ec..ee91cd6973 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -18,48 +18,48 @@ # pylint: disable=g-multiple-import,useless-import-alias # re-export commonly used modules and functions from .activation import ( - PReLU as PReLU, - celu as celu, - elu as elu, - gelu as gelu, - glu as glu, - hard_sigmoid as hard_sigmoid, - hard_silu as hard_silu, - hard_swish as hard_swish, - hard_tanh as hard_tanh, - leaky_relu as leaky_relu, - log_sigmoid as log_sigmoid, - log_softmax as log_softmax, - logsumexp as logsumexp, - normalize as normalize, - one_hot as one_hot, - relu as relu, - relu6 as relu6, - selu as selu, - sigmoid as sigmoid, - silu as silu, - soft_sign as soft_sign, - softmax as softmax, - softplus as softplus, - standardize as standardize, - swish as swish, - tanh as tanh + PReLU as PReLU, + celu as celu, + elu as elu, + gelu as gelu, + glu as glu, + hard_sigmoid as hard_sigmoid, + hard_silu as hard_silu, + hard_swish as hard_swish, + hard_tanh as hard_tanh, + leaky_relu as leaky_relu, + log_sigmoid as log_sigmoid, + log_softmax as log_softmax, + logsumexp as logsumexp, + normalize as normalize, + one_hot as one_hot, + relu as relu, + relu6 as relu6, + selu as selu, + sigmoid as sigmoid, + silu as silu, + soft_sign as soft_sign, + softmax as softmax, + softplus as softplus, + standardize as standardize, + swish as swish, + tanh as tanh, ) from .attention import ( - MultiHeadDotProductAttention as MultiHeadDotProductAttention, - SelfAttention as SelfAttention, - combine_masks as combine_masks, - dot_product_attention as dot_product_attention, - dot_product_attention_weights as dot_product_attention_weights, - make_attention_mask as make_attention_mask, - make_causal_mask as make_causal_mask + MultiHeadDotProductAttention as MultiHeadDotProductAttention, + SelfAttention as SelfAttention, + combine_masks as combine_masks, + dot_product_attention as dot_product_attention, + dot_product_attention_weights as dot_product_attention_weights, + make_attention_mask as make_attention_mask, + make_causal_mask as make_causal_mask, ) from .combinators import Sequential as Sequential from ..core import ( - DenyList as DenyList, - FrozenDict as FrozenDict, - broadcast as broadcast, - meta as meta, + DenyList as DenyList, + FrozenDict as FrozenDict, + broadcast as broadcast, + meta as meta, ) from ..core.meta import ( Partitioned as Partitioned, @@ -81,69 +81,65 @@ with_logical_partitioning as with_logical_partitioning, ) from .initializers import ( - ones as ones, - ones_init as ones_init, - zeros as zeros, - zeros_init as zeros_init + ones as ones, + ones_init as ones_init, + zeros as zeros, + zeros_init as zeros_init, ) from .linear import ( - Conv as Conv, - ConvLocal as ConvLocal, - ConvTranspose as ConvTranspose, - Dense as Dense, - DenseGeneral as DenseGeneral, - Embed as Embed + Conv as Conv, + ConvLocal as ConvLocal, + ConvTranspose as ConvTranspose, + Dense as Dense, + DenseGeneral as DenseGeneral, + Embed as Embed, ) from .module import ( - Module as Module, - Variable as Variable, - apply as apply, - compact as compact, - disable_named_call as disable_named_call, - enable_named_call as enable_named_call, - init as init, - init_with_output as init_with_output, - merge_param as merge_param, - nowrap as nowrap, - override_named_call as override_named_call + Module as Module, + Variable as Variable, + apply as apply, + compact as compact, + disable_named_call as disable_named_call, + enable_named_call as enable_named_call, + init as init, + init_with_output as init_with_output, + merge_param as merge_param, + nowrap as nowrap, + override_named_call as override_named_call, ) from .normalization import ( - BatchNorm as BatchNorm, - GroupNorm as GroupNorm, - LayerNorm as LayerNorm, - RMSNorm as RMSNorm -) -from .pooling import ( - avg_pool as avg_pool, - max_pool as max_pool, - pool as pool + BatchNorm as BatchNorm, + GroupNorm as GroupNorm, + LayerNorm as LayerNorm, + RMSNorm as RMSNorm, ) +from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool) from .recurrent import ( - ConvLSTMCell as ConvLSTMCell, - GRUCell as GRUCell, - LSTMCell as LSTMCell, - OptimizedLSTMCell as OptimizedLSTMCell, - RNN as RNN, - RNNCellBase as RNNCellBase, - Bidirectional as Bidirectional, + ConvLSTMCell as ConvLSTMCell, + GRUCell as GRUCell, + LSTMCell as LSTMCell, + OptimizedLSTMCell as OptimizedLSTMCell, + RNN as RNN, + RNNCellBase as RNNCellBase, + Bidirectional as Bidirectional, ) from .stochastic import Dropout as Dropout from .transforms import ( - checkpoint as checkpoint, - custom_vjp as custom_vjp, - jit as jit, - jvp as jvp, - map_variables as map_variables, - named_call as named_call, - remat as remat, - remat_scan as remat_scan, - scan as scan, - vjp as vjp, - vmap as vmap, - while_loop as while_loop, - cond as cond, - switch as switch, - add_metadata_axis, + checkpoint as checkpoint, + custom_vjp as custom_vjp, + jit as jit, + jvp as jvp, + map_variables as map_variables, + named_call as named_call, + remat as remat, + remat_scan as remat_scan, + scan as scan, + vjp as vjp, + vmap as vmap, + while_loop as while_loop, + cond as cond, + switch as switch, + add_metadata_axis, ) from .summary import tabulate # pylint: enable=g-multiple-import diff --git a/flax/linen/activation.py b/flax/linen/activation.py index d684869489..b5d22cc460 100644 --- a/flax/linen/activation.py +++ b/flax/linen/activation.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Activation functions. -""" +"""Activation functions.""" # pylint: disable=unused-import # re-export activation functions from jax.nn @@ -70,6 +69,7 @@ class PReLU(Module): negative_slope_init: the value to initialize the negative slope (default 0.01). """ + param_dtype: Dtype = jnp.float32 negative_slope_init: float = 0.01 @@ -85,6 +85,8 @@ def __call__(self, inputs: Array) -> Array: """ negative_slope = self.param( 'negative_slope', - lambda k: jnp.asarray(self.negative_slope_init, self.param_dtype)) - return jnp.where(inputs >= 0, inputs, - jnp.asarray(negative_slope, inputs.dtype) * inputs) + lambda k: jnp.asarray(self.negative_slope_init, self.param_dtype), + ) + return jnp.where( + inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs + ) diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 3a9953ae0f..28b277b5bc 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -38,16 +38,18 @@ Array = Any -def dot_product_attention_weights(query: Array, - key: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, - broadcast_dropout: bool = True, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0., - deterministic: bool = False, - dtype: Optional[Dtype] = None, - precision: PrecisionLike = None): +def dot_product_attention_weights( + query: Array, + key: Array, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + broadcast_dropout: bool = True, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0.0, + deterministic: bool = False, + dtype: Optional[Dtype] = None, + precision: PrecisionLike = None, +): """Computes dot-product attention weights given query and key. Used by :func:`dot_product_attention`, which is what you'll most likely use. @@ -83,18 +85,15 @@ def dot_product_attention_weights(query: Array, dtype = query.dtype assert query.ndim == key.ndim, 'q, k must have same rank.' - assert query.shape[:-3] == key.shape[:-3], ( - 'q, k batch dims must match.') - assert query.shape[-2] == key.shape[-2], ( - 'q, k num_heads must match.') + assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.' + assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.' assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' # calculate attention matrix depth = query.shape[-1] query = query / jnp.sqrt(depth).astype(dtype) # attn weight shape is (batch..., num_heads, q_length, kv_length) - attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key, - precision=precision) + attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key, precision=precision) # apply attention bias: masking, dropout, proximity bias, etc. if bias is not None: @@ -108,32 +107,33 @@ def dot_product_attention_weights(query: Array, attn_weights = jax.nn.softmax(attn_weights).astype(dtype) # apply attention dropout - if not deterministic and dropout_rate > 0.: + if not deterministic and dropout_rate > 0.0: keep_prob = 1.0 - dropout_rate if broadcast_dropout: # dropout is broadcast across the batch + head dimensions dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:] - keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore + keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore else: - keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore - multiplier = (keep.astype(dtype) / - jnp.asarray(keep_prob, dtype=dtype)) + keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore + multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype) attn_weights = attn_weights * multiplier return attn_weights -def dot_product_attention(query: Array, - key: Array, - value: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, - broadcast_dropout: bool = True, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0., - deterministic: bool = False, - dtype: Optional[Dtype] = None, - precision: PrecisionLike = None): +def dot_product_attention( + query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + broadcast_dropout: bool = True, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0.0, + deterministic: bool = False, + dtype: Optional[Dtype] = None, + precision: PrecisionLike = None, +): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -172,55 +172,66 @@ def dot_product_attention(query: Array, query, key, value = promote_dtype(query, key, value, dtype=dtype) dtype = query.dtype assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') - assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( - 'q, k, v num_heads must match.') + assert ( + query.shape[:-3] == key.shape[:-3] == value.shape[:-3] + ), 'q, k, v batch dims must match.' + assert ( + query.shape[-2] == key.shape[-2] == value.shape[-2] + ), 'q, k, v num_heads must match.' assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' # compute attention weights attn_weights = dot_product_attention_weights( - query, key, bias, mask, broadcast_dropout, dropout_rng, dropout_rate, - deterministic, dtype, precision) + query, + key, + bias, + mask, + broadcast_dropout, + dropout_rng, + dropout_rate, + deterministic, + dtype, + precision, + ) # return weighted sum over values for each query position - return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value, - precision=precision) + return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value, precision=precision) class MultiHeadDotProductAttention(Module): """Multi-head dot-product attention. - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - dtype: the dtype of the computation - (default: infer from inputs and params) - param_dtype: the dtype passed to parameter initializers (default: float32) - qkv_features: dimension of the key, query, and value. - out_features: dimension of the last projection - broadcast_dropout: bool: use a broadcasted dropout along batch dims. - dropout_rate: dropout rate - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer for the kernel of the Dense layers. - bias_init: initializer for the bias of the Dense layers. - use_bias: bool: whether pointwise QKVO dense transforms use bias. - attention_fn: dot_product_attention or compatible function. Accepts - query, key, value, and returns output of shape - `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` - decode: whether to prepare and use an autoregressive cache. + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + dtype: the dtype of the computation + (default: infer from inputs and params) + param_dtype: the dtype passed to parameter initializers (default: float32) + qkv_features: dimension of the key, query, and value. + out_features: dimension of the last projection + broadcast_dropout: bool: use a broadcasted dropout along batch dims. + dropout_rate: dropout rate + deterministic: if false, the attention weight is masked randomly + using dropout, whereas if true, the attention weights + are deterministic. + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the kernel of the Dense layers. + bias_init: initializer for the bias of the Dense layers. + use_bias: bool: whether pointwise QKVO dense transforms use bias. + attention_fn: dot_product_attention or compatible function. Accepts + query, key, value, and returns output of shape + `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` + decode: whether to prepare and use an autoregressive cache. """ + num_heads: int dtype: Optional[Dtype] = None param_dtype: Dtype = jnp.float32 qkv_features: Optional[int] = None out_features: Optional[int] = None broadcast_dropout: bool = True - dropout_rate: float = 0. + dropout_rate: float = 0.0 deterministic: Optional[bool] = None precision: PrecisionLike = None kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init @@ -232,11 +243,13 @@ class MultiHeadDotProductAttention(Module): out_dot_general: DotGeneralT = lax.dot_general @compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None): + def __call__( + self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + ): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, @@ -280,30 +293,33 @@ def __call__(self, ) # project inputs_q to multi-headed q/k/v # dimensions are then [batch..., length, n_heads, n_features_per_head] - query, key, value = (dense(name='query')(inputs_q), - dense(name='key')(inputs_kv), - dense(name='value')(inputs_kv)) + query, key, value = ( + dense(name='query')(inputs_q), + dense(name='key')(inputs_kv), + dense(name='value')(inputs_kv), + ) # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.decode: # detect if we're initializing by absence of existing cache data. is_initialized = self.has_variable('cache', 'cached_key') - cached_key = self.variable('cache', 'cached_key', - jnp.zeros, key.shape, key.dtype) - cached_value = self.variable('cache', 'cached_value', - jnp.zeros, value.shape, value.dtype) - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.int32)) + cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype) + cached_value = self.variable( + 'cache', 'cached_value', jnp.zeros, value.shape, value.dtype + ) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32) + ) if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = ( - cached_key.value.shape) + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape # shape check of cached keys against query input expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head) if expected_shape != query.shape: - raise ValueError('Autoregressive cache shape error, ' - 'expected query shape %s instead got %s.' % - (expected_shape, query.shape)) + raise ValueError( + 'Autoregressive cache shape error, ' + 'expected query shape %s instead got %s.' % (expected_shape, query.shape) + ) # update key, value caches with our new 1d spatial slices cur_index = cache_index.value indices = (0,) * len(batch_dims) + (cur_index, 0, 0) @@ -318,13 +334,15 @@ def __call__(self, # not the remaining zero elements. mask = combine_masks( mask, - jnp.broadcast_to(jnp.arange(max_length) <= cur_index, - tuple(batch_dims) + (1, 1, max_length))) + jnp.broadcast_to( + jnp.arange(max_length) <= cur_index, + tuple(batch_dims) + (1, 1, max_length), + ), + ) dropout_rng = None - if self.dropout_rate > 0.: # Require `deterministic` only if using dropout. - m_deterministic = merge_param('deterministic', self.deterministic, - deterministic) + if self.dropout_rate > 0.0: # Require `deterministic` only if using dropout. + m_deterministic = merge_param('deterministic', self.deterministic, deterministic) if not m_deterministic: dropout_rng = self.make_rng('dropout') else: @@ -341,7 +359,8 @@ def __call__(self, broadcast_dropout=self.broadcast_dropout, deterministic=m_deterministic, dtype=self.dtype, - precision=self.precision) # pytype: disable=wrong-keyword-args + precision=self.precision, + ) # pytype: disable=wrong-keyword-args # back to the original inputs dimensions out = DenseGeneral( features=features, @@ -353,7 +372,7 @@ def __call__(self, param_dtype=self.param_dtype, precision=self.precision, dot_general=self.out_dot_general, - name='out', # type: ignore[call-arg] + name='out', # type: ignore[call-arg] )(x) return out @@ -362,8 +381,12 @@ class SelfAttention(MultiHeadDotProductAttention): """Self-attention special case of multi-head dot-product attention.""" @compact - def __call__(self, inputs_q: Array, mask: Optional[Array] = None, # type: ignore - deterministic: Optional[bool] = None): + def __call__( + self, + inputs_q: Array, + mask: Optional[Array] = None, # type: ignore + deterministic: Optional[bool] = None, + ): """Applies multi-head dot product self-attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, @@ -383,18 +406,19 @@ def __call__(self, inputs_q: Array, mask: Optional[Array] = None, # type: ignore Returns: output of shape `[batch_sizes..., length, features]`. """ - return super().__call__(inputs_q, inputs_q, mask, - deterministic=deterministic) + return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic) # mask-making utility functions -def make_attention_mask(query_input: Array, - key_input: Array, - pairwise_fn: Callable[..., Any] = jnp.multiply, - extra_batch_dims: int = 0, - dtype: Dtype = jnp.float32): +def make_attention_mask( + query_input: Array, + key_input: Array, + pairwise_fn: Callable[..., Any] = jnp.multiply, + extra_batch_dims: int = 0, + dtype: Dtype = jnp.float32, +): """Mask-making helper for attention weights. In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the @@ -412,16 +436,17 @@ def make_attention_mask(query_input: Array, Returns: A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention. """ - mask = pairwise_fn(jnp.expand_dims(query_input, axis=-1), - jnp.expand_dims(key_input, axis=-2)) + mask = pairwise_fn( + jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2) + ) mask = jnp.expand_dims(mask, axis=-3) mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) return mask.astype(dtype) -def make_causal_mask(x: Array, - extra_batch_dims: int = 0, - dtype: Dtype = jnp.float32) -> Array: +def make_causal_mask( + x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32 +) -> Array: """Make a causal mask for self-attention. In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights @@ -438,12 +463,12 @@ def make_causal_mask(x: Array, A `[batch..., 1, len, len]` shaped causal mask for 1d attention. """ idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) - return make_attention_mask(idxs, idxs, jnp.greater_equal, - extra_batch_dims=extra_batch_dims, dtype=dtype) + return make_attention_mask( + idxs, idxs, jnp.greater_equal, extra_batch_dims=extra_batch_dims, dtype=dtype + ) -def combine_masks(*masks: Optional[Array], - dtype: Dtype = jnp.float32) -> Array: +def combine_masks(*masks: Optional[Array], dtype: Dtype = jnp.float32) -> Array: """Combine attention masks. Args: @@ -456,8 +481,9 @@ def combine_masks(*masks: Optional[Array], masks_list = [m for m in masks if m is not None] if not masks_list: return None - assert all(map(lambda x: x.ndim == masks_list[0].ndim, masks_list)), ( - f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}') + assert all( + map(lambda x: x.ndim == masks_list[0].ndim, masks_list) + ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}' mask, *other_masks = masks_list for other_mask in other_masks: mask = jnp.logical_and(mask, other_mask) diff --git a/flax/linen/combinators.py b/flax/linen/combinators.py index 82cc5a0a1a..299a0b8860 100644 --- a/flax/linen/combinators.py +++ b/flax/linen/combinators.py @@ -69,17 +69,19 @@ def __call__(self, x): return nn.Sequential([CrossAttentionBlock() for _ in range(self.num_layers)])(query, key_value) """ + layers: Sequence[Callable[..., Any]] def __post_init__(self): if not isinstance(self.layers, Sequence): - raise ValueError('\'layers\' must be a sequence, ' - f'got \'{type(self.layers).__name__}\'.') + raise ValueError( + "'layers' must be a sequence, " f"got '{type(self.layers).__name__}'." + ) super().__post_init__() def __call__(self, *args, **kwargs): if not self.layers: - raise ValueError(f'Empty Sequential module {self.name}.') + raise ValueError(f"Empty Sequential module {self.name}.") outputs = self.layers[0](*args, **kwargs) for layer in self.layers[1:]: diff --git a/flax/linen/dtypes.py b/flax/linen/dtypes.py index c447f249db..d0d9dd59d4 100644 --- a/flax/linen/dtypes.py +++ b/flax/linen/dtypes.py @@ -37,9 +37,9 @@ Array = Any -def canonicalize_dtype(*args, - dtype: Optional[Dtype] = None, - inexact: bool = True) -> Dtype: +def canonicalize_dtype( + *args, dtype: Optional[Dtype] = None, inexact: bool = True +) -> Dtype: """Canonicalize an optional dtype to the definitive dtype. If the ``dtype`` is None this function will infer the dtype. If it is not @@ -70,7 +70,7 @@ def canonicalize_dtype(*args, def promote_dtype(*args, dtype=None, inexact=True) -> List[Array]: - """"Promotes input arguments to a specified or inferred dtype. + """ "Promotes input arguments to a specified or inferred dtype. All args are cast to the same dtype. See ``canonicalize_dtype`` for how this dtype is determined. @@ -94,5 +94,4 @@ def promote_dtype(*args, dtype=None, inexact=True) -> List[Array]: The arguments cast to arrays of the same dtype. """ dtype = canonicalize_dtype(*args, dtype=dtype, inexact=inexact) - return [jnp.asarray(x, dtype) if x is not None else None - for x in args] + return [jnp.asarray(x, dtype) if x is not None else None for x in args] diff --git a/flax/linen/experimental/layers_with_named_axes.py b/flax/linen/experimental/layers_with_named_axes.py index 37a4e79779..a2ca4d619c 100644 --- a/flax/linen/experimental/layers_with_named_axes.py +++ b/flax/linen/experimental/layers_with_named_axes.py @@ -59,6 +59,7 @@ class Dense(nn.Module): kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. """ + features: int use_bias: bool = True dtype: DType = jnp.float32 @@ -80,11 +81,13 @@ def __call__(self, inputs: Array) -> Array: The transformed input. """ inputs = jnp.asarray(inputs, self.dtype) - kernel = param_with_axes('kernel', - self.kernel_init, - (inputs.shape[-1], self.features), - self.param_dtype, - axes=self.kernel_axes) + kernel = param_with_axes( + 'kernel', + self.kernel_init, + (inputs.shape[-1], self.features), + self.param_dtype, + axes=self.kernel_axes, + ) kernel = jnp.asarray(kernel, self.dtype) y = self.dot_general( inputs, @@ -93,11 +96,13 @@ def __call__(self, inputs: Array) -> Array: precision=self.precision, ) if self.use_bias: - bias = param_with_axes('bias', - self.bias_init, - (self.features,), - self.param_dtype, - axes=(self.kernel_axes[-1],)) + bias = param_with_axes( + 'bias', + self.bias_init, + (self.features,), + self.param_dtype, + axes=(self.kernel_axes[-1],), + ) bias = jnp.asarray(bias, self.dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y @@ -119,6 +124,7 @@ class Embed(nn.Module): one_hot: performs the gather with a one-hot contraction rather than a true gather. This is currently needed for SPMD partitioning. """ + num_embeddings: int features: int cast_input_dtype: Optional[DType] = None @@ -132,9 +138,11 @@ class Embed(nn.Module): def setup(self): self.embedding = param_with_axes( 'embedding', - self.embedding_init, (self.num_embeddings, self.features), + self.embedding_init, + (self.num_embeddings, self.features), self.param_dtype, - axes=('vocab', 'embed')) + axes=('vocab', 'embed'), + ) def __call__(self, inputs: Array) -> Array: """Embeds the inputs along the last dimension. @@ -210,18 +218,26 @@ def _compute_stats(x: Array, axes: Axes): mean2 = jnp.mean(_abs_sq(x), axes) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. - var = jnp.maximum(0., mean2 - _abs_sq(mean)) + var = jnp.maximum(0.0, mean2 - _abs_sq(mean)) return mean, var -def _normalize(mdl: nn.Module, x: Array, mean: Array, var: Array, - reduction_axes: Axes, feature_axes: Axes, - dtype: DType, param_dtype: DType, - epsilon: float, - use_bias: bool, use_scale: bool, - bias_init: Callable[[PRNGKey, Shape, DType], Array], - scale_init: Callable[[PRNGKey, Shape, DType], Array]): - """"Normalizes the input of a normalization layer and optionally applies a learned scale and bias. +def _normalize( + mdl: nn.Module, + x: Array, + mean: Array, + var: Array, + reduction_axes: Axes, + feature_axes: Axes, + dtype: DType, + param_dtype: DType, + epsilon: float, + use_bias: bool, + use_scale: bool, + bias_init: Callable[[PRNGKey, Shape, DType], Array], + scale_init: Callable[[PRNGKey, Shape, DType], Array], +): + """ "Normalizes the input of a normalization layer and optionally applies a learned scale and bias. A seperate bias and scale is learned for each feature as specified by feature_axes. @@ -242,17 +258,14 @@ def _normalize(mdl: nn.Module, x: Array, mean: Array, var: Array, mul = lax.rsqrt(var + epsilon) if use_scale: scale = mdl.param_with_axes( - 'scale', - scale_init, - reduced_feature_shape, - param_dtype, - axes=('embed',)).reshape(feature_shape) + 'scale', scale_init, reduced_feature_shape, param_dtype, axes=('embed',) + ).reshape(feature_shape) mul *= scale y *= mul if use_bias: bias = mdl.param_with_axes( - 'bias', bias_init, reduced_feature_shape, param_dtype, - axes=('embed',)).reshape(feature_shape) + 'bias', bias_init, reduced_feature_shape, param_dtype, axes=('embed',) + ).reshape(feature_shape) y += bias return jnp.asarray(y, dtype) @@ -282,6 +295,7 @@ class LayerNorm(nn.Module): bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. """ + epsilon: float = 1e-6 dtype: Any = jnp.float32 param_dtype: DType = jnp.float32 @@ -306,7 +320,17 @@ def __call__(self, x): mean, var = _compute_stats(x, reduction_axes) return _normalize( - self, x, mean, var, reduction_axes, feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, - self.bias_init, self.scale_init) + self, + x, + mean, + var, + reduction_axes, + feature_axes, + self.dtype, + self.param_dtype, + self.epsilon, + self.use_bias, + self.use_scale, + self.bias_init, + self.scale_init, + ) diff --git a/flax/linen/initializers.py b/flax/linen/initializers.py index 93a6934a32..8579a9c62e 100644 --- a/flax/linen/initializers.py +++ b/flax/linen/initializers.py @@ -37,6 +37,7 @@ from jax.nn.initializers import Initializer as Initializer # pylint: enable=unused-import + def zeros_init() -> Initializer: """Builds an initializer that returns a constant array full of zeros. @@ -49,6 +50,7 @@ def zeros_init() -> Initializer: """ return zeros + def ones_init() -> Initializer: """Builds an initializer that returns a constant array full of ones. @@ -60,4 +62,4 @@ def ones_init() -> Initializer: [1., 1.], [1., 1.]], dtype=float32) """ - return ones \ No newline at end of file + return ones diff --git a/flax/linen/kw_only_dataclasses.py b/flax/linen/kw_only_dataclasses.py index da010c6378..85a16917d9 100644 --- a/flax/linen/kw_only_dataclasses.py +++ b/flax/linen/kw_only_dataclasses.py @@ -76,9 +76,10 @@ def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs): A `dataclasses.Field` object. """ if kw_only is not dataclasses.MISSING and kw_only: - if (kwargs.get('default', dataclasses.MISSING) is dataclasses.MISSING and - kwargs.get('default_factory', - dataclasses.MISSING) is dataclasses.MISSING): + if ( + kwargs.get('default', dataclasses.MISSING) is dataclasses.MISSING + and kwargs.get('default_factory', dataclasses.MISSING) is dataclasses.MISSING + ): raise ValueError('Keyword-only fields with no default are not supported.') if metadata is None: metadata = {} @@ -131,8 +132,7 @@ def _process_class(cls, extra_fields=None, **kwargs): kw_only_name = name elif kw_only_name is not None: if not hasattr(cls, name): - raise ValueError( - 'Keyword-only fields with no default are not supported.') + raise ValueError('Keyword-only fields with no default are not supported.') default = getattr(cls, name) if isinstance(default, dataclasses.Field): default.metadata = {**default.metadata, **{KW_ONLY: True}} @@ -146,8 +146,9 @@ def _process_class(cls, extra_fields=None, **kwargs): if extra_fields: for name, annotation, default in extra_fields: if not (isinstance(name, str) and isinstance(default, dataclasses.Field)): - raise ValueError('Expected extra_fields to a be a list of ' - '(name, type, Field) tuples.') + raise ValueError( + 'Expected extra_fields to a be a list of ' '(name, type, Field) tuples.' + ) setattr(cls, name, default) cls.__annotations__[name] = annotation @@ -156,21 +157,20 @@ def _process_class(cls, extra_fields=None, **kwargs): if not dataclasses.is_dataclass(base): continue base_annotations = base.__dict__.get('__annotations__', {}) - base_dataclass_fields[base] = dict( - getattr(base, '__dataclass_fields__', {})) + base_dataclass_fields[base] = dict(getattr(base, '__dataclass_fields__', {})) for base_field in list(dataclasses.fields(base)): field_name = base_field.name if base_field.metadata.get(KW_ONLY) or field_name in kw_only_fields: - kw_only_fields[field_name] = (base_annotations.get(field_name), - base_field) + kw_only_fields[field_name] = (base_annotations.get(field_name), base_field) del base.__dataclass_fields__[field_name] # Remove any keyword-only fields from this class. cls_annotations = cls.__dict__['__annotations__'] for name, annotation in list(cls_annotations.items()): value = getattr(cls, name, None) - if ((isinstance(value, dataclasses.Field) and value.metadata.get(KW_ONLY)) - or name in kw_only_fields): + if ( + isinstance(value, dataclasses.Field) and value.metadata.get(KW_ONLY) + ) or name in kw_only_fields: del cls_annotations[name] kw_only_fields[name] = (annotation, value) @@ -185,7 +185,7 @@ def _process_class(cls, extra_fields=None, **kwargs): transformed_cls = dataclasses.dataclass(cls, **kwargs) # Restore the base classes' __dataclass_fields__. - for (cls, dataclass_fields) in base_dataclass_fields.items(): + for cls, dataclass_fields in base_dataclass_fields.items(): cls.__dataclass_fields__ = dataclass_fields # Return the transformed dataclass diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 95b7531de1..b711eab904 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -15,8 +15,7 @@ """Linear modules.""" import dataclasses -from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple, - Union) +from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union) from flax.core import meta from flax.linen import initializers @@ -36,8 +35,9 @@ Shape = Tuple[int, ...] Dtype = Any # this could be a real type? Array = Any -PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], - Tuple[lax.Precision, lax.Precision]] +PrecisionLike = Union[ + None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] +] DotGeneralT = Callable[..., Array] ConvGeneralDilatedT = Callable[..., Array] @@ -72,6 +72,7 @@ class DenseGeneral(Module): precision: numerical precision of the computation see `jax.lax.Precision` for details. """ + features: Union[int, Sequence[int]] axis: Union[int, Sequence[int]] = -1 batch_dims: Sequence[int] = () @@ -99,8 +100,10 @@ def __call__(self, inputs: Array) -> Array: if batch_dims: max_dim = np.max(batch_dims) if set(batch_dims) != set(range(max_dim + 1)): - raise ValueError('batch_dims %s must be consecutive leading ' - 'dimensions starting from 0.' % str(batch_dims)) + raise ValueError( + 'batch_dims %s must be consecutive leading ' + 'dimensions starting from 0.' % str(batch_dims) + ) ndim = inputs.ndim n_batch_dims = len(batch_dims) @@ -109,9 +112,11 @@ def __call__(self, inputs: Array) -> Array: n_axis, n_features = len(axis), len(features) def kernel_init_wrap(rng, shape, dtype=jnp.float32): - flat_shape = (np.prod(shape[:n_batch_dims]) * - np.prod(shape[n_batch_dims:n_axis + n_batch_dims]), - np.prod(shape[-n_features:]),) + flat_shape = ( + np.prod(shape[:n_batch_dims]) + * np.prod(shape[n_batch_dims : n_axis + n_batch_dims]), + np.prod(shape[-n_features:]), + ) flat_shape = jax.tree_map(int, flat_shape) kernel = self.kernel_init(rng, flat_shape, dtype) if isinstance(kernel, meta.AxisMetadata): @@ -122,26 +127,30 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32): # batch and non-contracting dims of input with 1s for batch dims. expanded_batch_shape = tuple( inputs.shape[ax] if ax in batch_dims else 1 - for ax in range(inputs.ndim) if ax not in axis) + for ax in range(inputs.ndim) + if ax not in axis + ) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features - kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape, - self.param_dtype) + kernel = self.param( + 'kernel', kernel_init_wrap, batch_shape + kernel_shape, self.param_dtype + ) batch_ind = tuple(range(n_batch_dims)) contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) if self.use_bias: + def bias_init_wrap(rng, shape, dtype=jnp.float32): - flat_shape = (np.prod(shape[:n_batch_dims]) * - np.prod(shape[-n_features:]),) + flat_shape = (np.prod(shape[:n_batch_dims]) * np.prod(shape[-n_features:]),) flat_shape = jax.tree_map(int, flat_shape) bias = self.bias_init(rng, flat_shape, dtype) if isinstance(bias, meta.AxisMetadata): return meta.replace_boxed(bias, jnp.reshape(bias.unbox(), shape)) return jnp.reshape(bias, shape) - bias = self.param('bias', bias_init_wrap, batch_shape + features, - self.param_dtype) + bias = self.param( + 'bias', bias_init_wrap, batch_shape + features, self.param_dtype + ) else: bias = None @@ -174,6 +183,7 @@ class Dense(Module): kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. """ + features: int use_bias: bool = True dtype: Optional[Dtype] = None @@ -193,13 +203,14 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ - kernel = self.param('kernel', - self.kernel_init, - (jnp.shape(inputs)[-1], self.features), - self.param_dtype) + kernel = self.param( + 'kernel', + self.kernel_init, + (jnp.shape(inputs)[-1], self.features), + self.param_dtype, + ) if self.use_bias: - bias = self.param('bias', self.bias_init, (self.features,), - self.param_dtype) + bias = self.param('bias', self.bias_init, (self.features,), self.param_dtype) else: bias = None inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype) @@ -228,7 +239,7 @@ def _conv_dimension_numbers(input_shape): def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: - """"Canonicalizes conv padding to a jax.lax supported format.""" + """ "Canonicalizes conv padding to a jax.lax supported format.""" if isinstance(padding, str): return padding if isinstance(padding, int): @@ -245,9 +256,10 @@ def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: if len(new_pad) == rank: return new_pad raise ValueError( - f'Invalid padding format: {padding}, should be str, int,' - f' or a sequence of len {rank} where each element is an' - f' int or pair of ints.') + f'Invalid padding format: {padding}, should be str, int,' + f' or a sequence of len {rank} where each element is an' + f' int or pair of ints.' + ) class _Conv(Module): @@ -287,6 +299,7 @@ class _Conv(Module): kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. """ + features: int kernel_size: Sequence[int] strides: Union[None, int, Sequence[int]] = 1 @@ -337,14 +350,15 @@ def __call__(self, inputs: Array) -> Array: """ if isinstance(self.kernel_size, int): - raise TypeError('Expected Conv kernel_size to be a' - ' tuple/list of integers (eg.: [3, 3]) but got' - f' {self.kernel_size}.') + raise TypeError( + 'Expected Conv kernel_size to be a' + ' tuple/list of integers (eg.: [3, 3]) but got' + f' {self.kernel_size}.' + ) else: kernel_size = tuple(self.kernel_size) - def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( - Tuple[int, ...]): + def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> Tuple[int, ...]: if x is None: # backward compatibility with using None as sentinel for # broadcast 1 @@ -358,8 +372,7 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( if num_batch_dimensions != 1: input_batch_shape = inputs.shape[:num_batch_dimensions] total_batch_size = int(np.prod(input_batch_shape)) - flat_input_shape = ( - (total_batch_size,) + inputs.shape[num_batch_dimensions:]) + flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:] inputs = jnp.reshape(inputs, flat_input_shape) # self.strides or (1,) * (inputs.ndim - 2) @@ -373,14 +386,12 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation) ] zero_pad: List[Tuple[int, int]] = [(0, 0)] - pads = (zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + - [(0, 0)]) + pads = zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)] inputs = jnp.pad(inputs, pads, mode='wrap') padding_lax = 'VALID' elif padding_lax == 'CAUSAL': if len(kernel_size) != 1: - raise ValueError( - 'Causal padding is only implemented for 1D convolutions.') + raise ValueError('Causal padding is only implemented for 1D convolutions.') left_pad = kernel_dilation[0] * (kernel_size[0] - 1) pads = [(0, 0), (left_pad, 0), (0, 0)] inputs = jnp.pad(inputs, pads) @@ -393,7 +404,9 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( # One shared convolutional kernel for all pixels in the output. assert in_features % self.feature_group_count == 0 kernel_shape = kernel_size + ( - in_features // self.feature_group_count, self.features) + in_features // self.feature_group_count, + self.features, + ) else: if self.feature_group_count != 1: @@ -419,15 +432,18 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( ).shape # One (unshared) convolutional kernel per each pixel in the output. - kernel_shape = conv_output_shape[1:-1] + (np.prod(kernel_size) * - in_features, self.features) + kernel_shape = conv_output_shape[1:-1] + ( + np.prod(kernel_size) * in_features, + self.features, + ) if self.mask is not None and self.mask.shape != kernel_shape: - raise ValueError('Mask needs to have the same shape as weights. ' - f'Shapes are: {self.mask.shape}, {kernel_shape}') + raise ValueError( + 'Mask needs to have the same shape as weights. ' + f'Shapes are: {self.mask.shape}, {kernel_shape}' + ) - kernel = self.param('kernel', self.kernel_init, kernel_shape, - self.param_dtype) + kernel = self.param('kernel', self.kernel_init, kernel_shape, self.param_dtype) if self.mask is not None: kernel *= self.mask @@ -467,7 +483,7 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( lhs_dilation=input_dilation, rhs_dilation=kernel_dilation, dimension_numbers=dimension_numbers, - precision=self.precision + precision=self.precision, ) if self.use_bias: @@ -597,6 +613,7 @@ class ConvTranspose(Module): transpose_kernel: if True flips spatial axes and swaps the input/output channel axes of the kernel. """ + features: int kernel_size: Union[int, Sequence[int]] strides: Optional[Sequence[int]] = None @@ -644,8 +661,7 @@ def __call__(self, inputs: Array) -> Array: if num_batch_dimensions != 1: input_batch_shape = inputs.shape[:num_batch_dimensions] total_batch_size = int(np.prod(input_batch_shape)) - flat_input_shape = ( - (total_batch_size,) + inputs.shape[num_batch_dimensions:]) + flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:] inputs = jnp.reshape(inputs, flat_input_shape) strides: Tuple[int, ...] @@ -661,11 +677,12 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = kernel_size + (in_features, self.features) if self.mask is not None and self.mask.shape != kernel_shape: - raise ValueError('Mask needs to have the same shape as weights. ' - f'Shapes are: {self.mask.shape}, {kernel_shape}') + raise ValueError( + 'Mask needs to have the same shape as weights. ' + f'Shapes are: {self.mask.shape}, {kernel_shape}' + ) - kernel = self.param('kernel', self.kernel_init, kernel_shape, - self.param_dtype) + kernel = self.param('kernel', self.kernel_init, kernel_shape, self.param_dtype) if self.mask is not None: kernel *= self.mask @@ -675,13 +692,11 @@ def __call__(self, inputs: Array) -> Array: padding_lax = 'VALID' if self.use_bias: - bias = self.param('bias', self.bias_init, (self.features,), - self.param_dtype) + bias = self.param('bias', self.bias_init, (self.features,), self.param_dtype) else: bias = None - inputs, kernel, bias = promote_dtype(inputs, kernel, bias, - dtype=self.dtype) + inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype) y = lax.conv_transpose( inputs, @@ -690,7 +705,8 @@ def __call__(self, inputs: Array) -> Array: padding_lax, rhs_dilation=self.kernel_dilation, transpose_kernel=self.transpose_kernel, - precision=self.precision) + precision=self.precision, + ) if self.padding == 'CIRCULAR': # For circular padding, we need to identify the size of the final output @@ -716,22 +732,17 @@ def __call__(self, inputs: Array) -> Array: # If the kernel is transposed, the "+1" is put on the right to # mirror the regular convolution. If the same kernel parameters are used # as for Conv, this layer then computes the proper transpose convolution. - total_pad = [ - (size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs - ] + total_pad = [(size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs] else: # Divide the padding equally between left and right. The choice to put # "+1" on the left (and not on the right) represents a convention for # aligning even-sized kernels. - total_pad = [ - ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs - ] + total_pad = [((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs] y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)]) # Wrap the result periodically around each spatial dimension, # one by one. for i in range(1, y.ndim - 1): - y = y.reshape(y.shape[:i] + (-1, scaled_x_dims[i - 1]) + - y.shape[i + 1:]) + y = y.reshape(y.shape[:i] + (-1, scaled_x_dims[i - 1]) + y.shape[i + 1 :]) y = y.sum(axis=i) if self.use_bias: @@ -759,6 +770,7 @@ class Embed(Module): param_dtype: the dtype passed to parameter initializers (default: float32). embedding_init: embedding initializer. """ + num_embeddings: int features: int dtype: Optional[Dtype] = None @@ -768,10 +780,12 @@ class Embed(Module): embedding: Array = dataclasses.field(init=False) def setup(self): - self.embedding = self.param('embedding', - self.embedding_init, - (self.num_embeddings, self.features), - self.param_dtype) + self.embedding = self.param( + 'embedding', + self.embedding_init, + (self.num_embeddings, self.features), + self.param_dtype, + ) def __call__(self, inputs: Array) -> Array: """Embeds the inputs along the last dimension. @@ -787,7 +801,7 @@ def __call__(self, inputs: Array) -> Array: raise ValueError('Input type must be an integer or unsigned integer.') # Use take because fancy indexing numpy arrays with JAX indices does not # work correctly. - embedding, = promote_dtype(self.embedding, dtype=self.dtype, inexact=False) + (embedding,) = promote_dtype(self.embedding, dtype=self.dtype, inexact=False) return jnp.take(embedding, inputs, axis=0) def attend(self, query: Array) -> Array: diff --git a/flax/linen/module.py b/flax/linen/module.py index 169bd22f83..e3a22fc078 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -24,25 +24,43 @@ from types import MappingProxyType import typing import weakref -from typing import (Any, Callable, Dict, Iterable, List, Sequence, NamedTuple, Mapping, - Optional, Set, Tuple, Type, TypeVar, Union, overload) +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Sequence, + NamedTuple, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + overload, +) import jax import numpy as np import jax.numpy as jnp -from typing_extensions import Protocol, \ - dataclass_transform # pytype: disable=not-supported-yet +from typing_extensions import Protocol, dataclass_transform # pytype: disable=not-supported-yet import flax import flax.linen as nn -from flax import (config, core, errors, serialization, traceback_util, - traverse_util) +from flax import (config, core, errors, serialization, traceback_util, traverse_util) from flax.core import Scope from flax.core import partial_eval from flax.core.frozen_dict import FrozenDict from flax.core.scope import ( # pylint: disable=g-multiple-import - CollectionFilter, DenyList, FrozenVariableDict, Variable, VariableDict, - union_filters) + CollectionFilter, + DenyList, + FrozenVariableDict, + Variable, + VariableDict, + union_filters, +) from flax.ids import FlaxId from flax.ids import uuid from flax.linen import kw_only_dataclasses @@ -52,7 +70,7 @@ KeyArray = Union[jax.Array, jax.random.KeyArray] # pylint: disable=invalid-name RNGSequences = Dict[str, KeyArray] -Array = Any # pylint: disable=invalid-name +Array = Any # pylint: disable=invalid-name T = TypeVar('T') @@ -62,13 +80,14 @@ # Used for abstractly testing module behavior. -TestScope = type('TestScope', - (Scope,), - {'make_rng': lambda self, name: jax.random.PRNGKey(0)}) +TestScope = type( + 'TestScope', (Scope,), {'make_rng': lambda self, name: jax.random.PRNGKey(0)} +) # pylint: disable=protected-access,attribute-defined-outside-init + def _indent(x: str, num_spaces: int): indent_str = ' ' * num_spaces lines = x.split('\n') @@ -99,8 +118,11 @@ def _module_repr(module: 'Module', num_spaces: int = 4): for f in dataclasses.fields(cls) if f.name not in ('parent', 'name') and f.repr } - child_modules = {k: v for k, v in module._state.children.items() # pytype: disable=attribute-error - if isinstance(v, Module)} + child_modules = { + k: v + for k, v in module._state.children.items() # pytype: disable=attribute-error + if isinstance(v, Module) + } if attributes: rep += '# attributes\n' for attr in attributes.keys(): @@ -118,11 +140,13 @@ def _module_repr(module: 'Module', num_spaces: int = 4): else: return f'{cls_name}()' + # Tabulation utilities. # ----------------------------------------------------------------------------- _find_non_lifted_module = re.compile(r'.*\((.*)\)') + def _fix_path_part(part: str): """Fixes a path part by removing transformation name and parenthesis sometimes inserted by lifted transformations""" @@ -131,6 +155,7 @@ def _fix_path_part(part: str): return match.group(1) return part + @dataclasses.dataclass class _CallInfo: index: int @@ -141,6 +166,7 @@ class _CallInfo: kwargs: Dict[str, Any] outputs: Any + @dataclasses.dataclass class _CallInfoContext(threading.local): index: int @@ -151,6 +177,7 @@ def get_call_index(self, module: 'Module') -> int: self.index += 1 return index + @contextlib.contextmanager def _tabulate_context(): _context.call_info_stack.append(_CallInfoContext(0, [])) @@ -159,18 +186,23 @@ def _tabulate_context(): finally: _context.call_info_stack.pop() + # Track parent relationship across Modules. # ----------------------------------------------------------------------------- class _DynamicContext(threading.local): """Dynamic context.""" + # TODO(marcvanzee): switch to using contextvars once minimum python version is # 3.7 def __init__(self): - self.module_stack = [None,] + self.module_stack = [ + None, + ] self.capture_stack = [] self.call_info_stack = [] + # The global context _context = _DynamicContext() @@ -198,6 +230,7 @@ def _get_fn_name(fn): if isinstance(fn, functools.partial): return _get_fn_name(fn.func) return fn.__name__ + fn_name = _get_fn_name(fn) method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' module_name = module.name or module.__class__.__name__ @@ -256,8 +289,7 @@ def _sorted_items(x): return sorted(x.items(), key=lambda x: x[0]) -def _get_suffix_value_pairs( - tree_or_leaf: Any) -> List[Tuple[str, Type['Module']]]: +def _get_suffix_value_pairs(tree_or_leaf: Any) -> List[Tuple[str, Type['Module']]]: """Helper for naming pytrees of submodules.""" dict_or_leaf = serialization.to_state_dict(tree_or_leaf) if not isinstance(dict_or_leaf, dict) or not dict_or_leaf: @@ -274,10 +306,12 @@ def _map_over_modules_in_tree(fn, tree_or_leaf): return fn('', tree_or_leaf) else: flat_dict = traverse_util.flatten_dict(dict_or_leaf, keep_empty_nodes=True) - mapped_flat_dict = {k: fn('_' + '_'.join(k), v) - for k, v in _sorted_items(flat_dict)} + mapped_flat_dict = { + k: fn('_' + '_'.join(k), v) for k, v in _sorted_items(flat_dict) + } return serialization.from_state_dict( - tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict)) + tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict) + ) def _freeze_attr(val: Any) -> Any: @@ -355,8 +389,7 @@ def __call__(self, x): return fun -def _get_local_method_names(cls: Any, - exclude: Iterable[str] = ()) -> Tuple[str, ...]: +def _get_local_method_names(cls: Any, exclude: Iterable[str] = ()) -> Tuple[str, ...]: """Gets method names of a class, excluding class and static methods. Args: @@ -373,8 +406,10 @@ def _get_local_method_names(cls: Any, true_methods.add(m) return tuple(true_methods.difference(set(exclude))) -def _get_local_descriptor_names(cls: Any, - exclude: Iterable[str] = ()) -> Tuple[str, ...]: + +def _get_local_descriptor_names( + cls: Any, exclude: Iterable[str] = () +) -> Tuple[str, ...]: """Gets descriptor names of a class. Args: @@ -386,8 +421,9 @@ def _get_local_descriptor_names(cls: Any, true_properties = set() for m, attr in cls.__dict__.items(): if not callable(attr) and ( - hasattr(attr, '__get__') or hasattr(attr, '__set__') or - hasattr(attr, '__delete__') + hasattr(attr, '__get__') + or hasattr(attr, '__set__') + or hasattr(attr, '__delete__') ): mtype = type(attr) if mtype != staticmethod and mtype != classmethod: @@ -419,10 +455,12 @@ def wrapped_module_method(*args, **kwargs): return self._call_wrapped_method(fun, args, kwargs) else: return fun(*args, **kwargs) + wrapped_module_method.method_handler_wrapped = True # type: ignore[attr-defined] return wrapped_module_method -def wrap_descriptor_once(descriptor) -> "DescriptorWrapper": + +def wrap_descriptor_once(descriptor) -> 'DescriptorWrapper': """Wraps a descriptor to give better error messages. Args: @@ -439,17 +477,21 @@ def wrap_descriptor_once(descriptor) -> "DescriptorWrapper": def _wrap_hash(hash_fn: Callable[..., Any]) -> Callable[..., Any]: """Wraps a hash function with some check for Flax Modules.""" + @functools.wraps(hash_fn) def wrapped(self): if self.scope is not None: - raise TypeError('Can\'t call __hash__ on modules that hold variables.') + raise TypeError("Can't call __hash__ on modules that hold variables.") try: hash_value = hash_fn(self) except TypeError as exc: - raise TypeError('Failed to hash Flax Module. ' - 'The module probably contains unhashable attributes. ' - f'Module={self}') from exc + raise TypeError( + 'Failed to hash Flax Module. ' + 'The module probably contains unhashable attributes. ' + f'Module={self}' + ) from exc return hash_value + return wrapped @@ -465,23 +507,23 @@ def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]: Returns: An unbound version of input function. """ - if (inspect.ismethod(method_or_fn) and - isinstance(method_or_fn.__self__, Module)): # pytype: disable=attribute-error + if inspect.ismethod(method_or_fn) and isinstance(method_or_fn.__self__, Module): # pytype: disable=attribute-error method_or_fn = method_or_fn.__func__ # pytype: disable=attribute-error # The method should be callable, and it should have at least one argument # representing the class that is passed in. - if (not callable(method_or_fn) or - len(inspect.signature(method_or_fn).parameters) < 1): + if not callable(method_or_fn) or len(inspect.signature(method_or_fn).parameters) < 1: raise errors.ApplyModuleInvalidMethodError(method_or_fn) return method_or_fn + def _map_submodules(fn: Callable[['Module'], Any], tree): """Map a function over all submodules in a tree.""" g = lambda _, x: fn(x) if isinstance(x, Module) else x return _freeze_attr(_map_over_modules_in_tree(g, tree)) + class SetupState(enum.IntEnum): # setup() has not been called. NEW = 0 @@ -499,13 +541,13 @@ class _ModuleInternalState: Modules for autonaming and error messages here, alongside the rules used to pass this ephemeral state across transform boundaries. """ + in_compact_method: bool = False in_setup: bool = False setup_called: SetupState = SetupState.NEW is_initialized: bool = False autoname_cursor: Dict[str, int] = dataclasses.field(default_factory=dict) - children: Dict[str, Union[str, 'Module']] = dataclasses.field( - default_factory=dict) + children: Dict[str, Union[str, 'Module']] = dataclasses.field(default_factory=dict) def reset(self) -> None: """Resets transient state. @@ -525,7 +567,8 @@ def export(self) -> '_ModuleInternalState': in_setup=self.in_setup, setup_called=setup_state, is_initialized=self.is_initialized, - autoname_cursor=dict(self.autoname_cursor)) + autoname_cursor=dict(self.autoname_cursor), + ) return cloned def reimport(self, other: '_ModuleInternalState') -> None: @@ -535,16 +578,24 @@ def reimport(self, other: '_ModuleInternalState') -> None: self.is_initialized = other.is_initialized self.autoname_cursor = dict(other.autoname_cursor) + _uninitialized_module_internal_state = _ModuleInternalState() _UNDEFINED_COPY_PICKLE_METHODS = ( - '__getstate__', '__setstate__', '__getnewargs_ex__', - '__reduce__', '__reduce_ex__', '__copy__', '__deepcopy__') + '__getstate__', + '__setstate__', + '__getnewargs_ex__', + '__reduce__', + '__reduce_ex__', + '__copy__', + '__deepcopy__', +) _caches: 'weakref.WeakKeyDictionary[Scope, weakref.WeakValueDictionary[FlaxId, Module]]' = ( - weakref.WeakKeyDictionary()) + weakref.WeakKeyDictionary() +) tuple_reduce = lambda xs, x: xs + (x,) @@ -566,28 +617,39 @@ class ParentDescriptor: more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms. """ + def __get__(self, obj, objtype=None): # check if obj is None, happens during %autoreload if obj is None: return None - parent = object.__getattribute__(obj, "_parent_ref") + parent = object.__getattribute__(obj, '_parent_ref') return parent() if isinstance(parent, weakref.ReferenceType) else parent def __set__(self, obj, value): maybe_weak = weakref.ref(value) if isinstance(value, Module) else value - object.__setattr__(obj, "_parent_ref", maybe_weak) + object.__setattr__(obj, '_parent_ref', maybe_weak) class Descriptor(Protocol): __isabstractmethod__: bool - def __get__(self, obj, objtype=None) -> Any: ... - def __set__(self, obj, value) -> None: ... - def __delete__(self, obj) -> None: ... - def __set_name__(self, owner, name) -> None: ... + + def __get__(self, obj, objtype=None) -> Any: + ... + + def __set__(self, obj, value) -> None: + ... + + def __delete__(self, obj) -> None: + ... + + def __set_name__(self, owner, name) -> None: + ... + class DescriptorWrapper: pass + def create_descriptor_wrapper(descriptor: Descriptor): """Creates a descriptor wrapper that calls a get_fn on the descriptor.""" @@ -602,6 +664,7 @@ def __init__(self, wrapped: Descriptor): # conditionally define descriptor methods if hasattr(descriptor, '__get__'): + def __get__(self, *args, **kwargs): # here we will catch internal AttributeError and re-raise it as a # more informative and correct error message. @@ -611,14 +674,17 @@ def __get__(self, *args, **kwargs): raise errors.DescriptorAttributeError() from e if hasattr(descriptor, '__set__'): + def __set__(self, *args, **kwargs): return self.wrapped.__set__(*args, **kwargs) if hasattr(descriptor, '__delete__'): + def __delete__(self, *args, **kwargs): return self.wrapped.__delete__(*args, **kwargs) if hasattr(descriptor, '__set_name__'): + def __set_name__(self, *args, **kwargs): self.wrapped.__set_name__(*args, **kwargs) @@ -627,9 +693,11 @@ def __getattr__(self, name): return _DescriptorWrapper(descriptor) + # Base Module definition. # ----------------------------------------------------------------------------- + # The ModuleBase class is created only to make static analyzers happy # mainly pytype and pyright. Some notes: # * pyright (correctly) complains that Module itself is not a dataclass, even @@ -653,6 +721,7 @@ class ModuleBase: parent: Union['Module', _Sentinel, None] __dataclass_fields__: Dict[str, dataclasses.Field] + class Module(ModuleBase): """Base class for all neural network modules. Layers and models should subclass this class. @@ -687,6 +756,7 @@ def __call__(self, x): """ if typing.TYPE_CHECKING: + def __init__(self, *args, **kwargs): # this stub makes sure pytype accepts constructor arguments. pass @@ -710,11 +780,11 @@ def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None: cls._verify_single_or_no_compact() cls._wrap_module_attributes() # Set empty class defaults. - cls._state = _uninitialized_module_internal_state # type: ignore[attr-defined] - cls.scope: Optional[Scope] = None # type: ignore + cls._state = _uninitialized_module_internal_state # type: ignore[attr-defined] + cls.scope: Optional[Scope] = None # type: ignore # Handles weak referencing of parent Modules to prevent reference cycles. - cls._parent_ref = None # type: ignore[attr-defined] - cls.parent = ParentDescriptor() # type: ignore[assignment] + cls._parent_ref = None # type: ignore[attr-defined] + cls.parent = ParentDescriptor() # type: ignore[assignment] @classmethod def _customized_dataclass_transform(cls, kw_only: bool): @@ -742,12 +812,16 @@ def _customized_dataclass_transform(cls, kw_only: bool): field_meta.hash = False field_meta.repr = False - extra_fields = [('parent', _ParentType, - kw_only_dataclasses.field( - repr=False, default=_unspecified_parent, - kw_only=True)), - ('name', Optional[str], - kw_only_dataclasses.field(default=None, kw_only=True))] + extra_fields = [ + ( + 'parent', + _ParentType, + kw_only_dataclasses.field( + repr=False, default=_unspecified_parent, kw_only=True + ), + ), + ('name', Optional[str], kw_only_dataclasses.field(default=None, kw_only=True)), + ] if kw_only: if tuple(sys.version_info)[:3] >= (3, 10, 0): @@ -758,7 +832,9 @@ def _customized_dataclass_transform(cls, kw_only: bool): unsafe_hash='__hash__' not in cls.__dict__, repr=False, kw_only=True, - )(cls) # type: ignore[call-overload] + )( + cls + ) # type: ignore[call-overload] else: raise TypeError('`kw_only` is not available before Py 3.10.') else: @@ -768,7 +844,8 @@ def _customized_dataclass_transform(cls, kw_only: bool): cls, unsafe_hash='__hash__' not in cls.__dict__, repr=False, - extra_fields=extra_fields) # pytype: disable=wrong-keyword-args + extra_fields=extra_fields, + ) # pytype: disable=wrong-keyword-args cls.__hash__ = _wrap_hash(cls.__hash__) # type: ignore[method-assign] @@ -776,8 +853,13 @@ def _customized_dataclass_transform(cls, kw_only: bool): def _verify_single_or_no_compact(cls): """Statically verifies that at most a single method is labelled compact.""" methods = [m[0] for m in inspect.getmembers(cls, predicate=callable)] - n_compact_fns = len([method_name for method_name in methods - if hasattr(getattr(cls, method_name), 'compact')]) + n_compact_fns = len( + [ + method_name + for method_name in methods + if hasattr(getattr(cls, method_name), 'compact') + ] + ) if n_compact_fns > 1: raise errors.MultipleMethodsCompactError() @@ -787,9 +869,13 @@ def _wrap_module_attributes(cls): management functions. """ # wrap methods - method_exclusions = ([f.name for f in dataclasses.fields(cls)] + - ['__eq__', '__repr__', '__init__', '__hash__', - '__post_init__']) + method_exclusions = [f.name for f in dataclasses.fields(cls)] + [ + '__eq__', + '__repr__', + '__init__', + '__hash__', + '__post_init__', + ] for key in _get_local_method_names(cls, exclude=method_exclusions): method = getattr(cls, key) if hasattr(method, 'nowrap'): @@ -797,8 +883,10 @@ def _wrap_module_attributes(cls): setattr(cls, key, wrap_method_once(method)) # wrap descriptors - descriptor_exclusions = ([f.name for f in dataclasses.fields(cls)] + - ['parent', '__dict__']) + descriptor_exclusions = [f.name for f in dataclasses.fields(cls)] + [ + 'parent', + '__dict__', + ] for key in _get_local_descriptor_names(cls, descriptor_exclusions): # don't use getattr here, since it will call the descriptor descriptor = cls.__dict__[key] @@ -808,7 +896,7 @@ def _wrap_module_attributes(cls): return cls def _call_wrapped_method(self, fun, args, kwargs): - """"Calls a wrapped method. + """ "Calls a wrapped method. This function is responsible for setting up the thread local state correctly before calling the method and cleaning up afterwards. @@ -863,7 +951,10 @@ def _call_wrapped_method(self, fun, args, kwargs): if add_call_info: _args, _kwargs, _y = flax.linen.summary._represent_tree((args, kwargs, y)) _context.call_info_stack[-1].calls.append( - _CallInfo(call_index, scope_path, type(self), fun.__name__, _args, _kwargs, _y)) + _CallInfo( + call_index, scope_path, type(self), fun.__name__, _args, _kwargs, _y + ) + ) return y finally: _context.module_stack.pop() @@ -902,8 +993,7 @@ def __setattr__(self, name: str, val: Any): else: # We're past all initialization and setup logic: # Raises a TypeError just like frozen python dataclasses. - raise errors.SetAttributeFrozenModuleError( - self.__class__.__name__, name, val) + raise errors.SetAttributeFrozenModuleError(self.__class__.__name__, name, val) # We're inside the setup() method: if is_dataclass_attr: @@ -927,8 +1017,10 @@ def __getattr__(self, name: str) -> Any: else: msg = f'"{self.__class__.__name__}" object has no attribute "{name}".' if self.scope is None: - msg += (f' If "{name}" is defined in \'.setup()\', remember these fields ' - 'are only accessible from inside \'init\' or \'apply\'.') + msg += ( + f' If "{name}" is defined in \'.setup()\', remember these fields ' + "are only accessible from inside 'init' or 'apply'." + ) raise AttributeError(msg) def __dir__(self) -> List[str]: @@ -973,8 +1065,10 @@ def __post_init__(self) -> None: self.name = f'{prefix}_{cursor}' self.parent._state.autoname_cursor[prefix] = cursor + 1 # Allow scope aliasing under transforms for submodules defined in setup. - reuse_scopes = (self.parent._state.in_setup and - self.parent._state.setup_called == SetupState.TRANSFORMED) + reuse_scopes = ( + self.parent._state.in_setup + and self.parent._state.setup_called == SetupState.TRANSFORMED + ) # Perform name-collision check. if self.parent._name_taken(self.name, self, reuse_scopes=reuse_scopes): parent_class = self.parent.__class__.__name__ @@ -983,7 +1077,8 @@ def __post_init__(self) -> None: self.parent._state.children[self.name] = self assert self.parent.scope is not None object.__setattr__( - self, 'scope', self.parent.scope.push(self.name, reuse=reuse_scopes)) + self, 'scope', self.parent.scope.push(self.name, reuse=reuse_scopes) + ) # Top-level invocation with a functional Scope. elif isinstance(self.parent, Scope): @@ -993,9 +1088,9 @@ def __post_init__(self) -> None: # eagerly bind submodules if scope is available if self.scope is not None: - for field in dataclasses.fields(self): - if field.name not in ('parent', 'name') and field.init: - self._register_submodules(field.name, getattr(self, field.name)) + for field in dataclasses.fields(self): + if field.name not in ('parent', 'name') and field.init: + self._register_submodules(field.name, getattr(self, field.name)) self._state.is_initialized = True @@ -1047,6 +1142,7 @@ def _register_submodules(self, name, val): preserve_adopted_names = config.flax_preserve_adopted_names if hasattr(type(self), 'preserve_adopted_names'): preserve_adopted_names = type(self).preserve_adopted_names + def adopt_attr_modules(cache, queue, suffix, subvalue): if isinstance(subvalue, Module): adopted_name = None @@ -1072,17 +1168,23 @@ def adopt_attr_modules(cache, queue, suffix, subvalue): object.__setattr__(subvalue, 'name', adopted_name) queue.append(subvalue) return subvalue - val = _freeze_attr(_map_over_modules_in_tree( - functools.partial(adopt_attr_modules, cache, queue), val)) + + val = _freeze_attr( + _map_over_modules_in_tree( + functools.partial(adopt_attr_modules, cache, queue), val + ) + ) object.__setattr__(self, name, val) for x in queue: x.__post_init__() def _try_setup(self, shallow: bool = False) -> None: """Tries to setup module if scope is available and setup has not been called yet.""" - if (self.scope + if ( + self.scope and not self._state.in_setup - and self._state.setup_called != SetupState.DONE): + and self._state.setup_called != SetupState.DONE + ): try: self._state.in_setup = True # A shallow setup will only register attribute submodules but it does @@ -1104,17 +1206,21 @@ def _try_setup(self, shallow: bool = False) -> None: def _validate_setup(self) -> None: """Abstractly evaluates setup only to run static checks.""" + def run_setup_only(x): wrapped_id = wrap_method_once(lambda m, x: x) with TestScope({}, rngs={}, mutable=True).temporary() as root: return wrapped_id(self.clone(parent=root), x) + _ = jax.eval_shape(run_setup_only, 0) - def _name_taken(self, - name: str, - module: Optional['Module'] = None, - reuse_scopes: bool = False, - collection: Optional[str] = None) -> bool: + def _name_taken( + self, + name: str, + module: Optional['Module'] = None, + reuse_scopes: bool = False, + collection: Optional[str] = None, + ) -> bool: assert self.scope is not None if reuse_scopes: return False @@ -1122,14 +1228,19 @@ def _name_taken(self, @property def _initialization_allowed(self): - return (not self._state.is_initialized # allow eager attachment in post-init - or self._state.in_setup - or self._state.in_compact_method) - - def clone(self: M, *, - parent: Optional[Union[Scope, 'Module']] = None, - _deep_clone: Union[bool, weakref.WeakValueDictionary] = False, - **updates) -> M: + return ( + not self._state.is_initialized # allow eager attachment in post-init + or self._state.in_setup + or self._state.in_compact_method + ) + + def clone( + self: M, + *, + parent: Optional[Union[Scope, 'Module']] = None, + _deep_clone: Union[bool, weakref.WeakValueDictionary] = False, + **updates, + ) -> M: """Creates a clone of this Module, with optionally updated arguments. Args: @@ -1154,7 +1265,12 @@ def clone(self: M, *, if _deep_clone != False: # We use a weak value dictionary to cache cloned submodules. When a shared # submodule is cloned, its only cloned once else its fetched from the cache. - cache = weakref.WeakValueDictionary() if isinstance(_deep_clone, bool) else _deep_clone + cache = ( + weakref.WeakValueDictionary() + if isinstance(_deep_clone, bool) + else _deep_clone + ) + def clone_fn(m: Module) -> Module: if hasattr(m, '_id'): key = m._id @@ -1178,10 +1294,14 @@ def clone_fn(m: Module) -> Module: return module - def variable(self, col: str, name: str, - init_fn: Optional[Callable[..., Any]] = None, - *init_args, - unbox: bool = True) -> Variable: + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., Any]] = None, + *init_args, + unbox: bool = True, + ) -> Variable: """Declares and returns a variable in this Module. See :mod:`flax.core.variables` for more information. See also :meth:`param` @@ -1216,7 +1336,8 @@ def variable(self, col: str, name: str, if not self._initialization_allowed: raise ValueError( 'Variables must be initialized in `setup()` or in a method ' - 'wrapped in `@compact`') + 'wrapped in `@compact`' + ) if self._name_taken(name, collection=col): raise errors.NameInUseError('variable', name, self.__class__.__name__) assert self.scope is not None @@ -1224,8 +1345,9 @@ def variable(self, col: str, name: str, self._state.children[name] = col return v - def param(self, name: str, init_fn: Callable[..., T], *init_args, - unbox: bool = True) -> T: + def param( + self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True + ) -> T: """Declares and returns a parameter in this Module. Parameters are read-only variables in the collection named "params". See @@ -1257,7 +1379,8 @@ def param(self, name: str, init_fn: Callable[..., T], *init_args, if not self._initialization_allowed: raise ValueError( 'Parameters must be initialized in `setup()` or in a method ' - 'wrapped in `@compact`') + 'wrapped in `@compact`' + ) if self._name_taken(name, collection='params'): raise errors.NameInUseError('param', name, self.__class__.__name__) assert self.scope is not None @@ -1332,15 +1455,17 @@ def _module_checks(self): raise errors.InvalidInstanceModuleError() overridden_post_init = self.__post_init__ != Module.__post_init__ - if overridden_post_init and not hasattr(self, "_id"): + if overridden_post_init and not hasattr(self, '_id'): raise errors.IncorrectPostInitOverrideError() @traceback_util.api_boundary - def bind(self: M, - variables: VariableDict, - *args, - rngs: Optional[RNGSequences] = None, - mutable: CollectionFilter = False) -> M: + def bind( + self: M, + variables: VariableDict, + *args, + rngs: Optional[RNGSequences] = None, + mutable: CollectionFilter = False, + ) -> M: """Creates an interactive Module instance by binding variables and RNGs. ``bind`` provides an "interactive" instance of a Module directly without @@ -1433,14 +1558,16 @@ def __call__(self, x): return module, variables @traceback_util.api_boundary - def apply(self, - variables: VariableDict, - *args, - rngs: Optional[RNGSequences] = None, - method: Union[Callable[..., Any], str, None] = None, - mutable: CollectionFilter = False, - capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, - **kwargs) -> Union[Any, Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: + def apply( + self, + variables: VariableDict, + *args, + rngs: Optional[RNGSequences] = None, + method: Union[Callable[..., Any], str, None] = None, + mutable: CollectionFilter = False, + capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, + **kwargs, + ) -> Union[Any, Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: """Applies a module method to variables and returns output and modified variables. Note that `method` should be set if one would like to call `apply` on a @@ -1505,24 +1632,29 @@ def other_fn(instance, ...): method = getattr(self, attribute_name) if not callable(method): class_name = type(self).__name__ - raise TypeError(f'\'{class_name}.{attribute_name}\' must be a callable, got {type(method)}.') + raise TypeError( + f"'{class_name}.{attribute_name}' must be a callable, got {type(method)}." + ) elif method is None: method = self.__call__ method = _get_unbound_fn(method) return apply( - method, self, + method, + self, mutable=mutable, capture_intermediates=capture_intermediates, )(variables, *args, **kwargs, rngs=rngs) @traceback_util.api_boundary - def init_with_output(self, - rngs: Union[KeyArray, RNGSequences], - *args, - method: Union[Callable[..., Any], str, None] = None, - mutable: CollectionFilter = DenyList('intermediates'), - capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, - **kwargs) -> Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]: + def init_with_output( + self, + rngs: Union[KeyArray, RNGSequences], + *args, + method: Union[Callable[..., Any], str, None] = None, + mutable: CollectionFilter = DenyList('intermediates'), + capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, + **kwargs, + ) -> Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]: """Initializes a module method with variables and returns output and modified variables. Args: @@ -1553,7 +1685,8 @@ def init_with_output(self, if not core.scope._is_valid_rng(rngs): raise errors.InvalidRngError( 'RNGs should be of shape (2,) or KeyArray in Module ' - f'{self.__class__.__name__}, but rngs are: {rngs}') + f'{self.__class__.__name__}, but rngs are: {rngs}' + ) rngs = {'params': rngs} if isinstance(method, str): @@ -1561,25 +1694,26 @@ def init_with_output(self, method = getattr(self, attribute_name) if not callable(method): class_name = type(self).__name__ - raise TypeError(f'\'{class_name}.{attribute_name}\' must be a callable, got {type(method)}.') + raise TypeError( + f"'{class_name}.{attribute_name}' must be a callable, got {type(method)}." + ) elif method is None: method = self.__call__ method = _get_unbound_fn(method) return init_with_output( - method, - self, - mutable=mutable, - capture_intermediates=capture_intermediates + method, self, mutable=mutable, capture_intermediates=capture_intermediates )(rngs, *args, **kwargs) @traceback_util.api_boundary - def init(self, - rngs: Union[KeyArray, RNGSequences], - *args, - method: Union[Callable[..., Any], str, None] = None, - mutable: CollectionFilter = DenyList('intermediates'), - capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, - **kwargs) -> Union[FrozenVariableDict, Dict[str, Any]]: + def init( + self, + rngs: Union[KeyArray, RNGSequences], + *args, + method: Union[Callable[..., Any], str, None] = None, + mutable: CollectionFilter = DenyList('intermediates'), + capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, + **kwargs, + ) -> Union[FrozenVariableDict, Dict[str, Any]]: """Initializes a module method with variables and returns modified variables. ``init`` takes as first argument either a single ``PRNGKey``, or a dictionary mapping variable collections names to their ``PRNGKeys``, and will call ``method`` (which is the module's ``__call__`` function by default) passing ``*args`` and ``**kwargs``, and returns @@ -1666,16 +1800,19 @@ def init(self, method=method, mutable=mutable, capture_intermediates=capture_intermediates, - **kwargs) + **kwargs, + ) return v_out @traceback_util.api_boundary - def lazy_init(self, - rngs: Union[KeyArray, RNGSequences], - *args, - method: Optional[Callable[..., Any]] = None, - mutable: CollectionFilter = DenyList('intermediates'), - **kwargs) -> FrozenVariableDict: + def lazy_init( + self, + rngs: Union[KeyArray, RNGSequences], + *args, + method: Optional[Callable[..., Any]] = None, + mutable: CollectionFilter = DenyList('intermediates'), + **kwargs, + ) -> FrozenVariableDict: """Initializes a module without computing on an actual input. lazy_init will initialize the variables without doing unnecessary compute. @@ -1710,8 +1847,10 @@ def lazy_init(self, The initialized variable dict. """ Module._module_checks(self) + def lazy_wrapper(rngs, *args, **kwargs): return self.init(rngs, *args, method=method, mutable=mutable, **kwargs) + return partial_eval.lazy_init(lazy_wrapper)(rngs, *args, **kwargs) @property @@ -1755,14 +1894,24 @@ def sow(self, col: str, name: str, value: Any) -> bool: ... @overload - def sow(self, col: str, name: str, value: T, - reduce_fn: Callable[[K, T], K] = tuple_reduce, - init_fn: Callable[[], K] = tuple_init) -> bool: # type: ignore + def sow( + self, + col: str, + name: str, + value: T, + reduce_fn: Callable[[K, T], K] = tuple_reduce, + init_fn: Callable[[], K] = tuple_init, + ) -> bool: # type: ignore ... - def sow(self, col: str, name: str, value: T, - reduce_fn: Callable[[K, T], K] = tuple_reduce, - init_fn: Callable[[], K] = tuple_init) -> bool: # type: ignore + def sow( + self, + col: str, + name: str, + value: T, + reduce_fn: Callable[[K, T], K] = tuple_reduce, + init_fn: Callable[[], K] = tuple_init, + ) -> bool: # type: ignore """Stores a value in a collection. Collections can be used to collect intermediate values without @@ -1885,28 +2034,31 @@ def loss(params, perturbations, inputs, targets): model.apply({'params': params}, x) # behaves like a no-op """ + def _root_has_collection(): """Returns True if the root scope has the collection.""" assert self.scope is not None return collection in self.scope.root._variables + # we will only add the perturbation variable if the collection is mutable # (e.g. during `init`) or if the collection was passed to `apply` (contained in # the root scope). if self.is_mutable_collection(collection) or _root_has_collection(): - value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value # type: ignore + value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value # type: ignore return value def tabulate( - self, - rngs: Union[KeyArray, RNGSequences], - *args, - depth: Optional[int] = None, - show_repeated: bool = False, - mutable: CollectionFilter = True, - console_kwargs: Optional[Mapping[str, Any]] = None, - table_kwargs: Mapping[str, Any] = MappingProxyType({}), - column_kwargs: Mapping[str, Any] = MappingProxyType({}), - **kwargs) -> str: + self, + rngs: Union[KeyArray, RNGSequences], + *args, + depth: Optional[int] = None, + show_repeated: bool = False, + mutable: CollectionFilter = True, + console_kwargs: Optional[Mapping[str, Any]] = None, + table_kwargs: Mapping[str, Any] = MappingProxyType({}), + column_kwargs: Mapping[str, Any] = MappingProxyType({}), + **kwargs, + ) -> str: """Creates a summary of the Module represented as a table. This method has the same signature and internally calls `Module.init`, @@ -1993,13 +2145,21 @@ def __call__(self, x): from flax.linen import summary tabulate_fn = summary.tabulate( - self, rngs, depth=depth, show_repeated=show_repeated, mutable=mutable, - console_kwargs=console_kwargs, table_kwargs=table_kwargs, column_kwargs=column_kwargs) + self, + rngs, + depth=depth, + show_repeated=show_repeated, + mutable=mutable, + console_kwargs=console_kwargs, + table_kwargs=table_kwargs, + column_kwargs=column_kwargs, + ) return tabulate_fn(*args, **kwargs) _ParentType = Union[Type[Module], Type[Scope], Type[_Sentinel], None] + def merge_param(name: str, a: Optional[T], b: Optional[T]) -> T: """Merges construction- and call-time argument. @@ -2027,10 +2187,14 @@ def __call__(self, train: Optional[bool] = None): """ if a is None and b is None: - raise ValueError(f'Parameter "{name}" must be passed to the constructor or at call time.') + raise ValueError( + f'Parameter "{name}" must be passed to the constructor or at call time.' + ) if a is not None and b is not None: - raise ValueError(f'Parameter "{name}" was passed to the constructor and at call time.' - ' Should be passed just once.') + raise ValueError( + f'Parameter "{name}" was passed to the constructor and at call time.' + ' Should be passed just once.' + ) if a is None: assert b is not None return b @@ -2038,10 +2202,12 @@ def __call__(self, train: Optional[bool] = None): @traceback_util.api_boundary -def apply(fn: Callable[..., Any], module: Module, - mutable: CollectionFilter = False, - capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, - ) -> Callable[..., Any]: +def apply( + fn: Callable[..., Any], + module: Module, + mutable: CollectionFilter = False, + capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, +) -> Callable[..., Any]: """Creates an apply function to call ``fn`` with a bound module. Unlike ``Module.apply`` this function returns a new function with the signature @@ -2082,6 +2248,7 @@ def f(foo, x): Returns: The apply function wrapping ``fn``. """ + @functools.wraps(fn) def scope_fn(scope, *args, **kwargs): _context.capture_stack.append(capture_intermediates) @@ -2098,10 +2265,12 @@ def scope_fn(scope, *args, **kwargs): @traceback_util.api_boundary -def init_with_output(fn: Callable[..., Any], module: Module, - mutable: CollectionFilter = DenyList('intermediates'), - capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, - ) -> Callable[..., Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: +def init_with_output( + fn: Callable[..., Any], + module: Module, + mutable: CollectionFilter = DenyList('intermediates'), + capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, +) -> Callable[..., Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: """Creates an init function to call ``fn`` with a bound module that also returns the function outputs. Unlike ``Module.init_with_output`` this function returns a new function with the signature @@ -2143,6 +2312,7 @@ def f(foo, x): Returns: The init function wrapping ``fn``. """ + @functools.wraps(fn) def scope_fn(scope, *args, **kwargs): _context.capture_stack.append(capture_intermediates) @@ -2159,10 +2329,12 @@ def scope_fn(scope, *args, **kwargs): @traceback_util.api_boundary -def init(fn: Callable[..., Any], module: Module, - mutable: CollectionFilter = DenyList('intermediates'), - capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, - ) -> Callable[..., Union[FrozenVariableDict, Dict[str, Any]]]: +def init( + fn: Callable[..., Any], + module: Module, + mutable: CollectionFilter = DenyList('intermediates'), + capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, +) -> Callable[..., Union[FrozenVariableDict, Dict[str, Any]]]: """Creates an init function to call ``fn`` with a bound module. Unlike ``Module.init`` this function returns a new function with the signature @@ -2205,7 +2377,9 @@ def f(foo, x): The init function wrapping ``fn``. """ init_fn = init_with_output(fn, module, mutable, capture_intermediates) + @functools.wraps(init_fn) def init_wrapper(*args, **kwargs): return init_fn(*args, **kwargs)[1] + return init_wrapper diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 2c318a99d9..ef7af34084 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -110,14 +110,22 @@ def mean(x, axes=axes): return mu, var -def _normalize(mdl: Module, x: Array, mean: Array, var: Array, - reduction_axes: Axes, feature_axes: Axes, - dtype: Dtype, param_dtype: Dtype, - epsilon: float, - use_bias: bool, use_scale: bool, - bias_init: Callable[[PRNGKey, Shape, Dtype], Array], - scale_init: Callable[[PRNGKey, Shape, Dtype], Array]): - """"Normalizes the input of a normalization layer and optionally applies a learned scale and bias. +def _normalize( + mdl: Module, + x: Array, + mean: Array, + var: Array, + reduction_axes: Axes, + feature_axes: Axes, + dtype: Dtype, + param_dtype: Dtype, + epsilon: float, + use_bias: bool, + use_scale: bool, + bias_init: Callable[[PRNGKey, Shape, Dtype], Array], + scale_init: Callable[[PRNGKey, Shape, Dtype], Array], +): + """ "Normalizes the input of a normalization layer and optionally applies a learned scale and bias. Arguments: mdl: Module to apply the normalization in (normalization params will reside @@ -153,14 +161,16 @@ def _normalize(mdl: Module, x: Array, mean: Array, var: Array, mul = lax.rsqrt(var + epsilon) args = [x] if use_scale: - scale = mdl.param('scale', scale_init, reduced_feature_shape, - param_dtype).reshape(feature_shape) + scale = mdl.param('scale', scale_init, reduced_feature_shape, param_dtype).reshape( + feature_shape + ) mul *= scale args.append(scale) y *= mul if use_bias: - bias = mdl.param('bias', bias_init, reduced_feature_shape, - param_dtype).reshape(feature_shape) + bias = mdl.param('bias', bias_init, reduced_feature_shape, param_dtype).reshape( + feature_shape + ) y += bias args.append(bias) dtype = canonicalize_dtype(*args, dtype=dtype) @@ -220,6 +230,7 @@ class BatchNorm(Module): the examples on the first two and last two devices. See `jax.lax.psum` for more details. """ + use_running_average: Optional[bool] = None axis: int = -1 momentum: float = 0.99 @@ -254,37 +265,49 @@ def __call__(self, x, use_running_average: Optional[bool] = None): """ use_running_average = merge_param( - 'use_running_average', self.use_running_average, use_running_average) + 'use_running_average', self.use_running_average, use_running_average + ) feature_axes = _canonicalize_axes(x.ndim, self.axis) reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) feature_shape = [x.shape[ax] for ax in feature_axes] - ra_mean = self.variable('batch_stats', 'mean', - lambda s: jnp.zeros(s, jnp.float32), - feature_shape) - ra_var = self.variable('batch_stats', 'var', - lambda s: jnp.ones(s, jnp.float32), - feature_shape) + ra_mean = self.variable( + 'batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), feature_shape + ) + ra_var = self.variable( + 'batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), feature_shape + ) if use_running_average: mean, var = ra_mean.value, ra_var.value else: mean, var = _compute_stats( - x, reduction_axes, + x, + reduction_axes, dtype=self.dtype, axis_name=self.axis_name if not self.is_initializing() else None, - axis_index_groups=self.axis_index_groups) + axis_index_groups=self.axis_index_groups, + ) if not self.is_initializing(): - ra_mean.value = self.momentum * ra_mean.value + (1 - - self.momentum) * mean + ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var return _normalize( - self, x, mean, var, reduction_axes, feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, - self.bias_init, self.scale_init) + self, + x, + mean, + var, + reduction_axes, + feature_axes, + self.dtype, + self.param_dtype, + self.epsilon, + self.use_bias, + self.use_scale, + self.bias_init, + self.scale_init, + ) class LayerNorm(Module): @@ -319,6 +342,7 @@ class LayerNorm(Module): use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. """ + epsilon: float = 1e-6 dtype: Optional[Dtype] = None param_dtype: Dtype = jnp.float32 @@ -352,10 +376,20 @@ def __call__(self, x): ) return _normalize( - self, x, mean, var, self.reduction_axes, self.feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, - self.bias_init, self.scale_init) + self, + x, + mean, + var, + self.reduction_axes, + self.feature_axes, + self.dtype, + self.param_dtype, + self.epsilon, + self.use_bias, + self.use_scale, + self.bias_init, + self.scale_init, + ) class RMSNorm(Module): @@ -397,6 +431,7 @@ class RMSNorm(Module): the examples on the first two and last two devices. See `jax.lax.psum` for more details. """ + epsilon: float = 1e-6 dtype: Optional[Dtype] = None param_dtype: Dtype = jnp.float32 @@ -417,50 +452,66 @@ def __call__(self, x): Returns: Normalized inputs (the same shape as inputs). """ - mean, var = _compute_stats(x, self.reduction_axes, self.dtype, - self.axis_name, self.axis_index_groups, - use_mean=False) + mean, var = _compute_stats( + x, + self.reduction_axes, + self.dtype, + self.axis_name, + self.axis_index_groups, + use_mean=False, + ) return _normalize( - self, x, mean, var, self.reduction_axes, self.feature_axes, - self.dtype, self.param_dtype, self.epsilon, - False, self.use_scale, - initializers.zeros, self.scale_init) + self, + x, + mean, + var, + self.reduction_axes, + self.feature_axes, + self.dtype, + self.param_dtype, + self.epsilon, + False, + self.use_scale, + initializers.zeros, + self.scale_init, + ) class GroupNorm(Module): """Group normalization (arxiv.org/abs/1803.08494). - This op is similar to batch normalization, but statistics are shared across - equally-sized groups of channels and not shared across batch dimension. - Thus, group normalization does not depend on the batch composition and does - not require maintaining internal state for storing statistics. - The user should either specify the total number of channel groups or the - number of channels per group. - - Attributes: - num_groups: the total number of channel groups. The default value of 32 is - proposed by the original group normalization paper. - group_size: the number of channels in a group. - epsilon: A small float added to variance to avoid dividing by zero. - dtype: the dtype of the result (default: infer from input and params). - param_dtype: the dtype passed to parameter initializers (default: float32). - use_bias: If True, bias (beta) is added. - use_scale: If True, multiply by scale (gamma). When the next layer is - linear (also e.g. nn.relu), this can be disabled since the scaling will - be done by the next layer. - bias_init: Initializer for bias, by default, zero. - scale_init: Initializer for scale, by default, one. - axis_name: the axis name used to combine batch statistics from multiple - devices. See `jax.pmap` for a description of axis names (default: None). - This is only needed if the model is subdivided across devices, i.e. the - array being normalized is sharded across devices within a pmap. - axis_index_groups: groups of axis indices within that named axis - representing subsets of devices to reduce over (default: None). For - example, `[[0, 1], [2, 3]]` would independently batch-normalize over - the examples on the first two and last two devices. See `jax.lax.psum` - for more details. + This op is similar to batch normalization, but statistics are shared across + equally-sized groups of channels and not shared across batch dimension. + Thus, group normalization does not depend on the batch composition and does + not require maintaining internal state for storing statistics. + The user should either specify the total number of channel groups or the + number of channels per group. + + Attributes: + num_groups: the total number of channel groups. The default value of 32 is + proposed by the original group normalization paper. + group_size: the number of channels in a group. + epsilon: A small float added to variance to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: If True, bias (beta) is added. + use_scale: If True, multiply by scale (gamma). When the next layer is + linear (also e.g. nn.relu), this can be disabled since the scaling will + be done by the next layer. + bias_init: Initializer for bias, by default, zero. + scale_init: Initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + This is only needed if the model is subdivided across devices, i.e. the + array being normalized is sharded across devices within a pmap. + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over + the examples on the first two and last two devices. See `jax.lax.psum` + for more details. """ + num_groups: Optional[int] = 32 group_size: Optional[int] = None epsilon: float = 1e-6 @@ -488,38 +539,59 @@ def __call__(self, x): reduction_axes = list(range(1, x.ndim - 1)) + [-1] feature_axes = (-1,) - if ((self.num_groups is None and self.group_size is None) or - (self.num_groups is not None and self.group_size is not None)): - raise ValueError('Either `num_groups` or `group_size` should be ' - 'specified. If `group_size` is to be specified, ' - 'pass `num_groups=None` as argument to override ' - 'the default `num_groups` value of 32.') + if (self.num_groups is None and self.group_size is None) or ( + self.num_groups is not None and self.group_size is not None + ): + raise ValueError( + 'Either `num_groups` or `group_size` should be ' + 'specified. If `group_size` is to be specified, ' + 'pass `num_groups=None` as argument to override ' + 'the default `num_groups` value of 32.' + ) channels = x.shape[-1] if self.group_size is not None: if channels % self.group_size != 0: - raise ValueError('Number of channels ({}) is not multiple of the ' - 'group size ({}).'.format(channels, self.group_size)) + raise ValueError( + 'Number of channels ({}) is not multiple of the ' + 'group size ({}).'.format(channels, self.group_size) + ) num_groups = channels // self.group_size else: num_groups = self.num_groups assert isinstance(num_groups, int) if num_groups <= 0 or channels % num_groups != 0: - raise ValueError('Number of groups ({}) does not divide the number' - ' of channels ({}).'.format(num_groups, channels)) + raise ValueError( + 'Number of groups ({}) does not divide the number' + ' of channels ({}).'.format(num_groups, channels) + ) group_size = x.shape[-1] // num_groups group_shape = x.shape[:-1] + (num_groups, group_size) mean, var = _compute_stats( - x.reshape(group_shape), reduction_axes, self.dtype, self.axis_name, - self.axis_index_groups) + x.reshape(group_shape), + reduction_axes, + self.dtype, + self.axis_name, + self.axis_index_groups, + ) mean = jnp.repeat(mean, group_size, axis=-1) var = jnp.repeat(var, group_size, axis=-1) return _normalize( - self, x, mean, var, reduction_axes[:-1], feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, - self.bias_init, self.scale_init) + self, + x, + mean, + var, + reduction_axes[:-1], + feature_axes, + self.dtype, + self.param_dtype, + self.epsilon, + self.use_bias, + self.use_scale, + self.bias_init, + self.scale_init, + ) diff --git a/flax/linen/partitioning.py b/flax/linen/partitioning.py index 29b1918786..822789d9e7 100644 --- a/flax/linen/partitioning.py +++ b/flax/linen/partitioning.py @@ -77,6 +77,7 @@ @struct.dataclass class AxisMetadata: """Contains a tuple of axis names, which is passed through FLAX.""" + names: LogicalPartitionSpecPytree = struct.field(pytree_node=False) @@ -102,8 +103,9 @@ def _param_with_axes_sow_reduce_fn(x, y): if isinstance(x, AxisMetadata): if x != y: - raise ValueError('If axis names are sown twice, expected them to match. ' - f'Got {x} and {y}.') + raise ValueError( + 'If axis names are sown twice, expected them to match. ' f'Got {x} and {y}.' + ) elif x: # Shouldn't happen, so raise a fairly internal error. raise AssertionError(f'Non-initial-or-AxisMetadata value encountered: {x}') @@ -115,7 +117,8 @@ def param_with_axes( init_fn, *init_args, axes: Optional[Tuple[str, ...]] = None, - module: Optional['nn.Module'] = None): + module: Optional['nn.Module'] = None, +): """Declares and returns a parameter with logical axes in the current Module. See :mod:`flax.linen.module.param` for original docstring. @@ -145,12 +148,16 @@ def param_with_axes( module_param = module.param(name, init_fn, *init_args) if axes is not None: # apply logical axis constraint immediately - module_param = with_sharding_constraint(module_param, - jax.sharding.PartitionSpec(*axes)) + module_param = with_sharding_constraint( + module_param, jax.sharding.PartitionSpec(*axes) + ) # record logical axis constraint for global axis metadata module.sow( - 'params_axes', f'{name}_axes', AxisMetadata(axes), # type: ignore - reduce_fn=_param_with_axes_sow_reduce_fn) + 'params_axes', + f'{name}_axes', + AxisMetadata(axes), # type: ignore + reduce_fn=_param_with_axes_sow_reduce_fn, + ) return module_param @@ -164,12 +171,14 @@ class PartitionedVariable(flax.core.scope.Variable): and assignment. """ - def __init__(self, - scope, - collection: str, - name: str, - axes: Optional[Tuple[str, ...]] = None, - fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED): + def __init__( + self, + scope, + collection: str, + name: str, + axes: Optional[Tuple[str, ...]] = None, + fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, + ): """Initializes a partitioned variable. Args: @@ -208,7 +217,8 @@ def _core_variable_with_axes( init_fn: Callable[..., Any], *init_args, axes: Optional[Tuple[str, ...]] = None, - fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED): + fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, +): """Variant of flax core variable scope call with sharding constraints.""" scope.reserve(name) if not scope.has_variable(col, name): @@ -228,7 +238,8 @@ def variable_with_axes( *init_args, axes: Optional[Tuple[str, ...]] = None, module: Optional['nn.Module'] = None, - fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED): + fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, +): """Declares and returns a variable with logical axes in the current Module. See :mod:`flax.linen.module.variable` for original docstring. @@ -258,18 +269,16 @@ def variable_with_axes( module = nn.module._context.module_stack[-1] # pylint: disable=protected-access assert module is not None module_var = _core_variable_with_axes( - module.scope, - collection, - name, - init_fn, - *init_args, - axes=axes, - fallback=fallback) + module.scope, collection, name, init_fn, *init_args, axes=axes, fallback=fallback + ) if axes is not None: # record logical axis constraint for global axis metadata module.sow( - f'{collection}_axes', f'{name}_axes', AxisMetadata(axes), # type: ignore - reduce_fn=_param_with_axes_sow_reduce_fn) + f'{collection}_axes', + f'{name}_axes', + AxisMetadata(axes), # type: ignore + reduce_fn=_param_with_axes_sow_reduce_fn, + ) return module_var @@ -285,17 +294,19 @@ def get_axis_names(axes_metadata): suffix on variable names removed to match original variable collection for annotations. """ + def leaf_rewrite(x): return None if x is None else jax.sharding.PartitionSpec(*x) + def rewrite(tree): return jax.tree_util.tree_map(leaf_rewrite, tree, is_leaf=_is_logical_spec) + axes_metadata = unfreeze(axes_metadata) # pytype: disable=wrong-arg-types flat_dict = { re.sub(r'_axes$', '', '/'.join(k)): rewrite(v.names) for k, v in flatten_dict(axes_metadata).items() } - return freeze(unflatten_dict( - {tuple(k.split('/')): v for k, v in flat_dict.items()})) + return freeze(unflatten_dict({tuple(k.split('/')): v for k, v in flat_dict.items()})) # Metadata Aware Scan @@ -306,7 +317,8 @@ def _tree_map_axes(fn, tree): """Only map over AxisMetadata leaves in pytree - identity for other leaves.""" safe_fn = lambda x: fn(x) if isinstance(x, AxisMetadata) else x return jax.tree_util.tree_map( - safe_fn, tree, is_leaf=lambda x: isinstance(x, AxisMetadata)) + safe_fn, tree, is_leaf=lambda x: isinstance(x, AxisMetadata) + ) def _is_mutable(axis_col: str) -> bool: @@ -346,8 +358,9 @@ def insert_fn_leaf(names): return tuple(names) def insert_fn(x): - new_names = jax.tree_util.tree_map(insert_fn_leaf, x.names, - is_leaf=_is_logical_spec) + new_names = jax.tree_util.tree_map( + insert_fn_leaf, x.names, is_leaf=_is_logical_spec + ) return x.replace(names=new_names) def remove_fn_leaf(names): @@ -355,14 +368,17 @@ def remove_fn_leaf(names): return names names = list(names) if names[axis_pos] != axis_name: - raise ValueError(f'Expected axis {axis_name} at position {axis_pos} in ' - f'axis metadata {names}.') + raise ValueError( + f'Expected axis {axis_name} at position {axis_pos} in ' + f'axis metadata {names}.' + ) names.pop(axis_pos) return tuple(names) def remove_fn(x): - new_names = jax.tree_util.tree_map(remove_fn_leaf, x.names, - is_leaf=_is_logical_spec) + new_names = jax.tree_util.tree_map( + remove_fn_leaf, x.names, is_leaf=_is_logical_spec + ) return x.replace(names=new_names) return nn.transforms.map_variables( @@ -370,15 +386,16 @@ def remove_fn(x): axis_col, mutable=_is_mutable(axis_col), trans_in_fn=lambda tree: _tree_map_axes(remove_fn, tree), - trans_out_fn=lambda tree: _tree_map_axes(insert_fn, tree) - ) + trans_out_fn=lambda tree: _tree_map_axes(insert_fn, tree), + ) # pylint: disable=dangerous-default-value def scan_with_axes( target: 'flax.linen.transforms.Target', - variable_axes: Mapping[flax.core.lift.CollectionFilter, - flax.core.lift.InOutScanAxis] = {}, + variable_axes: Mapping[ + flax.core.lift.CollectionFilter, flax.core.lift.InOutScanAxis + ] = {}, variable_broadcast: flax.core.lift.CollectionFilter = False, variable_carry: flax.core.lift.CollectionFilter = False, split_rngs: Mapping[flax.core.lift.PRNGSequenceFilter, bool] = {}, @@ -390,13 +407,13 @@ def scan_with_axes( axis_name: str = 'layers', axes_collections: Tuple[str, ...] = ('params',), data_transform: Optional[Callable[..., Any]] = None, - methods=None) -> 'flax.linen.transforms.Target': + methods=None, +) -> 'flax.linen.transforms.Target': """Wrapped version of nn.scan that handles logical axis metadata.""" # we broadcast the static metadata collections. axes_filters = tuple(f'{col}_axes' for col in axes_collections) - variable_broadcast = flax.core.scope.union_filters( - variable_broadcast, axes_filters) + variable_broadcast = flax.core.scope.union_filters(variable_broadcast, axes_filters) # perform usual lifted scan scanned = flax.linen.transforms.lift_transform( @@ -412,31 +429,34 @@ def scan_with_axes( reverse=reverse, unroll=unroll, data_transform=data_transform, - methods=methods) + methods=methods, + ) # add scan axis to logical axes metadata for col in axes_collections: if col in variable_axes: - scanned = _add_axis_to_metadata(scanned, - axis_pos=variable_axes[col], - axis_name=axis_name, - axis_col=f'{col}_axes') + scanned = _add_axis_to_metadata( + scanned, + axis_pos=variable_axes[col], + axis_name=axis_name, + axis_col=f'{col}_axes', + ) return scanned # pylint: disable=dangerous-default-value -def vmap_with_axes(target: 'flax.linen.transforms.Target', - variable_axes: Mapping[flax.core.lift.CollectionFilter, - flax.core.lift.InOutAxis], - split_rngs: Mapping[flax.core.lift.PRNGSequenceFilter, - bool] = {}, - in_axes=0, - out_axes=0, - axis_size: Optional[int] = None, - axis_name: Optional[str] = None, - partitioning_axis_names: Mapping[Any, str] = {}, - spmd_axis_name: Optional[str] = None, - methods=None) -> 'flax.linen.transforms.Target': +def vmap_with_axes( + target: 'flax.linen.transforms.Target', + variable_axes: Mapping[flax.core.lift.CollectionFilter, flax.core.lift.InOutAxis], + split_rngs: Mapping[flax.core.lift.PRNGSequenceFilter, bool] = {}, + in_axes=0, + out_axes=0, + axis_size: Optional[int] = None, + axis_name: Optional[str] = None, + partitioning_axis_names: Mapping[Any, str] = {}, + spmd_axis_name: Optional[str] = None, + methods=None, +) -> 'flax.linen.transforms.Target': """Wrapped version of nn.vmap that handles logical axis metadata.""" # tell normal vmap to broadcast axis metadata. @@ -455,7 +475,8 @@ def vmap_with_axes(target: 'flax.linen.transforms.Target', axis_size=axis_size, axis_name=axis_name, spmd_axis_name=spmd_axis_name, - methods=methods) + methods=methods, + ) for collection_name, axis in variable_axes.items(): if collection_name in partitioning_axis_names: @@ -463,7 +484,8 @@ def vmap_with_axes(target: 'flax.linen.transforms.Target', vmapped, axis_pos=axis, axis_name=partitioning_axis_names[collection_name], - axis_col=f'{collection_name}_axes') + axis_col=f'{collection_name}_axes', + ) return vmapped @@ -475,13 +497,15 @@ def vmap_with_axes(target: 'flax.linen.transforms.Target', # static_argnums behavior for flax remat via closure before applying jax remat. -def core_remat_static(fn, - variables=True, - rngs=True, - concrete=False, - prevent_cse=True, - static_argnums=(), - policy=None): +def core_remat_static( + fn, + variables=True, + rngs=True, + concrete=False, + prevent_cse=True, + static_argnums=(), + policy=None, +): """Flax functional core remat version with static_argnums.""" static_argnums = tuple(sorted(static_argnums)) @@ -504,7 +528,8 @@ def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): dyn_args = tuple(x for i, x in enumerate(args) if i not in static_argnums) @functools.partial( - jax.remat, concrete=concrete, prevent_cse=prevent_cse, policy=policy) + jax.remat, concrete=concrete, prevent_cse=prevent_cse, policy=policy + ) @functools.wraps(fn) def rematted(variable_groups, rng_groups, *dyn_args): args = _repack_remat_args(dyn_args, static_args) @@ -514,18 +539,19 @@ def rematted(variable_groups, rng_groups, *dyn_args): return rematted(variable_groups, rng_groups, *dyn_args) - return flax.core.lift.pack( - inner, (variables,), (variables,), (rngs,), name='remat') + return flax.core.lift.pack(inner, (variables,), (variables,), (rngs,), name='remat') -def remat(target, - variables=True, - rngs=True, - concrete=False, - prevent_cse=True, - static_argnums=(), - policy=None, - methods=None): +def remat( + target, + variables=True, + rngs=True, + concrete=False, + prevent_cse=True, + static_argnums=(), + policy=None, + methods=None, +): """Flax lifted remat that supports static_argnums.""" return flax.linen.transforms.lift_transform( core_remat_static, @@ -536,4 +562,5 @@ def remat(target, prevent_cse=prevent_cse, static_argnums=static_argnums, policy=policy, - methods=methods) + methods=methods, + ) diff --git a/flax/linen/pooling.py b/flax/linen/pooling.py index 87f039c107..65a96d242e 100644 --- a/flax/linen/pooling.py +++ b/flax/linen/pooling.py @@ -43,8 +43,9 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding): """ num_batch_dims = inputs.ndim - (len(window_shape) + 1) strides = strides or (1,) * len(window_shape) - assert len(window_shape) == len(strides), ( - f"len({window_shape}) must equal len({strides})") + assert len(window_shape) == len( + strides + ), f"len({window_shape}) must equal len({strides})" strides = (1,) * num_batch_dims + strides + (1,) dims = (1,) * num_batch_dims + window_shape + (1,) @@ -62,9 +63,11 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding): padding = tuple(map(tuple, padding)) assert len(padding) == len(window_shape), ( f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}") - assert all([len(x) == 2 for x in padding]), ( - f"each entry in padding {padding} must be length 2") + f"window_shape {window_shape}" + ) + assert all( + [len(x) == 2 for x in padding] + ), f"each entry in padding {padding} must be length 2" padding = ((0, 0),) + padding + ((0, 0),) y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) if is_single_input: @@ -72,7 +75,9 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding): return y -def avg_pool(inputs, window_shape, strides=None, padding="VALID", count_include_pad=True): +def avg_pool( + inputs, window_shape, strides=None, padding="VALID", count_include_pad=True +): """Pools the input by taking the average over a window. Args: @@ -88,14 +93,14 @@ def avg_pool(inputs, window_shape, strides=None, padding="VALID", count_include_ Returns: The average for each window slice. """ - y = pool(inputs, 0., lax.add, window_shape, strides, padding) + y = pool(inputs, 0.0, lax.add, window_shape, strides, padding) if count_include_pad: y = y / np.prod(window_shape) else: div_shape = inputs.shape[:-1] + (1,) if len(div_shape) - 2 == len(window_shape): - div_shape = (1,) + div_shape[1:] - y = y / pool(jnp.ones(div_shape), 0., lax.add, window_shape, strides, padding) + div_shape = (1,) + div_shape[1:] + y = y / pool(jnp.ones(div_shape), 0.0, lax.add, window_shape, strides, padding) return y diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index a00dfdb540..8506e22321 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -19,7 +19,7 @@ """ from abc import ABCMeta -from functools import partial # pylint: disable=g-importing-member +from functools import partial # pylint: disable=g-importing-member from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union, TypeVar, cast from typing_extensions import Protocol from absl import logging @@ -51,19 +51,22 @@ CarryHistory = Any Output = Any + class _Never: pass + NEVER = _Never() LEGACY_UPDATE_MESSAGE = ( - "The RNNCellBase API has changed, " - "the error you are experiencing might be caused by this change. Please " - "update your code to the new API, for more information on how to do this " - "please check out the RNNCellBase migration guide: " - "https://flax.readthedocs.io/en/latest/guides/rnncell_upgrade_guide.html" + 'The RNNCellBase API has changed, ' + 'the error you are experiencing might be caused by this change. Please ' + 'update your code to the new API, for more information on how to do this ' + 'please check out the RNNCellBase migration guide: ' + 'https://flax.readthedocs.io/en/latest/guides/rnncell_upgrade_guide.html' ) + class RNNCellCompatibilityMeta(ABCMeta): """Metaclass for RNNCell compatibility.""" @@ -74,13 +77,16 @@ def __call__(self, *args: Any, **kwds: Any) -> Any: msg = e.args[0] raise TypeError(f'{msg} \n\n {LEGACY_UPDATE_MESSAGE}') from e + def deprecation_method_decorator(f): def wrapper(*args, **kwargs): if len(args) < 1 or not isinstance(args[0], RNNCellBase): raise TypeError(LEGACY_UPDATE_MESSAGE) return f(*args, **kwargs) + return wrapper + class RNNCellBase(Module): """RNN cell base class.""" @@ -97,7 +103,6 @@ def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]) -> Carry: """ raise NotImplementedError - @property def num_feature_axes(self) -> int: """Returns the number of feature axes of the RNN cell.""" @@ -161,19 +166,23 @@ def __call__(self, carry, inputs): c, h = carry hidden_features = h.shape[-1] # input and recurrent layers are summed so only one needs a bias. - dense_h = partial(Dense, - features=hidden_features, - use_bias=True, - kernel_init=self.recurrent_kernel_init, - bias_init=self.bias_init, - dtype=self.dtype, - param_dtype=self.param_dtype) - dense_i = partial(Dense, - features=hidden_features, - use_bias=False, - kernel_init=self.kernel_init, - dtype=self.dtype, - param_dtype=self.param_dtype) + dense_h = partial( + Dense, + features=hidden_features, + use_bias=True, + kernel_init=self.recurrent_kernel_init, + bias_init=self.bias_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + dense_i = partial( + Dense, + features=hidden_features, + use_bias=False, + kernel_init=self.kernel_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) i = self.gate_fn(dense_i(name='ii')(inputs) + dense_h(name='hi')(h)) f = self.gate_fn(dense_i(name='if')(inputs) + dense_h(name='hf')(h)) g = self.activation_fn(dense_i(name='ig')(inputs) + dense_h(name='hg')(h)) @@ -185,7 +194,8 @@ def __call__(self, carry, inputs): @nowrap @deprecation_method_decorator def initialize_carry( - self, rng: PRNGKey, input_shape: Tuple[int, ...]) -> Tuple[Array, Array]: + self, rng: PRNGKey, input_shape: Tuple[int, ...] + ) -> Tuple[Array, Array]: """Initialize the RNN cell carry. Args: @@ -219,8 +229,8 @@ class DenseParams(Module): @compact def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: k = self.param( - 'kernel', self.kernel_init, (inputs.shape[-1], self.features), - self.param_dtype) + 'kernel', self.kernel_init, (inputs.shape[-1], self.features), self.param_dtype + ) if self.use_bias: b = self.param('bias', self.bias_init, (self.features,), self.param_dtype) else: @@ -274,8 +284,9 @@ class OptimizedLSTMCell(RNNCellBase, metaclass=RNNCellCompatibilityMeta): carry_init: initializers.Initializer = initializers.zeros_init() @compact - def __call__(self, carry: Tuple[Array, Array], - inputs: Array) -> Tuple[Tuple[Array, Array], Array]: + def __call__( + self, carry: Tuple[Array, Array], inputs: Array + ) -> Tuple[Tuple[Array, Array], Array]: r"""An optimized long short-term memory (LSTM) cell. Args: @@ -290,9 +301,11 @@ def __call__(self, carry: Tuple[Array, Array], c, h = carry hidden_features = h.shape[-1] - def _concat_dense(inputs: Array, - params: Mapping[str, Tuple[Array, Optional[Array]]], - use_bias: bool = True) -> Dict[str, Array]: + def _concat_dense( + inputs: Array, + params: Mapping[str, Tuple[Array, Optional[Array]]], + use_bias: bool = True, + ) -> Dict[str, Array]: # Concatenates the individual kernels and biases, given in params, into a # single kernel and single bias for efficiency before applying them using # dot_general. @@ -324,15 +337,25 @@ def _concat_dense(inputs: Array, dense_params_i = {} for component in ['i', 'f', 'g', 'o']: dense_params_i[component] = DenseParams( - features=hidden_features, use_bias=False, + features=hidden_features, + use_bias=False, param_dtype=self.param_dtype, - kernel_init=self.kernel_init, bias_init=self.bias_init, - name=f'i{component}')(inputs) # type: ignore[call-arg] + kernel_init=self.kernel_init, + bias_init=self.bias_init, + name=f'i{component}', + )( + inputs + ) # type: ignore[call-arg] dense_params_h[component] = DenseParams( - features=hidden_features, use_bias=True, + features=hidden_features, + use_bias=True, param_dtype=self.param_dtype, - kernel_init=self.recurrent_kernel_init, bias_init=self.bias_init, - name=f'h{component}')(h) # type: ignore[call-arg] + kernel_init=self.recurrent_kernel_init, + bias_init=self.bias_init, + name=f'h{component}', + )( + h + ) # type: ignore[call-arg] dense_h = _concat_dense(h, dense_params_h, use_bias=True) dense_i = _concat_dense(inputs, dense_params_i, use_bias=False) @@ -348,7 +371,8 @@ def _concat_dense(inputs: Array, @nowrap @deprecation_method_decorator def initialize_carry( - self, rng: PRNGKey, input_shape: Tuple[int, ...]) -> Tuple[Array, Array]: + self, rng: PRNGKey, input_shape: Tuple[int, ...] + ) -> Tuple[Array, Array]: """Initialize the RNN cell carry. Args: @@ -369,6 +393,7 @@ def initialize_carry( def num_feature_axes(self) -> int: return 1 + class GRUCell(RNNCellBase, metaclass=RNNCellCompatibilityMeta): r"""GRU cell. @@ -423,26 +448,31 @@ def __call__(self, carry, inputs): h = carry hidden_features = h.shape[-1] # input and recurrent layers are summed so only one needs a bias. - dense_h = partial(Dense, - features=hidden_features, - use_bias=False, - dtype=self.dtype, - param_dtype=self.param_dtype, - kernel_init=self.recurrent_kernel_init, - bias_init=self.bias_init) - dense_i = partial(Dense, - features=hidden_features, - use_bias=True, - dtype=self.dtype, - param_dtype=self.param_dtype, - kernel_init=self.kernel_init, - bias_init=self.bias_init) + dense_h = partial( + Dense, + features=hidden_features, + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=self.recurrent_kernel_init, + bias_init=self.bias_init, + ) + dense_i = partial( + Dense, + features=hidden_features, + use_bias=True, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + ) r = self.gate_fn(dense_i(name='ir')(inputs) + dense_h(name='hr')(h)) z = self.gate_fn(dense_i(name='iz')(inputs) + dense_h(name='hz')(h)) # add bias because the linear transformations aren't directly summed. - n = self.activation_fn(dense_i(name='in')(inputs) + - r * dense_h(name='hn', use_bias=True)(h)) - new_h = (1. - z) * n + z * h + n = self.activation_fn( + dense_i(name='in')(inputs) + r * dense_h(name='hn', use_bias=True)(h) + ) + new_h = (1.0 - z) * n + z * h return new_h, new_h @nowrap @@ -528,25 +558,29 @@ def __call__(self, carry, inputs): A tuple with the new carry and the output. """ c, h = carry - input_to_hidden = partial(Conv, - features=4*self.features, - kernel_size=self.kernel_size, - strides=self.strides, - padding=self.padding, - use_bias=self.use_bias, - dtype=self.dtype, - param_dtype=self.param_dtype, - name='ih') - - hidden_to_hidden = partial(Conv, - features=4*self.features, - kernel_size=self.kernel_size, - strides=self.strides, - padding=self.padding, - use_bias=self.use_bias, - dtype=self.dtype, - param_dtype=self.param_dtype, - name='hh') + input_to_hidden = partial( + Conv, + features=4 * self.features, + kernel_size=self.kernel_size, + strides=self.strides, + padding=self.padding, + use_bias=self.use_bias, + dtype=self.dtype, + param_dtype=self.param_dtype, + name='ih', + ) + + hidden_to_hidden = partial( + Conv, + features=4 * self.features, + kernel_size=self.kernel_size, + strides=self.strides, + padding=self.padding, + use_bias=self.use_bias, + dtype=self.dtype, + param_dtype=self.param_dtype, + name='hh', + ) gates = input_to_hidden()(inputs) + hidden_to_hidden()(h) i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1) @@ -569,8 +603,8 @@ def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]): An initialized carry for the given RNN cell. """ # (*batch_dims, *signal_dims, features) - signal_dims = input_shape[-self.num_feature_axes:-1] - batch_dims = input_shape[:-self.num_feature_axes] + signal_dims = input_shape[-self.num_feature_axes : -1] + batch_dims = input_shape[: -self.num_feature_axes] key1, key2 = random.split(rng) mem_shape = batch_dims + signal_dims + (self.features,) c = self.carry_init(key1, mem_shape, self.param_dtype) @@ -581,6 +615,7 @@ def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]): def num_feature_axes(self) -> int: return len(self.kernel_size) + 1 + class RNN(Module): """The ``RNN`` module takes any :class:`RNNCellBase` instance and applies it over a sequence using :func:`flax.linen.scan`. @@ -681,6 +716,7 @@ class RNN(Module): PRNG key should be split such that its values are different at each step, or replicated such that its values remain the same at each step. This argument is forwarded to `nn.scan`. """ + cell: RNNCellBase cell_size: Any = NEVER time_major: bool = False @@ -688,7 +724,7 @@ class RNN(Module): reverse: bool = False keep_order: bool = False unroll: int = 1 - variable_axes: Mapping[lift.CollectionFilter,lift.InOutScanAxis] = FrozenDict() + variable_axes: Mapping[lift.CollectionFilter, lift.InOutScanAxis] = FrozenDict() variable_broadcast: lift.CollectionFilter = 'params' variable_carry: lift.CollectionFilter = False split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict({'params': False}) @@ -696,21 +732,21 @@ class RNN(Module): def __post_init__(self) -> None: if self.cell_size is not NEVER: raise TypeError( - f'The `cell_size` argument is no longer available`. ' + LEGACY_UPDATE_MESSAGE + f'The `cell_size` argument is no longer available`. ' + LEGACY_UPDATE_MESSAGE ) return super().__post_init__() def __call__( - self, - inputs: jax.Array, - *, - initial_carry: Optional[Carry] = None, - init_key: Optional[random.KeyArray] = None, - seq_lengths: Optional[Array] = None, - return_carry: Optional[bool] = None, - time_major: Optional[bool] = None, - reverse: Optional[bool] = None, - keep_order: Optional[bool] = None, + self, + inputs: jax.Array, + *, + initial_carry: Optional[Carry] = None, + init_key: Optional[random.KeyArray] = None, + seq_lengths: Optional[Array] = None, + return_carry: Optional[bool] = None, + time_major: Optional[bool] = None, + reverse: Optional[bool] = None, + keep_order: Optional[bool] = None, ) -> Union[Output, Tuple[Carry, Output]]: """ Applies the RNN to the inputs. @@ -765,23 +801,25 @@ def __call__( if time_major: # we add +1 because we moved the time axis to the front - batch_dims = inputs.shape[1:-self.cell.num_feature_axes] + batch_dims = inputs.shape[1 : -self.cell.num_feature_axes] else: batch_dims = inputs.shape[:time_axis] # maybe reverse the sequence if reverse: inputs = jax.tree_map( - lambda x: flip_sequences( - x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major), # type: ignore - inputs) + lambda x: flip_sequences( + x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major + ), # type: ignore + inputs, + ) carry: Carry if initial_carry is None: if init_key is None: init_key = random.PRNGKey(0) - input_shape = inputs.shape[:time_axis] + inputs.shape[time_axis + 1:] + input_shape = inputs.shape[:time_axis] + inputs.shape[time_axis + 1 :] carry = self.cell.initialize_carry(init_key, input_shape) else: carry = initial_carry @@ -789,7 +827,7 @@ def __call__( slice_carry = seq_lengths is not None and return_carry def scan_fn( - cell: RNNCellBase, carry: Carry, x: Array + cell: RNNCellBase, carry: Carry, x: Array ) -> Union[Tuple[Carry, Array], Tuple[Carry, Tuple[Carry, Array]]]: carry, y = cell(carry, x) # When we have a segmentation mask we return the carry as an output @@ -802,14 +840,14 @@ def scan_fn( return carry, y scan = transforms.scan( - scan_fn, - in_axes=time_axis, - out_axes=(0, time_axis) if slice_carry else time_axis, - unroll=self.unroll, - variable_axes=self.variable_axes, - variable_broadcast=self.variable_broadcast, - variable_carry=self.variable_carry, - split_rngs=self.split_rngs, + scan_fn, + in_axes=time_axis, + out_axes=(0, time_axis) if slice_carry else time_axis, + unroll=self.unroll, + variable_axes=self.variable_axes, + variable_broadcast=self.variable_broadcast, + variable_carry=self.variable_carry, + split_rngs=self.split_rngs, ) scan_output = scan(self.cell, carry, inputs) @@ -828,15 +866,18 @@ def scan_fn( if reverse and keep_order: outputs = jax.tree_map( - lambda x: flip_sequences( - x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major), # type: ignore - outputs) + lambda x: flip_sequences( + x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major + ), # type: ignore + outputs, + ) if return_carry: return carry, outputs else: return outputs + def _select_last_carry(sequence: A, seq_lengths: jnp.ndarray) -> A: last_idx = seq_lengths - 1 @@ -845,12 +886,14 @@ def _slice_array(x: jnp.ndarray): return jax.tree_map(_slice_array, sequence) + def _expand_dims_like(x, target): """Expands the shape of `x` to match `target`'s shape by adding singleton dimensions.""" return x.reshape(list(x.shape) + [1] * (target.ndim - x.ndim)) + def flip_sequences( - inputs: Array, seq_lengths: Optional[Array], num_batch_dims: int, time_major: bool + inputs: Array, seq_lengths: Optional[Array], num_batch_dims: int, time_major: bool ) -> Array: """Flips a sequence of inputs along the time axis. @@ -890,13 +933,13 @@ def flip_sequences( seq_lengths = jnp.expand_dims(seq_lengths, axis=time_axis) # create indexes - idxs = jnp.arange(max_steps - 1, -1, -1) # [max_steps] + idxs = jnp.arange(max_steps - 1, -1, -1) # [max_steps] if time_major: idxs = jnp.reshape(idxs, [max_steps] + [1] * num_batch_dims) else: - idxs = jnp.reshape(idxs, [1] * num_batch_dims + [max_steps]) # [1, ..., max_steps] - idxs = (idxs + seq_lengths) % max_steps # [*batch, max_steps] - idxs = _expand_dims_like(idxs, target=inputs) # [*batch, max_steps, *features] + idxs = jnp.reshape(idxs, [1] * num_batch_dims + [max_steps]) # [1, ..., max_steps] + idxs = (idxs + seq_lengths) % max_steps # [*batch, max_steps] + idxs = _expand_dims_like(idxs, target=inputs) # [*batch, max_steps, *features] # Select the inputs in flipped order. outputs = jnp.take_along_axis(inputs, idxs, axis=time_axis) @@ -907,23 +950,27 @@ def _concatenate(a: Array, b: Array) -> Array: """Concatenates two arrays along the last dimension.""" return jnp.concatenate([a, b], axis=-1) + class RNNBase(Protocol): + def __call__( - self, - inputs: jax.Array, - *, - initial_carry: Optional[Carry] = None, - init_key: Optional[random.KeyArray] = None, - seq_lengths: Optional[Array] = None, - return_carry: Optional[bool] = None, - time_major: Optional[bool] = None, - reverse: Optional[bool] = None, - keep_order: Optional[bool] = None, + self, + inputs: jax.Array, + *, + initial_carry: Optional[Carry] = None, + init_key: Optional[random.KeyArray] = None, + seq_lengths: Optional[Array] = None, + return_carry: Optional[bool] = None, + time_major: Optional[bool] = None, + reverse: Optional[bool] = None, + keep_order: Optional[bool] = None, ) -> Union[Output, Tuple[Carry, Output]]: ... + class Bidirectional(Module): """Processes the input in both directions and merges the results.""" + forward_rnn: RNNBase backward_rnn: RNNBase merge_fn: Callable[[Array, Array], Array] = _concatenate @@ -931,16 +978,16 @@ class Bidirectional(Module): return_carry: bool = False def __call__( - self, - inputs: jax.Array, - *, - initial_carry: Optional[Carry] = None, - init_key: Optional[random.KeyArray] = None, - seq_lengths: Optional[Array] = None, - return_carry: Optional[bool] = None, - time_major: Optional[bool] = None, - reverse: Optional[bool] = None, - keep_order: Optional[bool] = None, + self, + inputs: jax.Array, + *, + initial_carry: Optional[Carry] = None, + init_key: Optional[random.KeyArray] = None, + seq_lengths: Optional[Array] = None, + return_carry: Optional[bool] = None, + time_major: Optional[bool] = None, + reverse: Optional[bool] = None, + keep_order: Optional[bool] = None, ) -> Union[Output, Tuple[Carry, Output]]: if time_major is None: time_major = self.time_major @@ -957,19 +1004,34 @@ def __call__( # Throw a warning in case the user accidentally re-uses the forward RNN # for the backward pass and does not intend for them to share parameters. if self.forward_rnn is self.backward_rnn: - logging.warning(("forward_rnn and backward_rnn is the same object, so " - "they will share parameters.")) + logging.warning( + ( + 'forward_rnn and backward_rnn is the same object, so ' + 'they will share parameters.' + ) + ) # Encode in the forward direction. carry_forward, outputs_forward = self.forward_rnn( - inputs, initial_carry=initial_carry_forward, init_key=key_forward, - seq_lengths=seq_lengths, return_carry=True, - time_major=time_major, reverse=False) + inputs, + initial_carry=initial_carry_forward, + init_key=key_forward, + seq_lengths=seq_lengths, + return_carry=True, + time_major=time_major, + reverse=False, + ) carry_backward, outputs_backward = self.backward_rnn( - inputs, initial_carry=initial_carry_backward, init_key=key_backward, - seq_lengths=seq_lengths, return_carry=True, - time_major=time_major, reverse=True, keep_order=True) + inputs, + initial_carry=initial_carry_backward, + init_key=key_backward, + seq_lengths=seq_lengths, + return_carry=True, + time_major=time_major, + reverse=True, + keep_order=True, + ) carry = (carry_forward, carry_backward) outputs = jax.tree_map(self.merge_fn, outputs_forward, outputs_backward) @@ -977,4 +1039,4 @@ def __call__( if return_carry: return carry, outputs else: - return outputs \ No newline at end of file + return outputs diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index c56e3edf3f..c4af231c01 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -52,11 +52,14 @@ # Dynamic Axis Mapping Context # ------------------------------------------------------------------------------ + @dataclasses.dataclass class _AxisRules(threading.local): """Dynamic logical axis to mesh axis binding context.""" + rules: LogicalRules = () + # Global axis binding context. _axis_rules = _AxisRules() @@ -114,11 +117,11 @@ def _logical_to_mesh_axes( if rules is None: rules = _axis_rules.rules axis_name_counts = collections.Counter(array_dim_names) - dups = tuple( - k for k, v in axis_name_counts.items() if v > 1 and k is not None) + dups = tuple(k for k, v in axis_name_counts.items() if v > 1 and k is not None) if dups: raise ValueError( - f'Unsupported: Dimensions {dups} occur more than once in array names.') + f'Unsupported: Dimensions {dups} occur more than once in array names.' + ) if not isinstance(rules, (tuple, list)): raise ValueError('Unknown axis rule specification type.') # We assign mesh axes using a priority based ruleset over logical axis names. @@ -127,8 +130,10 @@ def _logical_to_mesh_axes( for rule_model_name, rule_mesh_names in rules: if rule_model_name in array_dim_names: pos = array_dim_names.index(rule_model_name) - if (_mesh_assignment_free(rule_mesh_names, result) and - result[pos] == _unassigned_axis): + if ( + _mesh_assignment_free(rule_mesh_names, result) + and result[pos] == _unassigned_axis + ): result[pos] = rule_mesh_names return result @@ -177,10 +182,7 @@ def logical_to_mesh_axes( return jax.sharding.PartitionSpec(*result) -def logical_to_mesh( - tree: Any, - rules: Optional[LogicalRules] = None -) -> Any: +def logical_to_mesh(tree: Any, rules: Optional[LogicalRules] = None) -> Any: """Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs.""" return jax.tree_map( lambda x: logical_to_mesh_axes(x, rules), @@ -210,6 +212,7 @@ def _global_mesh_defined() -> bool: class RulesFallback(enum.Enum): """How a sharding constraint should behave when no matching rule is found.""" + AXIS_IS_UNSHARDED = 'axis_is_unsharded' RAISE_ERROR = 'raise_error' NO_CONSTRAINT = 'no_constraint' @@ -218,9 +221,12 @@ class RulesFallback(enum.Enum): def _with_sharding_constraint( x: Array, axis_resources: Optional[jax.sharding.PartitionSpec], - mesh: Optional[jax.sharding.Mesh] = None): + mesh: Optional[jax.sharding.Mesh] = None, +): """Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit.""" - if jax.devices()[0].platform == 'cpu' or (not _global_mesh_defined() and mesh is None): + if jax.devices()[0].platform == 'cpu' or ( + not _global_mesh_defined() and mesh is None + ): return x else: if mesh is not None and axis_resources is not None: @@ -234,7 +240,8 @@ def _with_sharding_constraint_one_fallback( x: Array, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, rules: Optional[LogicalRules] = None, - mesh: Optional[jax.sharding.Mesh] = None): + mesh: Optional[jax.sharding.Mesh] = None, +): """Either imposes a sharding constraint or applies fallback.""" mesh_axes = _logical_to_mesh_axes(axis_resources, rules) if mesh_axes is None: @@ -253,7 +260,8 @@ def _with_sharding_constraint_one_fallback( def _is_logical_spec(x): return x is None or ( - isinstance(x, tuple) and all(isinstance(e, str) or e is None for e in x)) + isinstance(x, tuple) and all(isinstance(e, str) or e is None for e in x) + ) def with_logical_constraint( @@ -261,7 +269,8 @@ def with_logical_constraint( logical_axis_resources: LogicalPartitionSpecPytree, rules: Optional[LogicalRules] = None, mesh: Optional[jax.sharding.Mesh] = None, - fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED): + fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, +): """Version of pjit's with_sharding_constraint that uses logical axis names.""" # If no axis binding is set, this is a no-op. if rules is None: @@ -271,11 +280,15 @@ def with_logical_constraint( # Translate logical names to mesh assignments. return jax.tree_util.tree_map( functools.partial( - _with_sharding_constraint_one_fallback, fallback=fallback, - rules=rules, mesh=mesh), + _with_sharding_constraint_one_fallback, + fallback=fallback, + rules=rules, + mesh=mesh, + ), logical_axis_resources, x, - is_leaf=_is_logical_spec) + is_leaf=_is_logical_spec, + ) # Logical Partitioning Axis Metadata @@ -284,12 +297,13 @@ def with_logical_constraint( class LogicallyPartitioned(meta.Partitioned): rules: Optional[LogicalRules] = struct.field(default=None, pytree_node=False) + def unbox(self, apply_constraint=True) -> Any: """Returns the wrapped value with the partitioning constraint applied.""" if apply_constraint and (_global_mesh_defined() or self.mesh is not None): return with_logical_constraint( - self.value, self.get_partition_spec(), - rules=self.rules, mesh=self.mesh) + self.value, self.get_partition_spec(), rules=self.rules, mesh=self.mesh + ) else: return self.value @@ -299,7 +313,7 @@ def with_logical_partitioning( names: meta.LogicalNames, mesh: Optional[jax.sharding.Mesh] = None, rules: Optional[LogicalRules] = None, - ) -> Callable[..., LogicallyPartitioned]: +) -> Callable[..., LogicallyPartitioned]: """Wraps a function's return value with LogicallyPartitioned. Example:: @@ -319,8 +333,9 @@ def with_logical_partitioning( A function wrapping ``fn`` that will return an instance of ``LogicallyPartitioned``. """ + @functools.wraps(fn) def wrapper(*args, **kwargs): - return LogicallyPartitioned(fn(*args, **kwargs), names, - rules=rules, mesh=mesh) + return LogicallyPartitioned(fn(*args, **kwargs), names, rules=rules, mesh=mesh) + return wrapper diff --git a/flax/linen/stochastic.py b/flax/linen/stochastic.py index b562646255..f2b6ecf4cc 100644 --- a/flax/linen/stochastic.py +++ b/flax/linen/stochastic.py @@ -31,19 +31,20 @@ class Dropout(Module): """Create a dropout layer. - Note: When using :meth:`Module.apply() `, make sure - to include an RNG seed named `'dropout'`. For example:: - - model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng})` - - Attributes: - rate: the dropout probability. (_not_ the keep rate!) - broadcast_dims: dimensions that will share the same dropout mask - deterministic: if false the inputs are scaled by `1 / (1 - rate)` and - masked, whereas if true, no mask is applied and the inputs are returned - as is. - rng_collection: the rng collection name to use when requesting an rng key. + Note: When using :meth:`Module.apply() `, make sure + to include an RNG seed named `'dropout'`. For example:: + + model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng})` + + Attributes: + rate: the dropout probability. (_not_ the keep rate!) + broadcast_dims: dimensions that will share the same dropout mask + deterministic: if false the inputs are scaled by `1 / (1 - rate)` and + masked, whereas if true, no mask is applied and the inputs are returned + as is. + rng_collection: the rng collection name to use when requesting an rng key. """ + rate: float broadcast_dims: Sequence[int] = () deterministic: Optional[bool] = None @@ -51,7 +52,7 @@ class Dropout(Module): @compact def __call__( - self, inputs, deterministic: Optional[bool] = None, rng: Optional[KeyArray] = None + self, inputs, deterministic: Optional[bool] = None, rng: Optional[KeyArray] = None ): """Applies a random dropout mask to the input. @@ -66,17 +67,16 @@ def __call__( Returns: The masked inputs reweighted to preserve mean. """ - deterministic = merge_param( - 'deterministic', self.deterministic, deterministic) + deterministic = merge_param('deterministic', self.deterministic, deterministic) - if (self.rate == 0.) or deterministic: + if (self.rate == 0.0) or deterministic: return inputs # Prevent gradient NaNs in 1.0 edge-case. if self.rate == 1.0: return jnp.zeros_like(inputs) - keep_prob = 1. - self.rate + keep_prob = 1.0 - self.rate if rng is None: rng = self.make_rng(self.rng_collection) broadcast_shape = list(inputs.shape) diff --git a/flax/linen/summary.py b/flax/linen/summary.py index 59d3369c63..88c92ac448 100644 --- a/flax/linen/summary.py +++ b/flax/linen/summary.py @@ -33,7 +33,8 @@ PRNGKey = Any # pylint: disable=invalid-name RNGSequences = Dict[str, PRNGKey] -Array = Any # pylint: disable=invalid-name +Array = Any # pylint: disable=invalid-name + class _ValueRepresentation(ABC): """A class that represents a value in the summary table.""" @@ -42,6 +43,7 @@ class _ValueRepresentation(ABC): def render(self) -> str: ... + @dataclasses.dataclass class _ArrayRepresentation(_ValueRepresentation): shape: Tuple[int, ...] @@ -59,18 +61,22 @@ def render(self): shape_repr = ','.join(str(x) for x in self.shape) return f'[dim]{self.dtype}[/dim][{shape_repr}]' + @dataclasses.dataclass class _PartitionedArrayRepresentation(_ValueRepresentation): array_representation: _ArrayRepresentation names: meta.LogicalNames @classmethod - def from_partitioned(cls, partitioned: meta.Partitioned) -> '_PartitionedArrayRepresentation': + def from_partitioned( + cls, partitioned: meta.Partitioned + ) -> '_PartitionedArrayRepresentation': return cls(_ArrayRepresentation.from_array(partitioned.value), partitioned.names) def render(self): return self.array_representation.render() + f' [dim]P[/dim]{self.names}' + @dataclasses.dataclass class _ObjectRepresentation(_ValueRepresentation): obj: Any @@ -78,6 +84,7 @@ class _ObjectRepresentation(_ValueRepresentation): def render(self): return repr(self.obj) + @dataclasses.dataclass class Row: """Contains the information about a single row in the summary table. @@ -93,6 +100,7 @@ class Row: summarization is done then this dictionary potentially contains parameters from submodules depending on the depth of the Module in question. """ + path: Tuple[str, ...] module_type: Type[module_lib.Module] method: str @@ -110,7 +118,9 @@ def __post_init__(self): def size_and_bytes(self, collections: Iterable[str]) -> Dict[str, Tuple[int, int]]: return { col: _size_and_bytes(self.counted_variables[col]) - if col in self.counted_variables else (0, 0) for col in collections + if col in self.counted_variables + else (0, 0) + for col in collections } @@ -124,23 +134,24 @@ class Table(List[Row]): * `collections`: a list containing the parameter collections (e.g. 'params', 'batch_stats', etc) """ - def __init__(self, module: module_lib.Module, collections: Sequence[str], - rows: Iterable[Row]): + def __init__( + self, module: module_lib.Module, collections: Sequence[str], rows: Iterable[Row] + ): super().__init__(rows) self.module = module self.collections = collections def tabulate( - module: module_lib.Module, - rngs: Union[PRNGKey, RNGSequences], - depth: Optional[int] = None, - show_repeated: bool = False, - mutable: CollectionFilter = True, - console_kwargs: Optional[Mapping[str, Any]] = None, - table_kwargs: Mapping[str, Any] = MappingProxyType({}), - column_kwargs: Mapping[str, Any] = MappingProxyType({}), - **kwargs, + module: module_lib.Module, + rngs: Union[PRNGKey, RNGSequences], + depth: Optional[int] = None, + show_repeated: bool = False, + mutable: CollectionFilter = True, + console_kwargs: Optional[Mapping[str, Any]] = None, + table_kwargs: Mapping[str, Any] = MappingProxyType({}), + column_kwargs: Mapping[str, Any] = MappingProxyType({}), + **kwargs, ) -> Callable[..., str]: """Returns a function that creates a summary of the Module represented as a table. @@ -239,6 +250,7 @@ def _tabulate_fn(*fn_args, **fn_kwargs): return _tabulate_fn + def _get_module_table( module: module_lib.Module, depth: Optional[int], @@ -248,7 +260,6 @@ def _get_module_table( but returns the Table representation of the Module.""" def _get_table_fn(*args, **kwargs): - with module_lib._tabulate_context(): def _get_variables(): @@ -286,14 +297,26 @@ def _get_variables(): visited_paths.add(c.path) rows.append( - Row(c.path, c.module_type, c.method, inputs, c.outputs, module_vars, counted_vars)) + Row( + c.path, + c.module_type, + c.method, + inputs, + c.outputs, + module_vars, + counted_vars, + ) + ) return Table(module, tuple(collections), rows) return _get_table_fn + def _get_module_variables( - path: Tuple[str, ...], variables: FrozenVariableDict, all_paths: Set[Tuple[str, ...]] + path: Tuple[str, ...], + variables: FrozenVariableDict, + all_paths: Set[Tuple[str, ...]], ) -> Tuple[MutableVariableDict, Any]: """A function that takes a path and variables structure and returns a (module_variables, submodule_variables) tuple for that path. _get_module_variables @@ -305,14 +328,16 @@ def _get_module_variables( for key in all_keys: submodule_path = path + (key,) if submodule_path in all_paths: - for collection in module_variables: if key in module_variables[collection]: submodule_variables[collection][key] = module_variables[collection].pop(key) return module_variables, submodule_variables -def _get_path_variables(path: Tuple[str, ...], variables: FrozenVariableDict) -> MutableVariableDict: + +def _get_path_variables( + path: Tuple[str, ...], variables: FrozenVariableDict +) -> MutableVariableDict: """A function that takes a path and a variables structure and returns the variable structure at that path.""" path_variables = {} @@ -330,6 +355,7 @@ def _get_path_variables(path: Tuple[str, ...], variables: FrozenVariableDict) -> return path_variables + def _process_inputs(args, kwargs) -> Any: """A function that normalizes the representation of the ``args`` and ``kwargs`` for the ``inputs`` column.""" @@ -344,11 +370,12 @@ def _process_inputs(args, kwargs) -> Any: return input_values + def _render_table( - table: Table, - console_extras: Optional[Mapping[str, Any]], - table_kwargs: Mapping[str, Any], - column_kwargs: Mapping[str, Any], + table: Table, + console_extras: Optional[Mapping[str, Any]], + table_kwargs: Mapping[str, Any], + column_kwargs: Mapping[str, Any], ) -> str: """A function that renders a Table to a string representation using rich.""" console_kwargs = {'force_terminal': True, 'force_jupyter': False} @@ -357,11 +384,11 @@ def _render_table( non_params_cols = 4 rich_table = rich.table.Table( - show_header=True, - show_lines=True, - show_footer=True, - title=f'{table.module.__class__.__name__} Summary', - **table_kwargs, + show_header=True, + show_lines=True, + show_footer=True, + title=f'{table.module.__class__.__name__} Summary', + **table_kwargs, ) rich_table.add_column('path', **column_kwargs) @@ -381,8 +408,7 @@ def _render_table( if collection in row.module_variables: module_variables = _represent_tree(row.module_variables[collection]) module_variables = _normalize_structure(module_variables) - col_repr += _as_yaml_str( - _summary_tree_map(_maybe_render, module_variables)) + col_repr += _as_yaml_str(_summary_tree_map(_maybe_render, module_variables)) if col_repr: col_repr += '\n\n' @@ -391,17 +417,25 @@ def _render_table( no_show_methods = {'__call__', ''} path_repr = '/'.join(row.path) - method_repr = f' [dim]({row.method})[/dim]' if row.method not in no_show_methods else '' + method_repr = ( + f' [dim]({row.method})[/dim]' if row.method not in no_show_methods else '' + ) rich_table.add_row( path_repr, row.module_type.__name__ + method_repr, - _as_yaml_str(_summary_tree_map(_maybe_render, _normalize_structure(row.inputs))), - _as_yaml_str(_summary_tree_map(_maybe_render, _normalize_structure(row.outputs))), - *collections_size_repr) + _as_yaml_str( + _summary_tree_map(_maybe_render, _normalize_structure(row.inputs)) + ), + _as_yaml_str( + _summary_tree_map(_maybe_render, _normalize_structure(row.outputs)) + ), + *collections_size_repr, + ) # add footer with totals rich_table.columns[non_params_cols - 1].footer = rich.text.Text.from_markup( - 'Total', justify='right') + 'Total', justify='right' + ) # get collection totals collection_total = {col: (0, 0) for col in table.collections} @@ -414,12 +448,13 @@ def _render_table( # add totals to footer for i, col in enumerate(table.collections): - rich_table.columns[non_params_cols + i].footer = \ - _size_and_bytes_repr(*collection_total[col]) + rich_table.columns[non_params_cols + i].footer = _size_and_bytes_repr( + *collection_total[col] + ) # add final totals to caption caption_totals = (0, 0) - for (size, num_bytes) in collection_total.values(): + for size, num_bytes in collection_total.values(): caption_totals = ( caption_totals[0] + size, caption_totals[1] + num_bytes, @@ -430,9 +465,11 @@ def _render_table( return '\n' + _get_rich_repr(rich_table, console_kwargs) + '\n' + def _summary_tree_map(f, tree, *rest): return jax.tree_util.tree_map(f, tree, *rest, is_leaf=lambda x: x is None) + def _size_and_bytes_repr(size: int, num_bytes: int) -> str: if not size: return '' @@ -467,7 +504,7 @@ def _as_yaml_str(value) -> str: sort_keys=False, explicit_end=False, ) - return file.getvalue().replace('\n...', '').replace('\'', '').strip() + return file.getvalue().replace('\n...', '').replace("'", '').strip() def _normalize_structure(obj): @@ -478,22 +515,32 @@ def _normalize_structure(obj): elif isinstance(obj, Mapping): return {k: _normalize_structure(v) for k, v in obj.items()} elif dataclasses.is_dataclass(obj): - return {f.name: _normalize_structure(getattr(obj, f.name)) for f in dataclasses.fields(obj)} + return { + f.name: _normalize_structure(getattr(obj, f.name)) + for f in dataclasses.fields(obj) + } else: return obj + def _bytes_repr(num_bytes): - count, units = ((f'{num_bytes / 1e9 :,.1f}', 'GB') if num_bytes > 1e9 else - (f'{num_bytes / 1e6 :,.1f}', 'MB') if num_bytes > 1e6 else - (f'{num_bytes / 1e3 :,.1f}', 'KB') if num_bytes > 1e3 else - (f'{num_bytes:,}', 'B')) + count, units = ( + (f'{num_bytes / 1e9 :,.1f}', 'GB') + if num_bytes > 1e9 + else (f'{num_bytes / 1e6 :,.1f}', 'MB') + if num_bytes > 1e6 + else (f'{num_bytes / 1e3 :,.1f}', 'KB') + if num_bytes > 1e3 + else (f'{num_bytes:,}', 'B') + ) return f'{count} {units}' def _get_value_representation(x: Any) -> _ValueRepresentation: if isinstance(x, (int, float, bool, type(None))) or ( - isinstance(x, np.ndarray) and np.isscalar(x)): + isinstance(x, np.ndarray) and np.isscalar(x) + ): return _ObjectRepresentation(x) elif isinstance(x, meta.Partitioned): return _PartitionedArrayRepresentation.from_partitioned(x) @@ -502,12 +549,16 @@ def _get_value_representation(x: Any) -> _ValueRepresentation: except: return _ObjectRepresentation(x) + def _represent_tree(x): """Returns a tree with the same structure as `x` but with each leaf replaced by a `_ValueRepresentation` object.""" return jax.tree_util.tree_map( - _get_value_representation, x, - is_leaf=lambda x: x is None or isinstance(x, meta.Partitioned)) + _get_value_representation, + x, + is_leaf=lambda x: x is None or isinstance(x, meta.Partitioned), + ) + def _maybe_render(x): - return x.render() if hasattr(x, 'render') else repr(x) \ No newline at end of file + return x.render() if hasattr(x, 'render') else repr(x) diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 59cabfe9ec..dd1a0e7701 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -25,8 +25,19 @@ import dataclasses import functools import inspect -from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, - Tuple, Type, TypeVar, Union) +from typing import ( + Any, + Callable, + Dict, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) from flax import errors from flax import struct @@ -54,8 +65,8 @@ def clean_clone(x): """Remove scopes and tracers from children.""" if isinstance(x, Module): object.__setattr__( - x, 'children', - {k: clean_clone(v) for k, v in x.children.items()}) + x, 'children', {k: clean_clone(v) for k, v in x.children.items()} + ) object.__setattr__(x, 'scope', None) return x @@ -63,6 +74,7 @@ def clean_clone(x): @struct.dataclass class VariablePlaceholder: """Used to mark Variables in a JAX-compatible way when lifting arguments.""" + collection: str = struct.field(pytree_node=False) name: str = struct.field(pytree_node=False) unbox: bool = struct.field(pytree_node=False) @@ -72,6 +84,7 @@ class VariablePlaceholder: @struct.dataclass class InstancePlaceholder: """Marks module instances in a JAX-compatible way when lifting arguments.""" + cls: Type[Any] = struct.field(pytree_node=False) attrs: Dict[Any, Any] = struct.field(pytree_node=False) id: int = struct.field(pytree_node=False) @@ -79,6 +92,7 @@ class InstancePlaceholder: def _memoize_by_id(fn, refs): """Memoization by module/variable id to handle aliasing in traversal.""" + @functools.wraps(fn) def wrapped_fn(x): nonlocal refs @@ -91,6 +105,7 @@ def wrapped_fn(x): if x_id not in refs: refs[x_id] = fn(x) return refs[x_id] + return wrapped_fn @@ -122,6 +137,7 @@ def get_module_scopes(module, args=None, kwargs=None): """ scopes = [] refs = {} + # Gather scopes associated with Variables and Module instances passed as # positional and keyword arguments. @functools.partial(_memoize_by_id, refs=refs) @@ -141,6 +157,7 @@ def get_arg_scope(x): attrs = jax.tree_util.tree_map(get_arg_scope, attrs) return InstancePlaceholder(x.__class__, attrs, x._id) return x + new_args, new_kwargs = jax.tree_util.tree_map(get_arg_scope, (args, kwargs)) # Gather scopes in Variables and Submodules passed as Module attributes. @@ -148,6 +165,7 @@ def get_arg_scope(x): def get_scopes(module): nonlocal scopes module._try_setup(shallow=True) + def get_scopes_inner(x): nonlocal scopes if isinstance(x, Module) and isinstance(x.scope, Scope): @@ -162,6 +180,7 @@ def get_scopes_inner(x): } jax.tree_util.tree_map(get_scopes_inner, attrs) scopes.append(module.scope) + get_scopes(module) return scopes, new_args, new_kwargs @@ -194,16 +213,16 @@ def set_module_scopes(module, args, kwargs, scopes): """ idx = 0 refs = {} + # Set scopes associated with Variables and Module instances passed as # positional and keyword arguments. @functools.partial(_memoize_by_id, refs=refs) def set_arg_scope(x): nonlocal idx if isinstance(x, VariablePlaceholder): - new_x = Variable(scope=scopes[idx], - collection=x.collection, - name=x.name, - unbox=x.unbox) + new_x = Variable( + scope=scopes[idx], collection=x.collection, name=x.name, unbox=x.unbox + ) idx += 1 return new_x elif isinstance(x, InstancePlaceholder): @@ -217,21 +236,22 @@ def is_placeholder(x): return isinstance(x, (VariablePlaceholder, InstancePlaceholder)) new_args, new_kwargs = jax.tree_util.tree_map( - set_arg_scope, (args, kwargs), is_leaf=is_placeholder) + set_arg_scope, (args, kwargs), is_leaf=is_placeholder + ) # set scopes in Variables and Submodules passed as Module attributes @functools.partial(_memoize_by_id, refs=refs) def set_scopes(module): nonlocal idx + def set_scopes_inner(x): nonlocal idx if isinstance(x, Module) and isinstance(x.scope, Scope): return set_scopes(x) elif isinstance(x, Variable) and isinstance(x.scope, Scope): - new_x = Variable(scope=scopes[idx], - collection=x.collection, - name=x.name, - unbox=x.unbox) + new_x = Variable( + scope=scopes[idx], collection=x.collection, name=x.name, unbox=x.unbox + ) idx += 1 return new_x else: @@ -246,6 +266,7 @@ def set_scopes_inner(x): new_module = module.clone(parent=scopes[idx], **new_attrs) idx += 1 return new_module + new_module = set_scopes(module) assert len(scopes) == idx, f'scope list mismatch {len(scopes)} != {idx}' return new_module, new_args, new_kwargs @@ -253,8 +274,9 @@ def set_scopes_inner(x): def _test_transformed_return_values(tree, method_name): """Tests whether the return value contains any Modules or Variables.""" - impure = any(map(lambda x: isinstance(x, (Module, Variable)), - jax.tree_util.tree_leaves(tree))) + impure = any( + map(lambda x: isinstance(x, (Module, Variable)), jax.tree_util.tree_leaves(tree)) + ) if impure: raise errors.TransformedMethodReturnValueError(method_name) @@ -262,11 +284,8 @@ def _test_transformed_return_values(tree, method_name): # Class lifting # ----------------------------------------------------------------------------- def module_class_lift_transform( - transform, - module_class, - *trafo_args, - methods=None, - **trafo_kwargs): + transform, module_class, *trafo_args, methods=None, **trafo_kwargs +): """Module class lift transform.""" # TODO(marcvanzee): Improve docstrings (#1977). # TODO(levskaya): find nicer argument convention for multi-method case? @@ -282,12 +301,12 @@ def module_class_lift_transform( # Pass different trafo args per each method. class_trafo_args = {k: ((), v) for k, v in methods.items()} else: - raise ValueError( - 'transform methods argument must be None, tuple, list, or dict.') + raise ValueError('transform methods argument must be None, tuple, list, or dict.') # Handle partially initialized module class constructors. - if (isinstance(module_class, functools.partial) and - issubclass(module_class.func, Module)): + if isinstance(module_class, functools.partial) and issubclass( + module_class.func, Module + ): partial_object = module_class module_class = module_class.func else: @@ -297,10 +316,12 @@ def create_trans_fn(fn_name, fn_trafo_args): # get existing unbound method from class fn = getattr(module_class, fn_name) trafo_args, trafo_kwargs = fn_trafo_args + # we need to create a scope-function from our class for the given method @functools.wraps(fn) def wrapped_fn(self, *args, **kwargs): state = self._state.export() + # make a scope-function to transform def core_fn(scopes, *args, **kwargs): # make a clone of self using its arguments @@ -317,31 +338,38 @@ def core_fn(scopes, *args, **kwargs): self._state.reimport(cloned._state) _test_transformed_return_values(res, fn_name) return res + # here we apply the given lifting transform to the scope-ingesting fn trafo_fn = transform(core_fn, *trafo_args, **trafo_kwargs) module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) ret = trafo_fn(module_scopes, *args, **kwargs) return ret + return wrapped_fn - transformed_fns = {fn_name: create_trans_fn(fn_name, fn_trafo_args) - for fn_name, fn_trafo_args in class_trafo_args.items()} + + transformed_fns = { + fn_name: create_trans_fn(fn_name, fn_trafo_args) + for fn_name, fn_trafo_args in class_trafo_args.items() + } # construct new dynamic class w. transformed methods transformed_cls = type( transform.__name__.capitalize() + module_class.__name__, (module_class,), - transformed_fns) + transformed_fns, + ) # Handle partially initialized module class constructors. if partial_object is not None: - transformed_cls = functools.partial(transformed_cls, - *partial_object.args, - **partial_object.keywords) + transformed_cls = functools.partial( + transformed_cls, *partial_object.args, **partial_object.keywords + ) return transformed_cls # Function lifting as decorator on methods __inside__ class definition. # ----------------------------------------------------------------------------- -def decorator_lift_transform(transform, class_fn, *trafo_args, - multi_scope=True, **trafo_kwargs): +def decorator_lift_transform( + transform, class_fn, *trafo_args, multi_scope=True, **trafo_kwargs +): """Decorator for lifted transform.""" # TODO(marcvanzee): Improve docstrings (#1977). # Due to the ordering of method decorators, we must wrap the class_fn @@ -352,9 +380,11 @@ def decorator_lift_transform(transform, class_fn, *trafo_args, else: class_fns = (class_fn,) prewrapped_fns = [wrap_method_once(class_fn) for class_fn in class_fns] + @functools.wraps(prewrapped_fns[0]) def wrapped_fn(self, *args, **kwargs): state = self._state.export() + # make a scope-function to transform def core_fn(prewrapped_fn, class_fn, scopes, *args, **kwargs): if not multi_scope: @@ -365,8 +395,11 @@ def core_fn(prewrapped_fn, class_fn, scopes, *args, **kwargs): self._state.reimport(cloned._state) _test_transformed_return_values(res, getattr(class_fn, '__name__', None)) return res - core_fns = [functools.partial(core_fn, prewrapped_fn, class_fn) - for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns)] + + core_fns = [ + functools.partial(core_fn, prewrapped_fn, class_fn) + for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns) + ] # here we apply the given lifting transform to the scope-ingesting fn trafo_fn = transform(*core_fns, *trafo_args, **trafo_kwargs) module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) @@ -379,9 +412,11 @@ def core_fn(prewrapped_fn, class_fn, scopes, *args, **kwargs): # arguments per lifted Module. raise NotImplementedError( 'This transform does not yet support' - ' Modules that include other Modules passed as arguments.') + ' Modules that include other Modules passed as arguments.' + ) module_scopes = module_scopes[0] return trafo_fn(module_scopes, *args, **kwargs) + return wrapped_fn @@ -393,60 +428,66 @@ def core_fn(prewrapped_fn, class_fn, scopes, *args, **kwargs): def _is_module_class(target: TransformTarget) -> bool: - return (inspect.isclass(target) and issubclass(target, Module) or - (isinstance(target, functools.partial)) and - _is_module_class(target.func)) + return ( + inspect.isclass(target) + and issubclass(target, Module) + or (isinstance(target, functools.partial)) + and _is_module_class(target.func) + ) -def lift_transform(transform, - target, - *trafo_args, - methods=None, - **trafo_kwargs): +def lift_transform(transform, target, *trafo_args, methods=None, **trafo_kwargs): """Applies to class or as a decorator on class fns.""" # TODO(marcvanzee): Improve docstrings (#1977). if _is_module_class(target): return module_class_lift_transform( - transform, target, *trafo_args, methods=methods, **trafo_kwargs) + transform, target, *trafo_args, methods=methods, **trafo_kwargs + ) # we presume this is being used as a function decorator in class definition elif callable(target) and not isinstance(target, Module): - return decorator_lift_transform( - transform, target, *trafo_args, **trafo_kwargs) + return decorator_lift_transform(transform, target, *trafo_args, **trafo_kwargs) else: raise errors.TransformTargetError(target) -def lift_direct_transform(transform: Callable[..., Any], - targets: Tuple[Callable[..., Any], ...], - mdl: Module, - *args, multi_scope=True, **kwargs): +def lift_direct_transform( + transform: Callable[..., Any], + targets: Tuple[Callable[..., Any], ...], + mdl: Module, + *args, + multi_scope=True, + **kwargs, +): """Lift direct transform.""" # TODO(marcvanzee): Improve docstrings (#1977). for target in targets: if _is_module_class(target): raise ValueError( f'The {transform.__name__} transform can only be applied on a Module method.' - ' That is function that takes a Module instance as its first arg.') + ' That is function that takes a Module instance as its first arg.' + ) elif not callable(target): raise ValueError('transform target must be callable') # normalize self.foo bound methods to class.foo unbound methods. targets = tuple(_get_unbound_fn(target) for target in targets) aug_transform = lambda *fns: functools.partial(transform, *fns) - return decorator_lift_transform( - aug_transform, targets, multi_scope=multi_scope)(mdl, *args, **kwargs) - - -def vmap(target: Target, - variable_axes: Mapping[lift.CollectionFilter, - lift.InOutAxis] = FrozenDict(), - split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict(), - in_axes=0, - out_axes=0, - axis_size: Optional[int] = None, - axis_name: Optional[str] = None, - spmd_axis_name: Optional[str] = None, - metadata_params: Mapping[Any, Any] = {}, - methods=None) -> Target: + return decorator_lift_transform(aug_transform, targets, multi_scope=multi_scope)( + mdl, *args, **kwargs + ) + + +def vmap( + target: Target, + variable_axes: Mapping[lift.CollectionFilter, lift.InOutAxis] = FrozenDict(), + split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict(), + in_axes=0, + out_axes=0, + axis_size: Optional[int] = None, + axis_name: Optional[str] = None, + spmd_axis_name: Optional[str] = None, + metadata_params: Mapping[Any, Any] = {}, + methods=None, +) -> Target: """A lifted version of ``jax.vmap``. See ``jax.vmap`` for the unlifted batch transform in Jax. @@ -518,17 +559,20 @@ def vmap(target: Target, axis_size=axis_size, axis_name=axis_name, metadata_params=metadata_params, - spmd_axis_name=spmd_axis_name) + spmd_axis_name=spmd_axis_name, + ) -def jit(target: Target, - variables: lift.CollectionFilter = True, - rngs: lift.PRNGSequenceFilter = True, - static_argnums: Union[int, Iterable[int]] = (), - donate_argnums: Union[int, Iterable[int]] = (), - device=None, - backend: Union[str, None] = None, - methods=None) -> Target: +def jit( + target: Target, + variables: lift.CollectionFilter = True, + rngs: lift.PRNGSequenceFilter = True, + static_argnums: Union[int, Iterable[int]] = (), + donate_argnums: Union[int, Iterable[int]] = (), + device=None, + backend: Union[str, None] = None, + methods=None, +) -> Target: """Lifted version of ``jax.jit``. Args: @@ -569,23 +613,28 @@ def jit(target: Target, A wrapped version of target, set up for just-in-time compilation. """ return lift_transform( - lift.jit, target, - variables=variables, rngs=rngs, + lift.jit, + target, + variables=variables, + rngs=rngs, static_argnums=static_argnums, donate_argnums=donate_argnums, device=device, backend=backend, - methods=methods) + methods=methods, + ) -def checkpoint(target: Target, - variables: lift.CollectionFilter = True, - rngs: lift.PRNGSequenceFilter = True, - concrete: bool = False, - prevent_cse: bool = True, - static_argnums: Union[int, Tuple[int, ...]] = (), - policy: Optional[Callable[..., bool]] = None, - methods=None) -> Target: +def checkpoint( + target: Target, + variables: lift.CollectionFilter = True, + rngs: lift.PRNGSequenceFilter = True, + concrete: bool = False, + prevent_cse: bool = True, + static_argnums: Union[int, Tuple[int, ...]] = (), + policy: Optional[Callable[..., bool]] = None, + methods=None, +) -> Target: """Lifted version of ``jax.checkpoint``. Checkpointing is a technique for reducing memory usage by recomputing @@ -651,11 +700,16 @@ def checkpoint(target: Target, # lifted function static_argnums = jax.tree_util.tree_map(lambda x: x - 1, static_argnums) return lift_transform( - lift.checkpoint, target, - variables=variables, rngs=rngs, concrete=concrete, + lift.checkpoint, + target, + variables=variables, + rngs=rngs, + concrete=concrete, static_argnums=static_argnums, - prevent_cse=prevent_cse, policy=policy, - methods=methods) + prevent_cse=prevent_cse, + policy=policy, + methods=methods, + ) remat = checkpoint @@ -667,9 +721,10 @@ def remat_scan( policy: Optional[Callable[..., bool]] = None, variable_broadcast: lift.CollectionFilter = False, variable_carry: lift.CollectionFilter = False, - variable_axes: Mapping[lift.CollectionFilter, - lift.InOutScanAxis] = FrozenDict({True: 0}), - split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict({True: True}) + variable_axes: Mapping[lift.CollectionFilter, lift.InOutScanAxis] = FrozenDict( + {True: 0} + ), + split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict({True: True}), ) -> Target: """Combines remat and scan for memory efficiency and constant time compilation. @@ -713,7 +768,8 @@ def __call__(self, x): A wrapped version of ``target`` that repeats itself prod(lengths) times. """ return lift_transform( - lift.remat_scan, target, + lift.remat_scan, + target, lengths=lengths, variable_broadcast=variable_broadcast, variable_carry=variable_carry, @@ -723,19 +779,21 @@ def __call__(self, x): ) -def scan(target: Target, - variable_axes: Mapping[lift.CollectionFilter, - lift.InOutScanAxis] = FrozenDict(), - variable_broadcast: lift.CollectionFilter = False, - variable_carry: lift.CollectionFilter = False, - split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict(), - in_axes=0, out_axes=0, - length: Optional[int] = None, - reverse: bool = False, - unroll: int = 1, - data_transform: Optional[Callable[..., Any]] = None, - metadata_params: Mapping[Any, Any] = {}, - methods=None) -> Target: +def scan( + target: Target, + variable_axes: Mapping[lift.CollectionFilter, lift.InOutScanAxis] = FrozenDict(), + variable_broadcast: lift.CollectionFilter = False, + variable_carry: lift.CollectionFilter = False, + split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict(), + in_axes=0, + out_axes=0, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + data_transform: Optional[Callable[..., Any]] = None, + metadata_params: Mapping[Any, Any] = {}, + methods=None, +) -> Target: """A lifted version of ``jax.lax.scan``. See ``jax.lax.scan`` for the unlifted scan in Jax. @@ -879,18 +937,21 @@ def scan(target: Target, the loop. """ return lift_transform( - lift.scan, target, + lift.scan, + target, variable_axes=variable_axes, variable_broadcast=variable_broadcast, variable_carry=variable_carry, split_rngs=split_rngs, - in_axes=in_axes, out_axes=out_axes, + in_axes=in_axes, + out_axes=out_axes, length=length, reverse=reverse, unroll=unroll, data_transform=data_transform, metadata_params=metadata_params, - methods=methods) + methods=methods, + ) def map_variables( @@ -898,10 +959,12 @@ def map_variables( mapped_collections: lift.CollectionFilter = True, trans_in_fn: Callable[..., Any] = lift.id_fn, trans_out_fn: Callable[..., Any] = lift.id_fn, - init: bool = False, mutable: bool = False, + init: bool = False, + mutable: bool = False, rngs: lift.PRNGSequenceFilter = True, variables: lift.CollectionFilter = True, - methods=None) -> Target: + methods=None, +) -> Target: """Map Variables inside a module. ``map_variables`` can be used to transform the variables inside a module @@ -954,11 +1017,15 @@ def map_variables( """ return lift_transform( - lift.map_variables, target, + lift.map_variables, + target, mapped_collections, - trans_in_fn, trans_out_fn, - init, mutable, - rngs, variables, + trans_in_fn, + trans_out_fn, + init, + mutable, + rngs, + variables, methods=methods, ) @@ -972,7 +1039,7 @@ def vjp( vjp_variables: lift.CollectionFilter = 'params', variables: lift.CollectionFilter = True, rngs: lift.PRNGSequenceFilter = True, - ) -> Tuple[Any, Any]: +) -> Tuple[Any, Any]: """A lifted version of ``jax.vjp``. See ``jax.vjp`` for the unlifted vector-Jacobiam product (backward gradient). @@ -1034,12 +1101,17 @@ def __call__(self, x, y): returned by ``fn``. """ return lift_direct_transform( - lift.vjp, (fn,), mdl, *primals, + lift.vjp, + (fn,), + mdl, + *primals, multi_scope=False, - has_aux=has_aux, reduce_axes=reduce_axes, + has_aux=has_aux, + reduce_axes=reduce_axes, vjp_variables=vjp_variables, variables=variables, - rngs=rngs) + rngs=rngs, + ) def jvp( @@ -1118,10 +1190,16 @@ def f(scope, x): ``primals_out``. """ return lift_direct_transform( - lift.jvp, (fn,), mdl, primals, tangents, variable_tangents, + lift.jvp, + (fn,), + mdl, + primals, + tangents, + variable_tangents, multi_scope=False, variables=variables, - rngs=rngs) + rngs=rngs, + ) ModuleT = TypeVar('ModuleT', bound=Module) @@ -1135,7 +1213,8 @@ def while_loop( init: C, carry_variables: lift.CollectionFilter = False, broadcast_variables: lift.CollectionFilter = True, - split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict()) -> C: + split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict(), +) -> C: """Lifted version of jax.lax.while_loop. The lifted scope is passed to `cond_fn` and `body_fn`. @@ -1184,10 +1263,14 @@ def body_fn(mdl, c): The final state after executing the while loop. """ return lift_direct_transform( - lift.while_loop, (cond_fn, body_fn), mdl, + lift.while_loop, + (cond_fn, body_fn), + mdl, init, - carry_variables, broadcast_variables, - split_rngs) + carry_variables, + broadcast_variables, + split_rngs, + ) def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs): @@ -1196,10 +1279,13 @@ def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs): def cond( pred: Any, - true_fun: Callable[..., C], false_fun: Callable[..., C], - mdl: Module, *operands, + true_fun: Callable[..., C], + false_fun: Callable[..., C], + mdl: Module, + *operands, variables: lift.CollectionFilter = True, - rngs: lift.PRNGSequenceFilter = True) -> C: + rngs: lift.PRNGSequenceFilter = True, +) -> C: """Lifted version of ``jax.lax.cond``. The returned values from ``true_fun`` and ``false_fun`` @@ -1242,9 +1328,15 @@ def false_fn(mdl, x): The result of the evaluated branch (``true_fun`` or ``false_fun``). """ return lift_direct_transform( - _cond_wrapper, (true_fun, false_fun), mdl, - pred, *operands, - variables=variables, rngs=rngs) + _cond_wrapper, + (true_fun, false_fun), + mdl, + pred, + *operands, + variables=variables, + rngs=rngs, + ) + def _switch_wrapper(*args, variables, rngs, n_branches): # first n_branches arguments are branches. @@ -1253,12 +1345,15 @@ def _switch_wrapper(*args, variables, rngs, n_branches): scope, index, *operands = args[n_branches:] return lift.switch(index, branches, scope, *operands, variables=variables, rngs=rngs) + def switch( index: Any, branches: Sequence[Callable[..., C]], - mdl: Module, *operands, + mdl: Module, + *operands, variables: lift.CollectionFilter = True, - rngs: lift.PRNGSequenceFilter = True) -> C: + rngs: lift.PRNGSequenceFilter = True, +) -> C: """Lifted version of ``jax.lax.switch``. The returned values from ``branches`` @@ -1326,9 +1421,16 @@ def head_fn(i): The result of the evaluated branch. """ return lift_direct_transform( - _switch_wrapper, tuple(branches), mdl, - index, *operands, - variables=variables, rngs=rngs, n_branches=len(branches)) + _switch_wrapper, + tuple(branches), + mdl, + index, + *operands, + variables=variables, + rngs=rngs, + n_branches=len(branches), + ) + # a version of lift.custom_vjp with a single scope function # this avoids having to lift multiple functions in @@ -1337,18 +1439,20 @@ def _custom_vjp_single_scope_fn( fn: Callable[..., Any], backward_fn: Callable[..., Any], grad_vars: lift.CollectionFilter = 'params', - nondiff_argnums=()): + nondiff_argnums=(), +): nodiff_fn = functools.partial(fn, needs_residual=False) forward_fn = functools.partial(fn, needs_residual=True) - return lift.custom_vjp(nodiff_fn, forward_fn, backward_fn, grad_vars, - nondiff_argnums) + return lift.custom_vjp(nodiff_fn, forward_fn, backward_fn, grad_vars, nondiff_argnums) -def custom_vjp(fn: Callable[..., Any], - forward_fn: Callable[..., Any], - backward_fn: Callable[..., Any], - grad_vars: lift.CollectionFilter = 'params', - nondiff_argnums=()): +def custom_vjp( + fn: Callable[..., Any], + forward_fn: Callable[..., Any], + backward_fn: Callable[..., Any], + grad_vars: lift.CollectionFilter = 'params', + nondiff_argnums=(), +): """Lifted version of `jax.custom_vjp`. `forward_fn` and `backward_fn` together define a custom vjp for `fn`. @@ -1406,16 +1510,21 @@ def bwd(vjp_fn, y_t): Returns: A function with the same signature as `fn` with the custom vjp. """ + def shared_forward_fn(*args, needs_residual, **kwargs): if needs_residual: return forward_fn(*args, **kwargs) else: - return fn(*args, ** kwargs) + return fn(*args, **kwargs) + return decorator_lift_transform( - _custom_vjp_single_scope_fn, shared_forward_fn, - backward_fn=backward_fn, grad_vars=grad_vars, + _custom_vjp_single_scope_fn, + shared_forward_fn, + backward_fn=backward_fn, + grad_vars=grad_vars, nondiff_argnums=nondiff_argnums, - multi_scope=False) + multi_scope=False, + ) def named_call(class_fn, force=True): @@ -1431,22 +1540,23 @@ def named_call(class_fn, force=True): Returns: A wrapped version of ``class_fn`` that is labeled. """ + # We use JAX's dynamic name-stack named_call. No transform boundary needed! @functools.wraps(class_fn) def wrapped_fn(self, *args, **kwargs): - if ((not force and not linen_module._use_named_call) # pylint: disable=protected-access - or self._state.in_setup): # pylint: disable=protected-access + if (not force and not linen_module._use_named_call) or self._state.in_setup: # pylint: disable=protected-access # pylint: disable=protected-access return class_fn(self, *args, **kwargs) full_name = _derive_profiling_name(self, class_fn) return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs) + return wrapped_fn def add_metadata_axis( target: Target, - variable_axes: Mapping[lift.CollectionFilter, - lift.InOutAxis] = FrozenDict(), - metadata_params: Dict[Any, Any] = {}) -> Target: + variable_axes: Mapping[lift.CollectionFilter, lift.InOutAxis] = FrozenDict(), + metadata_params: Dict[Any, Any] = {}, +) -> Target: """A helper to manipulate boxed axis metadata. This is a helper to manipulate the *metadata* in boxed variables, similar @@ -1466,10 +1576,13 @@ def add_metadata_axis( A transformed version of ``target`` that performs a transform of the axis metadata on its variables. """ + def add_fn(axis): return lambda x: meta.add_axis(x, axis, metadata_params) + def remove_fn(axis): return lambda x: meta.remove_axis(x, axis, metadata_params) + for col_name, axis in variable_axes.items(): target = map_variables( target, @@ -1478,4 +1591,4 @@ def remove_fn(axis): trans_out_fn=add_fn(axis), mutable=True, ) - return target \ No newline at end of file + return target diff --git a/flax/metrics/__init__.py b/flax/metrics/__init__.py index cd2b215030..e80ba0b35f 100644 --- a/flax/metrics/__init__.py +++ b/flax/metrics/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/flax/metrics/tensorboard.py b/flax/metrics/tensorboard.py index 8885cbe528..0d7f92e0a7 100644 --- a/flax/metrics/tensorboard.py +++ b/flax/metrics/tensorboard.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Write Summaries from JAX for use with Tensorboard. -""" +"""Write Summaries from JAX for use with Tensorboard.""" import contextlib import functools @@ -43,7 +42,16 @@ def _flatten_dict(input_dict, parent_key='', sep='.'): new_key = parent_key + sep + k if parent_key else k # Valid types according to https://github.com/tensorflow/tensorboard/blob/1204566da5437af55109f7a4af18f9f8b7c4f864/tensorboard/plugins/hparams/summary_v2.py - valid_types = (bool, int, float, str, np.bool_, np.integer, np.floating, np.character) + valid_types = ( + bool, + int, + float, + str, + np.bool_, + np.integer, + np.floating, + np.character, + ) if isinstance(v, dict): # Recursively flatten the dict. @@ -167,8 +175,13 @@ def audio(self, tag, audiodata, step, sample_rate=44100, max_outputs=3): audio = tf.convert_to_tensor(audiodata, dtype=tf.float32) with self._as_default(self._event_writer): tf.summary.audio( - name=tag, data=audio, sample_rate=sample_rate, step=step, - max_outputs=max_outputs, encoding='wav') + name=tag, + data=audio, + sample_rate=sample_rate, + step=step, + max_outputs=max_outputs, + encoding='wav', + ) def histogram(self, tag, values, step, bins=None): """Saves histogram of values. @@ -211,11 +224,7 @@ def write(self, tag, tensor, step, metadata=None): Note: markdown formatting is rendered by tensorboard. """ with self._as_default(self._event_writer): - tf.summary.write( - tag=tag, - tensor=tensor, - step=step, - metadata=metadata) + tf.summary.write(tag=tag, tensor=tensor, step=step, metadata=metadata) def hparams(self, hparams): """Saves hyper parameters. diff --git a/flax/serialization.py b/flax/serialization.py index 1cd3a20098..a5b0235786 100644 --- a/flax/serialization.py +++ b/flax/serialization.py @@ -36,6 +36,7 @@ class _ErrorContext(threading.local): def __init__(self): self.path = [] + _error_context = _ErrorContext() @@ -55,6 +56,7 @@ def current_path(): class _NamedTuple: """Fake type marker for namedtuple for registry.""" + pass @@ -109,8 +111,9 @@ def to_state_dict(target) -> Dict[str, Any]: return state_dict -def register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, - override=False): +def register_serialization_state( + ty, ty_to_state_dict, ty_from_state_dict, override=False +): """Register a type for serialization. Args: @@ -123,8 +126,9 @@ def register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, (default: False). """ if ty in _STATE_DICT_REGISTRY and not override: - raise ValueError(f'a serialization handler for "{ty.__name__}"' - ' is already registered') + raise ValueError( + f'a serialization handler for "{ty.__name__}"' ' is already registered' + ) _STATE_DICT_REGISTRY[ty] = (ty_to_state_dict, ty_from_state_dict) @@ -134,9 +138,11 @@ def _list_state_dict(xs: List[Any]) -> Dict[str, Any]: def _restore_list(xs, state_dict: Dict[str, Any]) -> List[Any]: if len(state_dict) != len(xs): - raise ValueError('The size of the list and the state dict do not match,' - f' got {len(xs)} and {len(state_dict)} ' - f'at path {current_path()}') + raise ValueError( + 'The size of the list and the state dict do not match,' + f' got {len(xs)} and {len(state_dict)} ' + f'at path {current_path()}' + ) ys = [] for i in range(len(state_dict)): y = from_state_dict(xs[i], state_dict[str(i)], name=str(i)) @@ -157,12 +163,16 @@ def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]: def _restore_dict(xs, states: Dict[str, Any]) -> Dict[str, Any]: diff = set(map(str, xs.keys())).difference(states.keys()) if diff: - raise ValueError('The target dict keys and state dict keys do not match,' - f' target dict contains keys {diff} which are not present in state dict ' - f'at path {current_path()}') + raise ValueError( + 'The target dict keys and state dict keys do not match,' + f' target dict contains keys {diff} which are not present in state dict ' + f'at path {current_path()}' + ) - return {key: from_state_dict(value, states[str(key)], name=str(key)) - for key, value in xs.items()} + return { + key: from_state_dict(value, states[str(key)], name=str(key)) + for key, value in xs.items() + } def _namedtuple_state_dict(nt) -> Dict[str, Any]: @@ -173,8 +183,10 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]): """Rebuild namedtuple from serialized dict.""" if set(state_dict.keys()) == {'name', 'fields', 'values'}: # TODO(jheek): remove backward compatible named tuple restoration early 2022 - state_dict = {state_dict['fields'][str(i)]: state_dict['values'][str(i)] - for i in range(len(state_dict['fields']))} + state_dict = { + state_dict['fields'][str(i)]: state_dict['values'][str(i)] + for i in range(len(state_dict['fields'])) + } sd_keys = set(state_dict.keys()) nt_keys = set(xs._fields) @@ -182,10 +194,10 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]): if sd_keys != nt_keys: raise ValueError( 'The field names of the state dict and the named tuple do not match,' - f' got {sd_keys} and {nt_keys} at path {current_path()}') + f' got {sd_keys} and {nt_keys} at path {current_path()}' + ) fields = { - k: from_state_dict(getattr(xs, k), v, name=k) - for k, v in state_dict.items() + k: from_state_dict(getattr(xs, k), v, name=k) for k, v in state_dict.items() } return type(xs)(**fields) @@ -193,24 +205,22 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]): register_serialization_state(dict, _dict_state_dict, _restore_dict) register_serialization_state(list, _list_state_dict, _restore_list) register_serialization_state( - tuple, _list_state_dict, - lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict))) -register_serialization_state(_NamedTuple, - _namedtuple_state_dict, - _restore_namedtuple) + tuple, + _list_state_dict, + lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict)), +) +register_serialization_state(_NamedTuple, _namedtuple_state_dict, _restore_namedtuple) register_serialization_state( jax.tree_util.Partial, - lambda x: ( - { - "args": to_state_dict(x.args), - "keywords": to_state_dict(x.keywords), - } - ), + lambda x: ({ + 'args': to_state_dict(x.args), + 'keywords': to_state_dict(x.keywords), + }), lambda x, sd: jax.tree_util.Partial( x.func, - *from_state_dict(x.args, sd["args"]), - **from_state_dict(x.keywords, sd["keywords"]), + *from_state_dict(x.args, sd['args']), + **from_state_dict(x.keywords, sd['keywords']), ), ) @@ -232,8 +242,9 @@ def _ndarray_to_bytes(arr) -> bytes: if isinstance(arr, jax.Array): arr = np.array(arr) if arr.dtype.hasobject or arr.dtype.isalignedstruct: - raise ValueError('Object and structured dtypes not supported ' - 'for serialization of ndarrays.') + raise ValueError( + 'Object and structured dtypes not supported ' 'for serialization of ndarrays.' + ) tpl = (arr.shape, arr.dtype.name, arr.tobytes('C')) return msgpack.packb(tpl, use_bin_type=True) @@ -249,14 +260,14 @@ def _dtype_from_name(name: str): def _ndarray_from_bytes(data: bytes) -> np.ndarray: """Load ndarray from simple msgpack encoding.""" shape, dtype_name, buffer = msgpack.unpackb(data, raw=True) - return np.frombuffer(buffer, - dtype=_dtype_from_name(dtype_name), - count=-1, - offset=0).reshape(shape, order='C') + return np.frombuffer( + buffer, dtype=_dtype_from_name(dtype_name), count=-1, offset=0 + ).reshape(shape, order='C') class _MsgpackExtType(enum.IntEnum): """Messagepack custom type ids.""" + ndarray = 1 native_complex = 2 npscalar = 3 @@ -270,11 +281,11 @@ def _msgpack_ext_pack(x): return msgpack.ExtType(_MsgpackExtType.ndarray, _ndarray_to_bytes(x)) if np.issctype(type(x)): # pack scalar as ndarray - return msgpack.ExtType(_MsgpackExtType.npscalar, - _ndarray_to_bytes(np.asarray(x))) + return msgpack.ExtType(_MsgpackExtType.npscalar, _ndarray_to_bytes(np.asarray(x))) elif isinstance(x, complex): - return msgpack.ExtType(_MsgpackExtType.native_complex, - msgpack.packb((x.real, x.imag))) + return msgpack.ExtType( + _MsgpackExtType.native_complex, msgpack.packb((x.real, x.imag)) + ) return x @@ -321,10 +332,9 @@ def _np_convert_in_place(d): def _chunk(arr) -> Dict[str, Any]: """Convert array to a canonical dictionary of chunked arrays.""" chunksize = max(1, int(MAX_CHUNK_SIZE / arr.dtype.itemsize)) - data = {'__msgpack_chunked_array__': True, - 'shape': _tuple_to_dict(arr.shape)} + data = {'__msgpack_chunked_array__': True, 'shape': _tuple_to_dict(arr.shape)} flatarr = arr.reshape(-1) - chunks = [flatarr[i:i + chunksize] for i in range(0, flatarr.size, chunksize)] + chunks = [flatarr[i : i + chunksize] for i in range(0, flatarr.size, chunksize)] data['chunks'] = _tuple_to_dict(chunks) return data @@ -404,8 +414,7 @@ def msgpack_restore(encoded_pytree: bytes): Python tree of dict, list, tuple with python primitive and array leaves. """ - state_dict = msgpack.unpackb( - encoded_pytree, ext_hook=_msgpack_ext_unpack, raw=False) + state_dict = msgpack.unpackb(encoded_pytree, ext_hook=_msgpack_ext_unpack, raw=False) return _unchunk_array_leaves_in_place(state_dict) diff --git a/flax/struct.py b/flax/struct.py index aeafdf4c15..7dc225a4cd 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utilities for defining custom classes that can be used with jax transformations. -""" +"""Utilities for defining custom classes that can be used with jax transformations.""" import dataclasses from typing import TypeVar, Callable, Tuple, Union, Any @@ -24,14 +23,14 @@ from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet -_T = TypeVar("_T") +_T = TypeVar('_T') def field(pytree_node=True, **kwargs): return dataclasses.field(metadata={'pytree_node': pytree_node}, **kwargs) -@dataclass_transform(field_descriptors=(field,)) # type: ignore[literal-required] +@dataclass_transform(field_descriptors=(field,)) # type: ignore[literal-required] def dataclass(clz: _T) -> _T: """Create a class which can be passed to functional transformations. @@ -98,7 +97,7 @@ def create(cls, kernel): if '_flax_dataclass' in clz.__dict__: return clz - data_clz = dataclasses.dataclass(frozen=True)(clz) # type: ignore + data_clz = dataclasses.dataclass(frozen=True)(clz) # type: ignore meta_fields = [] data_fields = [] for field_info in dataclasses.fields(data_clz): @@ -109,7 +108,7 @@ def create(cls, kernel): meta_fields.append(field_info.name) def replace(self, **updates): - """"Returns a new object replacing the specified fields with new values.""" + """ "Returns a new object replacing the specified fields with new values.""" return dataclasses.replace(self, **updates) data_clz.replace = replace @@ -122,8 +121,7 @@ def iterate_clz(x): def iterate_clz_with_keys(x): meta = tuple(getattr(x, name) for name in meta_fields) data = tuple( - (jax.tree_util.GetAttrKey(name), getattr(x, name)) - for name in data_fields + (jax.tree_util.GetAttrKey(name), getattr(x, name)) for name in data_fields ) return data, meta @@ -138,8 +136,9 @@ def clz_from_iterable(meta, data): ) def to_state_dict(x): - state_dict = {name: serialization.to_state_dict(getattr(x, name)) - for name in data_fields} + state_dict = { + name: serialization.to_state_dict(getattr(x, name)) for name in data_fields + } return state_dict def from_state_dict(x, state): @@ -148,32 +147,35 @@ def from_state_dict(x, state): updates = {} for name in data_fields: if name not in state: - raise ValueError(f'Missing field {name} in state dict while restoring' - f' an instance of {clz.__name__},' - f' at path {serialization.current_path()}') + raise ValueError( + f'Missing field {name} in state dict while restoring' + f' an instance of {clz.__name__},' + f' at path {serialization.current_path()}' + ) value = getattr(x, name) value_state = state.pop(name) updates[name] = serialization.from_state_dict(value, value_state, name=name) if state: names = ','.join(state.keys()) - raise ValueError(f'Unknown field(s) "{names}" in state dict while' - f' restoring an instance of {clz.__name__}' - f' at path {serialization.current_path()}') + raise ValueError( + f'Unknown field(s) "{names}" in state dict while' + f' restoring an instance of {clz.__name__}' + f' at path {serialization.current_path()}' + ) return x.replace(**updates) - serialization.register_serialization_state( - data_clz, to_state_dict, from_state_dict) + serialization.register_serialization_state(data_clz, to_state_dict, from_state_dict) # add a _flax_dataclass flag to distinguish from regular dataclasses - data_clz._flax_dataclass = True # type: ignore[attr-defined] + data_clz._flax_dataclass = True # type: ignore[attr-defined] - return data_clz # type: ignore + return data_clz # type: ignore TNode = TypeVar('TNode', bound='PyTreeNode') -@dataclass_transform(field_descriptors=(field,)) # type: ignore[literal-required] +@dataclass_transform(field_descriptors=(field,)) # type: ignore[literal-required] class PyTreeNode: """Base class for dataclasses that should act like a JAX pytree node. diff --git a/flax/testing/benchmark.py b/flax/testing/benchmark.py index 74e8e874a6..9717a95e8f 100644 --- a/flax/testing/benchmark.py +++ b/flax/testing/benchmark.py @@ -45,20 +45,22 @@ flags.DEFINE_string( - 'benchmark_output_dir', default=None, help='Benchmark output directory.') + 'benchmark_output_dir', default=None, help='Benchmark output directory.' +) FLAGS = flags.FLAGS -_SCALAR_PLUGIN_NAME = summary_lib.scalar_pb( - '', 0).value[0].metadata.plugin_data.plugin_name +_SCALAR_PLUGIN_NAME = ( + summary_lib.scalar_pb('', 0).value[0].metadata.plugin_data.plugin_name +) def _make_events_generator(path): """Makes a generator yielding TensorBoard events from files in `path`.""" return directory_watcher.DirectoryWatcher( - path, event_file_loader.EventFileLoader, - io_wrapper.IsSummaryEventsFile).Load() + path, event_file_loader.EventFileLoader, io_wrapper.IsSummaryEventsFile + ).Load() def _is_scalar_value(value): @@ -76,8 +78,12 @@ def _process_event(event): continue if value.HasField('tensor'): - yield (value.tag, event.wall_time, - event.step, tensor_util.make_ndarray(value.tensor).item()) + yield ( + value.tag, + event.wall_time, + event.step, + tensor_util.make_ndarray(value.tensor).item(), + ) def _get_tensorboard_scalars(path): @@ -118,13 +124,11 @@ def __init__(self, *args, **kwargs): for func_name in dir(self): if func_name.startswith('assert'): func = getattr(self, func_name) - patched_func = functools.partial( - self._collect_assert_wrapper, fn=func) + patched_func = functools.partial(self._collect_assert_wrapper, fn=func) setattr(self, func_name, patched_func) # Create target directory if defined. - if FLAGS.benchmark_output_dir and not io.exists( - FLAGS.benchmark_output_dir): + if FLAGS.benchmark_output_dir and not io.exists(FLAGS.benchmark_output_dir): io.makedirs(FLAGS.benchmark_output_dir) # pylint: disable=invalid-name @@ -162,8 +166,9 @@ def get_tmp_model_dir(self): model_dir = FLAGS.benchmark_output_dir else: model_dir = tempfile.mkdtemp() - model_dir_path = os.path.join(model_dir, self._reported_name or - self._get_test_name()) + model_dir_path = os.path.join( + model_dir, self._reported_name or self._get_test_name() + ) # Create directories if they don't exist. if not io.exists(model_dir_path): io.makedirs(model_dir_path) @@ -255,15 +260,19 @@ def _report_benchmark_results(self): """ name = self._reported_name if not name: - raise ValueError('Unable to determine test name for reporting ' - 'benchmark results. Make sure you are using ' - '`self.report_*` methods.') + raise ValueError( + 'Unable to determine test name for reporting ' + 'benchmark results. Make sure you are using ' + '`self.report_*` methods.' + ) succeeded = not self.has_outstanding_fails() - results = {'name': name, - 'succeeded': succeeded, - 'metrics': self._reported_metrics, - 'extras': self._reported_extras} + results = { + 'name': name, + 'succeeded': succeeded, + 'metrics': self._reported_metrics, + 'extras': self._reported_extras, + } if self._reported_wall_time is not None: results['wall_time'] = self._reported_wall_time if not succeeded: diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 9623cb7776..df70dba214 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -50,20 +50,21 @@ try: from jax.experimental.array_serialization.serialization import get_tensorstore_spec from jax.experimental.array_serialization.serialization import GlobalAsyncCheckpointManager + _IMPORT_GDAM_SUCCESSFUL = True except ImportError: - logging.warning('GlobalAsyncCheckpointManager is not imported correctly. ' - 'Checkpointing of GlobalDeviceArrays will not be available.' - 'To use the feature, install tensorstore.') + logging.warning( + 'GlobalAsyncCheckpointManager is not imported correctly. ' + 'Checkpointing of GlobalDeviceArrays will not be available.' + 'To use the feature, install tensorstore.' + ) # Single-group reg-exps for int or float numerical substrings. # captures sign: -SIGNED_FLOAT_RE = re.compile( - r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)') +SIGNED_FLOAT_RE = re.compile(r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)') # does not capture sign: -UNSIGNED_FLOAT_RE = re.compile( - r'[-+]?((?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)') +UNSIGNED_FLOAT_RE = re.compile(r'[-+]?((?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)') # Module name followed by number. MODULE_NUM_RE = re.compile(r'(.*)_\d+$') # Alternative schemes handled by `gfile`, e.g. on Google Cloud Storage (GCS). @@ -99,9 +100,9 @@ def _is_multiprocess_array(value: Any) -> bool: return False -def _checkpoint_path(ckpt_dir: str, - step: Union[int, float, str], - prefix: str = 'checkpoint_') -> str: +def _checkpoint_path( + ckpt_dir: str, step: Union[int, float, str], prefix: str = 'checkpoint_' +) -> str: return os.path.join(ckpt_dir, f'{prefix}{step}') @@ -112,12 +113,14 @@ def _checkpoint_path_step(path: str) -> Optional[float]: return float(s) return None + def _allowempty_listdir(path: str): try: return io.listdir(path) except io.NotFoundError: return [] + def _safe_remove(path: str): """Identify whether a path is a dir or list and choose the correct remove method.""" if io.isdir(path): @@ -125,7 +128,8 @@ def _safe_remove(path: str): else: io.remove(path) -class AsyncManager(): + +class AsyncManager: """A simple object to track async checkpointing. How to use: create an instance and pass to save_checkpoint() calls: @@ -155,7 +159,7 @@ def save_async(self, task: Callable[[], Any]): task: The callable to be executed asynchronously. """ self.wait_previous_save() - self.save_future = self.executor.submit(task) # type: ignore + self.save_future = self.executor.submit(task) # type: ignore def _split_mp_arrays( @@ -179,8 +183,7 @@ def _split_mp_arrays( return target, mpa_targets -def _make_mpa_dirs(mpa_targets: List[Tuple[MultiprocessArrayType, str]], - tmp_path: str): +def _make_mpa_dirs(mpa_targets: List[Tuple[MultiprocessArrayType, str]], tmp_path: str): # Temporary array path is not used in GCS. if tmp_path.startswith('gs://'): return @@ -195,14 +198,24 @@ def _make_mpa_dirs(mpa_targets: List[Tuple[MultiprocessArrayType, str]], io.makedirs(os.path.join(mpa_tmp_path, subpath)) -def _save_mpas(gda_manager, mpa_targets: List[Tuple[MultiprocessArrayType, str]], - tmp_path: str, final_path: str, base_path: str, keep: int, - overwrite: bool, keep_every_n_steps: Optional[int], - ckpt_start_time: float, - async_manager: Optional[AsyncManager] = None): +def _save_mpas( + gda_manager, + mpa_targets: List[Tuple[MultiprocessArrayType, str]], + tmp_path: str, + final_path: str, + base_path: str, + keep: int, + overwrite: bool, + keep_every_n_steps: Optional[int], + ckpt_start_time: float, + async_manager: Optional[AsyncManager] = None, +): """Save the multiprocess arrays given the paths.""" mpa_list, mpa_subpaths = zip(*mpa_targets) - mpa_tmp_path, mpa_final_path = tmp_path + MP_ARRAY_POSTFIX, final_path + MP_ARRAY_POSTFIX + mpa_tmp_path, mpa_final_path = ( + tmp_path + MP_ARRAY_POSTFIX, + final_path + MP_ARRAY_POSTFIX, + ) write_commit_success = False # If the checkpoint directory is a GCS directory, then keep the final # checkpoint directory as the temporary checkpoint directory. This is because @@ -228,15 +241,19 @@ def _save_mpas(gda_manager, mpa_targets: List[Tuple[MultiprocessArrayType, str]] ckpt_start_time, has_mpa=True, write_commit_success=write_commit_success, - async_manager=async_manager)) + async_manager=async_manager, + ), + ) -def _restore_mpas(state_dict, - target: Optional[Any], - ckpt_path: str, - step: Optional[Union[int, float]], - gda_manager: Optional[Any], - allow_partial: bool = False): +def _restore_mpas( + state_dict, + target: Optional[Any], + ckpt_path: str, + step: Optional[Union[int, float]], + gda_manager: Optional[Any], + allow_partial: bool = False, +): """Restore the multiprocess arrays given the target structure and type.""" def _check_mpa_errors(): @@ -247,12 +264,14 @@ def _check_mpa_errors(): def _safe_deserialize( target_mpas: List[Tuple[Tuple[Any, ...], MultiprocessArrayType, str]], - gda_manager: Any) -> List[MultiprocessArrayType]: + gda_manager: Any, + ) -> List[MultiprocessArrayType]: gda_manager.wait_until_finished() # Check if reading from GCS and the array dir is potentially corrupted. if ckpt_path.startswith('gs://') and not io.exists( - os.path.join(ckpt_path + MP_ARRAY_POSTFIX, COMMIT_SUCCESS_FILE)): + os.path.join(ckpt_path + MP_ARRAY_POSTFIX, COMMIT_SUCCESS_FILE) + ): raise errors.MPARestoreDataCorruptedError(step, ckpt_path) # Check if the given target array types are valid. @@ -267,11 +286,15 @@ def _safe_deserialize( # When target is a single leaf instead of a pytree dict. if not isinstance(state_dict, (core.FrozenDict, dict)): - if _is_multiprocess_array(target) and isinstance( - state_dict, str) and state_dict.startswith(MP_ARRAY_PH): + if ( + _is_multiprocess_array(target) + and isinstance(state_dict, str) + and state_dict.startswith(MP_ARRAY_PH) + ): _check_mpa_errors() - return _safe_deserialize([((), target, ckpt_path + MP_ARRAY_POSTFIX)], - gda_manager)[0] + return _safe_deserialize( + [((), target, ckpt_path + MP_ARRAY_POSTFIX)], gda_manager + )[0] return state_dict # Go through the restored checkpoint pytree for all MPAs @@ -279,15 +302,19 @@ def _safe_deserialize( target_flattened = {} if target: target_flattened = traverse_util.flatten_dict( - serialization.to_state_dict(target), keep_empty_nodes=True) + serialization.to_state_dict(target), keep_empty_nodes=True + ) # A list of (state_dict_key, target_array, array_file_path) for every array # to be restored target_mpas = [] for key, value in flattened.items(): if isinstance(value, str) and value.startswith(MP_ARRAY_PH): _check_mpa_errors() - if not target or (key not in target_flattened) or ( - not _is_multiprocess_array(target_flattened[key])): + if ( + not target + or (key not in target_flattened) + or (not _is_multiprocess_array(target_flattened[key])) + ): if allow_partial: logging.warning( 'Multiprocess array %s could not be restored because a valid' @@ -299,8 +326,7 @@ def _safe_deserialize( else: raise errors.MPARestoreTargetRequiredError(ckpt_path, step, key) else: - mpa_path = os.path.join(ckpt_path + MP_ARRAY_POSTFIX, - value[len(MP_ARRAY_PH):]) + mpa_path = os.path.join(ckpt_path + MP_ARRAY_POSTFIX, value[len(MP_ARRAY_PH) :]) target_mpas.append((key, target_flattened[key], mpa_path)) # If any MPA needs to be restored, call deserialize @@ -327,13 +353,16 @@ def natural_sort(file_list: Iterable[str], signed: bool = True) -> List[str]: file_0.1, file_-0.2, file_2.0 --> file_-0.2, file_0.1, file_2.0 """ float_re = SIGNED_FLOAT_RE if signed else UNSIGNED_FLOAT_RE + def maybe_num(s): if float_re.match(s): return float(s) else: return s + def split_keys(s): return [maybe_num(c) for c in float_re.split(s)] + return sorted(file_list, key=split_keys) @@ -345,11 +374,15 @@ def safe_normpath(path: str) -> str: return (d['scheme'] or '') + os.path.normpath(d['path']) -def _remove_invalid_ckpts(ckpt_path: str, base_path: str, keep: int, - overwrite: bool, keep_every_n_steps: Optional[int], - has_mpa: bool) -> None: - """Clean up the checkpoint space according to `overwrite`, `keep`, and `keep_every_n_steps` parameters. - """ +def _remove_invalid_ckpts( + ckpt_path: str, + base_path: str, + keep: int, + overwrite: bool, + keep_every_n_steps: Optional[int], + has_mpa: bool, +) -> None: + """Clean up the checkpoint space according to `overwrite`, `keep`, and `keep_every_n_steps` parameters.""" dir_path, prefix = os.path.split(base_path) checkpoint_files: List[Any] = [ pathlib.PurePath(c) for c in _allowempty_listdir(dir_path) @@ -385,9 +418,12 @@ def _remove_invalid_ckpts(ckpt_path: str, base_path: str, keep: int, if keep_every_n_steps: step_number = _checkpoint_path_step(path) if step_number and (step_number - last_kept) >= keep_every_n_steps: - logging.debug('Not deleting %s, because last_kept=%f and keeping ' - 'every %d steps.', - path, last_kept, keep_every_n_steps) + logging.debug( + 'Not deleting %s, because last_kept=%f and keeping ' 'every %d steps.', + path, + last_kept, + keep_every_n_steps, + ) last_kept = step_number continue logging.info('Removing checkpoint at %s', path) @@ -398,11 +434,18 @@ def _remove_invalid_ckpts(ckpt_path: str, base_path: str, keep: int, _safe_remove(path) -def _save_commit(ckpt_tmp_path: str, ckpt_path: str, base_path: str, keep: int, - overwrite: bool, keep_every_n_steps: Optional[int], - ckpt_start_time: float, has_mpa: bool, - write_commit_success: bool, - async_manager: Optional[AsyncManager] = None) -> None: +def _save_commit( + ckpt_tmp_path: str, + ckpt_path: str, + base_path: str, + keep: int, + overwrite: bool, + keep_every_n_steps: Optional[int], + ckpt_start_time: float, + has_mpa: bool, + write_commit_success: bool, + async_manager: Optional[AsyncManager] = None, +) -> None: """Commit changes after saving checkpoints to disk. This function does the following, sequentially: @@ -413,7 +456,10 @@ def _save_commit(ckpt_tmp_path: str, ckpt_path: str, base_path: str, keep: int, 3. Remove old checkpoint files based on `keep` and `keep_every_n_steps`. 4. Record program duration saved by this checkpoint. """ - mpa_ckpt_tmp_path, mpa_ckpt_path = ckpt_tmp_path + MP_ARRAY_POSTFIX, ckpt_path + MP_ARRAY_POSTFIX + mpa_ckpt_tmp_path, mpa_ckpt_path = ( + ckpt_tmp_path + MP_ARRAY_POSTFIX, + ckpt_path + MP_ARRAY_POSTFIX, + ) # Rename the multiprocess array path once serialization and writing finished. if has_mpa: if write_commit_success: @@ -437,21 +483,25 @@ def _save_commit(ckpt_tmp_path: str, ckpt_path: str, base_path: str, keep: int, logging.info('Saved checkpoint at %s', ckpt_path) # Remove newer and older invalid checkpoints. - _remove_invalid_ckpts(ckpt_path, base_path, keep, overwrite, - keep_every_n_steps, has_mpa) + _remove_invalid_ckpts( + ckpt_path, base_path, keep, overwrite, keep_every_n_steps, has_mpa + ) # Record checkpoint-related metrics. ocp.utils.record_saved_duration(ckpt_start_time) if async_manager: jax.monitoring.record_event_duration_secs( - '/jax/checkpoint/write/async/total_duration_secs', - time.time() - ckpt_start_time) + '/jax/checkpoint/write/async/total_duration_secs', time.time() - ckpt_start_time + ) -def _check_overwrite_error(ckpt_tmp_path: str, ckpt_path: str, base_path: str, - step: int): +def _check_overwrite_error( + ckpt_tmp_path: str, ckpt_path: str, base_path: str, step: int +): """Throw error if a ckpt file of this step or higher already exists.""" dir_path, prefix = os.path.split(base_path) - checkpoint_files: List[Any] = [pathlib.PurePath(c) for c in _allowempty_listdir(dir_path)] + checkpoint_files: List[Any] = [ + pathlib.PurePath(c) for c in _allowempty_listdir(dir_path) + ] checkpoint_files = [ os.path.join(dir_path, c) for c in checkpoint_files @@ -470,11 +520,17 @@ def _check_overwrite_error(ckpt_tmp_path: str, ckpt_path: str, base_path: str, raise errors.InvalidCheckpointError(ckpt_path, step) -def _save_main_ckpt_file(target: bytes, has_mpa: bool, paths: Tuple[str, str], - base_path: str, step: int, - keep: int, overwrite: bool, - keep_every_n_steps: Optional[int], - ckpt_start_time: float): +def _save_main_ckpt_file( + target: bytes, + has_mpa: bool, + paths: Tuple[str, str], + base_path: str, + step: int, + keep: int, + overwrite: bool, + keep_every_n_steps: Optional[int], + ckpt_start_time: float, +): """Save the main checkpoint file via file system.""" ckpt_tmp_path, ckpt_path = paths io.makedirs(os.path.dirname(ckpt_path)) @@ -493,13 +549,15 @@ def _save_main_ckpt_file(target: bytes, has_mpa: bool, paths: Tuple[str, str], keep_every_n_steps, ckpt_start_time, has_mpa=False, - write_commit_success=False) + write_commit_success=False, + ) def _get_checkpoint_paths( ckpt_dir: Union[str, os.PathLike], step: Union[int, float], - prefix: str = 'checkpoint_') -> Tuple[str, str, str]: + prefix: str = 'checkpoint_', +) -> Tuple[str, str, str]: """Generate the checkpoint paths used in this save operation.""" ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str logging.info('Saving checkpoint at step: %s', step) @@ -564,8 +622,7 @@ def save_checkpoint( if async_manager: async_manager.wait_previous_save() - ckpt_path, ckpt_tmp_path, base_path = _get_checkpoint_paths( - ckpt_dir, step, prefix) + ckpt_path, ckpt_tmp_path, base_path = _get_checkpoint_paths(ckpt_dir, step, prefix) if config.flax_use_orbax_checkpointing or orbax_checkpointer: logging.info( @@ -582,9 +639,7 @@ def save_checkpoint( ) # Make sure any previous work is done before making file changes. - if orbax_checkpointer and isinstance( - orbax_checkpointer, ocp.AsyncCheckpointer - ): + if orbax_checkpointer and isinstance(orbax_checkpointer, ocp.AsyncCheckpointer): orbax_checkpointer.wait_until_finished() # If no checkpointer provided, save synchronously with default setting. if not orbax_checkpointer: @@ -602,15 +657,16 @@ def save_checkpoint( ) save_args = orbax_utils.save_args_from_target(target) - orbax_checkpointer.save( - ckpt_path, target, save_args=save_args, force=overwrite) + orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite) # Do a process check here in case people call this for multihost. if process_index() == 0: - _remove_invalid_ckpts(ckpt_path, base_path, keep, overwrite, - keep_every_n_steps, True) + _remove_invalid_ckpts( + ckpt_path, base_path, keep, overwrite, keep_every_n_steps, True + ) end_time = time.time() - monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT, - end_time - start_time) + monitoring.record_event_duration_secs( + _WRITE_CHECKPOINT_EVENT, end_time - start_time + ) return ckpt_path warnings.warn( @@ -624,22 +680,31 @@ def save_checkpoint( DeprecationWarning, ) if not overwrite: - _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore + _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore target = serialization.to_bytes(target) + # Save the files via I/O sync or async. def save_main_ckpt_task(): jax.monitoring.record_event('/jax/flax/checkpoint/save_main_ckpt_task') - return _save_main_ckpt_file(target, False, (ckpt_tmp_path, ckpt_path), - base_path, step, keep, overwrite, - keep_every_n_steps, start_time) + return _save_main_ckpt_file( + target, + False, + (ckpt_tmp_path, ckpt_path), + base_path, + step, + keep, + overwrite, + keep_every_n_steps, + start_time, + ) + if async_manager: async_manager.save_async(save_main_ckpt_task) else: save_main_ckpt_task() end_time = time.time() - monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT, - end_time - start_time) + monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT, end_time - start_time) return ckpt_path @@ -703,8 +768,7 @@ def save_checkpoint_multiprocess( gda_manager.wait_until_finished() sync_global_devices('Flax:Checkpoint:WaitLastSaveDone') - ckpt_path, ckpt_tmp_path, base_path = _get_checkpoint_paths( - ckpt_dir, step, prefix) + ckpt_path, ckpt_tmp_path, base_path = _get_checkpoint_paths(ckpt_dir, step, prefix) if config.flax_use_orbax_checkpointing or orbax_checkpointer: logging.info( @@ -713,9 +777,7 @@ def save_checkpoint_multiprocess( ' https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting' ) # Make sure any previous work is done before making file changes. - if orbax_checkpointer and isinstance( - orbax_checkpointer, ocp.AsyncCheckpointer - ): + if orbax_checkpointer and isinstance(orbax_checkpointer, ocp.AsyncCheckpointer): orbax_checkpointer.wait_until_finished() # If no checkpointer provided, save synchronously with default setting. @@ -734,14 +796,15 @@ def save_checkpoint_multiprocess( ) if process_index() == 0: - _remove_invalid_ckpts(ckpt_path, base_path, keep, overwrite, - keep_every_n_steps, True) + _remove_invalid_ckpts( + ckpt_path, base_path, keep, overwrite, keep_every_n_steps, True + ) save_args = orbax_utils.save_args_from_target(target) - orbax_checkpointer.save( - ckpt_path, target, save_args=save_args, force=overwrite) + orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite) end_time = time.time() - monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT, - end_time - start_time) + monitoring.record_event_duration_secs( + _WRITE_CHECKPOINT_EVENT, end_time - start_time + ) return ckpt_path warnings.warn( @@ -761,14 +824,24 @@ def save_checkpoint_multiprocess( has_mpa = mpa_targets and _IMPORT_GDAM_SUCCESSFUL if not overwrite: - _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore + _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore sync_global_devices('Flax:Checkpoint:CheckOverwriteBeforeSave') + # Save the files via I/O sync or async. def save_main_ckpt_task(): jax.monitoring.record_event('/jax/flax/checkpoint/save_main_ckpt_task') - return _save_main_ckpt_file(target, has_mpa, (ckpt_tmp_path, ckpt_path), - base_path, step, keep, overwrite, - keep_every_n_steps, start_time) + return _save_main_ckpt_file( + target, + has_mpa, + (ckpt_tmp_path, ckpt_path), + base_path, + step, + keep, + overwrite, + keep_every_n_steps, + start_time, + ) + # Write the main checkpoint file only via process 0, to avoid race condition. if process_index() == 0: if async_manager: @@ -784,17 +857,27 @@ def save_main_ckpt_task(): if process_index() == 0: _make_mpa_dirs(mpa_targets, ckpt_tmp_path) sync_global_devices('Flax:Checkpoint:AfterCreateMPADir') - _save_mpas(gda_manager, mpa_targets, ckpt_tmp_path, ckpt_path, base_path, - keep, overwrite, keep_every_n_steps, start_time, async_manager) + _save_mpas( + gda_manager, + mpa_targets, + ckpt_tmp_path, + ckpt_path, + base_path, + keep, + overwrite, + keep_every_n_steps, + start_time, + async_manager, + ) end_time = time.time() - monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT, - end_time - start_time) + monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT, end_time - start_time) return ckpt_path -def _all_checkpoints(ckpt_dir: Union[str, os.PathLike], - prefix: str = 'checkpoint_') -> List[str]: +def _all_checkpoints( + ckpt_dir: Union[str, os.PathLike], prefix: str = 'checkpoint_' +) -> List[str]: """Retrieve all checkpoint paths in directory. Args: @@ -823,8 +906,9 @@ def _all_checkpoints(ckpt_dir: Union[str, os.PathLike], return [] -def latest_checkpoint(ckpt_dir: Union[str, os.PathLike], - prefix: str = 'checkpoint_') -> Optional[str]: +def latest_checkpoint( + ckpt_dir: Union[str, os.PathLike], prefix: str = 'checkpoint_' +) -> Optional[str]: """Retrieve the path of the latest checkpoint in a directory. Args: @@ -841,9 +925,11 @@ def latest_checkpoint(ckpt_dir: Union[str, os.PathLike], return None -def available_steps(ckpt_dir: Union[str, os.PathLike], - prefix: str = 'checkpoint_', - step_type: Type = int) -> List[Union[int, float]]: +def available_steps( + ckpt_dir: Union[str, os.PathLike], + prefix: str = 'checkpoint_', + step_type: Type = int, +) -> List[Union[int, float]]: """Return step numbers of available checkpoints in a directory. @@ -919,9 +1005,7 @@ def restore_checkpoint( """ start_time = time.time() # Make sure any previous work is done before checking files. - if orbax_checkpointer and isinstance( - orbax_checkpointer, ocp.AsyncCheckpointer - ): + if orbax_checkpointer and isinstance(orbax_checkpointer, ocp.AsyncCheckpointer): orbax_checkpointer.wait_until_finished() ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str @@ -939,10 +1023,11 @@ def restore_checkpoint( if io.exists(os.path.join(ckpt_dir, ORBAX_CKPT_FILENAME)): ckpt_path = ckpt_dir else: - ckpt_path = latest_checkpoint(ckpt_dir, prefix) # type: ignore + ckpt_path = latest_checkpoint(ckpt_dir, prefix) # type: ignore if not ckpt_path: - logging.info('Found no checkpoint files in %s with prefix %s', - ckpt_dir, prefix) + logging.info( + 'Found no checkpoint files in %s with prefix %s', ckpt_dir, prefix + ) return target else: ckpt_path = ckpt_dir @@ -959,23 +1044,17 @@ def restore_checkpoint( restore_kwargs = {} if target is not None: - restore_kwargs['restore_args'] = orbax_utils.restore_args_from_target( - target - ) + restore_kwargs['restore_args'] = orbax_utils.restore_args_from_target(target) if isinstance(orbax_checkpointer._handler, ocp.PyTreeCheckpointHandler): # pylint: disable=protected-access - restore_kwargs['transforms'] = ( - orbax_utils.maybe_construct_transformations( - target, orbax_transforms - ) + restore_kwargs['transforms'] = orbax_utils.maybe_construct_transformations( + target, orbax_transforms ) - restored = orbax_checkpointer.restore( - ckpt_path, item=target, **restore_kwargs) + restored = orbax_checkpointer.restore(ckpt_path, item=target, **restore_kwargs) restored = serialization.to_state_dict(restored) if target is not None: restored = serialization.from_state_dict(target, restored) end_time = time.time() - monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT, - end_time - start_time) + monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT, end_time - start_time) return restored ckpt_size = io.getsize(ckpt_path) @@ -994,7 +1073,7 @@ def read_chunk(i): f.seek(i * buf_size) buf = f.read(buf_size) if buf: - checkpoint_contents[i * buf_size:i * buf_size + len(buf)] = buf + checkpoint_contents[i * buf_size : i * buf_size + len(buf)] = buf return len(buf) / buf_size pool_size = 32 @@ -1007,8 +1086,9 @@ def read_chunk(i): state_dict = serialization.msgpack_restore(checkpoint_contents) if _IMPORT_GDAM_SUCCESSFUL: - state_dict = _restore_mpas(state_dict, target, ckpt_path, step, gda_manager, - allow_partial_mpa_restoration) + state_dict = _restore_mpas( + state_dict, target, ckpt_path, step, gda_manager, allow_partial_mpa_restoration + ) if target is None: restored_checkpoint = state_dict @@ -1016,8 +1096,7 @@ def read_chunk(i): restored_checkpoint = serialization.from_state_dict(target, state_dict) end_time = time.time() - monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT, - end_time - start_time) + monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT, end_time - start_time) return restored_checkpoint @@ -1088,5 +1167,5 @@ def __call__(self, x): params_renamed[name] = convert_pre_linen(value) if isinstance(params, core.FrozenDict): - params_renamed = core.freeze(params_renamed) # type: ignore + params_renamed = core.freeze(params_renamed) # type: ignore return params_renamed diff --git a/flax/training/common_utils.py b/flax/training/common_utils.py index 10b7790284..546c8bb541 100644 --- a/flax/training/common_utils.py +++ b/flax/training/common_utils.py @@ -35,7 +35,8 @@ def shard(xs): """ local_device_count = jax.local_device_count() return jax.tree_util.tree_map( - lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs) + lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs + ) def shard_prng_key(prng_key): @@ -99,6 +100,6 @@ def onehot(labels, num_classes, on_value=1.0, off_value=0.0): A (n+1)-dim array whose last dimension contains one-hot vectors of length num_classes. """ - x = (labels[..., None] == jnp.arange(num_classes).reshape((1,) * labels.ndim + (-1,))) + x = labels[..., None] == jnp.arange(num_classes).reshape((1,) * labels.ndim + (-1,)) x = lax.select(x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value)) return x.astype(jnp.float32) diff --git a/flax/training/dynamic_scale.py b/flax/training/dynamic_scale.py index 1754981831..cb39be9fe7 100644 --- a/flax/training/dynamic_scale.py +++ b/flax/training/dynamic_scale.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Dynamic loss scaling for mixed precision gradients. -""" +"""Dynamic loss scaling for mixed precision gradients.""" import functools from typing import Any, Callable, NamedTuple, Optional, Sequence, Union @@ -25,7 +24,6 @@ import jax.numpy as jnp - Array = Any @@ -81,6 +79,7 @@ def loss_fn(p): minimum_scale: the minimum value that the scale can take (default: the smallest positive number representable in floating point). """ + growth_factor: float = struct.field(pytree_node=False, default=2.0) backoff_factor: float = struct.field(pytree_node=False, default=0.5) growth_interval: int = struct.field(pytree_node=False, default=2000) @@ -90,11 +89,13 @@ def loss_fn(p): pytree_node=False, default=jnp.finfo(jnp.float32).tiny ) - def value_and_grad(self, fun: Callable[..., Any], - argnums: Union[int, Sequence[int]] = 0, - has_aux: bool = False, - axis_name: Optional[str] = None, - ) -> Callable[..., DynamicScaleResult]: + def value_and_grad( + self, + fun: Callable[..., Any], + argnums: Union[int, Sequence[int]] = 0, + has_aux: bool = False, + axis_name: Optional[str] = None, + ) -> Callable[..., DynamicScaleResult]: """Wrapper around `jax.value_and_grad`. Args: @@ -114,6 +115,7 @@ def value_and_grad(self, fun: Callable[..., Any], A function that takes the same arguments as `fun` and returns a DynamicScaleResult """ + @functools.wraps(fun) def loss_wrapper(*args): aux = fun(*args) @@ -123,12 +125,14 @@ def loss_wrapper(*args): return self.scale * aux grad_fn = jax.value_and_grad(loss_wrapper, argnums, has_aux) + def grad_fn_wrapper(*args): aux, grad = grad_fn(*args) aux = (aux[0] / self.scale, aux[1]) if has_aux else aux / self.scale grad = jax.tree_util.tree_map( - lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad) + lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad + ) if axis_name is not None: grad = lax.pmean(grad, axis_name) @@ -140,7 +144,8 @@ def grad_fn_wrapper(*args): fin_scale = jnp.where( grow & finite, jnp.minimum(self.scale * self.growth_factor, jnp.finfo(jnp.float32).max), - self.scale) + self.scale, + ) inf_scale = self.scale * self.backoff_factor if self.minimum_scale is not None: inf_scale = jnp.maximum(inf_scale, self.minimum_scale) @@ -149,4 +154,5 @@ def grad_fn_wrapper(*args): new_self = self.replace(fin_steps=new_fin_steps, scale=new_scale) return DynamicScaleResult(new_self, finite, aux, grad) + return grad_fn_wrapper diff --git a/flax/training/early_stopping.py b/flax/training/early_stopping.py index ab5b31289d..43b4c2d02d 100644 --- a/flax/training/early_stopping.py +++ b/flax/training/early_stopping.py @@ -44,6 +44,7 @@ class EarlyStopping(struct.PyTreeNode): should_stop: Whether the training loop should stop to avoid overfitting. """ + min_delta: float = 0 patience: int = 0 best_metric: float = float('inf') @@ -51,9 +52,7 @@ class EarlyStopping(struct.PyTreeNode): should_stop: bool = False def reset(self): - return self.replace(best_metric=float('inf'), - patience_count=0, - should_stop=False) + return self.replace(best_metric=float('inf'), patience_count=0, should_stop=False) def update(self, metric): """Update the state based on metric. @@ -65,9 +64,9 @@ def update(self, metric): """ if math.isinf(self.best_metric) or self.best_metric - metric > self.min_delta: - return True, self.replace(best_metric=metric, - patience_count=0) + return True, self.replace(best_metric=metric, patience_count=0) else: should_stop = self.patience_count >= self.patience or self.should_stop - return False, self.replace(patience_count=self.patience_count + 1, - should_stop=should_stop) + return False, self.replace( + patience_count=self.patience_count + 1, should_stop=should_stop + ) diff --git a/flax/training/lr_schedule.py b/flax/training/lr_schedule.py index 753e2004bf..294d94f5cf 100644 --- a/flax/training/lr_schedule.py +++ b/flax/training/lr_schedule.py @@ -33,8 +33,9 @@ def _piecewise_constant(boundaries, values, t): return jnp.take(values, index) -def create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch, - warmup_length=0.0): +def create_constant_learning_rate_schedule( + base_learning_rate, steps_per_epoch, warmup_length=0.0 +): """Create a constant learning rate schedule with optional warmup. Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are @@ -60,20 +61,24 @@ def create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch, Function `f(step) -> lr` that computes the learning rate for a given step. """ logging.warning( - 'Learning rate schedules in ``flax.training`` are effectively deprecated ' - 'in favor of Optax schedules. Please refer to ' - 'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules' - ' for alternatives.') + 'Learning rate schedules in ``flax.training`` are effectively deprecated ' + 'in favor of Optax schedules. Please refer to ' + 'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules' + ' for alternatives.' + ) + def learning_rate_fn(step): lr = base_learning_rate if warmup_length > 0.0: - lr = lr * jnp.minimum(1., step / float(warmup_length) / steps_per_epoch) + lr = lr * jnp.minimum(1.0, step / float(warmup_length) / steps_per_epoch) return lr + return learning_rate_fn -def create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch, - lr_sched_steps, warmup_length=0.0): +def create_stepped_learning_rate_schedule( + base_learning_rate, steps_per_epoch, lr_sched_steps, warmup_length=0.0 +): """Create a stepped learning rate schedule with optional warmup. Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are @@ -114,10 +119,11 @@ def create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch, Function `f(step) -> lr` that computes the learning rate for a given step. """ logging.warning( - 'Learning rate schedules in ``flax.training`` are effectively deprecated ' - 'in favor of Optax schedules. Please refer to ' - 'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules' - ' for alternatives.') + 'Learning rate schedules in ``flax.training`` are effectively deprecated ' + 'in favor of Optax schedules. Please refer to ' + 'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules' + ' for alternatives.' + ) boundaries = [step[0] for step in lr_sched_steps] decays = [step[1] for step in lr_sched_steps] boundaries = np.array(boundaries) * steps_per_epoch @@ -127,13 +133,15 @@ def create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch, def learning_rate_fn(step): lr = _piecewise_constant(boundaries, values, step) if warmup_length > 0.0: - lr = lr * jnp.minimum(1., step / float(warmup_length) / steps_per_epoch) + lr = lr * jnp.minimum(1.0, step / float(warmup_length) / steps_per_epoch) return lr + return learning_rate_fn -def create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch, - halfcos_epochs, warmup_length=0.0): +def create_cosine_learning_rate_schedule( + base_learning_rate, steps_per_epoch, halfcos_epochs, warmup_length=0.0 +): """Create a cosine learning rate schedule with optional warmup. Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are @@ -164,18 +172,18 @@ def create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch, Function `f(step) -> lr` that computes the learning rate for a given step. """ logging.warning( - 'Learning rate schedules in ``flax.training`` are effectively deprecated ' - 'in favor of Optax schedules. Please refer to ' - 'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules' - ' for alternatives.') + 'Learning rate schedules in ``flax.training`` are effectively deprecated ' + 'in favor of Optax schedules. Please refer to ' + 'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules' + ' for alternatives.' + ) halfwavelength_steps = halfcos_epochs * steps_per_epoch def learning_rate_fn(step): scale_factor = jnp.cos(step * jnp.pi / halfwavelength_steps) * 0.5 + 0.5 lr = base_learning_rate * scale_factor if warmup_length > 0.0: - lr = lr * jnp.minimum(1., step / float(warmup_length) / steps_per_epoch) + lr = lr * jnp.minimum(1.0, step / float(warmup_length) / steps_per_epoch) return lr - return learning_rate_fn - + return learning_rate_fn diff --git a/flax/training/orbax_utils.py b/flax/training/orbax_utils.py index 44481a3827..1a3f5bb547 100644 --- a/flax/training/orbax_utils.py +++ b/flax/training/orbax_utils.py @@ -40,9 +40,7 @@ def save_args_from_target(target: Any) -> Any: ) -def maybe_construct_transformations( - target: Any, transforms: Optional[Any] -) -> Any: +def maybe_construct_transformations(target: Any, transforms: Optional[Any]) -> Any: if transforms is not None: return transforms flat_transforms = {} @@ -66,6 +64,7 @@ def restore_args_from_target(target: Any, mesh: Optional[Mesh] = None) -> Any: Returns: A Pytree of Orbax `RestoreArgs` or `ArrayRestoreArgs` """ + def find_sharding(x): if is_multi_device_array(x): return x.sharding @@ -90,7 +89,5 @@ def find_sharding(x): DeprecationWarning, ) axes_tree = jax.tree_util.tree_map(lambda s: s.spec, sharding_tree) - return ocp.checkpoint_utils.restore_args_from_target( - mesh, target, axes_tree - ) + return ocp.checkpoint_utils.restore_args_from_target(mesh, target, axes_tree) return ocp.checkpoint_utils.construct_restore_args(target, sharding_tree) diff --git a/flax/training/prefetch_iterator.py b/flax/training/prefetch_iterator.py index 13e0e418e2..c8deafe1bd 100644 --- a/flax/training/prefetch_iterator.py +++ b/flax/training/prefetch_iterator.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utility for constructing an iterator which prefetches data asynchronously. -""" +"""Utility for constructing an iterator which prefetches data asynchronously.""" import threading import warnings @@ -43,8 +42,11 @@ def __init__(self, data_iter, buffer_size=1): data_iter: the Iterator that should be prefetched. buffer_size: how many items to prefetch (default: 1). """ - warnings.warn('PrefetchIterator is deprecated. Use the standard `tf.data`' - ' prefetch method instead', DeprecationWarning) + warnings.warn( + 'PrefetchIterator is deprecated. Use the standard `tf.data`' + ' prefetch method instead', + DeprecationWarning, + ) self._data_iter = data_iter self.buffer_size = buffer_size @@ -77,6 +79,7 @@ def close(self): def _prefetch_loop(self): """Prefetch loop that prefetches a tf dataset.""" + def _predicate(): return len(self._buffer) < self.buffer_size or not self._active diff --git a/flax/training/train_state.py b/flax/training/train_state.py index d3b5641dfd..211fa39cfa 100644 --- a/flax/training/train_state.py +++ b/flax/training/train_state.py @@ -49,6 +49,7 @@ class TrainState(struct.PyTreeNode): tx: An Optax gradient transformation. opt_state: The state for `tx`. """ + step: int apply_fn: Callable = struct.field(pytree_node=False) params: core.FrozenDict[str, Any] = struct.field(pytree_node=True) @@ -70,8 +71,7 @@ def apply_gradients(self, *, grads, **kwargs): and `opt_state` updated by applying `grads`, and additional attributes replaced as specified by `kwargs`. """ - updates, new_opt_state = self.tx.update( - grads, self.opt_state, self.params) + updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params) new_params = optax.apply_updates(self.params, updates) return self.replace( step=self.step + 1, diff --git a/flax/traverse_util.py b/flax/traverse_util.py index e12b27be4b..d06460bd8e 100644 --- a/flax/traverse_util.py +++ b/flax/traverse_util.py @@ -54,11 +54,13 @@ Path = Tuple[str, ...] + # the empty node is a struct.dataclass to be compatible with JAX. @struct.dataclass class _EmptyNode: pass + empty_node = _EmptyNode() @@ -97,8 +99,8 @@ def flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None): The flattened dictionary. """ assert isinstance( - xs, - (flax.core.FrozenDict, dict)), f'expected (frozen)dict; got {type(xs)}' + xs, (flax.core.FrozenDict, dict) + ), f'expected (frozen)dict; got {type(xs)}' def _key(path): if sep is None: @@ -107,7 +109,8 @@ def _key(path): def _flatten(xs, prefix): if not isinstance(xs, (flax.core.FrozenDict, dict)) or ( - is_leaf and is_leaf(prefix, xs)): + is_leaf and is_leaf(prefix, xs) + ): return {_key(prefix): xs} result = {} is_empty = True @@ -120,6 +123,7 @@ def _flatten(xs, prefix): return {} return {_key(prefix): empty_node} return result + return _flatten(xs, ()) @@ -164,7 +168,8 @@ def unflatten_dict(xs, sep=None): def path_aware_map( - f: Callable[[Path, Any], Any], nested_dict: VariableDict) -> VariableDict: + f: Callable[[Path, Any], Any], nested_dict: VariableDict +) -> VariableDict: """A map function that operates over nested dictionary structures while taking the path to each leaf into account. @@ -187,8 +192,10 @@ def path_aware_map( A new nested dictionary structure with the mapped values. """ flat = flatten_dict(nested_dict, keep_empty_nodes=True) - return unflatten_dict({ - k: f(k, v) if v is not empty_node else v for k, v in flat.items()}) + return unflatten_dict( + {k: f(k, v) if v is not empty_node else v for k, v in flat.items()} + ) + class Traversal(abc.ABC): """Base class for all traversals.""" @@ -199,7 +206,9 @@ def __new__(cls, *args, **kwargs): '`flax.traverse_util.Traversal` will be deprecated. If you are using ' 'it for `flax.optim`, use `optax` instead. Refer to the update guide ' 'https://flax.readthedocs.io/en/latest/guides/optax_update_guide.html ' - 'for detailed instructions.', DeprecationWarning) + 'for detailed instructions.', + DeprecationWarning, + ) return super().__new__(cls) @abc.abstractmethod @@ -235,6 +244,7 @@ def set(self, values, inputs): Returns: A new object with the updated values. """ + def update_fn(_): if not values: raise ValueError('Not enough values provided') @@ -378,8 +388,7 @@ def update(self, fn, inputs): sl = slice(self._key, self._key + 1) indices = set(range(*sl.indices(len(inputs)))) - args = [fn(inputs[i]) if i in indices else inputs[i] - for i in range(len(inputs))] + args = [fn(inputs[i]) if i in indices else inputs[i] for i in range(len(inputs))] if _is_namedtuple(ty): return ty(*args) else: @@ -415,8 +424,7 @@ def iterate(self, inputs): class TraverseTree(Traversal): - """Traverse every item in a pytree. - """ + """Traverse every item in a pytree.""" def update(self, fn, inputs): return jax.tree_util.tree_map(fn, inputs) @@ -431,7 +439,8 @@ def _get_params_dict(inputs): else: raise ValueError( 'Can only traverse a flax Model instance or a nested dict, not ' - f'{type(inputs)}') + f'{type(inputs)}' + ) def _sorted_items(x): diff --git a/flax/version.py b/flax/version.py index bf52b2c38e..9c5e6150f2 100644 --- a/flax/version.py +++ b/flax/version.py @@ -14,4 +14,3 @@ """Current Flax version at head on Github.""" __version__ = "0.7.0" - diff --git a/tests/checkpoints_test.py b/tests/checkpoints_test.py index 1b4964ce1c..81ee79786b 100644 --- a/tests/checkpoints_test.py +++ b/tests/checkpoints_test.py @@ -43,7 +43,8 @@ def check_eq(xs, ys): return jax.tree_util.tree_all( - jax.tree_util.tree_map(np.testing.assert_allclose, xs, ys)) + jax.tree_util.tree_map(np.testing.assert_allclose, xs, ys) + ) def shuffle(l): @@ -111,52 +112,51 @@ def test_safe_normpath(self): def test_save_restore_checkpoints(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = pathlib.Path(self.create_tempdir().full_path) - test_object0 = {'a': np.array([0, 0, 0], np.int32), - 'b': np.array([0, 0, 0], np.int32)} - test_object1 = {'a': np.array([1, 2, 3], np.int32), - 'b': np.array([1, 1, 1], np.int32)} - test_object2 = {'a': np.array([4, 5, 6], np.int32), - 'b': np.array([2, 2, 2], np.int32)} - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, prefix='test_') + test_object0 = { + 'a': np.array([0, 0, 0], np.int32), + 'b': np.array([0, 0, 0], np.int32), + } + test_object1 = { + 'a': np.array([1, 2, 3], np.int32), + 'b': np.array([1, 1, 1], np.int32), + } + test_object2 = { + 'a': np.array([4, 5, 6], np.int32), + 'b': np.array([2, 2, 2], np.int32), + } + new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0, prefix='test_') check_eq(new_object, test_object0) - checkpoints.save_checkpoint( - tmp_dir, test_object1, 0, prefix='test_', keep=1) + checkpoints.save_checkpoint(tmp_dir, test_object1, 0, prefix='test_', keep=1) self.assertIn('test_0', os.listdir(tmp_dir)) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, prefix='test_') + new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0, prefix='test_') check_eq(new_object, test_object1) - checkpoints.save_checkpoint( - tmp_dir, test_object1, 1, prefix='test_', keep=1) - checkpoints.save_checkpoint( - tmp_dir, test_object2, 2, prefix='test_', keep=1) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, prefix='test_') + checkpoints.save_checkpoint(tmp_dir, test_object1, 1, prefix='test_', keep=1) + checkpoints.save_checkpoint(tmp_dir, test_object2, 2, prefix='test_', keep=1) + new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0, prefix='test_') check_eq(new_object, test_object2) - checkpoints.save_checkpoint( - tmp_dir, test_object2, 3, prefix='test_', keep=2) - checkpoints.save_checkpoint( - tmp_dir, test_object1, 4, prefix='test_', keep=2) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, prefix='test_') + checkpoints.save_checkpoint(tmp_dir, test_object2, 3, prefix='test_', keep=2) + checkpoints.save_checkpoint(tmp_dir, test_object1, 4, prefix='test_', keep=2) + new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0, prefix='test_') check_eq(new_object, test_object1) new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, step=3, prefix='test_') + tmp_dir, test_object0, step=3, prefix='test_' + ) check_eq(new_object, test_object2) # Restore with a specific checkpoint path, not the directory path. new_object = checkpoints.restore_checkpoint( - os.path.join(tmp_dir, 'test_3'), test_object0) + os.path.join(tmp_dir, 'test_3'), test_object0 + ) check_eq(new_object, test_object2) # If a specific path is specified, but it does not exist, the same behavior # as when a directory is empty should apply: the target is returned # unchanged. new_object = checkpoints.restore_checkpoint( - os.path.join(tmp_dir, 'test_not_there'), test_object0) + os.path.join(tmp_dir, 'test_not_there'), test_object0 + ) check_eq(new_object, test_object0) with self.assertRaises(ValueError): - checkpoints.restore_checkpoint( - tmp_dir, test_object0, step=5, prefix='test_') + checkpoints.restore_checkpoint(tmp_dir, test_object0, step=5, prefix='test_') @parameterized.parameters({'use_orbax': True}, {'use_orbax': False}) def test_overwrite_checkpoints(self, use_orbax): @@ -183,8 +183,10 @@ def test_overwrite_checkpoints(self, use_orbax): new_object = checkpoints.restore_checkpoint(non_norm_dir_path, test_object0) check_eq(new_object, test_object) - @parameterized.parameters({'use_orbax': True, 'keep_every_n_steps': None}, - {'use_orbax': False, 'keep_every_n_steps': 7}) + @parameterized.parameters( + {'use_orbax': True, 'keep_every_n_steps': None}, + {'use_orbax': False, 'keep_every_n_steps': 7}, + ) def test_keep(self, use_orbax, keep_every_n_steps): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = self.create_tempdir().full_path @@ -195,18 +197,20 @@ def test_keep(self, use_orbax, keep_every_n_steps): increment = 5 for step in range(steps_start, steps_end, increment): - checkpoints.save_checkpoint(tmp_dir, - test_object, - step=step, - keep=keep, - keep_every_n_steps=keep_every_n_steps) + checkpoints.save_checkpoint( + tmp_dir, + test_object, + step=step, + keep=keep, + keep_every_n_steps=keep_every_n_steps, + ) last_checkpoint = -float('inf') for step in range(steps_start, steps_end, increment): - if ((steps_end - step) / increment <= keep) or (keep_every_n_steps and ( - step - last_checkpoint) >= keep_every_n_steps): - restored = checkpoints.restore_checkpoint( - tmp_dir, target=None, step=step) + if ((steps_end - step) / increment <= keep) or ( + keep_every_n_steps and (step - last_checkpoint) >= keep_every_n_steps + ): + restored = checkpoints.restore_checkpoint(tmp_dir, target=None, step=step) check_eq(restored, test_object) last_checkpoint = step else: @@ -217,22 +221,24 @@ def test_keep(self, use_orbax, keep_every_n_steps): def test_save_restore_checkpoints_w_float_steps(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = self.create_tempdir().full_path - test_object0 = {'a': np.array([0, 0, 0], np.int32), - 'b': np.array([0, 0, 0], np.int32)} - test_object1 = {'a': np.array([1, 2, 3], np.int32), - 'b': np.array([1, 1, 1], np.int32)} - test_object2 = {'a': np.array([4, 5, 6], np.int32), - 'b': np.array([2, 2, 2], np.int32)} - checkpoints.save_checkpoint( - tmp_dir, test_object1, 0.0, prefix='test_', keep=1) + test_object0 = { + 'a': np.array([0, 0, 0], np.int32), + 'b': np.array([0, 0, 0], np.int32), + } + test_object1 = { + 'a': np.array([1, 2, 3], np.int32), + 'b': np.array([1, 1, 1], np.int32), + } + test_object2 = { + 'a': np.array([4, 5, 6], np.int32), + 'b': np.array([2, 2, 2], np.int32), + } + checkpoints.save_checkpoint(tmp_dir, test_object1, 0.0, prefix='test_', keep=1) self.assertIn('test_0.0', os.listdir(tmp_dir)) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, prefix='test_') + new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0, prefix='test_') check_eq(new_object, test_object1) - checkpoints.save_checkpoint( - tmp_dir, test_object1, 2.0, prefix='test_', keep=1) - checkpoints.save_checkpoint( - tmp_dir, test_object2, 3.0, prefix='test_', keep=2) + checkpoints.save_checkpoint(tmp_dir, test_object1, 2.0, prefix='test_', keep=1) + checkpoints.save_checkpoint(tmp_dir, test_object2, 3.0, prefix='test_', keep=2) self.assertIn('test_3.0', os.listdir(tmp_dir)) self.assertIn('test_2.0', os.listdir(tmp_dir)) check_eq(new_object, test_object1) @@ -241,15 +247,16 @@ def test_save_restore_checkpoints_w_float_steps(self, use_orbax): def test_save_restore_checkpoints_target_none(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = self.create_tempdir().full_path - test_object0 = {'a': np.array([0, 0, 0], np.int32), - 'b': np.array([0, 0, 0], np.int32)} + test_object0 = { + 'a': np.array([0, 0, 0], np.int32), + 'b': np.array([0, 0, 0], np.int32), + } # Target pytree is a dictionary, so it's equal to a restored state_dict. checkpoints.save_checkpoint(tmp_dir, test_object0, 0) new_object = checkpoints.restore_checkpoint(tmp_dir, target=None) check_eq(new_object, test_object0) # Target pytree it's a tuple, check the expected state_dict is recovered. - test_object1 = (np.array([0, 0, 0], np.int32), - np.array([1, 1, 1], np.int32)) + test_object1 = (np.array([0, 0, 0], np.int32), np.array([1, 1, 1], np.int32)) checkpoints.save_checkpoint(tmp_dir, test_object1, 1) new_object = checkpoints.restore_checkpoint(tmp_dir, target=None) expected_new_object = {str(k): v for k, v in enumerate(test_object1)} @@ -294,32 +301,41 @@ def test_save_restore_checkpoints_target_empty(self, use_orbax): def test_async_save_checkpoints(self): tmp_dir = pathlib.Path(self.create_tempdir().full_path) - test_object0 = {'a': np.array([0, 0, 0], np.int32), - 'b': np.array([0, 0, 0], np.int32)} - test_object1 = {'a': np.random.normal(size=(1000, 1000)), - 'b': np.random.normal(size=(1000, 1000))} - test_object2 = {'a': np.random.normal(size=(1000, 1000)), - 'b': np.random.normal(size=(1000, 1000))} - test_object3 = {'a': np.random.normal(size=(1000, 1000)), - 'b': np.random.normal(size=(1000, 1000))} + test_object0 = { + 'a': np.array([0, 0, 0], np.int32), + 'b': np.array([0, 0, 0], np.int32), + } + test_object1 = { + 'a': np.random.normal(size=(1000, 1000)), + 'b': np.random.normal(size=(1000, 1000)), + } + test_object2 = { + 'a': np.random.normal(size=(1000, 1000)), + 'b': np.random.normal(size=(1000, 1000)), + } + test_object3 = { + 'a': np.random.normal(size=(1000, 1000)), + 'b': np.random.normal(size=(1000, 1000)), + } am = checkpoints.AsyncManager() checkpoints.save_checkpoint( - tmp_dir, test_object1, 0, prefix='test_', keep=1, async_manager=am) + tmp_dir, test_object1, 0, prefix='test_', keep=1, async_manager=am + ) # Hard-wait the write to be done, then check its content. am.save_future.result() self.assertIn('test_0', os.listdir(tmp_dir)) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object1, prefix='test_') + new_object = checkpoints.restore_checkpoint(tmp_dir, test_object1, prefix='test_') check_eq(new_object, test_object1) # Check two consecutive saves happen in the right order. checkpoints.save_checkpoint( - tmp_dir, test_object2, 1, prefix='test_', keep=1, async_manager=am) + tmp_dir, test_object2, 1, prefix='test_', keep=1, async_manager=am + ) checkpoints.save_checkpoint( - tmp_dir, test_object3, 2, prefix='test_', keep=1, async_manager=am) + tmp_dir, test_object3, 2, prefix='test_', keep=1, async_manager=am + ) am.save_future.result() self.assertIn('test_2', os.listdir(tmp_dir)) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object1, prefix='test_') + new_object = checkpoints.restore_checkpoint(tmp_dir, test_object1, prefix='test_') check_eq(new_object, test_object3) def test_last_checkpoint(self): @@ -400,10 +416,12 @@ def test_auto_restore(self): # Both gets restored with same API. restored = checkpoints.restore_checkpoint( - os.path.join(tmp_dir, 'test_0'), target=target) + os.path.join(tmp_dir, 'test_0'), target=target + ) check_eq(restored, to_save) restored = checkpoints.restore_checkpoint( - os.path.join(tmp_dir, 'test_1'), target=target) + os.path.join(tmp_dir, 'test_1'), target=target + ) check_eq(restored, to_save) def test_convert_pre_linen(self): @@ -413,33 +431,23 @@ def test_convert_pre_linen(self): 'submod2_1': {}, 'submod1_2': {}, }, - 'mod2_2': { - 'submod2_2_0': {} - }, - 'mod2_11': { - 'submod2_11_0': {} - }, - 'mod2_1': { - 'submod2_1_0': {} - }, + 'mod2_2': {'submod2_2_0': {}}, + 'mod2_11': {'submod2_11_0': {}}, + 'mod2_1': {'submod2_1_0': {}}, }) self.assertDictEqual( - core.unfreeze(params), { + core.unfreeze(params), + { 'mod_0': { 'submod1_0': {}, 'submod1_1': {}, 'submod2_0': {}, }, - 'mod2_0': { - 'submod2_1_0': {} - }, - 'mod2_1': { - 'submod2_2_0': {} - }, - 'mod2_2': { - 'submod2_11_0': {} - }, - }) + 'mod2_0': {'submod2_1_0': {}}, + 'mod2_1': {'submod2_2_0': {}}, + 'mod2_2': {'submod2_11_0': {}}, + }, + ) if __name__ == '__main__': diff --git a/tests/core/core_frozen_dict_test.py b/tests/core/core_frozen_dict_test.py index 3ff17aab96..fa9821d00e 100644 --- a/tests/core/core_frozen_dict_test.py +++ b/tests/core/core_frozen_dict_test.py @@ -40,8 +40,8 @@ def test_frozen_dict_pop(self): def test_frozen_dict_partially_maps(self): x = jax.tree_util.tree_map( - lambda a, b: (a, b), - freeze({'a': 2}), freeze({'a': {'b': 1}})) + lambda a, b: (a, b), freeze({'a': 2}), freeze({'a': {'b': 1}}) + ) self.assertEqual(unfreeze(x), {'a': (2, {'b': 1})}) def test_frozen_dict_hash(self): @@ -96,12 +96,8 @@ def test_frozen_dict_copy_reserved_name(self): ) def test_utility_pop(self, x, key, actual_new_x, actual_value): new_x, value = pop(x, key) - self.assertTrue( - new_x == actual_new_x and isinstance(new_x, type(actual_new_x)) - ) - self.assertTrue( - value == actual_value and isinstance(value, type(actual_value)) - ) + self.assertTrue(new_x == actual_new_x and isinstance(new_x, type(actual_new_x))) + self.assertTrue(value == actual_value and isinstance(value, type(actual_value))) @parameterized.parameters( { @@ -117,9 +113,7 @@ def test_utility_pop(self, x, key, actual_new_x, actual_value): ) def test_utility_copy(self, x, add_or_replace, actual_new_x): new_x = copy(x, add_or_replace=add_or_replace) - self.assertTrue( - new_x == actual_new_x and isinstance(new_x, type(actual_new_x)) - ) + self.assertTrue(new_x == actual_new_x and isinstance(new_x, type(actual_new_x))) @parameterized.parameters( { @@ -149,13 +143,16 @@ def test_flatten(self): flat_path_leaves, tdef = jax.tree_util.tree_flatten_with_path(frozen) self.assertEqual( flat_path_leaves, - [((jax.tree_util.DictKey('b'), jax.tree_util.DictKey('a')), 2), - ((jax.tree_util.DictKey('c'),), 1)], + [ + ((jax.tree_util.DictKey('b'), jax.tree_util.DictKey('a')), 2), + ((jax.tree_util.DictKey('c'),), 1), + ], ) self.assertEqual( jax.tree_util.tree_unflatten(tdef, [l for _, l in flat_path_leaves]), frozen, ) + if __name__ == '__main__': absltest.main() diff --git a/tests/core/core_lift_test.py b/tests/core/core_lift_test.py index 1794e35478..2661b28efc 100644 --- a/tests/core/core_lift_test.py +++ b/tests/core/core_lift_test.py @@ -43,10 +43,13 @@ def g(scopes, _): def test_undefined_param(self): def f(scope): - dense = lift.vmap(nn.dense, - in_axes=(0, None), out_axes=0, - variable_axes={'params': 0}, - split_rngs={'params': True}) + dense = lift.vmap( + nn.dense, + in_axes=(0, None), + out_axes=0, + variable_axes={'params': 0}, + split_rngs={'params': True}, + ) dense(scope.push('dense'), np.ones((3, 2)), 2) msg = r'Could not find parameter named "kernel" in scope "/vmap\(dense\)".' @@ -55,11 +58,14 @@ def f(scope): def test_jit_cache(self): compiles = 0 + @lift.jit def f(scope, x): nonlocal compiles compiles += 1 - if scope.is_mutable_collection('intermediates') and not scope.is_mutable_collection('params'): + if scope.is_mutable_collection( + 'intermediates' + ) and not scope.is_mutable_collection('params'): scope.put_variable('intermediates', 'x', x + 1) return nn.dense(scope, x, 1) @@ -77,7 +83,6 @@ def f(scope, x): self.assertEqual(compiles, 3) # applying again should not self.assertEqual(state['intermediates']['x'].sum(), 3 * 2 * 2) - def test_vjp(self): def g(scope, x, y): p = scope.param('test', nn.initializers.constant(0.5), ()) @@ -88,15 +93,18 @@ def f(scope, x, y): z, bwd = lift.vjp(g, scope, x, y) return bwd(jnp.ones(y.shape)) - x = jnp.array([1., 2., 3.]) - y = jnp.array([4., 5., 6.]) + x = jnp.array([1.0, 2.0, 3.0]) + y = jnp.array([4.0, 5.0, 6.0]) _, params = init(f)(random.PRNGKey(0), x, y) params_grad, x_grad, y_grad = apply(f)(params, x, y) - self.assertEqual(params_grad, { - 'params': FrozenDict({'test': 32.}), - }) - np.testing.assert_allclose(x_grad, [2., 2.5, 3.]) - np.testing.assert_allclose(y_grad, [0.5, 1., 1.5]) + self.assertEqual( + params_grad, + { + 'params': FrozenDict({'test': 32.0}), + }, + ) + np.testing.assert_allclose(x_grad, [2.0, 2.5, 3.0]) + np.testing.assert_allclose(y_grad, [0.5, 1.0, 1.5]) def test_jvp(self): def g(scope, x): @@ -105,7 +113,9 @@ def g(scope, x): return p * x def f(scope, x): - vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {})) + vars_t = jax.tree_util.tree_map( + jnp.ones_like, scope.variables().get('params', {}) + ) _, out_t = lift.jvp(g, scope, (x,), (jnp.zeros_like(x),), {'params': vars_t}) return out_t @@ -126,27 +136,52 @@ def f(scope, x): def cond_fn(scope, c): acc = scope.get_variable('state', 'acc') return acc < x + def body_fn(scope, c): i = scope.get_variable('state', 'acc') p_rng = scope.make_rng('params') l_rng = scope.make_rng('loop') - scope.put_variable('state', 'rng_params', scope.get_variable('state', 'rng_params').at[i].set(p_rng)) - scope.put_variable('state', 'rng_loop', scope.get_variable('state', 'rng_loop').at[i].set(l_rng)) + scope.put_variable( + 'state', + 'rng_params', + scope.get_variable('state', 'rng_params').at[i].set(p_rng), + ) + scope.put_variable( + 'state', + 'rng_loop', + scope.get_variable('state', 'rng_loop').at[i].set(l_rng), + ) inc = scope.get_variable('params', 'inc') scope.put_variable('state', 'acc', i + inc) return c + 2 - return lift.while_loop(cond_fn, body_fn, scope, 0, carry_variables='state', split_rngs={'params': False, 'loop': True}) + + return lift.while_loop( + cond_fn, + body_fn, + scope, + 0, + carry_variables='state', + split_rngs={'params': False, 'loop': True}, + ) + x = 2 - c, vars = apply(f, mutable=True)({}, x, rngs={'params': random.PRNGKey(1), 'loop': random.PRNGKey(2)}) + c, vars = apply(f, mutable=True)( + {}, x, rngs={'params': random.PRNGKey(1), 'loop': random.PRNGKey(2)} + ) self.assertEqual(vars['state']['acc'], x) self.assertEqual(c, 2 * x) - np.testing.assert_array_equal(vars['state']['rng_params'][0], vars['state']['rng_params'][1]) - np.testing.assert_array_compare(operator.__ne__, vars['state']['rng_loop'][0], vars['state']['rng_loop'][1]) + np.testing.assert_array_equal( + vars['state']['rng_params'][0], vars['state']['rng_params'][1] + ) + np.testing.assert_array_compare( + operator.__ne__, vars['state']['rng_loop'][0], vars['state']['rng_loop'][1] + ) def test_cond(self): def f(scope, x, pred): scope.variable('state', 'true_count', lambda: 0) scope.variable('state', 'false_count', lambda: 0) + def true_fn(scope, x): scope.variable('state', 'true_count').value += 1 return scope.child(nn.dense)(x, 2) @@ -160,7 +195,7 @@ def false_fn(scope, x): x = jnp.ones((1, 3)) y1, vars = init(f)(random.PRNGKey(0), x, True) self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 0}) - y2, vars = apply(f, mutable="state")(vars, x, False) + y2, vars = apply(f, mutable='state')(vars, x, False) self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 1}) np.testing.assert_allclose(y1, -y2) @@ -187,32 +222,33 @@ def c_fn(scope, x): x = jnp.ones((1, 3)) y1, vars = init(f)(random.PRNGKey(0), x, 0) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 0, 'c_count': 0}) - y2, updates = apply(f, mutable="state")(vars, x, 1) + y2, updates = apply(f, mutable='state')(vars, x, 1) vars = copy(vars, updates) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 0}) np.testing.assert_allclose(y1, -y2) - y3, updates = apply(f, mutable="state")(vars, x, 2) + y3, updates = apply(f, mutable='state')(vars, x, 2) vars = copy(vars, updates) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 1}) np.testing.assert_allclose(y1, y3) def test_subscope_var_aliasing(self): def test(scope, x): - subscope = scope.push(name="a") - subscope.put_variable('state', 'x', 0.) + subscope = scope.push(name='a') + subscope.put_variable('state', 'x', 0.0) _ = lift.while_loop( - lambda scope, x: False, - lambda scope, x: x, - scope, - jnp.array(0, jnp.int32), - carry_variables=['state'], + lambda scope, x: False, + lambda scope, x: x, + scope, + jnp.array(0, jnp.int32), + carry_variables=['state'], ) - subscope.put_variable('state', 'x', 1.) + subscope.put_variable('state', 'x', 1.0) val0 = scope.variables()['state']['a']['x'] val1 = subscope.variables()['state']['x'] self.assertEqual(val0, val1) return x - init(test)( random.PRNGKey(0), 1.) + + init(test)(random.PRNGKey(0), 1.0) if __name__ == '__main__': diff --git a/tests/core/core_meta_test.py b/tests/core/core_meta_test.py index 7d90bbc9b8..6089921719 100644 --- a/tests/core/core_meta_test.py +++ b/tests/core/core_meta_test.py @@ -28,8 +28,7 @@ class MetaTest(absltest.TestCase): def test_boxed_param(self): def f(scope, xs): def g(scope, x): - kernel_init = meta.with_partitioning(nn.initializers.ones_init(), - ('in', 'out')) + kernel_init = meta.with_partitioning(nn.initializers.ones_init(), ('in', 'out')) kernel = scope.param('kernel', kernel_init, (x.shape[-1], 2)) kernel_box = scope.get_variable('params', 'kernel') self.assertIsInstance(kernel_box, meta.Partitioned) @@ -37,22 +36,24 @@ def g(scope, x): return x @ kernel lift.vmap( - g, in_axes=0, - variable_axes={'params': 0}, split_rngs={'params': True}, - metadata_params={meta.PARTITION_NAME: 'batch'})(scope, xs) + g, + in_axes=0, + variable_axes={'params': 0}, + split_rngs={'params': True}, + metadata_params={meta.PARTITION_NAME: 'batch'}, + )(scope, xs) _, variables = init(f)(random.PRNGKey(0), jnp.zeros((8, 3))) - self.assertEqual(variables['params']['kernel'].names, - ('batch', 'in', 'out')) + self.assertEqual(variables['params']['kernel'].names, ('batch', 'in', 'out')) def test_boxed_variable(self): def f(scope, xs): def g(scope, x): - kernel_init = meta.with_partitioning(nn.initializers.ones_init(), - ('in', 'out')) - kernel = scope.variable('params', 'kernel', kernel_init, - scope.make_rng('params'), (x.shape[-1], 2)) - kernel.value += 1. + kernel_init = meta.with_partitioning(nn.initializers.ones_init(), ('in', 'out')) + kernel = scope.variable( + 'params', 'kernel', kernel_init, scope.make_rng('params'), (x.shape[-1], 2) + ) + kernel.value += 1.0 self.assertEqual(kernel.value.sum(), kernel.value.size * 2) kernel_box = scope.get_variable('params', 'kernel') self.assertIsInstance(kernel_box, meta.Partitioned) @@ -60,69 +61,86 @@ def g(scope, x): return x @ kernel.value lift.vmap( - g, in_axes=0, - variable_axes={'params': 0}, split_rngs={'params': True}, - metadata_params={meta.PARTITION_NAME: 'batch'})(scope, xs) + g, + in_axes=0, + variable_axes={'params': 0}, + split_rngs={'params': True}, + metadata_params={meta.PARTITION_NAME: 'batch'}, + )(scope, xs) _, variables = init(f)(random.PRNGKey(0), jnp.zeros((8, 3))) - self.assertEqual(variables['params']['kernel'].names, - ('batch', 'in', 'out')) + self.assertEqual(variables['params']['kernel'].names, ('batch', 'in', 'out')) def test_partition_axis_unspecified(self): def f(scope, xs): def g(scope, x): - kernel_init = meta.with_partitioning(nn.initializers.ones_init(), - ('in', 'out')) + kernel_init = meta.with_partitioning(nn.initializers.ones_init(), ('in', 'out')) scope.param('kernel', kernel_init, (3, 2)) return x with self.assertRaises(errors.PartitioningUnspecifiedError): lift.vmap( - g, in_axes=0, - variable_axes={'params': 0}, split_rngs={'params': True}, - metadata_params={})(scope, xs) + g, + in_axes=0, + variable_axes={'params': 0}, + split_rngs={'params': True}, + metadata_params={}, + )(scope, xs) + init(f)(random.PRNGKey(0), jnp.zeros((8, 3))) def test_unbox(self): - xs = {'kernel': meta.Partitioned(jnp.zeros((3, 2)), ('in', 'out')), - 'complex': meta.Partitioned( - {'K': jnp.zeros((3, 2)), 'b': jnp.zeros((3,))}, ('data',))} + xs = { + 'kernel': meta.Partitioned(jnp.zeros((3, 2)), ('in', 'out')), + 'complex': meta.Partitioned( + {'K': jnp.zeros((3, 2)), 'b': jnp.zeros((3,))}, ('data',) + ), + } unboxed = meta.unbox(xs) unboxed_shapes = jax.tree_map(jnp.shape, unboxed) - self.assertEqual(unboxed_shapes, { - 'kernel': (3, 2), - 'complex': { - 'K': (3, 2), 'b': (3,), - } - }) + self.assertEqual( + unboxed_shapes, + { + 'kernel': (3, 2), + 'complex': { + 'K': (3, 2), + 'b': (3,), + }, + }, + ) def test_scan_over_layers(self): def f(scope, x): def body(scope, x): - kernel_init = meta.with_partitioning(nn.initializers.ones_init(), - ('in', 'out')) + kernel_init = meta.with_partitioning(nn.initializers.ones_init(), ('in', 'out')) y = nn.dense(scope, x, 3, kernel_init=kernel_init) return y, () c, _ = lift.scan( body, - variable_axes={'params': 0}, split_rngs={'params': True}, + variable_axes={'params': 0}, + split_rngs={'params': True}, length=8, - metadata_params={meta.PARTITION_NAME: 'layers'})(scope, x) + metadata_params={meta.PARTITION_NAME: 'layers'}, + )(scope, x) return c _, variables = init(f)(random.PRNGKey(0), jnp.zeros((8, 3))) boxed_shapes = jax.tree_map(jnp.shape, variables['params']) - self.assertEqual(boxed_shapes, { - 'kernel': meta.Partitioned((8, 3, 3), ('layers', 'in', 'out')), - 'bias': (8, 3), - }) + self.assertEqual( + boxed_shapes, + { + 'kernel': meta.Partitioned((8, 3, 3), ('layers', 'in', 'out')), + 'bias': (8, 3), + }, + ) def test_get_partition_spec(self): - xs = {'kernel': meta.Partitioned(jnp.zeros((8, 3, 3)), - ('layers', 'in', 'out')), - 'bias': jnp.zeros((8, 3)), - 'step': jnp.array(100)} + xs = { + 'kernel': meta.Partitioned(jnp.zeros((8, 3, 3)), ('layers', 'in', 'out')), + 'bias': jnp.zeros((8, 3)), + 'step': jnp.array(100), + } ps = meta.get_partition_spec(xs) self.assertEqual( ps, @@ -136,15 +154,18 @@ def test_get_partition_spec(self): def test_get_sharding(self): devices = mesh_utils.create_device_mesh((jax.local_device_count(), 1)) mesh = sharding.Mesh(devices, ('in', 'out')) - xs = {'kernel': meta.Partitioned(jnp.zeros((8, 3)), - ('in', 'out')), - 'bias': jnp.zeros((8, 3)), - 'step': jnp.array(100)} + xs = { + 'kernel': meta.Partitioned(jnp.zeros((8, 3)), ('in', 'out')), + 'bias': jnp.zeros((8, 3)), + 'step': jnp.array(100), + } ps = meta.get_sharding(xs, mesh) self.assertEqual( ps, { - 'kernel': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('in', 'out')), + 'kernel': jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('in', 'out') + ), 'bias': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()), 'step': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()), }, @@ -155,13 +176,14 @@ def test_boxed_param_with_mesh(self): mesh = sharding.Mesh(devices, ('in', 'out')) def f(scope, x): - kernel_init = meta.with_partitioning( - nn.initializers.ones_init(),('in', 'out'), mesh=mesh) - kernel = scope.param('kernel', kernel_init, (x.shape[-1], 2)) - kernel_box = scope.get_variable('params', 'kernel') - self.assertIsInstance(kernel_box, meta.Partitioned) - self.assertEqual(kernel_box.names, ('in', 'out')) - return x @ kernel + kernel_init = meta.with_partitioning( + nn.initializers.ones_init(), ('in', 'out'), mesh=mesh + ) + kernel = scope.param('kernel', kernel_init, (x.shape[-1], 2)) + kernel_box = scope.get_variable('params', 'kernel') + self.assertIsInstance(kernel_box, meta.Partitioned) + self.assertEqual(kernel_box.names, ('in', 'out')) + return x @ kernel @jax.jit def create_state(): @@ -171,11 +193,10 @@ def create_state(): variables = jax.lax.with_sharding_constraint(variables, shardings) return variables - variables = create_state() - self.assertEqual(variables['params']['kernel'].names, - ('in', 'out')) + self.assertEqual(variables['params']['kernel'].names, ('in', 'out')) self.assertIs(variables['params']['kernel'].mesh, mesh) + if __name__ == '__main__': absltest.main() diff --git a/tests/core/core_scope_test.py b/tests/core/core_scope_test.py index 4e40da18c5..cee176d128 100644 --- a/tests/core/core_scope_test.py +++ b/tests/core/core_scope_test.py @@ -28,6 +28,7 @@ from absl.testing import absltest + class ScopeTest(absltest.TestCase): def test_rng(self): @@ -36,11 +37,12 @@ def f(scope): self.assertFalse(scope.has_rng('dropout')) rng = scope.make_rng('params') self.assertTrue(np.all(rng == LazyRng.create(random.PRNGKey(0), 1).as_jax_rng())) + init(f)(random.PRNGKey(0)) def test_in_filter(self): - filter_true = lambda x, y : self.assertTrue(scope.in_filter(x, y)) - filter_false = lambda x, y : self.assertFalse(scope.in_filter(x, y)) + filter_true = lambda x, y: self.assertTrue(scope.in_filter(x, y)) + filter_false = lambda x, y: self.assertFalse(scope.in_filter(x, y)) filter_true(True, 'any_string1') filter_false(False, 'any_string2') @@ -60,7 +62,11 @@ def union_check(a, b, ans): union_check(True, False, True) union_check(False, False, set()) union_check(True, True, True) - union_check(scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), scope.DenyList(set(['b']))) + union_check( + scope.DenyList(['a', 'b']), + scope.DenyList(['b', 'c']), + scope.DenyList(set(['b'])), + ) union_check(scope.DenyList(['a', 'b']), ['b', 'c'], scope.DenyList(set(['a']))) def test_intersect_filter(self): @@ -72,7 +78,11 @@ def intersect_check(a, b, ans): intersect_check(True, False, False) intersect_check(False, False, set()) intersect_check(True, True, True) - intersect_check(scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), scope.DenyList(set(['a', 'b', 'c']))) + intersect_check( + scope.DenyList(['a', 'b']), + scope.DenyList(['b', 'c']), + scope.DenyList(set(['a', 'b', 'c'])), + ) intersect_check(scope.DenyList(['a', 'b']), ['b', 'c'], set(['c'])) def test_subtract_filter(self): @@ -85,13 +95,14 @@ def subtract_check(a, b, ans): subtract_check(True, True, False) subtract_check(True, 'a', scope.DenyList('a')) subtract_check(scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), set(['c'])) - subtract_check(scope.DenyList(['a', 'b']), ['b', 'c'], scope.DenyList(set(['a', 'b', 'c']))) - + subtract_check( + scope.DenyList(['a', 'b']), ['b', 'c'], scope.DenyList(set(['a', 'b', 'c'])) + ) def test_group_collections(self): - params = { 'dense1': { 'x': [10, 20] } } - batch_stats = { 'dense1': { 'ema': 5 } } - xs = { 'params': params, 'batch_stats': batch_stats } + params = {'dense1': {'x': [10, 20]}} + batch_stats = {'dense1': {'ema': 5}} + xs = {'params': params, 'batch_stats': batch_stats} # Retrieve all keys only once. group = scope.group_collections(xs, ['params', 'params']) @@ -101,8 +112,7 @@ def test_group_collections(self): self.assertEqual(scope.group_collections(xs, ['vars']), ({},)) # False gets nothing and True retrieves all keys once. - self.assertEqual(scope.group_collections(xs, [False, True, True]), - ({}, xs, {})) + self.assertEqual(scope.group_collections(xs, [False, True, True]), ({}, xs, {})) def test_inconsistent_param_shapes(self): def f(scope): @@ -123,15 +133,16 @@ def f(scope): }) apply(f)(params) # Valid. msg = 'but got a dict with an extra params layer' - with self.assertRaisesRegex(errors.ApplyScopeInvalidVariablesStructureError, - msg): + with self.assertRaisesRegex(errors.ApplyScopeInvalidVariablesStructureError, msg): apply(f)({'params': params}) def test_mutate_undefined_collection(self): def f(scope): scope.put_variable('state', 'test', 123) - msg = r'Cannot update variable "test" in "/" because collection "state" is immutable.' + msg = ( + r'Cannot update variable "test" in "/" because collection "state" is immutable.' + ) with self.assertRaisesRegex(errors.ModifyScopeVariableError, msg): init(f, mutable='params')(random.PRNGKey(0)) @@ -154,11 +165,13 @@ def f(scope, should_be_mutable): def test_rngs_check_w_frozen_dict(self): def f(scope, x): return x - _ = apply(f)( - {}, np.array([0.]), rngs=freeze({'a':random.PRNGKey(0)})) - @unittest.skipIf(not hasattr(jax_config, 'jax_enable_custom_prng'), - 'custom PRNG tests require config.jax_enable_custom_prng') + _ = apply(f)({}, np.array([0.0]), rngs=freeze({'a': random.PRNGKey(0)})) + + @unittest.skipIf( + not hasattr(jax_config, 'jax_enable_custom_prng'), + 'custom PRNG tests require config.jax_enable_custom_prng', + ) def test_rng_check_w_old_and_new_keys(self): old_setting = jax_config.jax_enable_custom_prng try: @@ -171,16 +184,21 @@ def test_rng_check_w_old_and_new_keys(self): def test_jax_leak_detector(self): with jax.check_tracer_leaks(True): + def f(scope): def g(scope): pass + scope.child(g)() + jax.jit(init(f))(random.PRNGKey(0)) def test_rng_counter_reuse(self): root = Scope({}, {'dropout': random.PRNGKey(0)}) + def f(scope): return scope.make_rng('dropout') + a = root.child(f)() root = root.rewound() b = root.child(f)() @@ -213,34 +231,43 @@ def test_variable_no_init(self): def test_variable_alias(self): scope = Scope({}, mutable='state') - subscope = scope.push(name="a") - subscope.put_variable('state', 'x', 0.) - scope.put_variable('state', 'a', {'x': jnp.array(1., jnp.float32)}) - self.assertEqual(scope.variables()['state']['a']['x'], subscope.variables()['state']['x']) + subscope = scope.push(name='a') + subscope.put_variable('state', 'x', 0.0) + scope.put_variable('state', 'a', {'x': jnp.array(1.0, jnp.float32)}) + self.assertEqual( + scope.variables()['state']['a']['x'], subscope.variables()['state']['x'] + ) def test_lazy_init(self): def f(scope, x): - k = scope.param("kernel", nn.initializers.lecun_normal(), (x.shape[-1], x.shape[-1])) + k = scope.param( + 'kernel', nn.initializers.lecun_normal(), (x.shape[-1], x.shape[-1]) + ) return x @ k + init_fn = lazy_init(f) # provide a massive input message which would OOM if any compute ops were actually executed - variables = init_fn(random.PRNGKey(0), jax.ShapeDtypeStruct((1024 * 1024 * 1024, 128), jnp.float32)) - self.assertEqual(variables["params"]["kernel"].shape, (128, 128)) + variables = init_fn( + random.PRNGKey(0), jax.ShapeDtypeStruct((1024 * 1024 * 1024, 128), jnp.float32) + ) + self.assertEqual(variables['params']['kernel'].shape, (128, 128)) def test_lazy_init_fails_on_data_dependence(self): def f(scope, x): # kernel is initialized with x so params are now dependent on the input - k = scope.param("kernel", lambda _: x) + k = scope.param('kernel', lambda _: x) return x * k + init_fn = lazy_init(f) with self.assertRaises(errors.LazyInitError): init_fn(random.PRNGKey(0), jax.ShapeDtypeStruct((8, 4), jnp.float32)) @temp_flip_flag('fix_rng_separator', True) def test_fold_in_static_seperator(self): - x = LazyRng(random.PRNGKey(0), ("ab", "c")) - y = LazyRng(random.PRNGKey(0), ("a", "bc")) + x = LazyRng(random.PRNGKey(0), ('ab', 'c')) + y = LazyRng(random.PRNGKey(0), ('a', 'bc')) self.assertFalse(np.all(x.as_jax_rng() == y.as_jax_rng())) + if __name__ == '__main__': absltest.main() diff --git a/tests/core/design/core_attention_test.py b/tests/core/design/core_attention_test.py index bfc5042814..2712cc8c60 100644 --- a/tests/core/design/core_attention_test.py +++ b/tests/core/design/core_attention_test.py @@ -29,40 +29,39 @@ def softmax_attn(scope: Scope, weights: Array): del scope norm_dims = tuple(range(weights.ndim // 2, weights.ndim)) - log_norms = jax.scipy.special.logsumexp( - weights, axis=norm_dims, keepdims=True) + log_norms = jax.scipy.special.logsumexp(weights, axis=norm_dims, keepdims=True) return jnp.exp(weights - log_norms) + def with_dropout(fn, rate: float, deterministic: bool = False): def attn_fn(scope: Scope, weights: Array): attn_weights = fn(scope, weights) return nn.dropout(scope, attn_weights, deterministic=deterministic, rate=rate) + return attn_fn + def _dot_product_attention( scope: Scope, - query: Array, key: Array, value: Array, + query: Array, + key: Array, + value: Array, bias: Optional[Array] = None, attn_fn: Callable = softmax_attn, - dtype=jnp.float32): + dtype=jnp.float32, +): assert key.ndim == query.ndim assert key.ndim == value.ndim n = query.ndim - attn_weights = lax.dot_general( - query, key, - (((n-1,), (n - 1,)), ((), ()))) + attn_weights = lax.dot_general(query, key, (((n - 1,), (n - 1,)), ((), ()))) if bias is not None: attn_weights += bias attn_weights = attn_fn(scope, attn_weights) attn_weights = attn_weights.astype(dtype) - contract_dims = ( - tuple(range(n - 1, attn_weights.ndim)), - tuple(range(0, n - 1))) - y = lax.dot_general( - attn_weights, value, - (contract_dims, ((), ()))) + contract_dims = (tuple(range(n - 1, attn_weights.ndim)), tuple(range(0, n - 1))) + y = lax.dot_general(attn_weights, value, (contract_dims, ((), ()))) return y @@ -74,7 +73,8 @@ def dot_product_attention( qkv_features: Optional[int] = None, out_features: Optional[int] = None, attn_fn: Callable = softmax_attn, - dtype=jnp.float32): + dtype=jnp.float32, +): if qkv_features is None: qkv_features = inputs_q.shape[-1] if out_features is None: @@ -86,14 +86,12 @@ def dot_product_attention( value = scope.child(dense, 'value')(inputs_kv) y = _dot_product_attention( - scope, query, key, value, - bias=bias, - attn_fn=attn_fn, dtype=dtype) + scope, query, key, value, bias=bias, attn_fn=attn_fn, dtype=dtype + ) return scope.child(nn.dense, 'out')(y, features=out_features, dtype=dtype) - def multi_head_dot_product_attention( scope: Scope, inputs_q: Array, @@ -105,8 +103,8 @@ def multi_head_dot_product_attention( batch_axes: Sequence[int] = (0,), num_heads: int = 1, dtype=jnp.float32, - broadcast_dropout=False): - + broadcast_dropout=False, +): if qkv_features is None: qkv_features = inputs_q.shape[-1] if out_features is None: @@ -117,19 +115,24 @@ def multi_head_dot_product_attention( attn_fn=attn_fn, qkv_features=qkv_features // num_heads, out_features=out_features, - dtype=dtype) + dtype=dtype, + ) attn_fn = lift.vmap( attn_fn, - in_axes=(None, None, None), out_axes=-2, + in_axes=(None, None, None), + out_axes=-2, axis_size=num_heads, variable_axes={'params': 0}, - split_rngs={'params': True, 'dropout': not broadcast_dropout}) + split_rngs={'params': True, 'dropout': not broadcast_dropout}, + ) for axis in reversed(sorted(batch_axes)): attn_fn = lift.vmap( attn_fn, - in_axes=(axis, axis, axis), out_axes=axis, + in_axes=(axis, axis, axis), + out_axes=axis, variable_axes={'params': None}, - split_rngs={'params': False, 'dropout': not broadcast_dropout}) + split_rngs={'params': False, 'dropout': not broadcast_dropout}, + ) y = attn_fn(scope, inputs_q, inputs_kv, bias) return y.mean(axis=-2) @@ -141,19 +144,24 @@ def test_attention(self): inputs = jnp.ones((2, 7, 16)) model = partial( multi_head_dot_product_attention, - num_heads=2, batch_axes=(0,), - attn_fn=with_dropout(softmax_attn, 0.1, deterministic=False)) + num_heads=2, + batch_axes=(0,), + attn_fn=with_dropout(softmax_attn, 0.1, deterministic=False), + ) rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)} y, variables = jax.jit(init(model))(rngs, inputs, inputs) variable_shapes = jax.tree_util.tree_map(jnp.shape, variables['params']) self.assertEqual(y.shape, (2, 7, 16)) - self.assertEqual(unfreeze(variable_shapes), { - 'key': {'kernel': (2, 16, 8)}, - 'value': {'kernel': (2, 16, 8)}, - 'query': {'kernel': (2, 16, 8)}, - 'out': {'bias': (2, 16), 'kernel': (2, 8, 16)}, - }) + self.assertEqual( + unfreeze(variable_shapes), + { + 'key': {'kernel': (2, 16, 8)}, + 'value': {'kernel': (2, 16, 8)}, + 'query': {'kernel': (2, 16, 8)}, + 'out': {'bias': (2, 16), 'kernel': (2, 8, 16)}, + }, + ) if __name__ == '__main__': diff --git a/tests/core/design/core_auto_encoder_test.py b/tests/core/design/core_auto_encoder_test.py index f28aafd733..ac6d0b1f8c 100644 --- a/tests/core/design/core_auto_encoder_test.py +++ b/tests/core/design/core_auto_encoder_test.py @@ -36,7 +36,6 @@ def mlp(scope: Scope, x: Array, hidden: int, out: int): @dataclass class AutoEncoder: - latents: int features: int hidden: int @@ -60,6 +59,7 @@ def wrapper(self, *args, **kwargs): scope = self.scope.rewound() mod_fn = lambda scope: fn(self, scope, *args, **kwargs) return scope.child(mod_fn, name)() + return wrapper @@ -107,16 +107,19 @@ def test_auto_encoder_hp_struct(self): x_r, variables = init(ae)(random.PRNGKey(0), x) self.assertEqual(x.shape, x_r.shape) variable_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) - self.assertEqual(variable_shapes, { - 'encoder': { - 'hidden': {'kernel': (4, 3), 'bias': (3,)}, - 'out': {'kernel': (3, 2), 'bias': (2,)}, - }, - 'decoder': { - 'hidden': {'kernel': (2, 3), 'bias': (3,)}, - 'out': {'kernel': (3, 4), 'bias': (4,)}, + self.assertEqual( + variable_shapes, + { + 'encoder': { + 'hidden': {'kernel': (4, 3), 'bias': (3,)}, + 'out': {'kernel': (3, 2), 'bias': (2,)}, + }, + 'decoder': { + 'hidden': {'kernel': (2, 3), 'bias': (3,)}, + 'out': {'kernel': (3, 4), 'bias': (4,)}, + }, }, - }) + ) def test_auto_encoder_with_scope(self): ae = lambda scope, x: AutoEncoder2(scope, latents=2, features=4, hidden=3)(x) @@ -125,16 +128,19 @@ def test_auto_encoder_with_scope(self): x_r, variables = init(ae)(random.PRNGKey(0), x) self.assertEqual(x.shape, x_r.shape) variable_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) - self.assertEqual(variable_shapes, { - 'encode': { - 'hidden': {'kernel': (4, 3), 'bias': (3,)}, - 'out': {'kernel': (3, 2), 'bias': (2,)}, + self.assertEqual( + variable_shapes, + { + 'encode': { + 'hidden': {'kernel': (4, 3), 'bias': (3,)}, + 'out': {'kernel': (3, 2), 'bias': (2,)}, + }, + 'decode': { + 'hidden': {'kernel': (2, 3), 'bias': (3,)}, + 'out': {'kernel': (3, 4), 'bias': (4,)}, + }, }, - 'decode': { - 'hidden': {'kernel': (2, 3), 'bias': (3,)}, - 'out': {'kernel': (3, 4), 'bias': (4,)}, - }, - }) + ) def test_auto_encoder_bind_method(self): ae = lambda scope, x: AutoEncoder3.create(scope, latents=2, features=4, hidden=3)(x) @@ -143,16 +149,19 @@ def test_auto_encoder_bind_method(self): x_r, variables = init(ae)(random.PRNGKey(0), x) self.assertEqual(x.shape, x_r.shape) variable_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) - self.assertEqual(variable_shapes, { - 'encode': { - 'hidden': {'kernel': (4, 3), 'bias': (3,)}, - 'out': {'kernel': (3, 2), 'bias': (2,)}, - }, - 'decode': { - 'hidden': {'kernel': (2, 3), 'bias': (3,)}, - 'out': {'kernel': (3, 4), 'bias': (4,)}, + self.assertEqual( + variable_shapes, + { + 'encode': { + 'hidden': {'kernel': (4, 3), 'bias': (3,)}, + 'out': {'kernel': (3, 2), 'bias': (2,)}, + }, + 'decode': { + 'hidden': {'kernel': (2, 3), 'bias': (3,)}, + 'out': {'kernel': (3, 4), 'bias': (4,)}, + }, }, - }) + ) if __name__ == '__main__': diff --git a/tests/core/design/core_big_resnets_test.py b/tests/core/design/core_big_resnets_test.py index d8a3a5a9bf..41e963ff67 100644 --- a/tests/core/design/core_big_resnets_test.py +++ b/tests/core/design/core_big_resnets_test.py @@ -36,8 +36,10 @@ def residual_block(scope: Scope, x: Array, conv, norm, act, features: int): x = scope.child(norm, 'bn_2')(x) return act(residual + x) -def big_resnet(scope: Scope, x, blocks=(10, 5), dtype=jnp.float32, - norm=default_norm, act=nn.relu): + +def big_resnet( + scope: Scope, x, blocks=(10, 5), dtype=jnp.float32, norm=default_norm, act=nn.relu +): conv = partial(nn.conv, bias=False, dtype=dtype) norm = partial(norm, dtype=dtype) @@ -50,10 +52,12 @@ def body_fn(scope, x): return residual_block(scope, x, conv, norm, act, features=x.shape[-1]) return lift.remat_scan( - body_fn, lengths=blocks, + body_fn, + lengths=blocks, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, - policy=None)(scope, x) + policy=None, + )(scope, x) class BigResnetTest(absltest.TestCase): @@ -62,20 +66,26 @@ def test_big_resnet(self): x = random.normal(random.PRNGKey(0), (1, 8, 8, 8)) y, variables = init(big_resnet)(random.PRNGKey(1), x) self.assertEqual(y.shape, (1, 8, 8, 8)) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) batch_stats_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['batch_stats'])) - self.assertEqual(param_shapes, { - 'conv_1': {'kernel': (10, 5, 3, 3, 8, 8)}, - 'conv_2': {'kernel': (10, 5, 3, 3, 8, 8)}, - 'bn_1': {'scale': (10, 5, 8), 'bias': (10, 5, 8)}, - 'bn_2': {'scale': (10, 5, 8), 'bias': (10, 5, 8)} - }) - self.assertEqual(batch_stats_shapes, { - 'bn_1': {'var': (10, 5, 8), 'mean': (10, 5, 8)}, - 'bn_2': {'var': (10, 5, 8), 'mean': (10, 5, 8)} - }) + jax.tree_util.tree_map(jnp.shape, variables['batch_stats']) + ) + self.assertEqual( + param_shapes, + { + 'conv_1': {'kernel': (10, 5, 3, 3, 8, 8)}, + 'conv_2': {'kernel': (10, 5, 3, 3, 8, 8)}, + 'bn_1': {'scale': (10, 5, 8), 'bias': (10, 5, 8)}, + 'bn_2': {'scale': (10, 5, 8), 'bias': (10, 5, 8)}, + }, + ) + self.assertEqual( + batch_stats_shapes, + { + 'bn_1': {'var': (10, 5, 8), 'mean': (10, 5, 8)}, + 'bn_2': {'var': (10, 5, 8), 'mean': (10, 5, 8)}, + }, + ) if __name__ == '__main__': diff --git a/tests/core/design/core_custom_vjp_test.py b/tests/core/design/core_custom_vjp_test.py index 853aff502c..65705d4744 100644 --- a/tests/core/design/core_custom_vjp_test.py +++ b/tests/core/design/core_custom_vjp_test.py @@ -26,10 +26,12 @@ from jax import random, numpy as jnp -def mlp_custom_grad(scope: Scope, x: Array, - sizes: Sequence[int] = (8, 1), - act_fn: Callable[[Array], Array] = nn.relu): - +def mlp_custom_grad( + scope: Scope, + x: Array, + sizes: Sequence[int] = (8, 1), + act_fn: Callable[[Array], Array] = nn.relu, +): f = nn.dense def fwd(scope, x, features): @@ -44,7 +46,8 @@ def bwd(features, res, y_t): return (params_t, *input_t) dense_custom_grad = lift.custom_vjp( - f, forward_fn=fwd, backward_fn=bwd, nondiff_argnums=(2,)) + f, forward_fn=fwd, backward_fn=bwd, nondiff_argnums=(2,) + ) # hidden layers for size in sizes[:-1]: @@ -60,12 +63,10 @@ class CustomVJPTest(absltest.TestCase): def test_custom_vjp(self): x = random.normal(random.PRNGKey(0), (1, 4)) y, variables = init(mlp_custom_grad)(random.PRNGKey(1), x) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) loss_fn = lambda p, x: jnp.mean(apply(mlp_custom_grad)(p, x) ** 2) grad = jax.grad(loss_fn)(variables, x) - grad_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, grad['params'])) + grad_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, grad['params'])) self.assertEqual(y.shape, (1, 1)) expected_param_shapes = { 'hidden_0': {'kernel': (4, 8), 'bias': (8,)}, diff --git a/tests/core/design/core_dense_test.py b/tests/core/design/core_dense_test.py index 9d2d51b625..5881224c61 100644 --- a/tests/core/design/core_dense_test.py +++ b/tests/core/design/core_dense_test.py @@ -48,9 +48,14 @@ class ExplicitDense: # a fully explicit "scope free" version @staticmethod - def create(rng, in_size, out_size, bias=True, - kernel_init=nn.linear.default_kernel_init, - bias_init=nn.initializers.zeros_init()): + def create( + rng, + in_size, + out_size, + bias=True, + kernel_init=nn.linear.default_kernel_init, + bias_init=nn.initializers.zeros_init(), + ): k1, k2 = random.split(rng, 2) kernel = kernel_init(k1, (in_size, out_size)) if bias: @@ -61,9 +66,14 @@ def create(rng, in_size, out_size, bias=True, # a semi-explicit version where a scope is used to create explicit params @staticmethod - def create_in_scope(scope, in_size, out_size, bias=True, - kernel_init=nn.linear.default_kernel_init, - bias_init=nn.initializers.zeros_init()): + def create_in_scope( + scope, + in_size, + out_size, + bias=True, + kernel_init=nn.linear.default_kernel_init, + bias_init=nn.initializers.zeros_init(), + ): kernel = scope.param('kernel', kernel_init, (in_size, out_size)) if bias: bias = scope.param('bias', bias_init, (out_size,)) @@ -77,6 +87,7 @@ def __call__(self, x): y += self.bias.reshape((1,) * (y.ndim - 1) + (-1,)) return y + def explicit_mlp(scope, x, sizes=(3, 1)): for i, size in enumerate(sizes): dense = scope.param(f'dense_{i}', ExplicitDense.create, x.shape[-1], size) @@ -85,9 +96,12 @@ def explicit_mlp(scope, x, sizes=(3, 1)): x = nn.relu(x) return x + def semi_explicit_mlp(scope, x, sizes=(3, 1)): for i, size in enumerate(sizes): - dense = scope.child(ExplicitDense.create_in_scope, prefix='dense_')(x.shape[-1], size) + dense = scope.child(ExplicitDense.create_in_scope, prefix='dense_')( + x.shape[-1], size + ) x = dense(x) if i + 1 < len(sizes): x = nn.relu(x) @@ -100,46 +114,54 @@ def test_dense(self): model = Dense(features=4) x = jnp.ones((1, 3)) y, variables = init(model)(random.PRNGKey(0), x) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) self.assertEqual(y.shape, (1, 4)) - self.assertEqual(param_shapes, { - 'kernel': (3, 4), - 'bias': (4,), - }) + self.assertEqual( + param_shapes, + { + 'kernel': (3, 4), + 'bias': (4,), + }, + ) def test_explicit_dense(self): x = jnp.ones((1, 3)) y, variables = init(explicit_mlp)(random.PRNGKey(0), x) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) self.assertEqual(y.shape, (1, 4)) - self.assertEqual(param_shapes, { - 'kernel': (3, 4), - 'bias': (4,), - }) + self.assertEqual( + param_shapes, + { + 'kernel': (3, 4), + 'bias': (4,), + }, + ) def test_explicit_dense(self): x = jnp.ones((1, 4)) y, variables = init(explicit_mlp)(random.PRNGKey(0), x) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) self.assertEqual(y.shape, (1, 1)) - self.assertEqual(param_shapes, { - 'dense_0': ExplicitDense((4, 3), (3,)), - 'dense_1': ExplicitDense((3, 1), (1,)) - }) + self.assertEqual( + param_shapes, + { + 'dense_0': ExplicitDense((4, 3), (3,)), + 'dense_1': ExplicitDense((3, 1), (1,)), + }, + ) def test_semi_explicit_dense(self): x = jnp.ones((1, 4)) y, variables = init(semi_explicit_mlp)(random.PRNGKey(0), x) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) self.assertEqual(y.shape, (1, 1)) - self.assertEqual(param_shapes, { - 'dense_0': {'kernel': (4, 3), 'bias': (3,)}, - 'dense_1': {'kernel': (3, 1), 'bias': (1,)} - }) + self.assertEqual( + param_shapes, + { + 'dense_0': {'kernel': (4, 3), 'bias': (3,)}, + 'dense_1': {'kernel': (3, 1), 'bias': (1,)}, + }, + ) if __name__ == '__main__': diff --git a/tests/core/design/core_flow_test.py b/tests/core/design/core_flow_test.py index 215fdc0fc0..a9b3229861 100644 --- a/tests/core/design/core_flow_test.py +++ b/tests/core/design/core_flow_test.py @@ -40,13 +40,11 @@ def params(self, scope: Scope, features: int): def forward(self, scope: Scope, x: Array): kernel, bias = self.params(scope, x.shape[-1]) - return jnp.dot( - x, expm(kernel)) + bias.reshape((1,) * (x.ndim - 1) + (-1,)) + return jnp.dot(x, expm(kernel)) + bias.reshape((1,) * (x.ndim - 1) + (-1,)) def backward(self, scope: Scope, y: Array): kernel, bias = self.params(scope, y.shape[-1]) - return jnp.dot( - y - bias.reshape((1,) * (y.ndim - 1) + (-1,)), expm(-kernel)) + return jnp.dot(y - bias.reshape((1,) * (y.ndim - 1) + (-1,)), expm(-kernel)) @dataclass @@ -70,14 +68,16 @@ def test_flow(self): x = jnp.ones((1, 3)) flow = StackFlow((DenseFlow(),) * 3) y, variables = init(flow.forward)(random.PRNGKey(0), x) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) self.assertEqual(y.shape, (1, 3)) - self.assertEqual(param_shapes, { - '0': {'kernel': (3, 3), 'bias': (3,)}, - '1': {'kernel': (3, 3), 'bias': (3,)}, - '2': {'kernel': (3, 3), 'bias': (3,)}, - }) + self.assertEqual( + param_shapes, + { + '0': {'kernel': (3, 3), 'bias': (3,)}, + '1': {'kernel': (3, 3), 'bias': (3,)}, + '2': {'kernel': (3, 3), 'bias': (3,)}, + }, + ) x_restored = apply(flow.backward)(variables, y) self.assertTrue(jnp.allclose(x, x_restored)) diff --git a/tests/core/design/core_resnet_test.py b/tests/core/design/core_resnet_test.py index 19f1a5caa9..f98ae8a258 100644 --- a/tests/core/design/core_resnet_test.py +++ b/tests/core/design/core_resnet_test.py @@ -25,7 +25,9 @@ default_norm = partial(nn.batch_norm) -def residual_block(scope: Scope, x: Array, conv, norm, act, features: int, strides=(1, 1)): +def residual_block( + scope: Scope, x: Array, conv, norm, act, features: int, strides=(1, 1) +): residual = x x = scope.child(conv, 'conv_1')(x, features, (1, 1)) x = scope.child(norm, 'bn_1')(x) @@ -37,19 +39,24 @@ def residual_block(scope: Scope, x: Array, conv, norm, act, features: int, strid x = scope.child(norm, 'bn_3')(x) if x.shape != residual.shape: - residual = scope.child(conv, 'proj_conv')(residual, 4 * features, (1, 1), strides=strides) + residual = scope.child(conv, 'proj_conv')( + residual, 4 * features, (1, 1), strides=strides + ) residual = scope.child(norm, 'proj_bn')(residual) return act(residual + x) -def resnet(scope: Scope, x, - block_sizes=(3, 4, 6, 3), - features=16, num_classes=1000, - dtype=jnp.float32, - norm=default_norm, - act=nn.relu, - ): +def resnet( + scope: Scope, + x, + block_sizes=(3, 4, 6, 3), + features=16, + num_classes=1000, + dtype=jnp.float32, + norm=default_norm, + act=nn.relu, +): conv = partial(nn.conv, bias=False, dtype=dtype) norm = partial(norm, dtype=dtype) @@ -63,7 +70,7 @@ def resnet(scope: Scope, x, strides = (1, 1) if i > 0 and j == 0: strides = (2, 2) - block_features = features * 2 ** i + block_features = features * 2**i block_scope = scope.push(f'block_{i}_{j}') x = residual_block(block_scope, x, conv, norm, act, block_features, strides) # we can access parameters of the sub module by operating on the scope @@ -79,58 +86,56 @@ class ResNetTest(absltest.TestCase): def test_resnet(self): block_sizes = (2, 2) x = random.normal(random.PRNGKey(0), (1, 64, 64, 3)) - y, variables = init(resnet)(random.PRNGKey(1), x, block_sizes=block_sizes, features=16) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) + y, variables = init(resnet)( + random.PRNGKey(1), x, block_sizes=block_sizes, features=16 + ) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) self.assertEqual(y.shape, (1, 1000)) - self.assertEqual(param_shapes, { - 'init_conv': {'kernel': (7, 7, 3, 16)}, - 'init_bn': {'bias': (16,), 'scale': (16,)}, - 'out': {'kernel': (128, 1000), 'bias': (1000,)}, - 'block_0_0': { - 'conv_1': {'kernel': (1, 1, 16, 16)}, - 'conv_2': {'kernel': (3, 3, 16, 64)}, - 'conv_3': {'kernel': (1, 1, 64, 64)}, - - 'bn_1': {'bias': (16,), 'scale': (16,)}, - 'bn_2': {'bias': (64,), 'scale': (64,)}, - 'bn_3': {'bias': (64,), 'scale': (64,)}, - - 'proj_conv': {'kernel': (1, 1, 16, 64)}, - 'proj_bn': {'bias': (64,), 'scale': (64,)}, + self.assertEqual( + param_shapes, + { + 'init_conv': {'kernel': (7, 7, 3, 16)}, + 'init_bn': {'bias': (16,), 'scale': (16,)}, + 'out': {'kernel': (128, 1000), 'bias': (1000,)}, + 'block_0_0': { + 'conv_1': {'kernel': (1, 1, 16, 16)}, + 'conv_2': {'kernel': (3, 3, 16, 64)}, + 'conv_3': {'kernel': (1, 1, 64, 64)}, + 'bn_1': {'bias': (16,), 'scale': (16,)}, + 'bn_2': {'bias': (64,), 'scale': (64,)}, + 'bn_3': {'bias': (64,), 'scale': (64,)}, + 'proj_conv': {'kernel': (1, 1, 16, 64)}, + 'proj_bn': {'bias': (64,), 'scale': (64,)}, + }, + 'block_0_1': { + 'conv_1': {'kernel': (1, 1, 64, 16)}, + 'conv_2': {'kernel': (3, 3, 16, 64)}, + 'conv_3': {'kernel': (1, 1, 64, 64)}, + 'bn_1': {'bias': (16,), 'scale': (16,)}, + 'bn_2': {'bias': (64,), 'scale': (64,)}, + 'bn_3': {'bias': (64,), 'scale': (64,)}, + }, + 'block_1_0': { + 'conv_1': {'kernel': (1, 1, 64, 32)}, + 'conv_2': {'kernel': (3, 3, 32, 128)}, + 'conv_3': {'kernel': (1, 1, 128, 128)}, + 'bn_1': {'bias': (32,), 'scale': (32,)}, + 'bn_2': {'bias': (128,), 'scale': (128,)}, + 'bn_3': {'bias': (128,), 'scale': (128,)}, + 'proj_conv': {'kernel': (1, 1, 64, 128)}, + 'proj_bn': {'bias': (128,), 'scale': (128,)}, + }, + 'block_1_1': { + 'conv_1': {'kernel': (1, 1, 128, 32)}, + 'conv_2': {'kernel': (3, 3, 32, 128)}, + 'conv_3': {'kernel': (1, 1, 128, 128)}, + 'bn_1': {'bias': (32,), 'scale': (32,)}, + 'bn_2': {'bias': (128,), 'scale': (128,)}, + 'bn_3': {'bias': (128,), 'scale': (128,)}, + }, }, - 'block_0_1': { - 'conv_1': {'kernel': (1, 1, 64, 16)}, - 'conv_2': {'kernel': (3, 3, 16, 64)}, - 'conv_3': {'kernel': (1, 1, 64, 64)}, - - 'bn_1': {'bias': (16,), 'scale': (16,)}, - 'bn_2': {'bias': (64,), 'scale': (64,)}, - 'bn_3': {'bias': (64,), 'scale': (64,)}, - }, - 'block_1_0': { - 'conv_1': {'kernel': (1, 1, 64, 32)}, - 'conv_2': {'kernel': (3, 3, 32, 128)}, - 'conv_3': {'kernel': (1, 1, 128, 128)}, - - 'bn_1': {'bias': (32,), 'scale': (32,)}, - 'bn_2': {'bias': (128,), 'scale': (128,)}, - 'bn_3': {'bias': (128,), 'scale': (128,)}, - - 'proj_conv': {'kernel': (1, 1, 64, 128)}, - 'proj_bn': {'bias': (128,), 'scale': (128,)}, - }, - 'block_1_1': { - 'conv_1': {'kernel': (1, 1, 128, 32)}, - 'conv_2': {'kernel': (3, 3, 32, 128)}, - 'conv_3': {'kernel': (1, 1, 128, 128)}, - - 'bn_1': {'bias': (32,), 'scale': (32,)}, - 'bn_2': {'bias': (128,), 'scale': (128,)}, - 'bn_3': {'bias': (128,), 'scale': (128,)}, - }, - }) + ) if __name__ == '__main__': diff --git a/tests/core/design/core_scan_test.py b/tests/core/design/core_scan_test.py index 410e914bc5..e47cc99e26 100644 --- a/tests/core/design/core_scan_test.py +++ b/tests/core/design/core_scan_test.py @@ -21,10 +21,9 @@ from jax import random, numpy as jnp -def mlp_scan(scope: Scope, xs: Array, - share_params: bool = False): - +def mlp_scan(scope: Scope, xs: Array, share_params: bool = False): scope.variable('counter', 'i', jnp.zeros, ()) + def body_fn(scope, c, x): counter = scope.variable('counter', 'i', jnp.zeros, ()) counter.value += 1 @@ -36,13 +35,15 @@ def body_fn(scope, c, x): body_fn, variable_carry='counter', variable_broadcast='params', - split_rngs={'params': False})(scope, (), xs) + split_rngs={'params': False}, + )(scope, (), xs) else: _, ys = lift.scan( body_fn, variable_carry='counter', variable_axes={'params': 0}, - split_rngs={'params': True})(scope, (), xs) + split_rngs={'params': True}, + )(scope, (), xs) # output layer return ys @@ -55,12 +56,14 @@ def test_scan_unshared_params(self): x = jnp.concatenate([x, x], 0) y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=False) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) self.assertEqual(variables['counter']['i'], 2) - self.assertEqual(param_shapes, { - 'dense_0': {'kernel': (2, 4, 1), 'bias': (2, 1)}, - }) + self.assertEqual( + param_shapes, + { + 'dense_0': {'kernel': (2, 4, 1), 'bias': (2, 1)}, + }, + ) self.assertNotEqual(y[0], y[1]) k1, k2 = variables['params']['dense_0']['kernel'] @@ -71,12 +74,14 @@ def test_scan_shared_params(self): x = jnp.concatenate([x, x], 0) y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=True) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) self.assertEqual(variables['counter']['i'], 2) - self.assertEqual(param_shapes, { - 'dense_0': {'kernel': (4, 1), 'bias': (1,)}, - }) + self.assertEqual( + param_shapes, + { + 'dense_0': {'kernel': (4, 1), 'bias': (1,)}, + }, + ) self.assertEqual(y[0], y[1]) diff --git a/tests/core/design/core_tied_autoencoder_test.py b/tests/core/design/core_tied_autoencoder_test.py index cc259fa55e..b0c8bf8142 100644 --- a/tests/core/design/core_tied_autoencoder_test.py +++ b/tests/core/design/core_tied_autoencoder_test.py @@ -28,13 +28,12 @@ def trans(variables): return jax.tree_util.tree_map(lambda x: x.T, variables) return lift.map_variables( - fn, "params", map_in_fn=trans, map_out_fn=trans, - mutable=True) + fn, 'params', map_in_fn=trans, map_out_fn=trans, mutable=True + ) @dataclass class TiedAutoEncoder: - latents: int features: int @@ -46,8 +45,7 @@ def encode(self, scope, x): return nn.dense(scope, x, self.latents, bias=False) def decode(self, scope, z): - return transpose(nn.dense)( - scope, z, self.features, bias=False) + return transpose(nn.dense)(scope, z, self.features, bias=False) class TiedAutoEncoderTest(absltest.TestCase): @@ -57,11 +55,13 @@ def test_tied_auto_encoder(self): x = jnp.ones((1, ae.features)) x_r, variables = init(ae)(random.PRNGKey(0), x) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) - self.assertEqual(param_shapes, { - 'kernel': (4, 2), - }) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) + self.assertEqual( + param_shapes, + { + 'kernel': (4, 2), + }, + ) self.assertEqual(x.shape, x_r.shape) def test_init_from_decoder(self): @@ -69,11 +69,13 @@ def test_init_from_decoder(self): z = jnp.ones((1, ae.latents)) x_r, variables = init(ae.decode)(random.PRNGKey(0), z) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) - self.assertEqual(param_shapes, { - 'kernel': (4, 2), - }) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) + self.assertEqual( + param_shapes, + { + 'kernel': (4, 2), + }, + ) self.assertEqual(x_r.shape, (1, 4)) diff --git a/tests/core/design/core_vmap_test.py b/tests/core/design/core_vmap_test.py index a85203b87d..85ad0d3a65 100644 --- a/tests/core/design/core_vmap_test.py +++ b/tests/core/design/core_vmap_test.py @@ -22,20 +22,27 @@ from jax import random, numpy as jnp -def mlp_vmap(scope: Scope, x: Array, - sizes: Sequence[int] = (8, 1), - act_fn: Callable[[Array], Array] = nn.relu, - share_params: bool = False): +def mlp_vmap( + scope: Scope, + x: Array, + sizes: Sequence[int] = (8, 1), + act_fn: Callable[[Array], Array] = nn.relu, + share_params: bool = False, +): if share_params: - dense_vmap = lift.vmap(nn.dense, - in_axes=(0, None), - variable_axes={'params': None}, - split_rngs={'params': False}) + dense_vmap = lift.vmap( + nn.dense, + in_axes=(0, None), + variable_axes={'params': None}, + split_rngs={'params': False}, + ) else: - dense_vmap = lift.vmap(nn.dense, - in_axes=(0, None), - variable_axes={'params': 0}, - split_rngs={'params': True}) + dense_vmap = lift.vmap( + nn.dense, + in_axes=(0, None), + variable_axes={'params': 0}, + split_rngs={'params': True}, + ) # hidden layers for size in sizes[:-1]: @@ -54,12 +61,14 @@ def test_vmap_shared(self): y, variables = init(mlp_vmap)(random.PRNGKey(1), x, share_params=True) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) - self.assertEqual(param_shapes, { - 'hidden_0' : {'kernel': (4, 8), 'bias': (8,)}, - 'out': {'kernel': (8, 1), 'bias': (1,)}, - }) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) + self.assertEqual( + param_shapes, + { + 'hidden_0': {'kernel': (4, 8), 'bias': (8,)}, + 'out': {'kernel': (8, 1), 'bias': (1,)}, + }, + ) self.assertEqual(y.shape, (2, 1)) self.assertTrue(jnp.allclose(y[0], y[1])) @@ -69,12 +78,14 @@ def test_vmap_unshared(self): y, variables = init(mlp_vmap)(random.PRNGKey(1), x, share_params=False) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) - self.assertEqual(param_shapes, { - 'hidden_0': {'kernel': (2, 4, 8), 'bias': (2, 8)}, - 'out': {'kernel': (2, 8, 1), 'bias': (2, 1)}, - }) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) + self.assertEqual( + param_shapes, + { + 'hidden_0': {'kernel': (2, 4, 8), 'bias': (2, 8)}, + 'out': {'kernel': (2, 8, 1), 'bias': (2, 1)}, + }, + ) self.assertEqual(y.shape, (2, 1)) self.assertFalse(jnp.allclose(y[0], y[1])) diff --git a/tests/core/design/core_weight_std_test.py b/tests/core/design/core_weight_std_test.py index e192c9b636..8320ddbd80 100644 --- a/tests/core/design/core_weight_std_test.py +++ b/tests/core/design/core_weight_std_test.py @@ -39,12 +39,13 @@ def std(variables): # this way we avoid lost mutations to param # map_variables also avoids accidental reuse of rngs # and it makes sure that other state is updated correctly (not twice during init!) - return lift.map_variables(fn, "params", std, init=True) + return lift.map_variables(fn, 'params', std, init=True) -def mlp(scope: Scope, x: Array, - sizes: Sequence[int] = (8, 1)): - std_dense = weight_std(partial( - nn.dense, kernel_init=nn.initializers.normal(stddev=1e5))) + +def mlp(scope: Scope, x: Array, sizes: Sequence[int] = (8, 1)): + std_dense = weight_std( + partial(nn.dense, kernel_init=nn.initializers.normal(stddev=1e5)) + ) for size in sizes[:-1]: x = scope.child(std_dense, prefix='hidden_')(x, size) return scope.child(nn.dense, 'out')(x, sizes[-1]) @@ -53,17 +54,25 @@ def mlp(scope: Scope, x: Array, class WeightStdTest(absltest.TestCase): def test_weight_std(self): - x = random.normal(random.PRNGKey(0), (1, 4,)) + x = random.normal( + random.PRNGKey(0), + ( + 1, + 4, + ), + ) y, variables = init(mlp)(random.PRNGKey(1), x) - param_shapes = unfreeze( - jax.tree_util.tree_map(jnp.shape, variables['params'])) - self.assertEqual(param_shapes, { - 'hidden_0': {'kernel': (4, 8), 'bias': (8,)}, - 'out': {'kernel': (8, 1), 'bias': (1,)}, - }) + param_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, variables['params'])) + self.assertEqual( + param_shapes, + { + 'hidden_0': {'kernel': (4, 8), 'bias': (8,)}, + 'out': {'kernel': (8, 1), 'bias': (1,)}, + }, + ) self.assertEqual(y.shape, (1, 1)) - self.assertTrue(y.ravel() < 1.) + self.assertTrue(y.ravel() < 1.0) y2 = apply(mlp)(variables, x) self.assertTrue(jnp.allclose(y, y2)) diff --git a/tests/early_stopping_test.py b/tests/early_stopping_test.py index b7f3e6e361..92b74fb9ed 100644 --- a/tests/early_stopping_test.py +++ b/tests/early_stopping_test.py @@ -29,13 +29,12 @@ class EarlyStoppingTests(absltest.TestCase): def test_update(self): - es = early_stopping.EarlyStopping(min_delta=0, - patience=0) + es = early_stopping.EarlyStopping(min_delta=0, patience=0) for i in range(2): improve_steps = 0 for step in range(10): - metric = 1. + metric = 1.0 did_improve, es = es.update(metric) if not did_improve: improve_steps += 1 @@ -48,12 +47,10 @@ def test_update(self): es = es.reset() # ensure object is reusable if reset. def test_patience(self): - es = early_stopping.EarlyStopping(min_delta=0, - patience=0) - patient_es = early_stopping.EarlyStopping(min_delta=0, - patience=6) + es = early_stopping.EarlyStopping(min_delta=0, patience=0) + patient_es = early_stopping.EarlyStopping(min_delta=0, patience=6) for step in range(10): - metric = 1. + metric = 1.0 did_improve, es = es.update(metric) if es.should_stop: break @@ -61,7 +58,7 @@ def test_patience(self): self.assertEqual(step, 1) for patient_step in range(10): - metric = 1. + metric = 1.0 did_improve, patient_es = patient_es.update(metric) if patient_es.should_stop: break @@ -69,13 +66,10 @@ def test_patience(self): self.assertEqual(patient_step, 7) def test_delta(self): - es = early_stopping.EarlyStopping(min_delta=0, - patience=0) - delta_es = early_stopping.EarlyStopping(min_delta=1e-3, - patience=0) - delta_patient_es = early_stopping.EarlyStopping(min_delta=1e-3, - patience=1) - metric = 1. + es = early_stopping.EarlyStopping(min_delta=0, patience=0) + delta_es = early_stopping.EarlyStopping(min_delta=1e-3, patience=0) + delta_patient_es = early_stopping.EarlyStopping(min_delta=1e-3, patience=1) + metric = 1.0 for step in range(100): metric -= 1e-4 did_improve, es = es.update(metric) @@ -84,7 +78,7 @@ def test_delta(self): self.assertEqual(step, 99) - metric = 1. + metric = 1.0 for step in range(100): metric -= 1e-4 did_improve, delta_es = delta_es.update(metric) @@ -93,8 +87,18 @@ def test_delta(self): self.assertEqual(step, 1) - metrics = [0.01, 0.005, 0.0033, 0.0025, 0.002, - 0.0017, 0.0014, 0.0012, 0.0011, 0.001] + metrics = [ + 0.01, + 0.005, + 0.0033, + 0.0025, + 0.002, + 0.0017, + 0.0014, + 0.0012, + 0.0011, + 0.001, + ] improvement_steps = 0 for step in range(10): metric = metrics[step] diff --git a/tests/io_test.py b/tests/io_test.py index 2bae68e44a..8d2ff921c2 100644 --- a/tests/io_test.py +++ b/tests/io_test.py @@ -31,16 +31,15 @@ class IOTest(parameterized.TestCase): @parameterized.parameters( - {'backend_mode': io.BackendMode.DEFAULT}, - {'backend_mode': io.BackendMode.TF} + {'backend_mode': io.BackendMode.DEFAULT}, {'backend_mode': io.BackendMode.TF} ) def test_override(self, backend_mode): with io.override_mode(backend_mode): self.assertEqual(io.io_mode, backend_mode) @parameterized.parameters( - {'write_mode': io.BackendMode.DEFAULT, 'read_mode': io.BackendMode.TF}, - {'write_mode': io.BackendMode.TF, 'read_mode': io.BackendMode.DEFAULT} + {'write_mode': io.BackendMode.DEFAULT, 'read_mode': io.BackendMode.TF}, + {'write_mode': io.BackendMode.TF, 'read_mode': io.BackendMode.DEFAULT}, ) def test_GFile(self, write_mode, read_mode): test_string = b'testing write and read' @@ -72,8 +71,8 @@ def test_listdir(self): self.assertEqual(default_dir_set, tf_dir_set) @parameterized.parameters( - {'create_temp_fn': tempfile.TemporaryDirectory}, - {'create_temp_fn': tempfile.NamedTemporaryFile} + {'create_temp_fn': tempfile.TemporaryDirectory}, + {'create_temp_fn': tempfile.NamedTemporaryFile}, ) def test_isdir(self, create_temp_fn): with create_temp_fn() as temp: @@ -107,8 +106,8 @@ def test_copy(self): self.assertEqual(file.read(), test_string) @parameterized.parameters( - {'backend_mode': io.BackendMode.DEFAULT, 'error_type': errors.AlreadyExistsError}, - {'backend_mode': io.BackendMode.TF, 'error_type': tf.errors.AlreadyExistsError}, + {'backend_mode': io.BackendMode.DEFAULT, 'error_type': errors.AlreadyExistsError}, + {'backend_mode': io.BackendMode.TF, 'error_type': tf.errors.AlreadyExistsError}, ) def test_copy_raises_error(self, backend_mode, error_type): with tempfile.NamedTemporaryFile() as temp_file: @@ -135,8 +134,8 @@ def test_rename(self): self.assertTrue(os.path.exists(rename2_path)) @parameterized.parameters( - {'backend_mode': io.BackendMode.DEFAULT, 'error_type': errors.AlreadyExistsError}, - {'backend_mode': io.BackendMode.TF, 'error_type': tf.errors.AlreadyExistsError}, + {'backend_mode': io.BackendMode.DEFAULT, 'error_type': errors.AlreadyExistsError}, + {'backend_mode': io.BackendMode.TF, 'error_type': tf.errors.AlreadyExistsError}, ) def test_rename_raises_error(self, backend_mode, error_type): with tempfile.NamedTemporaryFile() as temp_file: @@ -146,7 +145,6 @@ def test_rename_raises_error(self, backend_mode, error_type): def test_exists(self): with tempfile.NamedTemporaryFile() as temp_file: - with io.override_mode(io.BackendMode.DEFAULT): default_exists = io.exists(temp_file.name) @@ -156,8 +154,7 @@ def test_exists(self): self.assertEqual(default_exists, tf_exists) @parameterized.parameters( - {'backend_mode': io.BackendMode.DEFAULT}, - {'backend_mode': io.BackendMode.TF} + {'backend_mode': io.BackendMode.DEFAULT}, {'backend_mode': io.BackendMode.TF} ) def test_makedirs(self, backend_mode): with tempfile.TemporaryDirectory() as temp_dir_path: @@ -184,8 +181,7 @@ def test_glob(self): self.assertEqual(default_glob_set, tf_glob_set) @parameterized.parameters( - {'backend_mode': io.BackendMode.DEFAULT}, - {'backend_mode': io.BackendMode.TF} + {'backend_mode': io.BackendMode.DEFAULT}, {'backend_mode': io.BackendMode.TF} ) def test_remove(self, backend_mode): with tempfile.TemporaryDirectory() as temp_dir_path: @@ -200,8 +196,7 @@ def test_remove(self, backend_mode): self.assertTrue(not os.path.exists(test_path)) @parameterized.parameters( - {'backend_mode': io.BackendMode.DEFAULT}, - {'backend_mode': io.BackendMode.TF} + {'backend_mode': io.BackendMode.DEFAULT}, {'backend_mode': io.BackendMode.TF} ) def test_rmtree(self, backend_mode): with tempfile.TemporaryDirectory() as temp_dir_path: @@ -220,10 +215,8 @@ def test_rmtree(self, backend_mode): self.assertTrue(not os.path.exists(dir0_path)) - @parameterized.parameters( - {'backend_mode': io.BackendMode.DEFAULT}, - {'backend_mode': io.BackendMode.TF} + {'backend_mode': io.BackendMode.DEFAULT}, {'backend_mode': io.BackendMode.TF} ) def test_getsize(self, backend_mode): with tempfile.TemporaryDirectory() as temp_dir_path: diff --git a/tests/jax_utils_test.py b/tests/jax_utils_test.py index d77aab33bd..5f306a552e 100644 --- a/tests/jax_utils_test.py +++ b/tests/jax_utils_test.py @@ -45,9 +45,9 @@ def add(a, b): return a + b x = np.arange(bs, dtype=dtype) - y = add(x, 10*x) + y = add(x, 10 * x) chex.assert_type(y.dtype, x.dtype) - np.testing.assert_allclose(np.float64(y), np.float64(x + 10*x)) + np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x)) @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) def test_trees(self, dtype, bs): @@ -57,9 +57,9 @@ def add(a, b): return a['a'] + b[0] x = np.arange(bs, dtype=dtype) - y = add(dict(a=x), (10*x, )) + y = add(dict(a=x), (10 * x,)) chex.assert_type(y.dtype, x.dtype) - np.testing.assert_allclose(np.float64(y), np.float64(x + 10*x)) + np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x)) @parameterized.parameters(DTYPES) def test_min_device_batch_avoids_recompile(self, dtype): @@ -73,9 +73,9 @@ def add(a, b): for bs in self.BATCH_SIZES: x = np.arange(bs, dtype=dtype) - y = add(x, 10*x, min_device_batch=9) # pylint: disable=unexpected-keyword-arg + y = add(x, 10 * x, min_device_batch=9) # pylint: disable=unexpected-keyword-arg chex.assert_type(y.dtype, x.dtype) - np.testing.assert_allclose(np.float64(y), np.float64(x + 10*x)) + np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x)) @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) def test_static_argnum(self, dtype, bs): diff --git a/tests/linen/initializers_test.py b/tests/linen/initializers_test.py index 3a44d3fadf..26274dd454 100644 --- a/tests/linen/initializers_test.py +++ b/tests/linen/initializers_test.py @@ -33,27 +33,31 @@ class InitializersTest(parameterized.TestCase): @parameterized.parameters( - { - 'builder_fn': initializers.zeros_init, - 'params_shape': (2, 3), - 'expected_params': jnp.zeros((2, 3)), - }, { - 'builder_fn': initializers.ones_init, - 'params_shape': (3, 2), - 'expected_params': jnp.ones((3, 2)), - }) + { + 'builder_fn': initializers.zeros_init, + 'params_shape': (2, 3), + 'expected_params': jnp.zeros((2, 3)), + }, + { + 'builder_fn': initializers.ones_init, + 'params_shape': (3, 2), + 'expected_params': jnp.ones((3, 2)), + }, + ) def test_call_builder(self, builder_fn, params_shape, expected_params): params = builder_fn()(random.PRNGKey(42), params_shape, jnp.float32) np.testing.assert_allclose(params, expected_params) @parameterized.parameters( - { - 'builder_fn': initializers.zeros_init, - 'expected_params': jnp.zeros((2, 5)), - }, { - 'builder_fn': initializers.ones_init, - 'expected_params': jnp.ones((2, 5)), - }) + { + 'builder_fn': initializers.zeros_init, + 'expected_params': jnp.zeros((2, 5)), + }, + { + 'builder_fn': initializers.ones_init, + 'expected_params': jnp.ones((2, 5)), + }, + ) def test_kernel_builder(self, builder_fn, expected_params): layer = nn.Dense(5, kernel_init=builder_fn()) params = layer.init(random.PRNGKey(42), jnp.empty((3, 2)))['params'] @@ -61,4 +65,4 @@ def test_kernel_builder(self, builder_fn, expected_params): if __name__ == '__main__': - absltest.main() \ No newline at end of file + absltest.main() diff --git a/tests/linen/kw_only_dataclasses_test.py b/tests/linen/kw_only_dataclasses_test.py index fe85a781fa..753f1b8c64 100644 --- a/tests/linen/kw_only_dataclasses_test.py +++ b/tests/linen/kw_only_dataclasses_test.py @@ -24,7 +24,6 @@ class KwOnlyDataclassesTest(absltest.TestCase): def test_kwonly_args_moved_to_end(self): - @kw_only_dataclasses.dataclass class TestClass: a: int = 1 @@ -47,7 +46,6 @@ class TestClass: self.assertDictEqual(dataclasses.asdict(v3), dict(a=1, b=2, c=30)) def test_base_optional_subclass_required(self): - @kw_only_dataclasses.dataclass class Parent: a: int = kw_only_dataclasses.field(default=2, kw_only=True) @@ -103,11 +101,9 @@ class C(B): self.assertEqual(c_params['size'].default, inspect.Parameter.empty) value = C(4, 'foo') # pylint: disable=too-many-function-args - self.assertDictEqual( - dataclasses.asdict(value), dict(name='foo', size=4, x=2, y=3)) + self.assertDictEqual(dataclasses.asdict(value), dict(name='foo', size=4, x=2, y=3)) def test_kwonly_marker(self): - @kw_only_dataclasses.dataclass class A: x: float diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index e16346c93b..af3d970bb2 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -33,6 +33,7 @@ # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() + class AttentionTest(parameterized.TestCase): def test_multihead_self_attention(self): @@ -141,7 +142,10 @@ def test_causal_mask_1d(self): ts = np.arange(16) mask_1d_simple = (ts[:, None] >= ts[None, :])[None, None, :, :] mask_1d_simple = jnp.broadcast_to(mask_1d_simple, (3, 1, 16, 16)) - np.testing.assert_allclose(mask_1d, mask_1d_simple,) + np.testing.assert_allclose( + mask_1d, + mask_1d_simple, + ) @parameterized.parameters([((5,), (1,)), ((6, 5), (2,))]) def test_decoding(self, spatial_shape, attn_dims): @@ -150,29 +154,28 @@ def test_decoding(self, spatial_shape, attn_dims): num_features = 4 rng = random.PRNGKey(0) key1, key2 = random.split(rng) - inputs = random.normal( - key1, (bs,) + spatial_shape + (num_heads * num_features,)) + inputs = random.normal(key1, (bs,) + spatial_shape + (num_heads * num_features,)) module = nn.SelfAttention( num_heads=num_heads, qkv_features=num_heads * num_features, precision=lax.Precision.HIGHEST, deterministic=False, - decode=False) + decode=False, + ) decode_module = module.clone(decode=True) initial_vars = decode_module.init(key2, inputs) state, params = pop(initial_vars, 'params') causal_mask = nn.attention.make_causal_mask(jnp.ones((bs,) + spatial_shape)) - y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, y))( - inputs, causal_mask) + y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, y))(inputs, causal_mask) + # feed the inputs sequentially to simulate decoding def body_fn(state, x): - y, state = decode_module.apply( - {'params': params, **state}, x, mutable=['cache']) + y, state = decode_module.apply({'params': params, **state}, x, mutable=['cache']) return state, y + # scan_in_dim supports scanning multiple dims - _, y = jax_utils.scan_in_dim(body_fn, state, inputs, - axis=attn_dims, keepdims=True) + _, y = jax_utils.scan_in_dim(body_fn, state, inputs, axis=attn_dims, keepdims=True) np.testing.assert_allclose(y_ref, y, atol=1e-5) @@ -188,9 +191,8 @@ def test_autoregresive_receptive_field_1d(self): inputs = random.normal(rng2, input_shape) module = nn.MultiHeadDotProductAttention( - num_heads=num_heads, - kernel_init=jax.nn.initializers.ones, - deterministic=False) + num_heads=num_heads, kernel_init=jax.nn.initializers.ones, deterministic=False + ) initial_vars = module.init(rng1, inputs, inputs) causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1])) @@ -209,14 +211,18 @@ def get_receptive_field_1d(pos): for i in range(length): deps = get_receptive_field_1d(i) - assert (deps[:i] == 1).all(), ('Receptive Field Error: Some of the ' - 'previous postions are not reachable ' - 'in autoregressive self-attention.') + assert (deps[:i] == 1).all(), ( + 'Receptive Field Error: Some of the ' + 'previous postions are not reachable ' + 'in autoregressive self-attention.' + ) if i != length - 1: k = i + 1 - assert (deps[k:] == 0).all(), ('Receptive Field Error: Some of the ' - 'future postions are reachable in ' - 'autoregressive self-attention.') + assert (deps[k:] == 0).all(), ( + 'Receptive Field Error: Some of the ' + 'future postions are reachable in ' + 'autoregressive self-attention.' + ) if __name__ == '__main__': diff --git a/tests/linen/linen_combinators_test.py b/tests/linen/linen_combinators_test.py index 67f41520fb..16a4de480d 100644 --- a/tests/linen/linen_combinators_test.py +++ b/tests/linen/linen_combinators_test.py @@ -41,8 +41,8 @@ def __call__(self, inputs): if self.activation is not None: x = self.activation(x) x = nn.Dense( - features=self.layer_sizes[-1], kernel_init=nn.initializers.ones_init())( - x) + features=self.layer_sizes[-1], kernel_init=nn.initializers.ones_init() + )(x) if self.activation_final is None: return x return self.activation_final(x) @@ -55,8 +55,8 @@ class AttentionTuple(nn.Module): @nn.compact def __call__(self, query, key_value): output = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, qkv_features=self.qkv_features)(query, - key_value) + num_heads=self.num_heads, qkv_features=self.qkv_features + )(query, key_value) return output, key_value @@ -67,8 +67,8 @@ class AttentionDict(nn.Module): @nn.compact def __call__(self, query, key_value): output = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, qkv_features=self.qkv_features)(query, - key_value) + num_heads=self.num_heads, qkv_features=self.qkv_features + )(query, key_value) return dict(query=output, key_value=key_value) @@ -84,15 +84,14 @@ def test_construction(self): def test_fails_if_layers_empty(self): sequential = nn.Sequential([]) - with self.assertRaisesRegex(ValueError, - 'Empty Sequential module'): + with self.assertRaisesRegex(ValueError, 'Empty Sequential module'): sequential.init(random.PRNGKey(42), jnp.ones((3, 5))) def test_same_output_as_mlp(self): sequential = nn.Sequential([ nn.Dense(4, kernel_init=nn.initializers.ones_init()), nn.Dense(8, kernel_init=nn.initializers.ones_init()), - nn.Dense(2, kernel_init=nn.initializers.ones_init()) + nn.Dense(2, kernel_init=nn.initializers.ones_init()), ]) mlp = MLP(layer_sizes=[4, 8, 2]) @@ -107,15 +106,17 @@ def test_same_output_as_mlp(self): def test_same_output_as_mlp_with_activation(self): sequential = nn.Sequential([ - nn.Dense(4, kernel_init=nn.initializers.ones_init()), nn.relu, - nn.Dense(8, kernel_init=nn.initializers.ones_init()), nn.relu, - nn.Dense(2, kernel_init=nn.initializers.ones_init()), nn.log_softmax + nn.Dense(4, kernel_init=nn.initializers.ones_init()), + nn.relu, + nn.Dense(8, kernel_init=nn.initializers.ones_init()), + nn.relu, + nn.Dense(2, kernel_init=nn.initializers.ones_init()), + nn.log_softmax, ]) mlp = MLP( - layer_sizes=[4, 8, 2], - activation=nn.relu, - activation_final=nn.log_softmax) + layer_sizes=[4, 8, 2], activation=nn.relu, activation_final=nn.log_softmax + ) key1, key2 = random.split(random.PRNGKey(0), 2) x = random.uniform(key1, (3, 5)) @@ -126,7 +127,6 @@ def test_same_output_as_mlp_with_activation(self): output_2 = mlp.apply(params_2, x) np.testing.assert_array_equal(output_1, output_2) - def test_tuple_output(self): sequential = nn.Sequential([ AttentionTuple(), diff --git a/tests/linen/linen_dtypes_test.py b/tests/linen/linen_dtypes_test.py index ab56d9374f..7233486c5e 100644 --- a/tests/linen/linen_dtypes_test.py +++ b/tests/linen/linen_dtypes_test.py @@ -25,12 +25,13 @@ import jax from jax import numpy as jnp -default_float_dtype = jnp.result_type(1.) +default_float_dtype = jnp.result_type(1.0) + class DtypesTest(absltest.TestCase): def test_no_inexact_dtype(self): - i32 = jnp.int32(1.) + i32 = jnp.int32(1.0) self.assertEqual(dtypes.canonicalize_dtype(i32, inexact=False), jnp.int32) def test_inexact_dtype(self): @@ -39,12 +40,12 @@ def test_inexact_dtype(self): self.assertEqual(dtypes.canonicalize_dtype(i64), jnp.float32) i32 = jnp.int32(1) self.assertEqual(dtypes.canonicalize_dtype(i32), jnp.float32) - i16 = jnp.int16(1.) + i16 = jnp.int16(1.0) self.assertEqual(dtypes.canonicalize_dtype(i16), jnp.float32) def test_explicit_downcast(self): - f32 = jnp.float32(1.) - x, = dtypes.promote_dtype(f32, dtype=jnp.float16) + f32 = jnp.float32(1.0) + (x,) = dtypes.promote_dtype(f32, dtype=jnp.float16) self.assertEqual(x.dtype, jnp.float16) diff --git a/tests/linen/linen_linear_test.py b/tests/linen/linen_linear_test.py index 3d5dca6ad6..0146979ae9 100644 --- a/tests/linen/linen_linear_test.py +++ b/tests/linen/linen_linear_test.py @@ -46,7 +46,7 @@ def test_dense(self): y, _ = dense_module.init_with_output(rng, x) self.assertEqual(y.shape, (1, 4)) self.assertEqual(y.dtype, jnp.float32) - np.testing.assert_allclose(y, np.full((1, 4), 4.)) + np.testing.assert_allclose(y, np.full((1, 4), 4.0)) def test_dense_extra_batch_dims(self): rng = dict(params=random.PRNGKey(0)) @@ -57,7 +57,7 @@ def test_dense_extra_batch_dims(self): bias_init=initializers.ones, ) y, _ = dense_module.init_with_output(rng, x) - np.testing.assert_allclose(y, np.full((1, 2, 4), 4.)) + np.testing.assert_allclose(y, np.full((1, 2, 4), 4.0)) def test_dense_no_bias(self): rng = dict(params=random.PRNGKey(0)) @@ -68,7 +68,7 @@ def test_dense_no_bias(self): kernel_init=initializers.ones, ) y, _ = dense_module.init_with_output(rng, x) - np.testing.assert_allclose(y, np.full((1, 4), 3.)) + np.testing.assert_allclose(y, np.full((1, 4), 3.0)) def test_dense_is_dense_general(self): x = jax.random.normal(random.PRNGKey(0), (5, 3)) @@ -108,7 +108,7 @@ def test_dense_general_two_out(self): bias_init=initializers.ones, ) y, _ = dg_module.init_with_output(rng, x) - np.testing.assert_allclose(y, np.full((1, 2, 2), 4.)) + np.testing.assert_allclose(y, np.full((1, 2, 2), 4.0)) def test_dense_general_two_in(self): rng = dict(params=random.PRNGKey(0)) @@ -120,17 +120,19 @@ def test_dense_general_two_in(self): bias_init=initializers.ones, ) y, _ = dg_module.init_with_output(rng, x) - np.testing.assert_allclose(y, np.full((1, 3), 5.)) + np.testing.assert_allclose(y, np.full((1, 3), 5.0)) def test_dense_general_batch_dim(self): rng = dict(params=random.PRNGKey(0)) x = jnp.ones((2, 1, 3, 5)) - state = {'counter': 0.} + state = {'counter': 0.0} + def _counter_init(rng, shape, dtype, state): del rng, dtype - state['counter'] += 1. + state['counter'] += 1.0 return jnp.full(shape, state['counter']) + counter_init = functools.partial(_counter_init, state=state) dg_module = nn.DenseGeneral( @@ -141,12 +143,14 @@ def _counter_init(rng, shape, dtype, state): kernel_init=counter_init, ) y, _ = dg_module.init_with_output(rng, x) - target = np.full((2, 1, 7), 16.) + target = np.full((2, 1, 7), 16.0) np.testing.assert_allclose(y, target) - @parameterized.parameters([((-2, 3), (), 'bijk,jklm->bilm'), - ((3, -2), (), 'bijk,jklm->bilm'), - ((-2, 3), (0,), 'bijk,bjklm->bilm')]) + @parameterized.parameters([ + ((-2, 3), (), 'bijk,jklm->bilm'), + ((3, -2), (), 'bijk,jklm->bilm'), + ((-2, 3), (0,), 'bijk,bjklm->bilm'), + ]) def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr): rng = dict(params=random.PRNGKey(0)) x = jnp.ones((16, 8, 9, 10)) @@ -159,13 +163,11 @@ def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr): kernel_init=initializers.normal(), ) y, initial_params = dg_module.init_with_output(rng, x) - target = np.einsum(einsum_expr, x, initial_params['params']['kernel']) + 1. + target = np.einsum(einsum_expr, x, initial_params['params']['kernel']) + 1.0 np.testing.assert_allclose(y, target, atol=1e-6) def test_complex_params_dense(self): - dense = nn.Dense( - features=2, - param_dtype=jnp.complex64) + dense = nn.Dense(features=2, param_dtype=jnp.complex64) x = jnp.ones((1, 2), jnp.float32) variables = dense.init(random.PRNGKey(0), x) self.assertEqual(variables['params']['kernel'].dtype, jnp.complex64) @@ -174,8 +176,7 @@ def test_complex_params_dense(self): self.assertEqual(y.dtype, jnp.complex64) def test_complex_input_dense(self): - dense = nn.Dense( - features=2) + dense = nn.Dense(features=2) x = jnp.ones((1, 2), jnp.complex64) variables = dense.init(random.PRNGKey(0), x) self.assertEqual(variables['params']['kernel'].dtype, jnp.float32) @@ -183,9 +184,7 @@ def test_complex_input_dense(self): y = dense.apply(variables, x) self.assertEqual(y.dtype, jnp.complex64) - - @parameterized.product( - use_bias=(True, False)) + @parameterized.product(use_bias=(True, False)) def test_conv(self, use_bias): rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 8, 3)) @@ -199,12 +198,10 @@ def test_conv(self, use_bias): ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) - expected = 10. if use_bias else 9. + expected = 10.0 if use_bias else 9.0 np.testing.assert_allclose(y, np.full((1, 6, 4), expected)) - - @parameterized.product( - use_bias=(True, False)) + @parameterized.product(use_bias=(True, False)) def test_multibatch_input_conv(self, use_bias): rng = dict(params=random.PRNGKey(0)) x = jnp.ones((2, 5, 8, 3)) @@ -218,10 +215,9 @@ def test_multibatch_input_conv(self, use_bias): ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) - expected = 10. if use_bias else 9. + expected = 10.0 if use_bias else 9.0 np.testing.assert_allclose(y, np.full((2, 5, 6, 4), expected)) - def test_conv_local(self): rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 8, 2)) @@ -234,7 +230,7 @@ def test_conv_local(self): ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (6, 3 * 2, 4)) - np.testing.assert_allclose(y, np.full((1, 6, 4), 7.)) + np.testing.assert_allclose(y, np.full((1, 6, 4), 7.0)) def test_single_input_conv(self): rng = dict(params=random.PRNGKey(0)) @@ -248,7 +244,7 @@ def test_single_input_conv(self): ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) - np.testing.assert_allclose(y, np.full((6, 4), 10.)) + np.testing.assert_allclose(y, np.full((6, 4), 10.0)) def test_single_input_masked_conv(self): rng = dict(params=random.PRNGKey(0)) @@ -262,12 +258,14 @@ def test_single_input_masked_conv(self): kernel_init=initializers.ones, bias_init=initializers.ones, ) - expected = jnp.array([[10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.]]) + expected = jnp.array([ + [10.0, 7.0, 4.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + ]) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) np.testing.assert_allclose(y, expected) @@ -284,7 +282,7 @@ def test_single_input_conv_local(self): ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (6, 3 * 2, 4)) - np.testing.assert_allclose(y, np.full((6, 4), 7.)) + np.testing.assert_allclose(y, np.full((6, 4), 7.0)) def test_group_conv(self): rng = dict(params=random.PRNGKey(0)) @@ -299,7 +297,7 @@ def test_group_conv(self): ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 2, 4)) - np.testing.assert_allclose(y, np.full((1, 6, 4), 7.)) + np.testing.assert_allclose(y, np.full((1, 6, 4), 7.0)) @parameterized.product( n_batch=(1, 3), @@ -307,11 +305,10 @@ def test_group_conv(self): kernel_size=(1, 2, 3, 9), n_input_features=(1, 3), input_size=(1, 8, 16), - module=(nn.Conv, nn.ConvLocal) + module=(nn.Conv, nn.ConvLocal), ) def test_circular_conv_1d_constant( - self, n_batch, n_features, kernel_size, n_input_features, input_size, - module + self, n_batch, n_features, kernel_size, n_input_features, input_size, module ): """ Test 1D convolution with circular padding: filter with all elements equal @@ -331,8 +328,7 @@ def test_circular_conv_1d_constant( ) y, initial_params = conv_module.init_with_output(rng, x) - kernel_shape = self._get_kernel_shape(x.shape, (kernel_size,), module, - n_features) + kernel_shape = self._get_kernel_shape(x.shape, (kernel_size,), module, n_features) self.assertEqual( initial_params['params']['kernel'].shape, @@ -343,16 +339,14 @@ def test_circular_conv_1d_constant( ) np.testing.assert_allclose(y, correct_ans) - def _get_kernel_shape(self, - input_shape, - kernel_size, - module, - n_features): + def _get_kernel_shape(self, input_shape, kernel_size, module, n_features): if module == nn.Conv: kernel_shape = kernel_size + (input_shape[-1], n_features) elif module == nn.ConvLocal: kernel_shape = input_shape[1:-1] + ( - input_shape[-1] * np.prod(kernel_size), n_features) + input_shape[-1] * np.prod(kernel_size), + n_features, + ) else: raise ValueError(module) return kernel_shape @@ -364,7 +358,7 @@ def _get_kernel_shape(self, n_input_features=(1, 5), input_x_size=(14,), input_y_size=(5, 10), - module=(nn.Conv, nn.ConvLocal) + module=(nn.Conv, nn.ConvLocal), ) def test_circular_conv_2d_constant( self, @@ -374,7 +368,7 @@ def test_circular_conv_2d_constant( n_input_features, input_x_size, input_y_size, - module + module, ): """ Test 2D convolution with circular padding: square filter with all elements @@ -395,8 +389,7 @@ def test_circular_conv_2d_constant( ) y, initial_params = conv_module.init_with_output(rng, x) - kernel_shape = self._get_kernel_shape(x.shape, kernel_size, module, - n_features) + kernel_shape = self._get_kernel_shape(x.shape, kernel_size, module, n_features) self.assertEqual( initial_params['params']['kernel'].shape, @@ -471,13 +464,15 @@ def test_circular_conv_1d_dilation(self): padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, - kernel_dilation=(3,)) + kernel_dilation=(3,), + ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 1, 1)) # Compare with manually computed convolution - correct_ans = np.array((3 + 2 * 1 + 4, 4 + 2 * 2 + 5, 5 + 2 * 3 + 1, - 1 + 2 * 4 + 2, 2 + 2 * 5 + 3)) + correct_ans = np.array( + (3 + 2 * 1 + 4, 4 + 2 * 2 + 5, 5 + 2 * 3 + 1, 1 + 2 * 4 + 2, 2 + 2 * 5 + 3) + ) correct_ans = np.expand_dims(correct_ans, (0, 2)) np.testing.assert_allclose(y, correct_ans) @@ -488,13 +483,7 @@ def test_circular_conv_local_1d_dilation(self): rng = dict(params=random.PRNGKey(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) - kernel = np.array(( - (1, 2, 1), - (3, 4, 5), - (-1, 1, 2), - (2, 3, 4), - (-1, -2, -3) - )) + kernel = np.array(((1, 2, 1), (3, 4, 5), (-1, 1, 2), (2, 3, 4), (-1, -2, -3))) kernel = np.expand_dims(kernel, (2,)) conv_module = nn.ConvLocal( @@ -503,17 +492,19 @@ def test_circular_conv_local_1d_dilation(self): padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, - kernel_dilation=(3,) + kernel_dilation=(3,), ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (5, 3, 1)) # Compare with manually computed convolution - correct_ans = np.array((1 * 3 + 2 * 1 + 1 * 4, - 3 * 4 + 4 * 2 + 5 * 5, - -1 * 5 + 1 * 3 + 2 * 1, - 2 * 1 + 3 * 4 + 4 * 2, - -1 * 2 + -2 * 5 + -3 * 3)) + correct_ans = np.array(( + 1 * 3 + 2 * 1 + 1 * 4, + 3 * 4 + 4 * 2 + 5 * 5, + -1 * 5 + 1 * 3 + 2 * 1, + 2 * 1 + 3 * 4 + 4 * 2, + -1 * 2 + -2 * 5 + -3 * 3, + )) correct_ans = np.expand_dims(correct_ans, (0, 2)) np.testing.assert_allclose(y, correct_ans) @@ -549,43 +540,23 @@ def test_circular_conv_local_2d_custom(self): Test 2d local convolution with circular padding on a 3x3 example """ rng = dict(params=random.PRNGKey(0)) - x = np.array(((1, 2, 3), - (4, 5, 6), - (7, 8, 9))) + x = np.array(((1, 2, 3), (4, 5, 6), (7, 8, 9))) x = np.expand_dims(x, (0, 3)) kernel = np.array(( ( - ((0, 1, 0), - (1, 2, 1), - (0, 1, 0)), - ((0, 1, 0), - (1, 3, 1), - (0, 1, 0)), - ((0, 1, 0), - (1, 4, 1), - (0, 1, 0)) + ((0, 1, 0), (1, 2, 1), (0, 1, 0)), + ((0, 1, 0), (1, 3, 1), (0, 1, 0)), + ((0, 1, 0), (1, 4, 1), (0, 1, 0)), ), ( - ((0, 1, 0), - (1, 5, 1), - (0, 1, 0)), - ((0, 1, 0), - (1, 6, 1), - (0, 1, 0)), - ((0, 1, 0), - (1, 7, 1), - (0, 1, 0)) + ((0, 1, 0), (1, 5, 1), (0, 1, 0)), + ((0, 1, 0), (1, 6, 1), (0, 1, 0)), + ((0, 1, 0), (1, 7, 1), (0, 1, 0)), ), ( - ((0, 1, 0), - (1, 8, 1), - (0, 1, 0)), - ((0, 1, 0), - (1, 9, 1), - (0, 1, 0)), - ((0, 1, 0), - (1, 10, 1), - (0, 1, 0)) + ((0, 1, 0), (1, 8, 1), (0, 1, 0)), + ((0, 1, 0), (1, 9, 1), (0, 1, 0)), + ((0, 1, 0), (1, 10, 1), (0, 1, 0)), ), )) kernel = np.expand_dims(kernel, (3,)) @@ -602,13 +573,11 @@ def test_circular_conv_local_2d_custom(self): self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 9, 1)) # Compare with manually computed convolution - correct_ans = np.array( - ( - (2 * 1 + 7 + 2 + 4 + 3, 3 * 2 + 8 + 3 + 5 + 1, 4 * 3 + 9 + 1 + 6 + 2), - (5 * 4 + 1 + 5 + 7 + 6, 6 * 5 + 2 + 6 + 8 + 4, 7 * 6 + 3 + 4 + 9 + 5), - (8 * 7 + 4 + 8 + 1 + 9, 9 * 8 + 5 + 9 + 2 + 7, 10 * 9 + 6 + 7 + 3 + 8), - ) - ) + correct_ans = np.array(( + (2 * 1 + 7 + 2 + 4 + 3, 3 * 2 + 8 + 3 + 5 + 1, 4 * 3 + 9 + 1 + 6 + 2), + (5 * 4 + 1 + 5 + 7 + 6, 6 * 5 + 2 + 6 + 8 + 4, 7 * 6 + 3 + 4 + 9 + 5), + (8 * 7 + 4 + 8 + 1 + 9, 9 * 8 + 5 + 9 + 2 + 7, 10 * 9 + 6 + 7 + 3 + 8), + )) correct_ans = np.expand_dims(correct_ans, (0, 3)) np.testing.assert_allclose(y, correct_ans) @@ -623,10 +592,16 @@ def test_causal_conv1d(self): bias_init=initializers.ones, ) y, _ = conv_module.init_with_output(rng, x) - correct_ans = np.array([[[5., 5., 5., 5.], [9., 9., 9., 9.], - [13., 13., 13., 13.], [13., 13., 13., 13.], - [13., 13., 13., 13.], [13., 13., 13., 13.], - [13., 13., 13., 13.], [13., 13., 13., 13.]]]) + correct_ans = np.array([[ + [5.0, 5.0, 5.0, 5.0], + [9.0, 9.0, 9.0, 9.0], + [13.0, 13.0, 13.0, 13.0], + [13.0, 13.0, 13.0, 13.0], + [13.0, 13.0, 13.0, 13.0], + [13.0, 13.0, 13.0, 13.0], + [13.0, 13.0, 13.0, 13.0], + [13.0, 13.0, 13.0, 13.0], + ]]) np.testing.assert_allclose(y, correct_ans) np.testing.assert_array_equal(correct_ans.shape, y.shape) @@ -646,18 +621,20 @@ def test_conv_transpose(self, use_bias): ) y, initial_params = conv_transpose_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) - correct_ans = np.array([[[ 4., 4., 4., 4.], - [ 7., 7., 7., 7.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [ 7., 7., 7., 7.], - [ 4., 4., 4., 4.]]]) + correct_ans = np.array([[ + [4.0, 4.0, 4.0, 4.0], + [7.0, 7.0, 7.0, 7.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [7.0, 7.0, 7.0, 7.0], + [4.0, 4.0, 4.0, 4.0], + ]]) if not use_bias: - correct_ans -= 1. + correct_ans -= 1.0 np.testing.assert_allclose(y, correct_ans) @parameterized.product( @@ -676,20 +653,22 @@ def test_multibatch_input_conv_transpose(self, use_bias): ) y, initial_params = conv_transpose_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) - correct_ans = np.array([[[ 4., 4., 4., 4.], - [ 7., 7., 7., 7.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [ 7., 7., 7., 7.], - [ 4., 4., 4., 4.]]]) + correct_ans = np.array([[ + [4.0, 4.0, 4.0, 4.0], + [7.0, 7.0, 7.0, 7.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [7.0, 7.0, 7.0, 7.0], + [4.0, 4.0, 4.0, 4.0], + ]]) correct_ans = np.repeat(correct_ans[None], repeats=2, axis=0) correct_ans = np.repeat(correct_ans, repeats=5, axis=1) if not use_bias: - correct_ans -= 1. + correct_ans -= 1.0 np.testing.assert_allclose(y, correct_ans) def test_single_input_conv_transpose(self): @@ -704,16 +683,18 @@ def test_single_input_conv_transpose(self): ) y, initial_params = conv_transpose_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) - correct_ans = np.array([[ 4., 4., 4., 4.], - [ 7., 7., 7., 7.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [ 7., 7., 7., 7.], - [ 4., 4., 4., 4.]]) + correct_ans = np.array([ + [4.0, 4.0, 4.0, 4.0], + [7.0, 7.0, 7.0, 7.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [10.0, 10.0, 10.0, 10.0], + [7.0, 7.0, 7.0, 7.0], + [4.0, 4.0, 4.0, 4.0], + ]) np.testing.assert_allclose(y, correct_ans) def test_single_input_masked_conv_transpose(self): @@ -730,16 +711,18 @@ def test_single_input_masked_conv_transpose(self): ) y, initial_params = conv_transpose_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) - correct_ans = np.array([[ 4., 3., 2., 1.], - [ 7., 5., 3., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [ 7., 5., 3., 1.], - [ 4., 3., 2., 1.]]) + correct_ans = np.array([ + [4.0, 3.0, 2.0, 1.0], + [7.0, 5.0, 3.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + [10.0, 7.0, 4.0, 1.0], + [7.0, 5.0, 3.0, 1.0], + [4.0, 3.0, 2.0, 1.0], + ]) np.testing.assert_allclose(y, correct_ans) @parameterized.product( @@ -750,7 +733,7 @@ def test_single_input_masked_conv_transpose(self): input_size=(1, 8, 16), ) def test_circular_conv_transpose_1d_constant( - self, n_batch, n_features, kernel_size, n_input_features, input_size + self, n_batch, n_features, kernel_size, n_input_features, input_size ): """ Test 1D transposed convolution with circular padding: filter with all @@ -774,8 +757,9 @@ def test_circular_conv_transpose_1d_constant( initial_params['params']['kernel'].shape, (kernel_size, n_input_features, n_features), ) - correct_ans = np.full((n_batch, input_size, n_features), - kernel_size * n_input_features) + correct_ans = np.full( + (n_batch, input_size, n_features), kernel_size * n_input_features + ) np.testing.assert_allclose(y, correct_ans) @parameterized.product( @@ -824,7 +808,7 @@ def test_circular_conv_transpose_2d_constant( np.testing.assert_allclose(y, correct_ans) def test_circular_conv_transpose_2d_with_vmap(self): - layer = nn.ConvTranspose(features=5, kernel_size=(3,), padding="CIRCULAR") + layer = nn.ConvTranspose(features=5, kernel_size=(3,), padding='CIRCULAR') # this is ok sample_input = jnp.ones((1, 32, 2)) @@ -858,14 +842,23 @@ def test_circular_conv_transpose_1d_custom(self): self.assertEqual(initial_params['params']['kernel'].shape, (3, 1, 1)) # Compare with manually computed convolution - correct_ans = np.array( # pyformat: disable - (1 * 1, 1 * 2, 1 * 1, - 2 * 1, 2 * 2, 2 * 1, - 3 * 1, 3 * 2, 3 * 1, - 4 * 1, 4 * 2, 4 * 1, - 5 * 1, 5 * 2, 5 * 1, - ) - ) + correct_ans = np.array(( # pyformat: disable + 1 * 1, + 1 * 2, + 1 * 1, + 2 * 1, + 2 * 2, + 2 * 1, + 3 * 1, + 3 * 2, + 3 * 1, + 4 * 1, + 4 * 2, + 4 * 1, + 5 * 1, + 5 * 2, + 5 * 1, + )) correct_ans = np.expand_dims(correct_ans, (0, 2)) np.testing.assert_allclose(y, correct_ans) @@ -929,8 +922,7 @@ def test_circular_conv_transpose_2d_custom_bias(self): correct_ans = np.expand_dims(correct_ans, (0, 3)) np.testing.assert_allclose(y, correct_ans) - @parameterized.product( - use_bias=(True, False)) + @parameterized.product(use_bias=(True, False)) def test_transpose_kernel_conv_transpose(self, use_bias): rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 15, 15, 3)) @@ -946,9 +938,7 @@ def test_transpose_kernel_conv_transpose(self, use_bias): self.assertEqual(initial_params['params']['kernel'].shape, (6, 6, 4, 3)) self.assertEqual(y.shape, (1, 30, 30, 4)) - @parameterized.product( - module=(nn.Conv, nn.ConvLocal) - ) + @parameterized.product(module=(nn.Conv, nn.ConvLocal)) def test_int_kernel_size(self, module): conv = module(features=4, kernel_size=3) x = jnp.ones((8, 3)) @@ -958,8 +948,9 @@ def test_int_kernel_size(self, module): def test_embed(self): rng = dict(params=random.PRNGKey(0)) x = jnp.arange(4)[None] - dummy_embedding = jnp.broadcast_to( - jnp.arange(4)[..., None], (4, 3)).astype(jnp.float32) + dummy_embedding = jnp.broadcast_to(jnp.arange(4)[..., None], (4, 3)).astype( + jnp.float32 + ) embed_module = nn.Embed( num_embeddings=4, features=3, @@ -967,15 +958,15 @@ def test_embed(self): ) y, initial_params = embed_module.init_with_output(rng, x) np.testing.assert_allclose(y, dummy_embedding[None]) - z = embed_module.apply(initial_params, jnp.ones((3,)), - method=embed_module.attend) - np.testing.assert_allclose(z, 3. * jnp.arange(4)) + z = embed_module.apply(initial_params, jnp.ones((3,)), method=embed_module.attend) + np.testing.assert_allclose(z, 3.0 * jnp.arange(4)) def test_embed_numpy(self): rng = dict(params=random.PRNGKey(0)) x = jnp.arange(4)[None] - dummy_embedding = np.broadcast_to( - np.arange(4)[..., None], (4, 3)).astype(np.float32) + dummy_embedding = np.broadcast_to(np.arange(4)[..., None], (4, 3)).astype( + np.float32 + ) embed_module = nn.Embed( num_embeddings=4, features=3, @@ -983,9 +974,8 @@ def test_embed_numpy(self): ) y, initial_params = embed_module.init_with_output(rng, x) np.testing.assert_allclose(y, dummy_embedding[None]) - z = embed_module.apply(initial_params, jnp.ones((3,)), - method=embed_module.attend) - np.testing.assert_allclose(z, 3. * jnp.arange(4)) + z = embed_module.apply(initial_params, jnp.ones((3,)), method=embed_module.attend) + np.testing.assert_allclose(z, 3.0 * jnp.arange(4)) def test_embed_hash(self): self.assertEqual(hash(nn.Embed(2, 3)), hash(nn.Embed(2, 3))) @@ -1002,10 +992,8 @@ def __call__(self, x): y, variables = Foo().init_with_output(random.PRNGKey(0), x) self.assertEqual( jax.tree_util.tree_map(jnp.shape, variables['params']), - {'dense': { - 'kernel': (4, 6), - 'bias': (6,) - }}) + {'dense': {'kernel': (4, 6), 'bias': (6,)}}, + ) self.assertEqual(y.shape, (2, 8, 6)) def test_non_final_axes(self): @@ -1019,10 +1007,8 @@ def __call__(self, x): y, variables = Foo().init_with_output(random.PRNGKey(0), x) self.assertEqual( jax.tree_util.tree_map(jnp.shape, variables['params']), - {'dense': { - 'kernel': (2, 4, 6), - 'bias': (6,) - }}) + {'dense': {'kernel': (2, 4, 6), 'bias': (6,)}}, + ) self.assertEqual(y.shape, (8, 6)) def test_canonicalize_padding(self): @@ -1032,12 +1018,14 @@ def test_pad(pad, rank, expected=None): nn.linear.canonicalize_padding(pad, rank) else: self.assertEqual(nn.linear.canonicalize_padding(pad, rank), expected) - test_pad("SAME", 2, "SAME") + + test_pad('SAME', 2, 'SAME') test_pad(2, 3, [(2, 2), (2, 2), (2, 2)]) test_pad((2, 2), 3) test_pad((2, 2), 1) test_pad([1, (2, 3)], 2, [(1, 1), (2, 3)]) test_pad([None, (1, 2)], 2) + if __name__ == '__main__': absltest.main() diff --git a/tests/linen/linen_meta_test.py b/tests/linen/linen_meta_test.py index d20fa40eb9..3652873eb0 100644 --- a/tests/linen/linen_meta_test.py +++ b/tests/linen/linen_meta_test.py @@ -32,8 +32,7 @@ class Bar(nn.Module): @nn.compact def __call__(mdl_self, x): # pylint: disable=no-self-argument - kernel_init = nn.with_partitioning(nn.initializers.ones_init(), - ('in', 'out')) + kernel_init = nn.with_partitioning(nn.initializers.ones_init(), ('in', 'out')) kernel = mdl_self.param('kernel', kernel_init, (x.shape[-1], 2)) kernel_box = mdl_self.get_variable('params', 'kernel') self.assertIsInstance(kernel_box, nn.Partitioned) @@ -45,27 +44,31 @@ class Foo(nn.Module): @nn.compact def __call__(self, xs): return nn.vmap( - Bar, in_axes=0, - variable_axes={'params': 0}, split_rngs={'params': True}, - metadata_params={nn.PARTITION_NAME: 'batch'})(name='bar')(xs) + Bar, + in_axes=0, + variable_axes={'params': 0}, + split_rngs={'params': True}, + metadata_params={nn.PARTITION_NAME: 'batch'}, + )(name='bar')(xs) m = Foo() variables = m.init(random.PRNGKey(0), jnp.zeros((8, 3))) - self.assertEqual(variables['params']['bar']['kernel'].names, - ('batch', 'in', 'out')) - + self.assertEqual(variables['params']['bar']['kernel'].names, ('batch', 'in', 'out')) def test_boxed_variable(self): class Bar(nn.Module): @nn.compact def __call__(mdl_self, x): # pylint: disable=no-self-argument - kernel_init = nn.with_partitioning(nn.initializers.ones_init(), - ('in', 'out')) + kernel_init = nn.with_partitioning(nn.initializers.ones_init(), ('in', 'out')) kernel = mdl_self.variable( - 'params', 'kernel', kernel_init, - mdl_self.make_rng('params'), (x.shape[-1], 2)) - kernel.value += 1. + 'params', + 'kernel', + kernel_init, + mdl_self.make_rng('params'), + (x.shape[-1], 2), + ) + kernel.value += 1.0 self.assertEqual(kernel.value.sum(), kernel.value.size * 2) kernel_box = mdl_self.get_variable('params', 'kernel') self.assertIsInstance(kernel_box, nn.Partitioned) @@ -77,29 +80,30 @@ class Foo(nn.Module): @nn.compact def __call__(self, xs): return nn.vmap( - Bar, in_axes=0, - variable_axes={'params': 0}, split_rngs={'params': True}, - metadata_params={nn.PARTITION_NAME: 'batch'})(name='bar')(xs) + Bar, + in_axes=0, + variable_axes={'params': 0}, + split_rngs={'params': True}, + metadata_params={nn.PARTITION_NAME: 'batch'}, + )(name='bar')(xs) m = Foo() variables = m.init(random.PRNGKey(0), jnp.zeros((8, 3))) - self.assertEqual(variables['params']['bar']['kernel'].names, - ('batch', 'in', 'out')) - + self.assertEqual(variables['params']['bar']['kernel'].names, ('batch', 'in', 'out')) # def test_boxed_variable(self): # def f(scope, xs): # def g(scope, x): - # kernel_init = nn.with_partitioning(nn.initializers.ones_init(), - # ('in', 'out')) - # kernel = scope.variable('params', 'kernel', kernel_init, - # scope.make_rng('params'), (x.shape[-1], 2)) - # kernel.value += 1. - # self.assertEqual(kernel.value.sum(), kernel.value.size * 2) - # kernel_box = scope.get_variable('params', 'kernel') - # self.assertIsInstance(kernel_box, nn.Partitioned) - # self.assertEqual(kernel_box.names, ('in', 'out')) - # return x @ kernel.value + # kernel_init = nn.with_partitioning(nn.initializers.ones_init(), + # ('in', 'out')) + # kernel = scope.variable('params', 'kernel', kernel_init, + # scope.make_rng('params'), (x.shape[-1], 2)) + # kernel.value += 1. + # self.assertEqual(kernel.value.sum(), kernel.value.size * 2) + # kernel_box = scope.get_variable('params', 'kernel') + # self.assertIsInstance(kernel_box, nn.Partitioned) + # self.assertEqual(kernel_box.names, ('in', 'out')) + # return x @ kernel.value # nn.vmap( # g, in_axes=0, @@ -118,12 +122,12 @@ class MLP(nn.Module): def __call__(self, x): ki = nn.linear.default_kernel_init h = nn.Dense( - self.hidden_size, - kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x) + self.hidden_size, kernel_init=nn.with_partitioning(ki, ('data', 'model')) + )(x) h = nn.relu(h) return nn.Dense( - x.shape[-1], - kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h) + x.shape[-1], kernel_init=nn.with_partitioning(ki, ('model', 'data')) + )(h) class Model(nn.Module): @@ -132,47 +136,58 @@ def __call__(self, x): def body(_, c): c = MLP(512)(c) return c, () + c, _ = nn.scan( - body, variable_axes={'params': 0}, split_rngs={'params': 0}, - length=8, metadata_params={nn.PARTITION_NAME: None})( - self, x) + body, + variable_axes={'params': 0}, + split_rngs={'params': 0}, + length=8, + metadata_params={nn.PARTITION_NAME: None}, + )(self, x) return c devs = mesh_utils.create_device_mesh((jax.device_count(), 1)) mesh = Mesh(devs, ['data', 'model']) model = Model() x = jnp.ones((8, 128)) - spec = nn.get_partition_spec( - jax.eval_shape(model.init, random.PRNGKey(0), x)) - self.assertEqual(spec, { - 'params': { - 'MLP_0': { - 'Dense_0': { - 'bias': PartitionSpec(), - 'kernel': PartitionSpec(None, 'data', 'model'), - }, - 'Dense_1': { - 'bias': PartitionSpec(), - 'kernel': PartitionSpec(None, 'model', 'data'), + spec = nn.get_partition_spec(jax.eval_shape(model.init, random.PRNGKey(0), x)) + self.assertEqual( + spec, + { + 'params': { + 'MLP_0': { + 'Dense_0': { + 'bias': PartitionSpec(), + 'kernel': PartitionSpec(None, 'data', 'model'), + }, + 'Dense_1': { + 'bias': PartitionSpec(), + 'kernel': PartitionSpec(None, 'model', 'data'), + }, }, }, }, - }) + ) x_spec = PartitionSpec('data', 'model') f = lambda x: jax.sharding.NamedSharding(mesh, x) if jax.config.jax_enable_custom_prng: key_spec = PartitionSpec() else: key_spec = PartitionSpec(None) - init_fn = jax.jit(model.init, - in_shardings=jax.tree_map(f, (key_spec, x_spec)), - out_shardings=jax.tree_map(f, spec)) + init_fn = jax.jit( + model.init, + in_shardings=jax.tree_map(f, (key_spec, x_spec)), + out_shardings=jax.tree_map(f, spec), + ) variables = init_fn(random.PRNGKey(0), x) - apply_fn = jax.jit(model.apply, - in_shardings=jax.tree_map(f, (spec, x_spec)), - out_shardings=jax.tree_map(f, x_spec)) + apply_fn = jax.jit( + model.apply, + in_shardings=jax.tree_map(f, (spec, x_spec)), + out_shardings=jax.tree_map(f, x_spec), + ) y = apply_fn(variables, x) self.assertEqual(y.shape, (8, 128)) + if __name__ == '__main__': absltest.main() diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index dd14f9e216..2685a09326 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -22,8 +22,18 @@ import inspect import operator import sys -from typing import (Any, Callable, Generic, Mapping, NamedTuple, Sequence, - Tuple, TypeVar, get_type_hints, Optional) +from typing import ( + Any, + Callable, + Generic, + Mapping, + NamedTuple, + Sequence, + Tuple, + TypeVar, + get_type_hints, + Optional, +) from absl.testing import absltest from flax import config @@ -43,6 +53,7 @@ # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() + def tree_equals(x, y): return jax.tree_util.tree_all(jax.tree_util.tree_map(operator.eq, x, y)) @@ -70,8 +81,9 @@ class Dense(nn.Module): @compact def __call__(self, x): - kernel = self.param('kernel', initializers.lecun_normal(), - (x.shape[-1], self.features)) + kernel = self.param( + 'kernel', initializers.lecun_normal(), (x.shape[-1], self.features) + ) y = jnp.dot(x, kernel) return y @@ -80,31 +92,37 @@ class ModuleTest(absltest.TestCase): def test_init_module(self): rngkey = jax.random.PRNGKey(0) - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = DummyModule(parent=scope)(x) params = scope.variables()['params'] y2 = DummyModule(parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) - np.testing.assert_allclose(y, jnp.array([2.])) - self.assertEqual(params, {'bias': jnp.array([1.])}) + np.testing.assert_allclose(y, jnp.array([2.0])) + self.assertEqual(params, {'bias': jnp.array([1.0])}) def test_lazy_init(self): - class Foo(nn.Module): + @compact def __call__(self, x): - k = self.param("kernel", nn.initializers.lecun_normal(), (x.shape[-1], x.shape[-1])) + k = self.param( + 'kernel', nn.initializers.lecun_normal(), (x.shape[-1], x.shape[-1]) + ) return x @ k + # provide a massive input message which would OOM if any compute ops were actually executed - variables = Foo().lazy_init(random.PRNGKey(0), jax.ShapeDtypeStruct((1024 * 1024 * 1024, 128), jnp.float32)) - self.assertEqual(variables["params"]["kernel"].shape, (128, 128)) + variables = Foo().lazy_init( + random.PRNGKey(0), jax.ShapeDtypeStruct((1024 * 1024 * 1024, 128), jnp.float32) + ) + self.assertEqual(variables['params']['kernel'].shape, (128, 128)) def test_lazy_init_fails_on_data_dependence(self): class Foo(nn.Module): + @compact def __call__(self, x): - k = self.param("kernel", lambda _: x) + k = self.param('kernel', lambda _: x) return x * k with self.assertRaises(errors.LazyInitError): @@ -141,14 +159,9 @@ def _mydense(self, x): y2 = MLP(parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) param_shape = jax.tree_util.tree_map(jnp.shape, params) - self.assertEqual(param_shape, { - 'Dense_0': { - 'kernel': (10, 3) - }, - 'Dense_1': { - 'kernel': (3, 3) - } - }) + self.assertEqual( + param_shape, {'Dense_0': {'kernel': (10, 3)}, 'Dense_1': {'kernel': (3, 3)}} + ) def test_nested_module_reuse(self): rngkey = jax.random.PRNGKey(0) @@ -180,16 +193,10 @@ def __call__(self, x): y2 = Top(parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) param_shape = jax.tree_util.tree_map(jnp.shape, params) - self.assertEqual(param_shape, { - 'MLP_0': { - 'Dense_0': { - 'kernel': (10, 3) - }, - 'Dense_1': { - 'kernel': (3, 3) - } - } - }) + self.assertEqual( + param_shape, + {'MLP_0': {'Dense_0': {'kernel': (10, 3)}, 'Dense_1': {'kernel': (3, 3)}}}, + ) def test_setup_dict_assignment(self): rngkey = jax.random.PRNGKey(0) @@ -215,17 +222,11 @@ def __call__(self, x): y2 = MLP(parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) param_shape = jax.tree_util.tree_map(jnp.shape, params) - self.assertEqual(param_shape, { - 'lyrs1_a': { - 'kernel': (10, 3) - }, - 'lyrs1_b': { - 'kernel': (3, 3) - } - }) + self.assertEqual( + param_shape, {'lyrs1_a': {'kernel': (10, 3)}, 'lyrs1_b': {'kernel': (3, 3)}} + ) def test_setup_dict_nonstring_keys(self): - class Foo(nn.Module): def setup(self): @@ -239,14 +240,9 @@ def __call__(self, x): x = jnp.ones(shape=(1, 3)) params = foo.init(random.PRNGKey(0), x)['params'] param_shape = jax.tree_util.tree_map(jnp.shape, params) - self.assertEqual(param_shape, - {'a_(1, 2)': { - 'kernel': (3, 2), - 'bias': (2,) - }}) + self.assertEqual(param_shape, {'a_(1, 2)': {'kernel': (3, 2), 'bias': (2,)}}) def test_setup_cloning(self): - class MLP(nn.Module): def setup(self): @@ -300,14 +296,14 @@ def setup(self): def __call__(self, x): return x + self.bias - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = DummyModuleWithoutCompact(x.shape, parent=scope)(x) params = scope.variables()['params'] y2 = DummyModuleWithoutCompact(x.shape, parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) - np.testing.assert_allclose(y, jnp.array([2.])) - self.assertEqual(params, {'bias': jnp.array([1.])}) + np.testing.assert_allclose(y, jnp.array([2.0])) + self.assertEqual(params, {'bias': jnp.array([1.0])}) def test_init_outside_setup_without_compact(self): rngkey = jax.random.PRNGKey(0) @@ -318,7 +314,7 @@ def __call__(self, x): bias = self.param('bias', initializers.ones, x.shape) return x + bias - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) with self.assertRaisesRegex(ValueError, 'must be initialized.*setup'): unused_y = DummyModuleWithoutCompact(parent=scope)(x) @@ -337,7 +333,7 @@ def foo(self, x): bias = self.param('bias', initializers.ones, x.shape) return x + bias - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) with self.assertRaisesRegex(ValueError, 'must be initialized.*setup'): unused_y = Dummy(parent=scope).foo(x) @@ -356,7 +352,7 @@ def __call__(self, x): unused_bias = self.param('bias', initializers.ones, x.shape) return x + self.bias - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) msg = 'Could not create param "bias" in Module Dummy: Name in use' with self.assertRaisesRegex(errors.NameInUseError, msg): @@ -374,7 +370,7 @@ def __call__(self, x): bias = self.param('bias', initializers.ones, self.xshape) return x + bias - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) msg = 'Could not create param "bias" in Module Dummy: Name in use' with self.assertRaisesRegex(errors.NameInUseError, msg): @@ -393,7 +389,7 @@ def setup(self): def __call__(self, x): return x + self.bias - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) msg = 'Could not create param "bias" in Module Dummy: Name in use' with self.assertRaisesRegex(errors.NameInUseError, msg): @@ -407,17 +403,16 @@ class Dummy(nn.Module): def setup(self): self.biases = [ - self.param(f'bias_{i}', initializers.ones, self.xshape) - for i in range(4) + self.param(f'bias_{i}', initializers.ones, self.xshape) for i in range(4) ] def __call__(self, x): return x + self.biases[0] - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = Dummy(x.shape, parent=scope)(x) - self.assertEqual(y, jnp.array([2.])) + self.assertEqual(y, jnp.array([2.0])) def test_setattr_name_var_disagreement_allowed_in_dicts(self): rngkey = jax.random.PRNGKey(0) @@ -439,10 +434,10 @@ def setup(self): def __call__(self, x): return x + self.biases['0'] - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = Dummy(x.shape, parent=scope)(x) - self.assertEqual(y, jnp.array([2.])) + self.assertEqual(y, jnp.array([2.0])) def test_submodule_var_collision_with_scope(self): rngkey = jax.random.PRNGKey(0) @@ -457,7 +452,7 @@ def setup(self): def __call__(self, x): return x + self.bias - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) with self.assertRaises(errors.NameInUseError): @@ -477,7 +472,7 @@ def __call__(self, x): unused_bias = DummyModule(name='bias') return x + self.bias - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) msg = 'Could not create submodule "bias" in Module Dummy: Name in use' @@ -498,7 +493,7 @@ def __call__(self, x): unused_bias = self.param('bias', initializers.ones, self.xshape) return x + self.bias - x = jnp.array([1.]) + x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) msg = 'Could not create param "bias" in Module Dummy: Name in use' @@ -506,7 +501,6 @@ def __call__(self, x): unused_y = Dummy(x.shape, parent=scope)(x) def test_attr_empty_container(self): - class Foo(nn.Module): bar: Mapping[str, Any] @@ -531,7 +525,6 @@ def call2(self): pass def test_only_one_compact_method_subclass(self): - class Dummy(nn.Module): @nn.compact @@ -552,7 +545,6 @@ def __call__(self): subdummy() def test_forgotten_compact_annotation(self): - class Bar(nn.Module): # user forgot to add @compact @@ -568,13 +560,14 @@ def __call__(self, x): x = bar(x) return x - msg = (r'Submodule Dense must be defined in `setup\(\)` or in a method ' - 'wrapped in `@compact`') + msg = ( + r'Submodule Dense must be defined in `setup\(\)` or in a method ' + 'wrapped in `@compact`' + ) with self.assertRaisesRegex(errors.AssignSubModuleError, msg): Foo().init(random.PRNGKey(0), jnp.ones((1, 3))) def test_forgotten_compact_annotation_with_explicit_parent(self): - class Bar(nn.Module): def __call__(self, x): @@ -589,13 +582,14 @@ def __call__(self, x): x = bar(x) return x - msg = (r'Submodule Dense must be defined in `setup\(\)` or in a method ' - 'wrapped in `@compact`') + msg = ( + r'Submodule Dense must be defined in `setup\(\)` or in a method ' + 'wrapped in `@compact`' + ) with self.assertRaisesRegex(errors.AssignSubModuleError, msg): Foo().init(random.PRNGKey(0), jnp.ones((1, 3))) def test_numpy_array_shape_class_args(self): - class MLP(nn.Module): widths: Sequence[int] @@ -610,7 +604,6 @@ def __call__(self, x): _ = test.apply(params, jnp.ones((3, 3))) def test_get_local_methods(self): - class Base: @staticmethod @@ -643,11 +636,11 @@ class Derived2(Derived1): self.assertEqual(nn.module._get_local_method_names(Base), ('bleep',)) self.assertEqual(nn.module._get_local_method_names(Derived1), ('bloop',)) self.assertEqual( - nn.module._get_local_method_names(Derived1, exclude=('bloop',)), ()) + nn.module._get_local_method_names(Derived1, exclude=('bloop',)), () + ) self.assertEqual(nn.module._get_local_method_names(Derived2), ()) def test_inheritance_dataclass_attribs(self): - class Test(nn.Module): bar: int @@ -696,39 +689,34 @@ def __call__(self, x): self.assertTrue(hasattr(test4, 'baz')) self.assertTrue(hasattr(test4, 'name')) self.assertTrue(hasattr(test4, 'parent')) + self.assertEqual(list(Test.__dataclass_fields__.keys()), ['bar', 'parent', 'name']) self.assertEqual( - list(Test.__dataclass_fields__.keys()), ['bar', 'parent', 'name']) - self.assertEqual( - list(Test2.__dataclass_fields__.keys()), - ['bar', 'baz', 'parent', 'name']) + list(Test2.__dataclass_fields__.keys()), ['bar', 'baz', 'parent', 'name'] + ) self.assertEqual( - list(Test3.__dataclass_fields__.keys()), - ['bar', 'baz', 'parent', 'name']) + list(Test3.__dataclass_fields__.keys()), ['bar', 'baz', 'parent', 'name'] + ) self.assertEqual( - list(Test4.__dataclass_fields__.keys()), - ['bar', 'baz', 'parent', 'name']) + list(Test4.__dataclass_fields__.keys()), ['bar', 'baz', 'parent', 'name'] + ) def test_get_suffix_value_pairs(self): for x in [(), [], {}, None, 0, set()]: self.assertEqual(nn.module._get_suffix_value_pairs(x), [('', x)]) self.assertEqual( - nn.module._get_suffix_value_pairs({ - 'a': 1, - 'b': 2 - }), [('_a', 1), ('_b', 2)]) + nn.module._get_suffix_value_pairs({'a': 1, 'b': 2}), [('_a', 1), ('_b', 2)] + ) self.assertEqual( - nn.module._get_suffix_value_pairs([1, 2, 3]), [('_0', 1), ('_1', 2), - ('_2', 3)]) + nn.module._get_suffix_value_pairs([1, 2, 3]), [('_0', 1), ('_1', 2), ('_2', 3)] + ) x1 = [nn.Dense(10), nn.relu, nn.Dense(10)] y1 = nn.module._get_suffix_value_pairs(x1) self.assertEqual(y1, [('_0', x1[0]), ('_1', x1[1]), ('_2', x1[2])]) x2 = {'a': 1, 'b': {'c': nn.Dense(10), 'd': nn.relu}} y2 = nn.module._get_suffix_value_pairs(x2) - self.assertEqual(y2, [('_a', 1), ('_b_c', x2['b']['c']), - ('_b_d', x2['b']['d'])]) + self.assertEqual(y2, [('_a', 1), ('_b_c', x2['b']['c']), ('_b_d', x2['b']['d'])]) def test_mixed_list_assignment_in_setup(self): - class Test(nn.Module): def setup(self): @@ -754,7 +742,6 @@ def test_module_is_hashable(self): self.assertNotEqual(hash(module_a), hash(module_b)) def test_module_custom_hash(self): - class Test(nn.Module): x: int = 3 y: int = 5 @@ -770,12 +757,11 @@ def __hash__(self): def test_module_with_scope_is_not_hashable(self): module_a = nn.Dense(10, parent=Scope({})) - msg = 'Can\'t call __hash__ on modules that hold variables.' + msg = "Can't call __hash__ on modules that hold variables." with self.assertRaisesWithLiteralMatch(TypeError, msg): hash(module_a) def test_module_trace(self): - class MLP(nn.Module): act: Callable = nn.relu sizes: Sequence[int] = (3, 2) @@ -823,7 +809,6 @@ def __call__(self, x): self.assertEqual(trace, expected_trace) def test_module_apply_method(self): - class Foo(nn.Module): not_callable: int = 1 @@ -855,7 +840,7 @@ def test(self): Foo().init({}, method='test') # non-existent attribute names will yield AttributeError. - with self.assertRaisesRegex(AttributeError, "allowed_apply_fn"): + with self.assertRaisesRegex(AttributeError, 'allowed_apply_fn'): Foo().apply({}, method='allowed_apply_fn') # test same for init. Foo().init({}, method='allowed_apply_fn') @@ -873,7 +858,6 @@ def test_call_unbound_compact_module_methods(self): dense(jnp.ones((1,))) def test_call_unbound_has_variable(self): - class EmptyModule(nn.Module): def foo(self): @@ -884,7 +868,6 @@ def foo(self): empty.foo() def test_call_unbound_make_rng(self): - class EmptyModule(nn.Module): def foo(self): @@ -895,7 +878,6 @@ def foo(self): empty.foo() def test_call_unbound_variables(self): - class EmptyModule(nn.Module): def foo(self): @@ -906,7 +888,6 @@ def foo(self): empty.foo() def test_call_unbound_noncompact_module_methods(self): - class EmptyModule(nn.Module): foo: int = 3 @@ -919,7 +900,6 @@ def bar(self): self.assertEqual(empty.bar(), 3) def test_call_unbound_noncompact_module_methods_depending_on_setup(self): - class EmptyModule(nn.Module): def setup(self): @@ -934,7 +914,6 @@ def bar(self): empty.bar() def test_module_with_attrs(self): - class Foo(nn.Module): bar: nn.Dense = dataclasses.field(init=False) @@ -950,7 +929,6 @@ def __call__(self, x): self.assertEqual(variables['params']['bar']['kernel'].shape, (2, 3)) def test_noncompact_module_frozen(self): - class Foo(nn.Module): def setup(self): @@ -959,26 +937,28 @@ def setup(self): def __call__(self): self.i = 2 # This is not allowed. - msg = ('Can\'t set i=2 for Module of type Foo: Module instance is frozen ' - 'outside of setup method.') + msg = ( + "Can't set i=2 for Module of type Foo: Module instance is frozen " + 'outside of setup method.' + ) with self.assertRaisesRegex(errors.SetAttributeFrozenModuleError, msg): Foo().init(random.PRNGKey(0)) def test_compact_module_frozen(self): - class Foo(nn.Module): @nn.compact def __call__(self): self.i = 2 - msg = ('Can\'t set i=2 for Module of type Foo: Module instance is frozen ' - 'outside of setup method.') + msg = ( + "Can't set i=2 for Module of type Foo: Module instance is frozen " + 'outside of setup method.' + ) with self.assertRaisesRegex(errors.SetAttributeFrozenModuleError, msg): Foo().init(random.PRNGKey(0)) def test_submodule_frozen(self): - class Foo(nn.Module): @nn.compact @@ -986,13 +966,14 @@ def __call__(self): dense = nn.Dense(10) dense.features = 20 # <--- This is not allowed - msg = ('Can\'t set features=20 for Module of type Dense: Module instance ' - 'is frozen outside of setup method.') + msg = ( + "Can't set features=20 for Module of type Dense: Module instance " + 'is frozen outside of setup method.' + ) with self.assertRaisesRegex(errors.SetAttributeFrozenModuleError, msg): Foo().init(random.PRNGKey(0)) def test_module_call_not_implemented(self): - class Foo(nn.Module): pass @@ -1001,7 +982,6 @@ class Foo(nn.Module): Foo().init(random.PRNGKey(0)) def test_is_mutable_collection(self): - class EmptyModule(nn.Module): def __call__(self): @@ -1012,7 +992,6 @@ def __call__(self): self.assertFalse(empty.apply({}, mutable=False)) def test_module_lazy_getattr_setup(self): - class A(nn.Module): def setup(self): @@ -1038,7 +1017,6 @@ def __call__(self, x): np.testing.assert_array_equal(y1, y2) def test_module_lazy_dir_setup(self): - class A(nn.Module): def setup(self): @@ -1063,7 +1041,6 @@ def __call__(self, x): _ = B().init_with_output(key, x) def test_module_unbound_getattr(self): - class A(nn.Module): def setup(self): @@ -1099,7 +1076,6 @@ def test(self): self.assertFalse(setup_called) def test_module_pass_as_attr(self): - class A(nn.Module): def setup(self): @@ -1153,7 +1129,6 @@ def __call__(self, x): self.assertIsNone(a.name) def test_toplevel_submodule_adoption(self): - class Encoder(nn.Module): n_layers: int ch: int @@ -1208,7 +1183,6 @@ def __call__(self, x): self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) def test_toplevel_submodule_adoption_pytree(self): - class A(nn.Module): @nn.compact @@ -1248,7 +1222,11 @@ def __call__(self, c, x): jax.tree_util.tree_all( jax.tree_util.tree_map( lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7), - counters, ref_counters))) + counters, + ref_counters, + ) + ) + ) def test_toplevel_submodule_adoption_sharing(self): dense = functools.partial(nn.Dense, use_bias=False) @@ -1355,7 +1333,6 @@ def __call__(self, x): self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) def test_toplevel_submodule_pytree_adoption_sharing(self): - class A(nn.Module): @nn.compact @@ -1391,7 +1368,6 @@ def __call__(self, x): self.assertTrue(tree_equals(counters, ref_counters)) def test_inner_class_def(self): - class X(nn.Module): class Hyper(struct.PyTreeNode): @@ -1406,7 +1382,6 @@ def __call__(self, x): self.assertIsInstance(X.Hyper(a=1), X.Hyper) def test_sow(self): - class Foo(nn.Module): @nn.compact @@ -1424,16 +1399,17 @@ def __call__(self, x, **sow_args): _, state = Foo().apply({}, 1, mutable=['intermediates']) self.assertEqual(state, {'intermediates': {'h': (1, 2)}}) - _, state = Foo().apply({}, - 1, - init_fn=lambda: 0, - reduce_fn=lambda a, b: a + b, - mutable=['intermediates']) + _, state = Foo().apply( + {}, + 1, + init_fn=lambda: 0, + reduce_fn=lambda a, b: a + b, + mutable=['intermediates'], + ) self.assertEqual(state, {'intermediates': {'h': 3}}) self.assertEqual(Foo().apply({}, 1), 3) def test_capture_intermediates(self): - class Bar(nn.Module): def test(self, x): @@ -1452,7 +1428,6 @@ def __call__(self, x): self.assertEqual(state, {'intermediates': {'Bar_0': {'test': (2,)}}}) def test_perturb(self): - class Foo(nn.Module): @nn.compact @@ -1471,15 +1446,17 @@ def loss(params, perturbations, inputs, targets): x = jax.random.uniform(jax.random.PRNGKey(1), shape=(10,)) y = jax.random.uniform(jax.random.PRNGKey(2), shape=(10,)) variables = Foo().init(jax.random.PRNGKey(0), x) - intm_grads = jax.grad( - loss, argnums=1)(variables['params'], variables['perturbations'], x, y) + intm_grads = jax.grad(loss, argnums=1)( + variables['params'], variables['perturbations'], x, y + ) # activation * 4 so reverse gradient also * 4 self.assertTrue( - all(intm_grads['after_multiply'] * 4 == intm_grads['before_multiply'])) + all(intm_grads['after_multiply'] * 4 == intm_grads['before_multiply']) + ) def test_perturb_noop(self): - class Foo(nn.Module): + @nn.compact def __call__(self, x): x = nn.Dense(10)(x) @@ -1505,7 +1482,6 @@ def __call__(self, x): module.apply({'params': params, 'perturbations': perturbations}, x) def test_functional_apply(self): - class Foo(nn.Module): def setup(self): @@ -1525,7 +1501,6 @@ def f(foo, x): self.assertEqual(y1, y2) def test_bind(self): - class Foo(nn.Module): def setup(self): @@ -1544,7 +1519,6 @@ def f(foo, x): self.assertEqual(y1, y2) def test_bind_stateful(self): - class Foo(nn.Module): def setup(self): @@ -1568,13 +1542,12 @@ def f(foo, x): self.assertEqual(y2, y3) bs_1 = new_state['batch_stats'] bs_2 = foo_b.variables['batch_stats'] - for x, y in zip( - jax.tree_util.tree_leaves(bs_1), jax.tree_util.tree_leaves(bs_2)): + for x, y in zip(jax.tree_util.tree_leaves(bs_1), jax.tree_util.tree_leaves(bs_2)): np.testing.assert_allclose(x, y) def test_unbind(self): - class Foo(nn.Module): + def setup(self): self.encoder = nn.Dense(4) self.decoder = nn.Dense(2) @@ -1599,7 +1572,6 @@ def __call__(self, x): np.testing.assert_equal(variables['params']['decoder'], decoder_vars['params']) def test_passing_mutable_variables(self): - class Foo(nn.Module): @nn.compact @@ -1612,7 +1584,6 @@ def __call__(self, x): self.assertEqual(y.shape, (2,)) def test_super_compact(self): - class Foo(nn.Module): @nn.compact @@ -1632,21 +1603,16 @@ def __call__(self, x): variables = Bar().init(k, x) shapes = jax.tree_util.tree_map(np.shape, variables['params']) self.assertEqual( - shapes, { - 'Dense_0': { - 'kernel': (7, 4), - 'bias': (4,) - }, - 'Dense_1': { - 'kernel': (4, 3), - 'bias': (3,) - }, - }) + shapes, + { + 'Dense_0': {'kernel': (7, 4), 'bias': (4,)}, + 'Dense_1': {'kernel': (4, 3), 'bias': (3,)}, + }, + ) y = Bar().apply(variables, x) self.assertEqual(y.shape, (4, 3)) def test_super_setup(self): - class Foo(nn.Module): def setup(self): @@ -1670,7 +1636,6 @@ def __call__(self, x): self.assertEqual(y.shape, (4, 3)) def test_freeze_attr(self): - class Foo(NamedTuple): a: int b: int @@ -1678,8 +1643,7 @@ class Foo(NamedTuple): self.assertEqual(nn.module._freeze_attr([1, 2]), (1, 2)) xs = nn.module._freeze_attr(Foo(1, 2)) self.assertEqual(xs, (1, 2)) - self.assertEqual(type(xs), - Foo) # equality test for NamedTuple doesn't check class! + self.assertEqual(type(xs), Foo) # equality test for NamedTuple doesn't check class! def test_generic_multiple_inheritance(self): T = TypeVar('T') @@ -1700,12 +1664,10 @@ def test_jit_rng_equivalance(self): model = nn.Dense(1, use_bias=False) jit_model = nn.jit(nn.Dense)(1, use_bias=False) param = model.init(random.PRNGKey(0), np.ones((1, 1)))['params']['kernel'] - param_2 = jit_model.init(random.PRNGKey(0), np.ones( - (1, 1)))['params']['kernel'] + param_2 = jit_model.init(random.PRNGKey(0), np.ones((1, 1)))['params']['kernel'] self.assertEqual(param, param_2) def test_rng_reuse_after_rewind(self): - class C(nn.Module): @nn.compact @@ -1736,7 +1698,6 @@ def __call__(self): self.assertFalse(rng_equals) def test_module_get_put_has_variable(self): - class A(nn.Module): @nn.compact @@ -1755,7 +1716,13 @@ def __call__(self, x): class C(nn.Module): def setup(self): - self.put_variable('test_col', 'a', jnp.ones(2,)) + self.put_variable( + 'test_col', + 'a', + jnp.ones( + 2, + ), + ) assert self.has_variable('test_col', 'a') def __call__(self): @@ -1796,7 +1763,6 @@ def __call__(self) -> None: D().init(rngs) def test_modifying_attribs_in_post_init(self): - class Foo(nn.Module): love: int = 99 @@ -1819,7 +1785,6 @@ def __post_init__(self): self.assertEqual(bar.love, 101) def test_has_rng(self): - class Foo(nn.Module): def __call__(self): @@ -1833,7 +1798,6 @@ def __call__(self): self.assertFalse(foo.apply({}, rngs={'baz': k})) def test_is_initializing(self): - class Foo(nn.Module): def __call__(self): @@ -1845,7 +1809,6 @@ def __call__(self): self.assertFalse(foo.apply({})) def test_throws_invalid_instance_module_error(self): - class B(nn.Module): @nn.compact @@ -1860,14 +1823,11 @@ def __call__(self, x): with self.assertRaises(errors.InvalidInstanceModuleError): B.init_with_output(k, x) with self.assertRaises(errors.InvalidInstanceModuleError): - B.apply({}, - x) # similar issue w. apply called on class instead of instance. + B.apply({}, x) # similar issue w. apply called on class instead of instance. with self.assertRaises(errors.InvalidInstanceModuleError): - B.bind({}, - x) # similar issue w. apply called on class instead of instance. + B.bind({}, x) # similar issue w. apply called on class instead of instance. def test_throws_incorrect_post_init_override_error(self): - class A(nn.Module): x: float @@ -1892,7 +1852,6 @@ def test_deepcopy_unspecified_parent(self): self.assertIs(unspecified_parent, copy.deepcopy(unspecified_parent)) def test_type_hints(self): - class Network(nn.Module): layers: int @@ -1900,7 +1859,6 @@ class Network(nn.Module): self.assertEqual(type_hints['layers'], int) def test_incorrect_property(self): - class Foo(nn.Module): @property @@ -1911,12 +1869,12 @@ def __call__(self): return self.prop foo = Foo() - with self.assertRaisesRegex(errors.DescriptorAttributeError, - 'Trying to access a property that'): + with self.assertRaisesRegex( + errors.DescriptorAttributeError, 'Trying to access a property that' + ): foo.apply({}) def test_custom_descriptor(self): - class Descriptor: def __get__(self, obj, objtype=None): @@ -1933,7 +1891,6 @@ def __call__(self): self.assertEqual(res, 10) def test_custom_descriptor_error(self): - class Descriptor: def __get__(self, obj, objtype=None): @@ -1946,8 +1903,9 @@ def __call__(self): return self.prop foo = Foo() - with self.assertRaisesRegex(errors.DescriptorAttributeError, - 'Trying to access a property that'): + with self.assertRaisesRegex( + errors.DescriptorAttributeError, 'Trying to access a property that' + ): foo.apply({}) def test_nested_external_modules(self): @@ -1967,6 +1925,7 @@ def __call__(self, x): return self.baz(x) class Foo(nn.Module): + def setup(self): self.bar = Bar(baz=Baz(a=1)) @@ -1979,23 +1938,29 @@ def __call__(self, x): def test_getattribute_triggers_setup(self): class B(nn.Module): + def setup(self): self.p1 = self.param('p1', lambda k: jnp.ones((2,))) + def fn1(self, x): return self.p1 + x + class A(nn.Module): b: nn.Module + def __call__(self, x): return self.b.fn1(x) + a = A(b=B()) k = random.PRNGKey(0) x = jnp.zeros((2,)) - vs = nn.init(lambda a,x: a(x), a)(k, x) - y = nn.apply(lambda a,x: a.b.fn1(x), a)(vs, x) + vs = nn.init(lambda a, x: a(x), a)(k, x) + y = nn.apply(lambda a, x: a.b.fn1(x), a)(vs, x) np.testing.assert_array_equal(y, jnp.ones((2,))) def test_nested_sequential_in_call(self): class Foo(nn.Module): + def setup(self): self.seq = nn.Sequential([nn.Dense(10) for i in range(10)]) @@ -2008,13 +1973,13 @@ def __call__(self, x): def test_setup_called_bounded_submodules(self): module = nn.Sequential([ - nn.Sequential([ - nn.Dense(2), + nn.Sequential([ + nn.Dense(2), + nn.relu, + nn.Dense(2), + ]), nn.relu, nn.Dense(2), - ]), - nn.relu, - nn.Dense(2), ]) x = jnp.ones((1, 3)) variables = module.init(jax.random.PRNGKey(0), x) @@ -2042,7 +2007,6 @@ def __call__(self, x): x = bar(x) return x - module = Foo(bars=[]) module.bars = [Bar(a=1)] @@ -2086,46 +2050,52 @@ def __call__(self, x): # run foo (y, bar_vars), variables = module.init_with_output( - jax.random.PRNGKey(0), jnp.ones(())) + jax.random.PRNGKey(0), jnp.ones(()) + ) self.assertIn('params', bar_vars) def test_nested_shared(self): class Shared(nn.Module): + @nn.compact def __call__(self, x): return nn.Dense(1)(x) class Unshared(nn.Module): shared: nn.Module + def __call__(self, x): return self.shared(x) class Super(nn.Module): a: nn.Module b: nn.Module + def run_a(self, x): return self.a(x) + def run_b(self, x): return self.b(x) + def __call__(self, x): return self.a(x) + self.b(x) - sh = Shared() a = Unshared(shared=sh) b = Unshared(shared=sh) module = Super(a=a, b=b) rng = jax.random.PRNGKey(0) - params = module.init(rng, jnp.ones(1))["params"] + params = module.init(rng, jnp.ones(1))['params'] - module.apply({"params": params}, jnp.ones(1)) # works as expected - module.apply({"params": params}, jnp.ones(1), method="run_a") # works as expected - module.apply({"params": params}, jnp.ones(1), method="run_b") # ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/b/shared/Dense_0" + module.apply({'params': params}, jnp.ones(1)) # works as expected + module.apply({'params': params}, jnp.ones(1), method='run_a') # works as expected + module.apply( + {'params': params}, jnp.ones(1), method='run_b' + ) # ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/b/shared/Dense_0" def test_repr(self): - class Base1(nn.Module): a: int @@ -2172,8 +2142,10 @@ class BaseLayer(nn.Module, kw_only=True): class ChildLayer(BaseLayer): child_multiplier: int # Don't want to have to set a default argument! + def __call__(self, x): return x * self.child_multiplier * self.base_multiplier + return BaseLayer, ChildLayer if tuple(sys.version_info)[:3] < (3, 10, 0): @@ -2218,6 +2190,7 @@ class RelaxedNamingTests(absltest.TestCase): def test_relaxed_adoption(self): class Foo(nn.Module): + @nn.compact def __call__(self, x): p = self.param('p', nn.initializers.zeros, x.shape) @@ -2225,6 +2198,7 @@ def __call__(self, x): class Bar(nn.Module): sub: nn.Module + def __call__(self, x): return self.sub(x) @@ -2234,7 +2208,7 @@ def __call__(self, x): k = random.PRNGKey(0) x = jnp.zeros((1,)) vs = bar.init(k, x) - self.assertTrue("foo" in vs['params'], "relaxed naming failure") + self.assertTrue('foo' in vs['params'], 'relaxed naming failure') y = bar.apply(vs, x) with set_config('flax_preserve_adopted_names', False): @@ -2243,11 +2217,12 @@ def __call__(self, x): k = random.PRNGKey(0) x = jnp.zeros((1,)) vs = bar.init(k, x) - self.assertTrue("sub" in vs['params'], "old policy naming failure") + self.assertTrue('sub' in vs['params'], 'old policy naming failure') y = bar.apply(vs, x) def test_class_optional_adoption_name_preservation(self): class Foo(nn.Module): + @nn.compact def __call__(self, x): p = self.param('p', nn.initializers.zeros, x.shape) @@ -2256,12 +2231,14 @@ def __call__(self, x): class Bar1(nn.Module): sub: nn.Module preserve_adopted_names = True + def __call__(self, x): return self.sub(x) class Bar2(nn.Module): sub: nn.Module preserve_adopted_names = False + def __call__(self, x): return self.sub(x) @@ -2271,7 +2248,7 @@ def __call__(self, x): k = random.PRNGKey(0) x = jnp.zeros((1,)) vs = bar.init(k, x) - self.assertTrue("foo" in vs['params'], "adoption naming failure") + self.assertTrue('foo' in vs['params'], 'adoption naming failure') y = bar.apply(vs, x) with set_config('flax_preserve_adopted_names', True): @@ -2280,12 +2257,12 @@ def __call__(self, x): k = random.PRNGKey(0) x = jnp.zeros((1,)) vs = bar.init(k, x) - self.assertTrue("sub" in vs['params'], "adoption naming failure") + self.assertTrue('sub' in vs['params'], 'adoption naming failure') y = bar.apply(vs, x) def test_nested_class_optional_adoption_name_preservation(self): - class Foo(nn.Module): + @nn.compact def __call__(self, x): p = self.param('p', nn.initializers.zeros, x.shape) @@ -2294,12 +2271,14 @@ def __call__(self, x): class Bar(nn.Module): sub: nn.Module preserve_adopted_names = True + def __call__(self, x): return self.sub(x) class Baz(nn.Module): sub: nn.Module preserve_adopted_names = True + def __call__(self, x): return self.sub(x) @@ -2310,12 +2289,13 @@ def __call__(self, x): k = random.PRNGKey(0) x = jnp.zeros((1,)) vs = baz.init(k, x) - self.assertTrue("bar" in vs['params'], "adoption naming failure") - self.assertTrue("foo" in vs['params']['bar'], "adoption naming failure") + self.assertTrue('bar' in vs['params'], 'adoption naming failure') + self.assertTrue('foo' in vs['params']['bar'], 'adoption naming failure') y = baz.apply(vs, x) def test_relaxed_adoption_still_conflict_checks(self): class Foo(nn.Module): + @nn.compact def __call__(self, x): p = self.param('p', nn.initializers.zeros, x.shape) @@ -2324,6 +2304,7 @@ def __call__(self, x): class Bar(nn.Module): sub1: nn.Module sub2: nn.Module + def __call__(self, x): return self.sub(x) @@ -2338,6 +2319,7 @@ def __call__(self, x): def test_relaxed_adoption_unnamed_adoptee(self): class Foo(nn.Module): + @nn.compact def __call__(self, x): p = self.param('p', nn.initializers.zeros, x.shape) @@ -2345,6 +2327,7 @@ def __call__(self, x): class Bar(nn.Module): sub: nn.Module + def __call__(self, x): return self.sub(x) @@ -2354,7 +2337,7 @@ def __call__(self, x): k = random.PRNGKey(0) x = jnp.zeros((1,)) vs = bar.init(k, x) - self.assertTrue("sub" in vs['params'], "relaxed naming failure") + self.assertTrue('sub' in vs['params'], 'relaxed naming failure') y = bar.apply(vs, x) with set_config('flax_preserve_adopted_names', False): @@ -2363,13 +2346,13 @@ def __call__(self, x): k = random.PRNGKey(0) x = jnp.zeros((1,)) vs = bar.init(k, x) - self.assertTrue("sub" in vs['params'], "old policy naming failure") + self.assertTrue('sub' in vs['params'], 'old policy naming failure') y = bar.apply(vs, x) def test_relaxed_python_conflict(self): - class Foo(nn.Module): dummy = 0 + @nn.compact def __call__(self, x): p = self.param('dummy', nn.initializers.zeros, x.shape) @@ -2381,8 +2364,8 @@ def __call__(self, x): vs = foo.init(k, x) def test_relaxed_intercollection_conflict(self): - class Foo(nn.Module): + @nn.compact def __call__(self, x): v1 = self.variable('col1', 'v', lambda x: jnp.zeros(x), x.shape) @@ -2395,8 +2378,8 @@ def __call__(self, x): vs = foo.init(k, x) def test_relaxed_intercollection_conflict_set(self): - class Foo(nn.Module): + @nn.compact def __call__(self, x): v1 = self.variable('col1', 'v', lambda x: jnp.zeros(x), x.shape) @@ -2414,15 +2397,14 @@ def __call__(self, x): class FrozenDictTests(absltest.TestCase): def test_frozendict_flag(self): - with set_config('flax_return_frozendict', True): - x = jnp.zeros((2,3)) + x = jnp.zeros((2, 3)) layer = nn.Dense(5) params = layer.init(random.PRNGKey(0), x) self.assertTrue(isinstance(params, FrozenDict)) with set_config('flax_return_frozendict', False): - x = jnp.zeros((2,3)) + x = jnp.zeros((2, 3)) layer = nn.Dense(5) params = layer.init(random.PRNGKey(0), x) self.assertTrue(isinstance(params, dict)) diff --git a/tests/linen/linen_recurrent_test.py b/tests/linen/linen_recurrent_test.py index 5a09b9cd87..cbd1798186 100644 --- a/tests/linen/linen_recurrent_test.py +++ b/tests/linen/linen_recurrent_test.py @@ -30,6 +30,7 @@ class RNNTest(absltest.TestCase): + def test_rnn_basic_forward(self): batch_size = 10 seq_len = 40 @@ -129,7 +130,7 @@ def test_rnn_with_spatial_dimensions(self): channels_out = 15 rnn = nn.RNN( - nn.ConvLSTMCell(channels_out, kernel_size), + nn.ConvLSTMCell(channels_out, kernel_size), ) xs = jnp.ones((batch_size, seq_len, *image_size, channels_in)) @@ -148,8 +149,10 @@ def test_rnn_with_spatial_dimensions(self): for layer_params in variables['params']['cell'].values(): if 'bias' in layer_params: - self.assertEqual(layer_params['bias'].shape, (channels_out * 4,)) - self.assertIn(layer_params['kernel'].shape[2], [channels_in, channels_out, channels_out * 4]) + self.assertEqual(layer_params['bias'].shape, (channels_out * 4,)) + self.assertIn( + layer_params['kernel'].shape[2], [channels_in, channels_out, channels_out * 4] + ) self.assertEqual(layer_params['kernel'].shape[3], channels_out * 4) def test_numerical_equivalence(self): @@ -186,7 +189,9 @@ def test_numerical_equivalence_with_mask(self): xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray - (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs, seq_lengths=seq_lengths) + (carry, ys), variables = rnn.init_with_output( + jax.random.PRNGKey(0), xs, seq_lengths=seq_lengths + ) cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), xs[:, 0].shape) cell_params = variables['params']['cell'] @@ -220,10 +225,12 @@ def test_numerical_equivalence_single_batch(self): cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), xs[:1, 0].shape) for i in range(seq_len): - cell_carry, y = rnn.cell.apply({'params': cell_params}, cell_carry, xs[batch_idx, i, :][None]) + cell_carry, y = rnn.cell.apply( + {'params': cell_params}, cell_carry, xs[batch_idx, i, :][None] + ) np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-6) - carry_i = jax.tree_map(lambda x: x[batch_idx:batch_idx+1], carry) + carry_i = jax.tree_map(lambda x: x[batch_idx : batch_idx + 1], carry) np.testing.assert_allclose(cell_carry, carry_i, rtol=1e-6) def test_numerical_equivalence_single_batch_nn_scan(self): @@ -233,9 +240,13 @@ def test_numerical_equivalence_single_batch_nn_scan(self): channels_out = 6 cell: nn.LSTMCell = nn.LSTMCell(channels_out) - rnn: nn.LSTMCell = nn.scan(nn.LSTMCell, in_axes=1, out_axes=1, - variable_broadcast='params', - split_rngs={'params': False})(channels_out) + rnn: nn.LSTMCell = nn.scan( + nn.LSTMCell, + in_axes=1, + out_axes=1, + variable_broadcast='params', + split_rngs={'params': False}, + )(channels_out) xs = jnp.ones((batch_size, seq_len, channels_in)) carry = rnn.initialize_carry(jax.random.PRNGKey(0), xs[:, 0].shape) @@ -248,10 +259,12 @@ def test_numerical_equivalence_single_batch_nn_scan(self): cell_carry = cell.initialize_carry(jax.random.PRNGKey(0), xs[:1, 0].shape) for i in range(seq_len): - cell_carry, y = cell.apply({'params': cell_params}, cell_carry, xs[batch_idx:batch_idx+1, i, :]) + cell_carry, y = cell.apply( + {'params': cell_params}, cell_carry, xs[batch_idx : batch_idx + 1, i, :] + ) np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-5) - carry_i = jax.tree_map(lambda x: x[batch_idx:batch_idx+1], carry) + carry_i = jax.tree_map(lambda x: x[batch_idx : batch_idx + 1], carry) np.testing.assert_allclose(cell_carry, carry_i, rtol=1e-5) def test_numerical_equivalence_single_batch_jax_scan(self): @@ -299,11 +312,16 @@ def test_reverse(self): cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), xs[:1, 0].shape) for i in range(seq_len): - cell_carry, y = rnn.cell.apply({'params': cell_params}, cell_carry, xs[batch_idx, seq_len - i - 1, :][None]) + cell_carry, y = rnn.cell.apply( + {'params': cell_params}, cell_carry, xs[batch_idx, seq_len - i - 1, :][None] + ) np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-5) np.testing.assert_allclose( - cell_carry, jax.tree_map(lambda x: x[batch_idx:batch_idx+1], carry), rtol=1e-5) + cell_carry, + jax.tree_map(lambda x: x[batch_idx : batch_idx + 1], carry), + rtol=1e-5, + ) def test_reverse_but_keep_order(self): batch_size = 3 @@ -311,7 +329,9 @@ def test_reverse_but_keep_order(self): channels_in = 5 channels_out = 6 - rnn = nn.RNN(nn.LSTMCell(channels_out), return_carry=True, reverse=True, keep_order=True) + rnn = nn.RNN( + nn.LSTMCell(channels_out), return_carry=True, reverse=True, keep_order=True + ) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray @@ -323,11 +343,16 @@ def test_reverse_but_keep_order(self): cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), xs[:1, 0].shape) for i in range(seq_len): - cell_carry, y = rnn.cell.apply({'params': cell_params}, cell_carry, xs[batch_idx, seq_len - i - 1, :][None]) + cell_carry, y = rnn.cell.apply( + {'params': cell_params}, cell_carry, xs[batch_idx, seq_len - i - 1, :][None] + ) np.testing.assert_allclose(y[0], ys[batch_idx, seq_len - i - 1, :], rtol=1e-5) np.testing.assert_allclose( - cell_carry, jax.tree_map(lambda x: x[batch_idx:batch_idx+1], carry), rtol=1e-5) + cell_carry, + jax.tree_map(lambda x: x[batch_idx : batch_idx + 1], carry), + rtol=1e-5, + ) def test_flip_sequence(self): x = jnp.arange(2 * 5).reshape((2, 5)) @@ -370,7 +395,6 @@ def test_flip_sequence_time_major_more_feature_dims(self): np.testing.assert_allclose(flipped[:2, 1], x[:2, 1][::-1]) def test_basic_seq_lengths(self): - x = jnp.ones((2, 10, 6)) lstm = nn.RNN(nn.LSTMCell(265)) variables = lstm.init(jax.random.PRNGKey(0), x) @@ -386,8 +410,7 @@ def test_bidirectional(self): channels_out = 6 bdirectional = nn.Bidirectional( - nn.RNN(nn.LSTMCell(channels_out)), - nn.RNN(nn.LSTMCell(channels_out)) + nn.RNN(nn.LSTMCell(channels_out)), nn.RNN(nn.LSTMCell(channels_out)) ) xs = jnp.ones((batch_size, seq_len, channels_in)) @@ -403,10 +426,7 @@ def test_shared_cell(self): channels_out = 6 cell = nn.LSTMCell(channels_out) - bdirectional = nn.Bidirectional( - nn.RNN(cell), - nn.RNN(cell) - ) + bdirectional = nn.Bidirectional(nn.RNN(cell), nn.RNN(cell)) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray @@ -421,9 +441,9 @@ def test_custom_merge_fn(self): channels_out = 6 bdirectional = nn.Bidirectional( - nn.RNN(nn.LSTMCell(channels_out)), - nn.RNN(nn.LSTMCell(channels_out)), - merge_fn=lambda x, y: x + y + nn.RNN(nn.LSTMCell(channels_out)), + nn.RNN(nn.LSTMCell(channels_out)), + merge_fn=lambda x, y: x + y, ) xs = jnp.ones((batch_size, seq_len, channels_in)) @@ -439,9 +459,9 @@ def test_return_carry(self): channels_out = 6 bdirectional = nn.Bidirectional( - nn.RNN(nn.LSTMCell(channels_out)), - nn.RNN(nn.LSTMCell(channels_out)), - return_carry=True + nn.RNN(nn.LSTMCell(channels_out)), + nn.RNN(nn.LSTMCell(channels_out)), + return_carry=True, ) xs = jnp.ones((batch_size, seq_len, channels_in)) @@ -451,46 +471,33 @@ def test_return_carry(self): self.assertEqual(ys.shape, (batch_size, seq_len, channels_out * 2)) self.assertEqual( - jax.tree_map(jnp.shape, carry_forward), - ((batch_size, channels_out), (batch_size, channels_out)) + jax.tree_map(jnp.shape, carry_forward), + ((batch_size, channels_out), (batch_size, channels_out)), ) self.assertEqual( - jax.tree_map(jnp.shape, carry_backward), - ((batch_size, channels_out), (batch_size, channels_out)) + jax.tree_map(jnp.shape, carry_backward), + ((batch_size, channels_out), (batch_size, channels_out)), ) + class TestRecurrentDeprecation(parameterized.TestCase): - @parameterized.product( - cell_type=[nn.LSTMCell, nn.GRUCell, nn.OptimizedLSTMCell] - ) + @parameterized.product(cell_type=[nn.LSTMCell, nn.GRUCell, nn.OptimizedLSTMCell]) def test_constructor(self, cell_type): - - with self.assertRaisesRegex( - TypeError, - "The RNNCellBase API has changed" - ): + with self.assertRaisesRegex(TypeError, 'The RNNCellBase API has changed'): cell_type() - @parameterized.product( - cell_type=[nn.LSTMCell, nn.GRUCell, nn.OptimizedLSTMCell] - ) + @parameterized.product(cell_type=[nn.LSTMCell, nn.GRUCell, nn.OptimizedLSTMCell]) def test_initialize_carry(self, cell_type): key = jax.random.PRNGKey(0) - with self.assertRaisesRegex( - TypeError, - "The RNNCellBase API has changed" - ): + with self.assertRaisesRegex(TypeError, 'The RNNCellBase API has changed'): cell_type.initialize_carry(key, (2,), 3) def test_rnn(self): cell = nn.LSTMCell(3) - with self.assertRaisesRegex( - TypeError, - "The RNNCellBase API has changed" - ): + with self.assertRaisesRegex(TypeError, 'The RNNCellBase API has changed'): nn.RNN(cell, cell_size=8) if __name__ == '__main__': - absltest.main() \ No newline at end of file + absltest.main() diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 3af2af8d7e..5767a3a95c 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -33,90 +33,95 @@ def check_eq(xs, ys): return jax.tree_util.tree_all( - jax.tree_util.tree_map(np.testing.assert_allclose, xs, ys)) + jax.tree_util.tree_map(np.testing.assert_allclose, xs, ys) + ) class PoolTest(parameterized.TestCase): def test_pool_custom_reduce(self): - x = jnp.full((1, 3, 3, 1), 2.) + x = jnp.full((1, 3, 3, 1), 2.0) mul_reduce = lambda x, y: x * y - y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID') - np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4)) + y = nn.pooling.pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID') + np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4)) - @parameterized.parameters( - {'count_include_pad': True}, - {'count_include_pad': False}) + @parameterized.parameters({'count_include_pad': True}, {'count_include_pad': False}) def test_avg_pool(self, count_include_pad): - x = jnp.full((1, 3, 3, 1), 2.) + x = jnp.full((1, 3, 3, 1), 2.0) pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad) y = pool(x) - np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.)) + np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0)) y_grad = jax.grad(lambda x: pool(x).sum())(x) expected_grad = jnp.array([ [0.25, 0.5, 0.25], - [0.5, 1., 0.5], + [0.5, 1.0, 0.5], [0.25, 0.5, 0.25], ]).reshape((1, 3, 3, 1)) np.testing.assert_allclose(y_grad, expected_grad) - @parameterized.parameters( - {'count_include_pad': True}, - {'count_include_pad': False}) + @parameterized.parameters({'count_include_pad': True}, {'count_include_pad': False}) def test_avg_pool_no_batch(self, count_include_pad): - x = jnp.full((3, 3, 1), 2.) + x = jnp.full((3, 3, 1), 2.0) pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad) y = pool(x) - np.testing.assert_allclose(y, np.full((2, 2, 1), 2.)) + np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0)) y_grad = jax.grad(lambda x: pool(x).sum())(x) expected_grad = jnp.array([ [0.25, 0.5, 0.25], - [0.5, 1., 0.5], + [0.5, 1.0, 0.5], [0.25, 0.5, 0.25], ]).reshape((3, 3, 1)) np.testing.assert_allclose(y_grad, expected_grad) - @parameterized.parameters( - {'count_include_pad': True}, - {'count_include_pad': False}) + @parameterized.parameters({'count_include_pad': True}, {'count_include_pad': False}) def test_avg_pool_padding_same(self, count_include_pad): x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1)) - pool = lambda x: nn.avg_pool(x, (2, 2), padding="SAME", count_include_pad=count_include_pad) + pool = lambda x: nn.avg_pool( + x, (2, 2), padding='SAME', count_include_pad=count_include_pad + ) y = pool(x) if count_include_pad: - expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1)) + expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape( + (1, 2, 2, 1) + ) else: - expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1)) + expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape( + (1, 2, 2, 1) + ) np.testing.assert_allclose(y, expected_y) def test_max_pool(self): x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) pool = lambda x: nn.max_pool(x, (2, 2)) expected_y = jnp.array([ - [4., 5.], - [7., 8.], + [4.0, 5.0], + [7.0, 8.0], ]).reshape((1, 2, 2, 1)) y = pool(x) np.testing.assert_allclose(y, expected_y) y_grad = jax.grad(lambda x: pool(x).sum())(x) expected_grad = jnp.array([ - [0., 0., 0.], - [0., 1., 1.], - [0., 1., 1.], + [0.0, 0.0, 0.0], + [0.0, 1.0, 1.0], + [0.0, 1.0, 1.0], ]).reshape((1, 3, 3, 1)) np.testing.assert_allclose(y_grad, expected_grad) - @parameterized.parameters( - {'count_include_pad': True}, - {'count_include_pad': False}) + @parameterized.parameters({'count_include_pad': True}, {'count_include_pad': False}) def test_avg_pool_padding_same(self, count_include_pad): x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1)) - pool = lambda x: nn.avg_pool(x, (2, 2), padding="SAME", count_include_pad=count_include_pad) + pool = lambda x: nn.avg_pool( + x, (2, 2), padding='SAME', count_include_pad=count_include_pad + ) y = pool(x) if count_include_pad: - expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1)) + expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape( + (1, 2, 2, 1) + ) else: - expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1)) + expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape( + (1, 2, 2, 1) + ) np.testing.assert_allclose(y, expected_y) def test_pooling_variable_batch_dims(self): @@ -131,6 +136,7 @@ def test_pooling_no_batch_dims(self): assert y.shape == (16, 16, 3) + class NormalizationTest(parameterized.TestCase): def test_batch_norm(self): @@ -142,37 +148,43 @@ def test_batch_norm(self): mean = y.mean((0, 1)) var = y.var((0, 1)) - np.testing.assert_allclose(mean, np.array([0., 0.]), atol=1e-4) - np.testing.assert_allclose(var, np.array([1., 1.]), rtol=1e-4) + np.testing.assert_allclose(mean, np.array([0.0, 0.0]), atol=1e-4) + np.testing.assert_allclose(var, np.array([1.0, 1.0]), rtol=1e-4) y, vars_out = model_cls.apply(initial_params, x, mutable=['batch_stats']) ema = vars_out['batch_stats'] np.testing.assert_allclose( - ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4) + ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4 + ) np.testing.assert_allclose( - ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4) + ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4 + ) def test_batch_norm_complex(self): rng = random.PRNGKey(0) key1, key2 = random.split(rng) x = random.normal(key1, (4, 3, 2), dtype=jnp.complex64) - model_cls = nn.BatchNorm(momentum=0.9, use_running_average=False, dtype=jnp.complex64) + model_cls = nn.BatchNorm( + momentum=0.9, use_running_average=False, dtype=jnp.complex64 + ) y, initial_params = model_cls.init_with_output(key2, x) mean = y.mean((0, 1)) var = y.var((0, 1)) - np.testing.assert_allclose(mean, np.array([0., 0.]), atol=1e-4) - np.testing.assert_allclose(var, np.array([1., 1.]), rtol=1e-4) + np.testing.assert_allclose(mean, np.array([0.0, 0.0]), atol=1e-4) + np.testing.assert_allclose(var, np.array([1.0, 1.0]), rtol=1e-4) self.assertEqual(mean.dtype, jnp.complex64) y, vars_out = model_cls.apply(initial_params, x, mutable=['batch_stats']) ema = vars_out['batch_stats'] np.testing.assert_allclose( - ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4) + ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4 + ) np.testing.assert_allclose( - ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4) + ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4 + ) @parameterized.parameters( {'reduction_axes': -1}, @@ -198,25 +210,26 @@ def test_layer_norm(self, reduction_axes, use_fast_variance=True): y, _ = model_cls.init_with_output(key2, x) self.assertEqual(x.dtype, y.dtype) self.assertEqual(x.shape, y.shape) - y_one_liner = ((x - x.mean(axis=reduction_axes, keepdims=True)) * - jax.lax.rsqrt(x.var(axis=reduction_axes, keepdims=True) + e)) + y_one_liner = (x - x.mean(axis=reduction_axes, keepdims=True)) * jax.lax.rsqrt( + x.var(axis=reduction_axes, keepdims=True) + e + ) np.testing.assert_allclose(y_one_liner, y, atol=1e-4) @parameterized.parameters( - {'reduction_axes': -1}, - {'reduction_axes': 1}, - {'reduction_axes': (1, 2)}) + {'reduction_axes': -1}, {'reduction_axes': 1}, {'reduction_axes': (1, 2)} + ) def test_rms_norm(self, reduction_axes): rng = random.PRNGKey(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 3, 4)) - model_cls = nn.RMSNorm(use_scale=False, epsilon=e, - reduction_axes=reduction_axes) + model_cls = nn.RMSNorm(use_scale=False, epsilon=e, reduction_axes=reduction_axes) y, _ = model_cls.init_with_output(key2, x) self.assertEqual(x.dtype, y.dtype) self.assertEqual(x.shape, y.shape) - y_one_liner = (x * jax.lax.rsqrt(jnp.mean(jax.lax.square(x), axis=reduction_axes, keepdims=True) + e)) + y_one_liner = x * jax.lax.rsqrt( + jnp.mean(jax.lax.square(x), axis=reduction_axes, keepdims=True) + e + ) np.testing.assert_allclose(y_one_liner, y, atol=1e-4) def test_group_norm(self): @@ -231,8 +244,9 @@ def test_group_norm(self): self.assertEqual(x.shape, y.shape) x_gr = x.reshape([2, 5, 4, 4, 2, 16]) - y_test = ((x_gr - x_gr.mean(axis=[1, 2, 3, 5], keepdims=True)) * - jax.lax.rsqrt(x_gr.var(axis=[1, 2, 3, 5], keepdims=True) + e)) + y_test = (x_gr - x_gr.mean(axis=[1, 2, 3, 5], keepdims=True)) * jax.lax.rsqrt( + x_gr.var(axis=[1, 2, 3, 5], keepdims=True) + e + ) y_test = y_test.reshape([2, 5, 4, 4, 32]) np.testing.assert_allclose(y_test, y, atol=1e-4) @@ -249,12 +263,13 @@ def test_group_norm_raises(self): def test_batch_norm_multi_init(self): class Foo(nn.Module): + @nn.compact def __call__(self, x): norm = nn.BatchNorm( - name="norm", + name='norm', use_running_average=False, - axis_name="batch", + axis_name='batch', ) x = norm(x) return x, norm(x) @@ -272,24 +287,20 @@ def test_dropout(self): rng = random.PRNGKey(0) key1, key2 = random.split(rng) module = nn.Dropout(rate=0.5) - y1 = module.apply({}, - jnp.ones((20, 20)), - deterministic=False, - rngs={'dropout': key1}) - y2 = module.apply({}, - jnp.ones((20, 20)), - deterministic=False, - rngs={'dropout': key2}) + y1 = module.apply( + {}, jnp.ones((20, 20)), deterministic=False, rngs={'dropout': key1} + ) + y2 = module.apply( + {}, jnp.ones((20, 20)), deterministic=False, rngs={'dropout': key2} + ) self.assertFalse(np.all(y1 == y2)) - y1 = module.apply({}, - jnp.ones((20, 20)), - deterministic=True, - rngs={'dropout': key1}) - y2 = module.apply({}, - jnp.ones((20, 20)), - deterministic=True, - rngs={'dropout': key2}) + y1 = module.apply( + {}, jnp.ones((20, 20)), deterministic=True, rngs={'dropout': key1} + ) + y2 = module.apply( + {}, jnp.ones((20, 20)), deterministic=True, rngs={'dropout': key2} + ) self.assertTrue(np.all(y1 == y2)) def test_dropout_rate_stats(self): @@ -300,10 +311,9 @@ def test_dropout_rate_stats(self): n_trials = 10 nonzero_counts = 0 for key in random.split(subkey, n_trials): - y = module.apply({}, - jnp.ones((100, 100)), - deterministic=False, - rngs={'dropout': key}) + y = module.apply( + {}, jnp.ones((100, 100)), deterministic=False, rngs={'dropout': key} + ) nonzero_counts += np.sum(y > 0.0) all_counts = np.prod((100, 100, n_trials)) frac = np.sum(nonzero_counts) / all_counts @@ -317,24 +327,19 @@ def test_dropout_rate_limits(self): key1, key2, key3 = random.split(rng, 3) inputs = jnp.ones((20, 20)) d0 = nn.Dropout(rate=0.0) - y1 = d0.apply({}, inputs, - deterministic=False, - rngs={'dropout': key1}) + y1 = d0.apply({}, inputs, deterministic=False, rngs={'dropout': key1}) np.testing.assert_array_equal(y1, inputs) d1 = nn.Dropout(rate=1.0) - y2 = d1.apply({}, inputs, - deterministic=False, - rngs={'dropout': key2}) + y2 = d1.apply({}, inputs, deterministic=False, rngs={'dropout': key2}) np.testing.assert_array_equal(y2, np.zeros_like(inputs)) # ensure gradient of rate==1.0 case is non-NaN - fn = lambda x, k: d1.apply({}, x, - rngs={'dropout': k}, - deterministic=False) + fn = lambda x, k: d1.apply({}, x, rngs={'dropout': k}, deterministic=False) res = jax.grad(lambda x, k: jnp.sum(fn(x, k)))(inputs, key3) self.assertFalse(np.isnan(res).any()) def test_dropout_manual_rng(self): class Foo(nn.Module): + @nn.compact def __call__(self, x): key = self.make_rng('dropout') @@ -343,8 +348,7 @@ def __call__(self, x): return x1, x2 module = Foo() - x1, x2 = module.apply( - {}, jnp.ones((20, 20)), rngs={'dropout': random.PRNGKey(0)}) + x1, x2 = module.apply({}, jnp.ones((20, 20)), rngs={'dropout': random.PRNGKey(0)}) np.testing.assert_array_equal(x1, x2) @@ -365,16 +369,19 @@ def test_lstm(self): self.assertEqual(carry[1].shape, (2, 4)) np.testing.assert_allclose(y, carry[1]) param_shapes = jax.tree_util.tree_map(np.shape, initial_params['params']) - self.assertEqual(param_shapes, { - 'ii': {'kernel': (3, 4)}, - 'if': {'kernel': (3, 4)}, - 'ig': {'kernel': (3, 4)}, - 'io': {'kernel': (3, 4)}, - 'hi': {'kernel': (4, 4), 'bias': (4,)}, - 'hf': {'kernel': (4, 4), 'bias': (4,)}, - 'hg': {'kernel': (4, 4), 'bias': (4,)}, - 'ho': {'kernel': (4, 4), 'bias': (4,)}, - }) + self.assertEqual( + param_shapes, + { + 'ii': {'kernel': (3, 4)}, + 'if': {'kernel': (3, 4)}, + 'ig': {'kernel': (3, 4)}, + 'io': {'kernel': (3, 4)}, + 'hi': {'kernel': (4, 4), 'bias': (4,)}, + 'hf': {'kernel': (4, 4), 'bias': (4,)}, + 'hg': {'kernel': (4, 4), 'bias': (4,)}, + 'ho': {'kernel': (4, 4), 'bias': (4,)}, + }, + ) def test_gru(self): gru = nn.GRUCell(features=4) @@ -387,14 +394,17 @@ def test_gru(self): self.assertEqual(carry.shape, (2, 4)) np.testing.assert_allclose(y, carry) param_shapes = jax.tree_util.tree_map(np.shape, initial_params['params']) - self.assertEqual(param_shapes, { - 'ir': {'kernel': (3, 4), 'bias': (4,)}, - 'iz': {'kernel': (3, 4), 'bias': (4,)}, - 'in': {'kernel': (3, 4), 'bias': (4,)}, - 'hr': {'kernel': (4, 4)}, - 'hz': {'kernel': (4, 4)}, - 'hn': {'kernel': (4, 4), 'bias': (4,)}, - }) + self.assertEqual( + param_shapes, + { + 'ir': {'kernel': (3, 4), 'bias': (4,)}, + 'iz': {'kernel': (3, 4), 'bias': (4,)}, + 'in': {'kernel': (3, 4), 'bias': (4,)}, + 'hr': {'kernel': (4, 4)}, + 'hz': {'kernel': (4, 4)}, + 'hn': {'kernel': (4, 4), 'bias': (4,)}, + }, + ) def test_complex_input_gru(self): gru = nn.GRUCell(features=4) @@ -420,13 +430,15 @@ def test_convlstm(self): self.assertEqual(carry[1].shape, (2, 4, 4, 6)) np.testing.assert_allclose(y, carry[1]) param_shapes = jax.tree_util.tree_map(np.shape, initial_params['params']) - self.assertEqual(param_shapes, { - 'hh': {'bias': (6*4,), 'kernel': (3, 3, 6, 6*4)}, - 'ih': {'bias': (6*4,), 'kernel': (3, 3, 3, 6*4)}, - }) + self.assertEqual( + param_shapes, + { + 'hh': {'bias': (6 * 4,), 'kernel': (3, 3, 6, 6 * 4)}, + 'ih': {'bias': (6 * 4,), 'kernel': (3, 3, 3, 6 * 4)}, + }, + ) def test_optimized_lstm_cell_matches_regular(self): - # Create regular LSTMCell. lstm = nn.LSTMCell(features=4) rng = random.PRNGKey(0) diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 6fb7c97b3d..3f0e3fa303 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -35,14 +35,15 @@ # pylint: disable=attribute-defined-outside-init,unused-variable,g-wrong-blank-lines,g-bare-generic + def tree_equals(x, y): - return jax.tree_util.tree_all( - jax.tree_util.tree_map(operator.eq, x, y)) + return jax.tree_util.tree_all(jax.tree_util.tree_map(operator.eq, x, y)) def tree_allclose(x, y): return jax.tree_util.tree_all( - jax.tree_util.tree_map(lambda x,y: np.all(np.isclose(x,y)), x, y)) + jax.tree_util.tree_map(lambda x, y: np.all(np.isclose(x, y)), x, y) + ) id_fn = lambda x: x @@ -77,6 +78,7 @@ def __call__(self, inputs): if i != len(self.features) - 1: x = nn.relu(x) return x + return MLP @@ -131,11 +133,14 @@ def test_remat_decorated(self): self.assertTrue(np.all(y1 == y2)) def test_remat_kwargs(self): - raise unittest.SkipTest("test breaks with grad") + raise unittest.SkipTest('test breaks with grad') + class ConditionalReLU(nn.Module): + @nn.compact - def __call__(self, input, apply_relu : bool = False): + def __call__(self, input, apply_relu: bool = False): return nn.relu(input) if apply_relu else input + key = random.PRNGKey(0) x = jnp.ones((4, 4)) * -1 remat_model = nn.remat(ConditionalReLU)() @@ -183,6 +188,7 @@ def test_remat_decorator_static_argnums(self): test = self class FooTrainStatic(nn.Module): + @partial(nn.remat, static_argnums=(2,)) @nn.compact def __call__(self, inputs, train: bool): @@ -199,6 +205,7 @@ def __call__(self, inputs, train: bool): self.assertEqual(y.shape, (1, 3)) class FooTrainDynamic(nn.Module): + @partial(nn.remat, static_argnums=()) @nn.compact def __call__(self, inputs, train: bool): @@ -213,23 +220,29 @@ def __call__(self, inputs, train: bool): y = foo.apply(variables, x, False) self.assertEqual(y.shape, (1, 3)) - def test_vmap(self): key1, key2 = random.split(random.PRNGKey(3), 2) x = random.uniform(key1, (4, 4)) x2 = random.uniform(key1, (5, 4, 4)) def vmap(cls): - return nn.vmap(cls, - in_axes=(0,), - variable_axes={'params': None}, - split_rngs={'params': False}) + return nn.vmap( + cls, + in_axes=(0,), + variable_axes={'params': None}, + split_rngs={'params': False}, + ) + normal_model = TransformedMLP(features=[3, 4, 5]) vmap_model = TransformedMLP(features=[3, 4, 5], transform=vmap) init_variables = normal_model.init(key2, x) # simulate vmap in python for comparison: - y1 = jnp.vstack([normal_model.apply(init_variables, x2[i])[None, ...] - for i in np.arange(x2.shape[0])]) + y1 = jnp.vstack( + [ + normal_model.apply(init_variables, x2[i])[None, ...] + for i in np.arange(x2.shape[0]) + ] + ) y2 = vmap_model.apply(init_variables, x2) np.testing.assert_allclose(y1, y2, atol=1e-7) @@ -239,16 +252,20 @@ def test_vmap_decorated(self): x2 = random.uniform(key1, (5, 4, 4)) def vmap(fn): - return nn.vmap(fn, - in_axes=(0,), - variable_axes={'params': None}, - split_rngs={'params': False}) + return nn.vmap( + fn, in_axes=(0,), variable_axes={'params': None}, split_rngs={'params': False} + ) + normal_model = decorated_MLP()(features=[3, 4, 5]) vmap_model = decorated_MLP(vmap)(features=[3, 4, 5]) init_variables = normal_model.init(key2, x) # simulate vmap in python for comparison: - y1 = jnp.vstack([normal_model.apply(init_variables, x2[i])[None, ...] - for i in np.arange(x2.shape[0])]) + y1 = jnp.vstack( + [ + normal_model.apply(init_variables, x2[i])[None, ...] + for i in np.arange(x2.shape[0]) + ] + ) y2 = vmap_model.apply(init_variables, x2) np.testing.assert_allclose(y1, y2, atol=1e-7) @@ -258,11 +275,14 @@ def test_vmap_batchnorm(self): x2 = random.uniform(key1, (5, 4, 4)) def vmap(cls): - return nn.vmap(cls, - in_axes=(0,), - variable_axes={'params': None, 'batch_stats': None}, - split_rngs={'params': False}, - axis_name='batch') + return nn.vmap( + cls, + in_axes=(0,), + variable_axes={'params': None, 'batch_stats': None}, + split_rngs={'params': False}, + axis_name='batch', + ) + class MlpBn(nn.Module): axis_name: Any = None @@ -275,7 +295,9 @@ def __call__(self, x): normal_model = MlpBn() vmap_model = vmap(MlpBn)(axis_name='batch') init_variables = normal_model.init(key2, x) - y1 = normal_model.apply(init_variables, x2.reshape((-1, 4)), mutable=['batch_stats'])[0] + y1 = normal_model.apply( + init_variables, x2.reshape((-1, 4)), mutable=['batch_stats'] + )[0] y1 = y1.reshape((5, 4, 3)) y2 = vmap_model.apply(init_variables, x2, mutable=['batch_stats'])[0] np.testing.assert_allclose(y1, y2, atol=1e-5) @@ -283,12 +305,13 @@ def __call__(self, x): def test_scan(self): class SimpleScan(nn.Module): features: int + @nn.compact def __call__(self, c, xs): - LSTM = nn.scan(nn.LSTMCell, - variable_broadcast='params', - split_rngs={'params': False}) - return LSTM(self.features, name="lstm_cell")(c, xs) + LSTM = nn.scan( + nn.LSTMCell, variable_broadcast='params', split_rngs={'params': False} + ) + return LSTM(self.features, name='lstm_cell')(c, xs) key1, key2 = random.split(random.PRNGKey(0), 2) xs = random.uniform(key1, (5, 3, 2)) @@ -313,14 +336,17 @@ def __call__(self, c, xs): def test_scan_decorated(self): class SimpleScan(nn.Module): features: int - @partial(nn.scan, - variable_broadcast='params', - in_axes=(nn.broadcast, 0), - split_rngs={'params': False}) + + @partial( + nn.scan, + variable_broadcast='params', + in_axes=(nn.broadcast, 0), + split_rngs={'params': False}, + ) @nn.compact def __call__(self, c, b, xs): assert b.shape == (4,) - return nn.LSTMCell(self.features, name="lstm_cell")(c, xs) + return nn.LSTMCell(self.features, name='lstm_cell')(c, xs) key1, key2 = random.split(random.PRNGKey(0), 2) xs = random.uniform(key1, (4, 3, 2)) @@ -345,22 +371,29 @@ def __call__(self, c, b, xs): def test_multiscope_lifting_simple(self): class Counter(nn.Module): + @nn.compact def __call__(self): v = self.variable('counter', 'foo', lambda: jnp.array([0])) v.value += jnp.array([1]) return v.value + class Outer(nn.Module): + @nn.compact def __call__(self, x): cntr = nn.jit(Counter)(name='cntr')() return x + class Inner(nn.Module): outer_module: nn.Module + @nn.compact def __call__(self, x): return self.outer_module(x) + class Test(nn.Module): + @nn.compact def __call__(self, x): outer_dense = nn.jit(Outer)(name='outer') @@ -374,32 +407,41 @@ def __call__(self, x): rngs = random.PRNGKey(0) init_vars = Test(None).init(rngs, x) _, new_vars = Test(None).apply(init_vars, x, mutable=['counter']) - self.assertEqual(init_vars['counter']['outer']['cntr']['foo'], - jnp.array([2], jnp.int32)) - self.assertEqual(new_vars['counter']['outer']['cntr']['foo'], - jnp.array([4], jnp.int32)) + self.assertEqual( + init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32) + ) + self.assertEqual( + new_vars['counter']['outer']['cntr']['foo'], jnp.array([4], jnp.int32) + ) def test_multiscope_lifting_simple_decorator(self): class Counter(nn.Module): + @nn.jit @nn.compact def __call__(self): v = self.variable('counter', 'foo', lambda: jnp.array([0])) v.value += jnp.array([1]) return v.value + class Outer(nn.Module): + @nn.jit @nn.compact def __call__(self, x): cntr = Counter(name='cntr')() return x + class Inner(nn.Module): outer_module: nn.Module + @nn.jit @nn.compact def __call__(self, x): return self.outer_module(x) + class Test(nn.Module): + @nn.compact def __call__(self, x): outer_dense = Outer(name='outer') @@ -413,29 +455,38 @@ def __call__(self, x): rngs = random.PRNGKey(0) init_vars = Test(None).init(rngs, x) _, new_vars = Test(None).apply(init_vars, x, mutable=['counter']) - self.assertEqual(init_vars['counter']['outer']['cntr']['foo'], - jnp.array([2], jnp.int32)) - self.assertEqual(new_vars['counter']['outer']['cntr']['foo'], - jnp.array([4], jnp.int32)) + self.assertEqual( + init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32) + ) + self.assertEqual( + new_vars['counter']['outer']['cntr']['foo'], jnp.array([4], jnp.int32) + ) def test_multiscope_lifting_argtree(self): class Counter(nn.Module): + @nn.compact def __call__(self): v = self.variable('counter', 'foo', lambda: jnp.array([0])) v.value += jnp.array([1]) return v.value + class Outer(nn.Module): + @nn.compact def __call__(self, x): cntr = nn.jit(Counter)(name='cntr')() return x + class Inner(nn.Module): outer_module: Sequence[nn.Module] + @nn.compact def __call__(self, x): return self.outer_module[0](x) + self.outer_module[1](x) + class Test(nn.Module): + @nn.compact def __call__(self, x): outer_dense1 = nn.jit(Outer)(name='outer1') @@ -450,36 +501,47 @@ def __call__(self, x): rngs = random.PRNGKey(0) init_vars = Test(None).init(rngs, x) _, new_vars = Test(None).apply(init_vars, x, mutable=['counter']) - self.assertEqual(init_vars['counter']['outer1']['cntr']['foo'], - jnp.array([2], jnp.int32)) - self.assertEqual(new_vars['counter']['outer1']['cntr']['foo'], - jnp.array([4], jnp.int32)) - self.assertEqual(init_vars['counter']['outer2']['cntr']['foo'], - jnp.array([2], jnp.int32)) - self.assertEqual(new_vars['counter']['outer2']['cntr']['foo'], - jnp.array([4], jnp.int32)) + self.assertEqual( + init_vars['counter']['outer1']['cntr']['foo'], jnp.array([2], jnp.int32) + ) + self.assertEqual( + new_vars['counter']['outer1']['cntr']['foo'], jnp.array([4], jnp.int32) + ) + self.assertEqual( + init_vars['counter']['outer2']['cntr']['foo'], jnp.array([2], jnp.int32) + ) + self.assertEqual( + new_vars['counter']['outer2']['cntr']['foo'], jnp.array([4], jnp.int32) + ) def test_multiscope_lifting_argtree_decorator(self): class Counter(nn.Module): + @nn.jit @nn.compact def __call__(self): v = self.variable('counter', 'foo', lambda: jnp.array([0])) v.value += jnp.array([1]) return v.value + class Outer(nn.Module): + @nn.jit @nn.compact def __call__(self, x): cntr = nn.jit(Counter)(name='cntr')() return x + class Inner(nn.Module): outer_module: Sequence[nn.Module] + @nn.jit @nn.compact def __call__(self, x): return self.outer_module[0](x) + self.outer_module[1](x) + class Test(nn.Module): + @nn.compact def __call__(self, x): outer_dense1 = Outer(name='outer1') @@ -494,37 +556,48 @@ def __call__(self, x): rngs = random.PRNGKey(0) init_vars = Test(None).init(rngs, x) _, new_vars = Test(None).apply(init_vars, x, mutable=['counter']) - self.assertEqual(init_vars['counter']['outer1']['cntr']['foo'], - jnp.array([2], jnp.int32)) - self.assertEqual(new_vars['counter']['outer1']['cntr']['foo'], - jnp.array([4], jnp.int32)) - self.assertEqual(init_vars['counter']['outer2']['cntr']['foo'], - jnp.array([2], jnp.int32)) - self.assertEqual(new_vars['counter']['outer2']['cntr']['foo'], - jnp.array([4], jnp.int32)) + self.assertEqual( + init_vars['counter']['outer1']['cntr']['foo'], jnp.array([2], jnp.int32) + ) + self.assertEqual( + new_vars['counter']['outer1']['cntr']['foo'], jnp.array([4], jnp.int32) + ) + self.assertEqual( + init_vars['counter']['outer2']['cntr']['foo'], jnp.array([2], jnp.int32) + ) + self.assertEqual( + new_vars['counter']['outer2']['cntr']['foo'], jnp.array([4], jnp.int32) + ) def test_multiscope_lifting_simple_decorator_w_jit(self): # TODO: actually test jaxpr on a simpler module. class Counter(nn.Module): + @nn.jit @nn.compact def __call__(self): v = self.variable('counter', 'foo', lambda: jnp.array([0])) v.value += jnp.array([1]) return v.value + class Outer(nn.Module): + @nn.jit @nn.compact def __call__(self, x): cntr = Counter(name='cntr')() return x + class Inner(nn.Module): outer_module: nn.Module + @nn.jit @nn.compact def __call__(self, x): return self.outer_module(x) + class Test(nn.Module): + @nn.jit @nn.compact def __call__(self, x): @@ -539,28 +612,37 @@ def __call__(self, x): rngs = random.PRNGKey(0) init_vars = Test(None).init(rngs, x) _, new_vars = Test(None).apply(init_vars, x, mutable=['counter']) - self.assertEqual(init_vars['counter']['outer']['cntr']['foo'], - jnp.array([2], jnp.int32)) - self.assertEqual(new_vars['counter']['outer']['cntr']['foo'], - jnp.array([4], jnp.int32)) + self.assertEqual( + init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32) + ) + self.assertEqual( + new_vars['counter']['outer']['cntr']['foo'], jnp.array([4], jnp.int32) + ) def test_vmapped_outer_module(self): class Outer(nn.Module): + @nn.jit @nn.compact def __call__(self, x): return nn.Dense(5)(x) + class Inner(nn.Module): outer_module: nn.Module - @partial(nn.vmap, - in_axes=(0,), - variable_axes={'params': 0}, - split_rngs={'params': True}) + + @partial( + nn.vmap, + in_axes=(0,), + variable_axes={'params': 0}, + split_rngs={'params': True}, + ) @nn.jit @nn.compact def __call__(self, x): return self.outer_module(x) + class Test(nn.Module): + @nn.compact def __call__(self, x): outer_dense = Outer(name='outer') @@ -573,47 +655,53 @@ def __call__(self, x): rngs = random.PRNGKey(0) init_vars = Test(None).init(rngs, x) y = Test(None).apply(init_vars, x) - self.assertEqual( - init_vars['params']['outer']['Dense_0']['kernel'].shape, - (3, 2, 5)) - self.assertEqual( - init_vars['params']['outer']['Dense_0']['bias'].shape, - (3, 5)) + self.assertEqual(init_vars['params']['outer']['Dense_0']['kernel'].shape, (3, 2, 5)) + self.assertEqual(init_vars['params']['outer']['Dense_0']['bias'].shape, (3, 5)) self.assertEqual(y.shape, (3, 1, 5)) def test_module_transform_with_setup(self): class Foo(nn.Module): + def setup(self): self.test = self.param('test', nn.initializers.ones_init(), ()) def __call__(self, x): return x * self.test - FooVmap = nn.vmap(Foo, in_axes=0, out_axes=0, - variable_axes={'params': 0}, split_rngs={'params': True}) + FooVmap = nn.vmap( + Foo, + in_axes=0, + out_axes=0, + variable_axes={'params': 0}, + split_rngs={'params': True}, + ) variables = FooVmap().init(random.PRNGKey(0), jnp.ones((4,))) self.assertEqual(variables['params']['test'].shape, (4,)) - def test_nested_module_args_vmap(self): class A(nn.Module): + @nn.compact def __call__(self, x): return nn.Dense(3)(x) + class B(nn.Module): A: nn.Module + @nn.compact def __call__(self, x): return self.A(x) + class C(nn.Module): B: nn.Module - @partial(nn.vmap, - variable_axes={'params': 0}, - split_rngs={'params': True}) + + @partial(nn.vmap, variable_axes={'params': 0}, split_rngs={'params': True}) @nn.compact def __call__(self, x): return self.B(x) + class D(nn.Module): + @nn.compact def __call__(self, x): a = A() @@ -626,34 +714,34 @@ def __call__(self, x): p = D().init(key, x) variable_shapes = jax.tree_util.tree_map(jnp.shape, p) - self.assertEqual( - variable_shapes['params']['A_0']['Dense_0']['kernel'], - (10, 10, 3)) - self.assertEqual( - variable_shapes['params']['A_0']['Dense_0']['bias'], - (10, 3)) + self.assertEqual(variable_shapes['params']['A_0']['Dense_0']['kernel'], (10, 10, 3)) + self.assertEqual(variable_shapes['params']['A_0']['Dense_0']['bias'], (10, 3)) def test_nested_module_args_vmap_2(self): class A(nn.Module): + @nn.compact def __call__(self, x): return nn.Dense(3)(x) + class B(nn.Module): A: nn.Module + @nn.compact def __call__(self, x): return self.A(x) + class C(nn.Module): A: nn.Module B: nn.Module - @partial( - nn.vmap, - variable_axes={'params': 0}, - split_rngs={'params': True}) + + @partial(nn.vmap, variable_axes={'params': 0}, split_rngs={'params': True}) @nn.compact def __call__(self, x): return self.B(x) + self.A(x) + class D(nn.Module): + @nn.compact def __call__(self, x): a1 = A() @@ -667,38 +755,36 @@ def __call__(self, x): p = D().init(key, x) variable_shapes = jax.tree_util.tree_map(jnp.shape, p) - self.assertEqual( - variable_shapes['params']['A_0']['Dense_0']['kernel'], - (10, 10, 3)) - self.assertEqual( - variable_shapes['params']['A_0']['Dense_0']['bias'], - (10, 3)) - self.assertEqual( - variable_shapes['params']['A_1']['Dense_0']['kernel'], - (10, 10, 3)) - self.assertEqual( - variable_shapes['params']['A_1']['Dense_0']['bias'], - (10, 3)) + self.assertEqual(variable_shapes['params']['A_0']['Dense_0']['kernel'], (10, 10, 3)) + self.assertEqual(variable_shapes['params']['A_0']['Dense_0']['bias'], (10, 3)) + self.assertEqual(variable_shapes['params']['A_1']['Dense_0']['kernel'], (10, 10, 3)) + self.assertEqual(variable_shapes['params']['A_1']['Dense_0']['bias'], (10, 3)) def test_nested_setup_calls_count(self): D = 3 N = 4 setup_cntr = 0 call_cntr = 0 + class Repeat(nn.Module): mdl_def: Any + def setup(self): self.lyrs = [self.mdl_def() for _ in range(N)] + @nn.remat # we just use remat as a convenient test of transform logic def __call__(self, x): for lyr in self.lyrs: lyr(x) return x + class Counter(nn.Module): + def setup(self): nonlocal setup_cntr setup_cntr += 1 self.dense = nn.Dense(2, use_bias=False) + @nn.remat def __call__(self, x): nonlocal call_cntr @@ -709,6 +795,7 @@ def nested_repeat(mdl): for _ in range(D): mdl = partial(Repeat, mdl) return mdl() + _ = nested_repeat(Counter).init(random.PRNGKey(0), jnp.ones((2,))) # setup_cntr == 128 due to 1 call in Counter.setup by _validate_setup # and 1 further "real" call. @@ -716,21 +803,28 @@ def nested_repeat(mdl): self.assertEqual(call_cntr, 64) def test_multimethod_setup_calls(self): - cntr=0 + cntr = 0 + class A(nn.Module): + def setup(self): nonlocal cntr - cntr+=1 + cntr += 1 self.d = nn.Dense(2) + @nn.remat def foo(self, x): return self.d(x) + @nn.remat def bar(self, x): return self.d(x) + class B(nn.Module): + def setup(self): self.a = A() + def __call__(self, x): y1 = self.a.foo(x) y2 = self.a.bar(x) @@ -748,31 +842,37 @@ def __call__(self, x): def test_toplevel_submodule_adoption_transform(self): class A(nn.Module): + @nn.compact def __call__(self, x): return nn.Dense(3)(x) + class B(nn.Module): A: nn.Module + @nn.compact def __call__(self, x): return self.A(x) + class C(nn.Module): A: nn.Module B: nn.Module - @partial( - nn.vmap, - variable_axes={'params': 0}, - split_rngs={'params': True}) + + @partial(nn.vmap, variable_axes={'params': 0}, split_rngs={'params': True}) @nn.compact def __call__(self, x): return self.B(x) + self.A(x) + class Csimple(nn.Module): A: nn.Module B: nn.Module + @nn.compact def __call__(self, x): return self.B(x) + self.A(x) + class D(nn.Module): + @nn.compact def __call__(self, x): a1 = A() @@ -789,26 +889,29 @@ def __call__(self, x): a1 = A() a2 = A() b = B(a1) - p2 = freeze({'params': { - 'A': p1['params']['A_0'], - 'B': { - 'A': p1['params']['A_1'], + p2 = freeze( + { + 'params': { + 'A': p1['params']['A_0'], + 'B': { + 'A': p1['params']['A_1'], + }, + } } - }}) + ) # Test method wrapper transform. y2 = C(a2, b).apply(p2, x) np.testing.assert_allclose(y1, y2, atol=1e-7) # Test class transform. - Ctrafo = nn.vmap(Csimple, - variable_axes={'params': 0}, - split_rngs={'params': True}) + Ctrafo = nn.vmap(Csimple, variable_axes={'params': 0}, split_rngs={'params': True}) y3 = Ctrafo(a2, b).apply(p2, x) np.testing.assert_allclose(y1, y3, atol=1e-7) def test_toplevel_submodule_adoption_pytree_transform(self): class A(nn.Module): + @nn.compact def __call__(self, c, x): counter = self.variable('counter', 'i', jnp.zeros, ()) @@ -818,17 +921,20 @@ def __call__(self, c, x): class B(nn.Module): A: Any + @nn.compact def __call__(self, c, x): return self.A['foo'](*self.A['bar'](c, x)) a = A() As = {'foo': A(), 'bar': A()} - b = nn.scan(B, - in_axes=0, - variable_carry='counter', - variable_broadcast='params', - split_rngs={'params': False})(As) + b = nn.scan( + B, + in_axes=0, + variable_carry='counter', + variable_broadcast='params', + split_rngs={'params': False}, + )(As) key = random.PRNGKey(0) x = jnp.ones((10, 2)) @@ -844,21 +950,24 @@ def __call__(self, c, x): 'i': jnp.array(11.0), }, }, - } - self.assertTrue(jax.tree_util.tree_all( - jax.tree_util.tree_map( - lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7), - cntrs, ref_cntrs) - )) + } + self.assertTrue( + jax.tree_util.tree_all( + jax.tree_util.tree_map( + lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7), + cntrs, + ref_cntrs, + ) + ) + ) def test_partially_applied_module_constructor_transform(self): k = random.PRNGKey(0) - x = jnp.ones((3,4,4)) + x = jnp.ones((3, 4, 4)) dense = partial(nn.Dense, use_bias=False) vmap_dense = nn.vmap( - dense, - variable_axes={'params':0}, - split_rngs={'params':True})(4) + dense, variable_axes={'params': 0}, split_rngs={'params': True} + )(4) init_vars = vmap_dense.init(k, x) init_vars_shapes = jax.tree_util.tree_map(jnp.shape, init_vars) ref_var_shapes = { @@ -870,7 +979,8 @@ def test_partially_applied_module_constructor_transform(self): def test_partial_module_method(self): k = random.PRNGKey(0) - x = jnp.ones((3,4,4)) + x = jnp.ones((3, 4, 4)) + class Foo(nn.Module): @nn.compact @@ -879,27 +989,26 @@ def inner(self, x): def __call__(self, x): return nn.vmap( - partial(Foo.inner), - variable_axes={'params':0}, - split_rngs={'params':True})(self, x) + partial(Foo.inner), variable_axes={'params': 0}, split_rngs={'params': True} + )(self, x) init_vars = Foo().init(k, x) init_vars_shapes = jax.tree_util.tree_map(jnp.shape, init_vars) ref_var_shapes = { - 'params': { - 'Dense_0': {'kernel': (3, 4, 2)} - }, + 'params': {'Dense_0': {'kernel': (3, 4, 2)}}, } self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes)) def test_variable_in_args_transform(self): class Test(nn.Module): + @nn.jit @nn.compact def __call__(self, x): baz = self.variable('test', 'baz', jnp.zeros, x.shape) y = self.mutate_variable_in_method(x, baz) return y + @nn.jit def mutate_variable_in_method(self, x, baz): baz.value += x @@ -908,14 +1017,25 @@ def mutate_variable_in_method(self, x, baz): k = random.PRNGKey(0) x = jnp.ones((1,)) variables = Test().init(k, x) - np.testing.assert_allclose(variables['test']['baz'], - jnp.array([1.0,]), atol=1e-7) + np.testing.assert_allclose( + variables['test']['baz'], + jnp.array([ + 1.0, + ]), + atol=1e-7, + ) y, variables = Test().apply(variables, x, mutable=['test']) - np.testing.assert_allclose(variables['test']['baz'], - jnp.array([2.0,]), atol=1e-7) + np.testing.assert_allclose( + variables['test']['baz'], + jnp.array([ + 2.0, + ]), + atol=1e-7, + ) def test_module_instance_in_args_transform(self): class Inner(nn.Module): + @nn.jit @nn.compact def __call__(self, x): @@ -924,12 +1044,14 @@ def __call__(self, x): return baz.value class Test(nn.Module): + @nn.jit @nn.compact def __call__(self, x): - inner = Inner(name="inner") + inner = Inner(name='inner') y = self.call_instance_arg_in_method(x, inner) return y + @nn.jit def call_instance_arg_in_method(self, x, inner): return inner(x) @@ -937,14 +1059,25 @@ def call_instance_arg_in_method(self, x, inner): k = random.PRNGKey(0) x = jnp.ones((1,)) variables = Test().init(k, x) - np.testing.assert_allclose(variables['test']['inner']['baz'], - jnp.array([1.0,]), atol=1e-7) + np.testing.assert_allclose( + variables['test']['inner']['baz'], + jnp.array([ + 1.0, + ]), + atol=1e-7, + ) y, variables = Test().apply(variables, x, mutable=['test']) - np.testing.assert_allclose(variables['test']['inner']['baz'], - jnp.array([2.0,]), atol=1e-7) + np.testing.assert_allclose( + variables['test']['inner']['baz'], + jnp.array([ + 2.0, + ]), + atol=1e-7, + ) def test_module_instance_in_args_transform_nested(self): class Inner(nn.Module): + @nn.jit @nn.compact def __call__(self, x): @@ -953,48 +1086,65 @@ def __call__(self, x): return baz.value class Outer(nn.Module): + @nn.jit @nn.compact def __call__(self, inner, x): y = self.call_instance_arg_in_method(x, inner) return y + @nn.jit def call_instance_arg_in_method(self, x, inner): return inner(x) class Test(nn.Module): + @nn.jit @nn.compact def __call__(self, x): - inner = Inner(name="inner") - outer = Outer(name="outer") + inner = Inner(name='inner') + outer = Outer(name='outer') return outer(inner, x) k = random.PRNGKey(0) x = jnp.ones((1,)) variables = Test().init(k, x) - np.testing.assert_allclose(variables['test']['inner']['baz'], - jnp.array([1.0,]), atol=1e-7) + np.testing.assert_allclose( + variables['test']['inner']['baz'], + jnp.array([ + 1.0, + ]), + atol=1e-7, + ) y, variables = Test().apply(variables, x, mutable=['test']) - np.testing.assert_allclose(variables['test']['inner']['baz'], - jnp.array([2.0,]), atol=1e-7) - + np.testing.assert_allclose( + variables['test']['inner']['baz'], + jnp.array([ + 2.0, + ]), + atol=1e-7, + ) def test_nested_variable_passing(self): class NestedVarUser(nn.Module): somevar: nn.Variable + @nn.jit @nn.compact def __call__(self, x): self.somevar.value += x return x + class VarUser(nn.Module): somevar: nn.Variable + @nn.jit @nn.compact def __call__(self, x): return NestedVarUser(self.somevar)(x) + class VarPasser(nn.Module): + @nn.jit @nn.compact def __call__(self, x): @@ -1005,47 +1155,67 @@ def __call__(self, x): k = random.PRNGKey(0) x = jnp.ones((1,)) variables = VarPasser().init(k, x) - np.testing.assert_allclose(variables['test']['baz'], - jnp.array([1.0,]), atol=1e-7) + np.testing.assert_allclose( + variables['test']['baz'], + jnp.array([ + 1.0, + ]), + atol=1e-7, + ) y, variables = VarPasser().apply(variables, x, mutable=['test']) - np.testing.assert_allclose(variables['test']['baz'], - jnp.array([2.0,]), atol=1e-7) + np.testing.assert_allclose( + variables['test']['baz'], + jnp.array([ + 2.0, + ]), + atol=1e-7, + ) def test_returned_module_warning(self): class Foo(nn.Module): + @nn.compact def __call__(self, x): return x + class Bar(nn.Module): + @nn.compact def __call__(self, x): f = self._helper() return f(x) + @nn.jit def _helper(self): return Foo() + b = Bar() with self.assertRaises(errors.TransformedMethodReturnValueError): b.apply({}, jnp.ones(2)) def test_returned_variable_warning(self): class Bar(nn.Module): + @nn.compact def __call__(self, x): f = self._helper() return f(x) + @nn.jit def _helper(self): return nn.Variable(None, None, None, False) + b = Bar() with self.assertRaises(errors.TransformedMethodReturnValueError): b.apply({}, jnp.ones(2)) def test_nowrap(self): class Bar(nn.Module): + @nn.compact def __call__(self, x): return self._helper(x) + @nn.nowrap def _helper(self, x): if len(nn.module._context.module_stack) > 2: # pylint: disable=protected-access @@ -1060,7 +1230,6 @@ def trans(variables): return jax.tree_util.tree_map(lambda x: x.T, variables) class TiedAutencoder(nn.Module): - features: int latents: int @@ -1073,7 +1242,7 @@ def f(self): map_fn = trans else: map_fn = lambda x: x - return nn.map_variables(f, "params", map_fn, map_fn, mutable=True)(self) + return nn.map_variables(f, 'params', map_fn, map_fn, mutable=True)(self) def encode(self, x): return self._call(x, False) @@ -1087,46 +1256,46 @@ def __call__(self, x): x = jnp.ones((2, 4)) ae = TiedAutencoder(4, 5) variables = ae.init(random.PRNGKey(0), x) - param_shapes = jax.tree_util.tree_map(jnp.shape, variables["params"]) - self.assertEqual(param_shapes, { - "Dense_0": {"kernel": (4, 5)} - }) - + param_shapes = jax.tree_util.tree_map(jnp.shape, variables['params']) + self.assertEqual(param_shapes, {'Dense_0': {'kernel': (4, 5)}}) def test_map_variables_bit_weights(self): class BitWeights(nn.Module): + @nn.compact def __call__(self, x): def sign(x): return jax.tree_util.tree_map(jnp.sign, x) - BitDense = nn.map_variables(nn.Dense, "params", sign, init=True) + + BitDense = nn.map_variables(nn.Dense, 'params', sign, init=True) return BitDense(4)(x) + bw = BitWeights() x = jnp.ones((2, 4)) y, variables = bw.init_with_output(random.PRNGKey(0), x) y_2 = bw.apply(variables, x) np.testing.assert_allclose(y, y_2) - def test_remat_scan(self): class BigModel(nn.Module): + @nn.compact def __call__(self, x): DenseStack = nn.remat_scan(nn.Dense, lengths=(100,)) - return DenseStack(8, name="dense_stack")(x) + return DenseStack(8, name='dense_stack')(x) x = jnp.ones((2, 8)) model = BigModel() variables = model.init(random.PRNGKey(0), x) param_shapes = jax.tree_util.tree_map(jnp.shape, variables['params']) - self.assertEqual(param_shapes["dense_stack"]["kernel"], (100, 8, 8)) - self.assertEqual(param_shapes["dense_stack"]["bias"], (100, 8)) + self.assertEqual(param_shapes['dense_stack']['kernel'], (100, 8, 8)) + self.assertEqual(param_shapes['dense_stack']['bias'], (100, 8)) y = model.apply(variables, x) self.assertEqual(y.shape, (2, 8)) - def test_vjp(self): class Bar(nn.Module): + @nn.compact def __call__(self, x, y): p = self.param('test', nn.initializers.constant(0.5), ()) @@ -1134,23 +1303,28 @@ def __call__(self, x, y): return p * x * y class Foo(nn.Module): + @nn.compact def __call__(self, x, y): z, bwd = nn.vjp(Bar.__call__, Bar(), x, y) return bwd(jnp.ones(z.shape)) - x = jnp.array([1., 2., 3.]) - y = jnp.array([4., 5., 6.]) + x = jnp.array([1.0, 2.0, 3.0]) + y = jnp.array([4.0, 5.0, 6.0]) params = Foo().init(random.PRNGKey(0), x, y) params_grad, x_grad, y_grad = Foo().apply(params, x, y) - self.assertEqual(params_grad, { - 'params': nn.FrozenDict({'test': 32.}), - }) - np.testing.assert_allclose(x_grad, [2., 2.5, 3.]) - np.testing.assert_allclose(y_grad, [0.5, 1., 1.5]) + self.assertEqual( + params_grad, + { + 'params': nn.FrozenDict({'test': 32.0}), + }, + ) + np.testing.assert_allclose(x_grad, [2.0, 2.5, 3.0]) + np.testing.assert_allclose(y_grad, [0.5, 1.0, 1.5]) def test_jvp(self): class Bar(nn.Module): + @nn.compact def __call__(self, x): p = self.param('test', nn.initializers.zeros, ()) @@ -1158,11 +1332,14 @@ def __call__(self, x): return p * x class Foo(nn.Module): + @nn.compact def __call__(self, x): bar = Bar() vars_t = jax.tree_util.tree_map(jnp.ones_like, bar.variables.get('params', {})) - _, out_t = nn.jvp(Bar.__call__, bar, (x,), (jnp.zeros_like(x),), {'params': vars_t}) + _, out_t = nn.jvp( + Bar.__call__, bar, (x,), (jnp.zeros_like(x),), {'params': vars_t} + ) return out_t x = jnp.ones((3,)) @@ -1173,19 +1350,24 @@ def __call__(self, x): def test_complicated_alias_mutation(self): class A(nn.Module): b: nn.Module + @nn.jit @nn.compact def __call__(self, x): return self.b(x) + class B(nn.Module): c: nn.Module + @nn.jit @nn.compact def __call__(self, x): y = C(name='outer_c')(x) z = self.c(x) return z + class C(nn.Module): + @nn.jit @nn.compact def __call__(self, x): @@ -1199,15 +1381,23 @@ def __call__(self, x): k = random.PRNGKey(0) x = jnp.ones((1,), jnp.float32) vs = a.init(k, x) - y, vs_new = a.apply(vs, x, mutable=['muts',]) - np.testing.assert_array_equal(vs_new['muts']['b']['c']['v'], - jnp.array([1.], jnp.float32)) - np.testing.assert_array_equal(vs_new['muts']['b']['outer_c']['v'], - jnp.array([1.], jnp.float32)) + y, vs_new = a.apply( + vs, + x, + mutable=[ + 'muts', + ], + ) + np.testing.assert_array_equal( + vs_new['muts']['b']['c']['v'], jnp.array([1.0], jnp.float32) + ) + np.testing.assert_array_equal( + vs_new['muts']['b']['outer_c']['v'], jnp.array([1.0], jnp.float32) + ) def test_custom_vjp(self): - class Foo(nn.Module): + @nn.compact def __call__(self, x): def f(mdl, x): @@ -1221,55 +1411,69 @@ def bwd(vjp_fn, y_t): params_t = jax.tree_util.tree_map(jnp.sign, params_t) return params_t, input_t - sign_grad = nn.custom_vjp( - f, forward_fn=fwd, backward_fn=bwd) + sign_grad = nn.custom_vjp(f, forward_fn=fwd, backward_fn=bwd) return sign_grad(nn.Dense(1), x).reshape(()) + x = jnp.ones((2,)) variables = Foo().init(random.PRNGKey(0), x) grad = jax.grad(Foo().apply)(variables, x) for grad_leaf in jax.tree_util.tree_leaves(grad): - self.assertTrue(jnp.all(jnp.abs(grad_leaf) == 1.)) + self.assertTrue(jnp.all(jnp.abs(grad_leaf) == 1.0)) def test_transform_with_setup_and_methods_on_submodules(self): # This is the archetypal example motivating the introduction of # SetupState as a triple-enum to handle multiple setup() calls # across transform boundaries and scope reuse. class Foo(nn.Module): + def setup(self): self.inner = nn.Dense(2) + def helper(self, x, m): return m(x) + def __call__(self, x): return self.helper(x, self.inner) + k = random.PRNGKey(0) x = jnp.ones((2,)) vs_foo = Foo().init(k, x) class Bar(nn.Module): + def setup(self): self.inner = nn.Dense(2) + @nn.jit def helper(self, x, m): return m(x) + @nn.jit def __call__(self, x): return self.helper(x, self.inner) + vs_bar = Bar().init(k, x) - self.assertTrue(tree_equals( - jax.tree_util.tree_map(jnp.shape, vs_foo), - jax.tree_util.tree_map(jnp.shape, vs_bar))) + self.assertTrue( + tree_equals( + jax.tree_util.tree_map(jnp.shape, vs_foo), + jax.tree_util.tree_map(jnp.shape, vs_bar), + ) + ) def test_transform_methods_on_submodules_still_reserve_names(self): class Foo(nn.Module): + @nn.jit def helper(self, x, m): - conflicting_a = nn.Dense(2, name="a") + conflicting_a = nn.Dense(2, name='a') return m(x) + @nn.jit @nn.compact def __call__(self, x): - a = nn.Dense(2, name="a") + a = nn.Dense(2, name='a') return self.helper(x, a) + k = random.PRNGKey(0) x = jnp.ones((2,)) with self.assertRaises(errors.NameInUseError): @@ -1277,37 +1481,48 @@ def __call__(self, x): def test_transform_setup_still_reserve_names(self): class Identity(nn.Module): + @nn.compact def __call__(self, x): return x + class Test(nn.Module): + def setup(self): self.sub = Identity() self.sub = Identity() + @nn.jit def __call__(self, x): return x k = random.PRNGKey(0) - x = jnp.array([1.]) + x = jnp.array([1.0]) with self.assertRaises(errors.NameInUseError): y = Test().init(k, x) def test_transform_with_setup_and_methods_on_submodule_pytrees(self): class Foo(nn.Module): + def setup(self): self.inners = [nn.Dense(2), nn.Dense(2)] + def helper(self, x, ms): return ms[0](x) + ms[1](x) + def __call__(self, x): return self.helper(x, self.inners) + class JitFoo(nn.Module): + def setup(self): self.inners = [nn.Dense(2), nn.Dense(2)] + @nn.jit def helper(self, x, ms): return ms[0](x) + ms[1](x) + @nn.jit def __call__(self, x): return self.helper(x, self.inners) @@ -1322,19 +1537,23 @@ def __call__(self, x): def test_transform_setup_still_reserve_names_pytrees(self): class Identity(nn.Module): + @nn.compact def __call__(self, x): return x + class Test(nn.Module): + def setup(self): self.subs = [Identity(), Identity()] self.subs = [Identity(), Identity()] + @nn.jit def __call__(self, x): return x k = random.PRNGKey(0) - x = jnp.array([1.]) + x = jnp.array([1.0]) msg = r'Could not create submodule "subs_0".*' with self.assertRaisesRegex(errors.NameInUseError, msg): @@ -1342,16 +1561,17 @@ def __call__(self, x): def test_scan_of_setup_parameter(self): class Body(nn.Module): + def setup(self): self.dense = nn.Dense(1) self.p = self.param('p', lambda k: jnp.ones((1,))) + def __call__(self, x): return self.dense(x) + self.p, None + scanbody = nn.scan( - Body, - variable_axes={'params': 0}, - split_rngs={'params': True}, - length=2) + Body, variable_axes={'params': 0}, split_rngs={'params': True}, length=2 + ) k = random.PRNGKey(0) x = jnp.ones((1,)) vs = scanbody().init(k, x) @@ -1359,36 +1579,46 @@ def __call__(self, x): def test_multi_method_class_transform(self): class Foo(nn.Module): + def setup(self): self.dense0 = nn.Dense(2) self.dense1 = nn.Dense(2) + def method_0(self, x): return self.dense0(x), x + def method_1(self, x, y): return self.dense1(x) + y, None + class Bar(nn.Module): + @nn.compact def __call__(self, x): - ScanFoo = nn.scan(Foo, - methods={ - 'method_0': dict( - variable_axes={'params': 0}, - split_rngs={'params': True}, - in_axes=nn.broadcast, out_axes=0, - length=3), - 'method_1': dict( - variable_axes={'params': 0}, - split_rngs={'params': True}, - in_axes=0, - length=3) - }) + ScanFoo = nn.scan( + Foo, + methods={ + 'method_0': dict( + variable_axes={'params': 0}, + split_rngs={'params': True}, + in_axes=nn.broadcast, + out_axes=0, + length=3, + ), + 'method_1': dict( + variable_axes={'params': 0}, + split_rngs={'params': True}, + in_axes=0, + length=3, + ), + }, + ) sf = ScanFoo() y, ys = sf.method_0(x) z, _ = sf.method_1(y, ys) return z k = random.PRNGKey(0) - x = random.uniform(random.PRNGKey(1), (2,2)) + x = random.uniform(random.PRNGKey(1), (2, 2)) vs = Bar().init(k, x) y = Bar().apply(vs, x) @@ -1396,16 +1626,20 @@ def test_compact_aliasing_collision(self): class Foo(nn.Module): m1: nn.Module m2: nn.Module + @nn.compact def __call__(self, x): x = self.m2(self.m1(x)) return x + class Bar(nn.Module): + @nn.compact def __call__(self, x): dense = nn.Dense(2) x = nn.jit(Foo)(dense, dense)(x) return x + k = random.PRNGKey(0) x = jnp.zeros((2, 2)) _ = Bar().init(k, x) @@ -1413,40 +1647,52 @@ def __call__(self, x): def test_compact_aliasing_collision_arg_and_attrib(self): class Foo(nn.Module): m1: nn.Module + @nn.compact def __call__(self, x, m2): x = m2(self.m1(x)) return x + class Bar(nn.Module): + @nn.compact def __call__(self, x): dense = nn.Dense(2) x = nn.jit(Foo)(dense)(x, dense) return x + k = random.PRNGKey(0) x = jnp.zeros((2, 2)) _ = Bar().init(k, x) def test_jit_with_setup_helpers(self): class Foo(nn.Module): + def setup(self): self.a = nn.Dense(2) self.setup_helper() + def setup_helper(self): self.b = nn.Dense(2) + def __call__(self, x): return self.b(self.a(x)) + class JitFoo(nn.Module): + def setup(self): self.a = nn.Dense(2) self.setup_helper() + def setup_helper(self): self.b = nn.Dense(2) + @nn.jit def __call__(self, x): return self.b(self.a(x)) + k = random.PRNGKey(0) - x = jnp.ones((2,2)) + x = jnp.ones((2, 2)) vs = JitFoo().init(k, x) y0 = JitFoo().apply(vs, x) vs = Foo().init(k, x) @@ -1455,6 +1701,7 @@ def __call__(self, x): def test_while_loop(self): class Foo(nn.Module): + @nn.compact def __call__(self, x): key_zero = random.PRNGKey(0) @@ -1467,31 +1714,58 @@ def __call__(self, x): def cond_fn(mdl, c): acc = mdl.get_variable('state', 'acc') return acc < x + def body_fn(mdl, c): i = mdl.get_variable('state', 'acc') p_rng = mdl.make_rng('params') l_rng = mdl.make_rng('loop') - mdl.put_variable('state', 'rng_params', mdl.get_variable('state', 'rng_params').at[i].set(p_rng)) - mdl.put_variable('state', 'rng_loop', mdl.get_variable('state', 'rng_loop').at[i].set(l_rng)) + mdl.put_variable( + 'state', + 'rng_params', + mdl.get_variable('state', 'rng_params').at[i].set(p_rng), + ) + mdl.put_variable( + 'state', + 'rng_loop', + mdl.get_variable('state', 'rng_loop').at[i].set(l_rng), + ) inc = mdl.get_variable('params', 'inc') mdl.put_variable('state', 'acc', i + inc) return c + return nn.while_loop( - cond_fn, body_fn, self, (), - carry_variables='state', split_rngs={'params': False, 'loop': True}) + cond_fn, + body_fn, + self, + (), + carry_variables='state', + split_rngs={'params': False, 'loop': True}, + ) + x = 2 mdl = Foo() - _, vars = mdl.apply({}, x, mutable=True, rngs={'params': random.PRNGKey(1), 'loop': random.PRNGKey(2)}) + _, vars = mdl.apply( + {}, + x, + mutable=True, + rngs={'params': random.PRNGKey(1), 'loop': random.PRNGKey(2)}, + ) self.assertEqual(vars['state']['acc'], x) - np.testing.assert_array_equal(vars['state']['rng_params'][0], vars['state']['rng_params'][1]) - np.testing.assert_array_compare(operator.__ne__, vars['state']['rng_loop'][0], vars['state']['rng_loop'][1]) + np.testing.assert_array_equal( + vars['state']['rng_params'][0], vars['state']['rng_params'][1] + ) + np.testing.assert_array_compare( + operator.__ne__, vars['state']['rng_loop'][0], vars['state']['rng_loop'][1] + ) def test_cond(self): class Foo(nn.Module): + @nn.compact def __call__(self, x, pred): self.variable('state', 'true_count', lambda: 0) self.variable('state', 'false_count', lambda: 0) + def true_fn(mdl, x): mdl.variable('state', 'true_count').value += 1 return nn.Dense(2, name='dense')(x) @@ -1504,11 +1778,13 @@ def false_fn(mdl, x): def test_switch(self): class Foo(nn.Module): + @nn.compact def __call__(self, x, pred): self.variable('state', 'a_count', lambda: 0) self.variable('state', 'b_count', lambda: 0) self.variable('state', 'c_count', lambda: 0) + def a_fn(mdl, x): mdl.variable('state', 'a_count').value += 1 return nn.Dense(2, name='dense')(x) @@ -1527,22 +1803,23 @@ def c_fn(mdl, x): foo = Foo() y1, vars = foo.init_with_output(random.PRNGKey(0), x, 0) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 0, 'c_count': 0}) - y2, updates = foo.apply(vars, x, 1, mutable="state") + y2, updates = foo.apply(vars, x, 1, mutable='state') vars = copy(vars, updates) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 0}) np.testing.assert_allclose(y1, -y2) - y3, updates = foo.apply(vars, x, 2, mutable="state") + y3, updates = foo.apply(vars, x, 2, mutable='state') vars = copy(vars, updates) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 1}) np.testing.assert_allclose(y1, y3) def test_switch_multihead(self): class Foo(nn.Module): + def setup(self) -> None: self.heads = [ - nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]), - nn.Sequential([nn.Dense(11), nn.Dense(5)]), - nn.Dense(5), + nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]), + nn.Sequential([nn.Dense(11), nn.Dense(5)]), + nn.Dense(5), ] @nn.compact @@ -1551,6 +1828,7 @@ def head_fn(i): def fn(mdl, x): mdl.variable('state', f'{i}_count', lambda: -1).value += 1 return mdl.heads[i](x) + return fn branches = [head_fn(i) for i in range(len(self.heads))] @@ -1565,10 +1843,10 @@ def fn(mdl, x): foo = Foo() y1, vars = foo.init_with_output(random.PRNGKey(0), x, 0) self.assertEqual(vars['state'], {'0_count': 1, '1_count': 0, '2_count': 0}) - y2, updates = foo.apply(vars, x, 1, mutable="state") + y2, updates = foo.apply(vars, x, 1, mutable='state') vars = copy(vars, updates) self.assertEqual(vars['state'], {'0_count': 1, '1_count': 1, '2_count': 0}) - y3, updates = foo.apply(vars, x, 2, mutable="state") + y3, updates = foo.apply(vars, x, 2, mutable='state') vars = copy(vars, updates) self.assertEqual(vars['state'], {'0_count': 1, '1_count': 1, '2_count': 1}) @@ -1587,13 +1865,13 @@ def fn(mdl, x): self.assertEqual(vars['params']['heads_2']['kernel'].shape, (3, 5)) self.assertEqual(vars['params']['heads_2']['bias'].shape, (5,)) - - def test_lift_instance_error(self): class Foo(nn.Module): + @nn.compact def __call__(self, x): return nn.checkpoint(nn.Dense(2))(x) + with self.assertRaises(errors.TransformTargetError): Foo().init(random.PRNGKey(0), jnp.zeros((2, 3))) @@ -1605,7 +1883,13 @@ class Foo(nn.Module): def __call__(self, x): def body_fn(mdl, x): return nn.Dense(features=x.shape[-1])(x), () - x, _ = nn.scan(body_fn, length=self.num_layers, variable_axes={"params": 0}, split_rngs={"params": True})(self, x) + + x, _ = nn.scan( + body_fn, + length=self.num_layers, + variable_axes={'params': 0}, + split_rngs={'params': True}, + )(self, x) return x m = Foo() @@ -1616,6 +1900,7 @@ def body_fn(mdl, x): def test_bound_methods_in_direct_transforms(self): class CondModel(nn.Module): + def setup(self): self.dense = nn.Dense(3) @@ -1633,69 +1918,105 @@ def __call__(self, x): cond_model = CondModel() output, init_params = jax.jit(cond_model.init_with_output)( - jax.random.PRNGKey(0), - x=jnp.ones(3)) + jax.random.PRNGKey(0), x=jnp.ones(3) + ) def test_add_metadata_axis(self): vars_copy = None + class Foo(nn.Module): + @nn.compact def __call__(self, x): nonlocal vars_copy - kernel_init=nn.with_partitioning( - nn.initializers.lecun_normal(), ('foo', 'bar')) + kernel_init = nn.with_partitioning( + nn.initializers.lecun_normal(), ('foo', 'bar') + ) vars_copy = self.variables - return nn.Dense(4, kernel_init=kernel_init, use_bias=False, name="dense")(x) + return nn.Dense(4, kernel_init=kernel_init, use_bias=False, name='dense')(x) + class Test(nn.Module): - @partial(nn.add_metadata_axis, - variable_axes={'params': 0}, - metadata_params={nn.PARTITION_NAME: 'baz'}) + + @partial( + nn.add_metadata_axis, + variable_axes={'params': 0}, + metadata_params={nn.PARTITION_NAME: 'baz'}, + ) @nn.compact def __call__(self, x): - return Foo(name="foo")(x) + return Foo(name='foo')(x) k = random.PRNGKey(0) - x = jnp.ones((4,4), dtype=jnp.float32) + x = jnp.ones((4, 4), dtype=jnp.float32) vs = Test().init(k, x) y = Test().apply(vs, x) - outer_expect = jax.tree_map(jnp.shape, - freeze({'params': {'foo': {'dense': {'kernel': - nn.Partitioned(jnp.ones((4, 4)), names=('baz', 'foo', 'bar'))}}}})) - inner_expect = jax.tree_map(jnp.shape, - freeze({'params': {'dense': {'kernel': - nn.Partitioned(jnp.ones((4, 4)), names=('foo', 'bar'))}}})) + outer_expect = jax.tree_map( + jnp.shape, + freeze( + { + 'params': { + 'foo': { + 'dense': { + 'kernel': nn.Partitioned( + jnp.ones((4, 4)), names=('baz', 'foo', 'bar') + ) + } + } + } + } + ), + ) + inner_expect = jax.tree_map( + jnp.shape, + freeze( + { + 'params': { + 'dense': { + 'kernel': nn.Partitioned(jnp.ones((4, 4)), names=('foo', 'bar')) + } + } + } + ), + ) self.assertEqual(jax.tree_map(jnp.shape, vs), outer_expect) self.assertEqual(jax.tree_map(jnp.shape, vars_copy), inner_expect) - def test_outer_setup_called_with_sharing_across_transforms(self): class A(nn.Module): + def setup(self): - self.foo = self.param( - 'foo', nn.initializers.zeros, (2, 2), jnp.float32) + self.foo = self.param('foo', nn.initializers.zeros, (2, 2), jnp.float32) + def __call__(self, x): return self.foo + class B(nn.Module): a: Any + @nn.compact def __call__(self, x): return self.a(x) + class C(nn.Module): + def setup(self): self.a = A() self.b = nn.jit(B)(self.a) + def __call__(self, x): b = self.b(x) a = self.a(x) return a + b + k = random.PRNGKey(0) x = random.randint(k, (2, 2), minval=0, maxval=10) vs = C().init(k, x) y = C().apply(vs, x) - outer_expect = jax.tree_map(jnp.shape, - freeze({'params': {'a': {'foo': jnp.zeros((2, 2))}}})) + outer_expect = jax.tree_map( + jnp.shape, freeze({'params': {'a': {'foo': jnp.zeros((2, 2))}}}) + ) self.assertEqual(jax.tree_map(jnp.shape, vs), outer_expect) if __name__ == '__main__': - absltest.main() \ No newline at end of file + absltest.main() diff --git a/tests/linen/partitioning_test.py b/tests/linen/partitioning_test.py index 70555f6f65..99784fcd8c 100644 --- a/tests/linen/partitioning_test.py +++ b/tests/linen/partitioning_test.py @@ -56,66 +56,67 @@ def test_logical_to_mesh_axes(self): axes_0 = ('foo', 'bar') # direct rule assignment self.assertEqual( - partitioning.logical_to_mesh_axes(axes_0, rules=AXIS_RULES_1), - ('data', 'model')) + partitioning.logical_to_mesh_axes(axes_0, rules=AXIS_RULES_1), ('data', 'model') + ) # axis rules context with partitioning.axis_rules(AXIS_RULES_1): - self.assertEqual( - partitioning.logical_to_mesh_axes(axes_0), ('data', 'model')) + self.assertEqual(partitioning.logical_to_mesh_axes(axes_0), ('data', 'model')) # nested context with partitioning.axis_rules(AXIS_RULES_2): - self.assertEqual( - partitioning.logical_to_mesh_axes(axes_0), ('model', None)) + self.assertEqual(partitioning.logical_to_mesh_axes(axes_0), ('model', None)) # duplicated logical names with partitioning.axis_rules(AXIS_RULES_1): with self.assertRaises(ValueError): partitioning.logical_to_mesh_axes(('foo', 'foo', 'baz')) def test_logical_to_mesh_axes_priorities(self): - p_rules = ( - ('foo', 'model'), - ('bar', 'model'), - ('baz', 'data')) + p_rules = (('foo', 'model'), ('bar', 'model'), ('baz', 'data')) with partitioning.axis_rules(p_rules): self.assertEqual( partitioning.logical_to_mesh_axes(('foo', 'bar', 'baz')), - ('model', None, 'data')) + ('model', None, 'data'), + ) self.assertEqual( partitioning.logical_to_mesh_axes(('bar', 'foo', 'baz')), - (None, 'model', 'data')) + (None, 'model', 'data'), + ) self.assertEqual( partitioning.logical_to_mesh_axes(('baz', 'bar', 'foo')), - ('data', None, 'model')) + ('data', None, 'model'), + ) self.assertEqual( - partitioning.logical_to_mesh_axes( - ('baz', 'bar', 'foo', 'unassigned')), - ('data', None, 'model', None)) + partitioning.logical_to_mesh_axes(('baz', 'bar', 'foo', 'unassigned')), + ('data', None, 'model', None), + ) @parameterized.parameters( - dict(rules=(('a', ('model', 'data')), ('b', 'data')), - axes=('a', 'b'), - expected=(('model', 'data'), None)), - dict(rules=(('a', ('model', 'replica')), ('b', 'data')), - axes=('a', 'b'), - expected=(('model', 'replica'), 'data')), - dict(rules=(('a', ('model', 'replica')), ('b', ('data', 'model'))), - axes=('a', 'b'), - expected=(('model', 'replica'), None)), - dict(rules=(('a', ('model', 'replica')), ('b', 'model')), - axes=('a', 'b', 'c'), - expected=(('model', 'replica'), None, None)), - dict(rules=(), - axes=('a', 'b', 'c'), - expected=(None, None, None)), - dict(rules=(('a', None), ('a', 'model')), - axes=('a', 'b'), - expected=(None, None)), - dict(rules=(('baz', 'data'), - ('bar', None), - ('foo', 'model'), - ('foo', 'data')), - axes=('baz', 'bar', 'foo'), - expected=('data', None, 'model')), + dict( + rules=(('a', ('model', 'data')), ('b', 'data')), + axes=('a', 'b'), + expected=(('model', 'data'), None), + ), + dict( + rules=(('a', ('model', 'replica')), ('b', 'data')), + axes=('a', 'b'), + expected=(('model', 'replica'), 'data'), + ), + dict( + rules=(('a', ('model', 'replica')), ('b', ('data', 'model'))), + axes=('a', 'b'), + expected=(('model', 'replica'), None), + ), + dict( + rules=(('a', ('model', 'replica')), ('b', 'model')), + axes=('a', 'b', 'c'), + expected=(('model', 'replica'), None, None), + ), + dict(rules=(), axes=('a', 'b', 'c'), expected=(None, None, None)), + dict(rules=(('a', None), ('a', 'model')), axes=('a', 'b'), expected=(None, None)), + dict( + rules=(('baz', 'data'), ('bar', None), ('foo', 'model'), ('foo', 'data')), + axes=('baz', 'bar', 'foo'), + expected=('data', None, 'model'), + ), ) def test_logical_to_mesh_axes_cases(self, rules, axes, expected): with partitioning.axis_rules(rules): @@ -142,21 +143,31 @@ def test_with_sharding_constraint_fallback(self, wsc_fn): arr = jnp.ones((2, 2)) with partitioning.axis_rules(AXIS_RULES_1): _ = partitioning.with_sharding_constraint(arr, ('foo', 'not_recognized')) - wsc_fn.assert_called_with(arr, jax.sharding.PartitionSpec('data', None), mesh=None) + wsc_fn.assert_called_with( + arr, jax.sharding.PartitionSpec('data', None), mesh=None + ) wsc_fn.reset_mock() _ = partitioning.with_sharding_constraint( - arr, ('foo', 'not_recognized'), - fallback=partitioning.RulesFallback.AXIS_IS_UNSHARDED) - wsc_fn.assert_called_with(arr, jax.sharding.PartitionSpec('data', None), mesh=None) + arr, + ('foo', 'not_recognized'), + fallback=partitioning.RulesFallback.AXIS_IS_UNSHARDED, + ) + wsc_fn.assert_called_with( + arr, jax.sharding.PartitionSpec('data', None), mesh=None + ) wsc_fn.reset_mock() with self.assertRaises(ValueError): _ = partitioning.with_sharding_constraint( - arr, ('foo', 'not_recognized'), - fallback=partitioning.RulesFallback.RAISE_ERROR) + arr, + ('foo', 'not_recognized'), + fallback=partitioning.RulesFallback.RAISE_ERROR, + ) wsc_fn.assert_not_called() _ = partitioning.with_sharding_constraint( - arr, ('foo', 'not_recognized'), - fallback=partitioning.RulesFallback.NO_CONSTRAINT) + arr, + ('foo', 'not_recognized'), + fallback=partitioning.RulesFallback.NO_CONSTRAINT, + ) wsc_fn.assert_not_called() @parameterized.parameters(dict(axes_spec=None), dict(axes_spec=())) @@ -166,8 +177,8 @@ class ParamTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.param_with_axes( - 'foo', lambda k, s, d: jnp.zeros(s, d), - (2, 2), x.dtype, axes=axes_spec) + 'foo', lambda k, s, d: jnp.zeros(s, d), (2, 2), x.dtype, axes=axes_spec + ) return x + foo k = random.PRNGKey(0) @@ -180,57 +191,62 @@ class ParamTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.param_with_axes( - 'foo', lambda k, s, d: jnp.zeros(s, d), - (2, 2), x.dtype, axes=('foo', 'bar')) + 'foo', lambda k, s, d: jnp.zeros(s, d), (2, 2), x.dtype, axes=('foo', 'bar') + ) return x + foo - p_rules = ( - ('foo', 'model'), - ('bar', 'data'), - ('baz', None)) + p_rules = (('foo', 'model'), ('bar', 'data'), ('baz', None)) k = random.PRNGKey(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = ParamTest().init(k, x) self.assertIn('params', variables) self.assertIn('params_axes', variables) - self.assertEqual(variables['params_axes']['foo_axes'], - partitioning.AxisMetadata(names=('foo', 'bar'))) + self.assertEqual( + variables['params_axes']['foo_axes'], + partitioning.AxisMetadata(names=('foo', 'bar')), + ) logical_axis_names = partitioning.get_axis_names(variables['params_axes']) - self.assertEqual(logical_axis_names, - {'foo': jax.sharding.PartitionSpec('foo', 'bar')}) + self.assertEqual( + logical_axis_names, {'foo': jax.sharding.PartitionSpec('foo', 'bar')} + ) def test_param_pytree_with_axes(self): def init_fn(k, s, d): del k return {'a': jnp.zeros(s, d), 'b': (jnp.zeros(s, d), jnp.zeros(s, d))} + axes = {'a': ('foo', 'bar'), 'b': (('foo', 'bar'), ('bar', 'foo'))} + class ParamTest(nn.Module): @nn.compact def __call__(self, x): - foo = partitioning.param_with_axes( - 'foo', init_fn, (2, 2), x.dtype, axes=axes) + foo = partitioning.param_with_axes('foo', init_fn, (2, 2), x.dtype, axes=axes) return x + foo['a'] - p_rules = ( - ('foo', 'model'), - ('bar', 'data'), - ('baz', None)) + p_rules = (('foo', 'model'), ('bar', 'data'), ('baz', None)) k = random.PRNGKey(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = ParamTest().init(k, x) self.assertIn('params', variables) self.assertIn('params_axes', variables) - self.assertEqual(variables['params_axes']['foo_axes'], - partitioning.AxisMetadata(names=axes)) + self.assertEqual( + variables['params_axes']['foo_axes'], partitioning.AxisMetadata(names=axes) + ) logical_axis_names = partitioning.get_axis_names(variables['params_axes']) expected = freeze( - {'foo': - {'a': jax.sharding.PartitionSpec('foo', 'bar'), - 'b': (jax.sharding.PartitionSpec('foo', 'bar'), - jax.sharding.PartitionSpec('bar', 'foo'))}}) + { + 'foo': { + 'a': jax.sharding.PartitionSpec('foo', 'bar'), + 'b': ( + jax.sharding.PartitionSpec('foo', 'bar'), + jax.sharding.PartitionSpec('bar', 'foo'), + ), + } + } + ) self.assertEqual(logical_axis_names, expected) @parameterized.parameters(dict(axes_spec=None), dict(axes_spec=())) @@ -240,7 +256,8 @@ class VarTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.variable_with_axes( - 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=axes_spec) + 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=axes_spec + ) return x + foo.value k = random.PRNGKey(0) @@ -248,13 +265,13 @@ def __call__(self, x): _ = VarTest().init(k, x) def test_variable_with_empty_tuple_has_empty_axes(self): - class VarTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.variable_with_axes( - 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=()) + 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=() + ) return x + foo.value k = random.PRNGKey(0) @@ -269,24 +286,25 @@ class VarTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.variable_with_axes( - 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=('foo', 'bar')) + 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=('foo', 'bar') + ) return x + foo.value - p_rules = ( - ('foo', 'model'), - ('bar', 'data'), - ('baz', None)) + p_rules = (('foo', 'model'), ('bar', 'data'), ('baz', None)) k = random.PRNGKey(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = VarTest().init(k, x) self.assertIn('test', variables) self.assertIn('test_axes', variables) - self.assertEqual(variables['test_axes']['foo_axes'], - partitioning.AxisMetadata(names=('foo', 'bar'))) + self.assertEqual( + variables['test_axes']['foo_axes'], + partitioning.AxisMetadata(names=('foo', 'bar')), + ) logical_axis_names = partitioning.get_axis_names(variables['test_axes']) - self.assertEqual(logical_axis_names, - {'foo': jax.sharding.PartitionSpec('foo', 'bar')}) + self.assertEqual( + logical_axis_names, {'foo': jax.sharding.PartitionSpec('foo', 'bar')} + ) @mock.patch('flax.linen.partitioning._with_sharding_constraint') def test_variable_with_axes_fallback(self, wsc_fn): @@ -295,14 +313,21 @@ class VarTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.variable_with_axes( - 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=('foo', 'bar'), - fallback=partitioning.RulesFallback.NO_CONSTRAINT) + 'test', + 'foo', + jnp.zeros, + (2, 2), + x.dtype, + axes=('foo', 'bar'), + fallback=partitioning.RulesFallback.NO_CONSTRAINT, + ) return x + foo.value p_rules = ( # No rule for 'foo': ('bar', 'data'), - ('baz', None)) + ('baz', None), + ) k = random.PRNGKey(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): @@ -311,11 +336,14 @@ def __call__(self, x): wsc_fn.assert_not_called() self.assertIn('test', variables) self.assertIn('test_axes', variables) - self.assertEqual(variables['test_axes']['foo_axes'], - partitioning.AxisMetadata(names=('foo', 'bar'))) + self.assertEqual( + variables['test_axes']['foo_axes'], + partitioning.AxisMetadata(names=('foo', 'bar')), + ) logical_axis_names = partitioning.get_axis_names(variables['test_axes']) - self.assertEqual(logical_axis_names, - {'foo': jax.sharding.PartitionSpec('foo', 'bar')}) + self.assertEqual( + logical_axis_names, {'foo': jax.sharding.PartitionSpec('foo', 'bar')} + ) def test_scan_with_axes(self): # MLP Hparams @@ -333,15 +361,18 @@ def __call__(self, x): 'W1', nn.initializers.xavier_normal(), (x.shape[-1], self.depth), - axes=('emb', 'mlp')) + axes=('emb', 'mlp'), + ) W2 = partitioning.param_with_axes( # pylint: disable=invalid-name 'W2', nn.initializers.xavier_normal(), (self.depth, x.shape[-1]), - axes=('mlp', 'emb')) + axes=('mlp', 'emb'), + ) y = jnp.dot(jnp.sin(jnp.dot(x, W1)), W2) _ = partitioning.variable_with_axes( - 'stats', 'y_st', lambda: y, axes=('batch', 'emb')) + 'stats', 'y_st', lambda: y, axes=('batch', 'emb') + ) # scan expects a (carry, out) return signature. return y, None @@ -358,8 +389,8 @@ def __call__(self, x): split_rngs={'params': True}, axis_name='layer', axes_collections=('params', 'stats'), - length=self.num_layers)(self.depth, - name='scanned_layer') + length=self.num_layers, + )(self.depth, name='scanned_layer') y, _ = scanned_sindot(x) # test calling again to test metadata compatibility across calls _, _ = scanned_sindot(x) @@ -377,27 +408,39 @@ def __call__(self, x): self.assertIn('stats_axes', variables) self.assertEqual( variables['params_axes']['scanned_layer']['W1_axes'], - partitioning.AxisMetadata(names=('layer', 'emb', 'mlp'))) + partitioning.AxisMetadata(names=('layer', 'emb', 'mlp')), + ) logical_axis_names = partitioning.get_axis_names(variables['params_axes']) self.assertEqual( logical_axis_names, - {'scanned_layer': { - 'W1': jax.sharding.PartitionSpec('layer', 'emb', 'mlp'), - 'W2': jax.sharding.PartitionSpec('layer', 'mlp', 'emb')}}) + { + 'scanned_layer': { + 'W1': jax.sharding.PartitionSpec('layer', 'emb', 'mlp'), + 'W2': jax.sharding.PartitionSpec('layer', 'mlp', 'emb'), + } + }, + ) logical_axis_names = partitioning.get_axis_names(variables['stats_axes']) self.assertEqual( logical_axis_names, - {'scanned_layer': { - 'y_st': jax.sharding.PartitionSpec('batch', 'layer', 'emb')}}) + { + 'scanned_layer': { + 'y_st': jax.sharding.PartitionSpec('batch', 'layer', 'emb') + } + }, + ) def test_vmap_with_axes(self): - class Foo(nn.Module): @nn.compact def __call__(self, x): - return partitioning.param_with_axes( - 'w', jax.nn.initializers.uniform(), [4, 3], axes=('out', 'in')) @ x + return ( + partitioning.param_with_axes( + 'w', jax.nn.initializers.uniform(), [4, 3], axes=('out', 'in') + ) + @ x + ) class Vmapped(nn.Module): @@ -409,7 +452,8 @@ def __call__(self, x): 'params': 1, }, split_rngs={'params': True}, - partitioning_axis_names={'params': 'vmap_axis'}) + partitioning_axis_names={'params': 'vmap_axis'}, + ) return FooVmapped(name='foo_vmapped')(x) p_rules = (('out', None), ('in', 'data'), ('vmap_axis', 'model')) @@ -420,36 +464,33 @@ def __call__(self, x): variables = unfreeze(variables) variables['params'] = jax.tree_util.tree_map(lambda x: x.shape, variables['params']) self.assertDictEqual( - variables, { - 'params': { - 'w': (4, 3) - }, - 'params_axes': { - 'w_axes': partitioning.AxisMetadata(names=('out', 'in')) - } - }) + variables, + { + 'params': {'w': (4, 3)}, + 'params_axes': {'w_axes': partitioning.AxisMetadata(names=('out', 'in'))}, + }, + ) # check that FooVmapped adds 'vmap_axis' to axis 1 with partitioning.axis_rules(p_rules): variables = Vmapped().init( - jax.random.PRNGKey(0), jnp.array([[1, 2, 3], [4, 5, 6]])) + jax.random.PRNGKey(0), jnp.array([[1, 2, 3], [4, 5, 6]]) + ) variables = unfreeze(variables) variables['params'] = jax.tree_util.tree_map(lambda x: x.shape, variables['params']) self.assertDictEqual( - variables, { - 'params': { - 'foo_vmapped': { - 'w': (4, 2, 3) - } - }, + variables, + { + 'params': {'foo_vmapped': {'w': (4, 2, 3)}}, 'params_axes': { 'foo_vmapped': { - 'w_axes': - partitioning.AxisMetadata( - names=('out', 'vmap_axis', 'in')) + 'w_axes': partitioning.AxisMetadata( + names=('out', 'vmap_axis', 'in') + ) } - } - }) + }, + }, + ) def test_logical_with_mesh_and_rules(self): devices = mesh_utils.create_device_mesh((jax.local_device_count(), 1)) @@ -458,10 +499,12 @@ def test_logical_with_mesh_and_rules(self): rules = (('a', 'in'), ('b', 'out')) class Foo(nn.Module): + @nn.compact def __call__(self, x): kernel_init = nn.with_logical_partitioning( - nn.initializers.ones_init(), ('a', 'b'), mesh=mesh, rules=rules) + nn.initializers.ones_init(), ('a', 'b'), mesh=mesh, rules=rules + ) kernel = self.param('kernel', kernel_init, (x.shape[-1], 2)) kernel_box = self.get_variable('params', 'kernel') test.assertIsInstance(kernel_box, nn.Partitioned) @@ -477,12 +520,11 @@ def create_state(): variables = jax.lax.with_sharding_constraint(variables, shardings) return variables - variables = create_state() - self.assertEqual(variables['params']['kernel'].names, - ('a', 'b')) + self.assertEqual(variables['params']['kernel'].names, ('a', 'b')) self.assertIs(variables['params']['kernel'].mesh, mesh) self.assertEqual(variables['params']['kernel'].rules, rules) + if __name__ == '__main__': absltest.main() diff --git a/tests/linen/summary_test.py b/tests/linen/summary_test.py index f20924ad22..736f179033 100644 --- a/tests/linen/summary_test.py +++ b/tests/linen/summary_test.py @@ -31,14 +31,17 @@ CONSOLE_TEST_KWARGS = dict(force_terminal=False, no_color=True, width=10_000) + def _get_shapes(pytree): - return jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, pytree) + return jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, "shape") else x, pytree) + def _get_obj_repr_value(x): if isinstance(x, summary._ObjectRepresentation): return x.obj return x + class ConvBlock(nn.Module): features: int kernel_size: List[int] @@ -53,7 +56,7 @@ def block_method(self, x: Array, training: bool) -> Array: x = self.conv(x) if self.test_sow: - self.sow('intermediates', 'INTERM', x) + self.sow("intermediates", "INTERM", x) x = self.bn(x, use_running_average=not training) x = self.dropout(x, deterministic=not training) @@ -64,13 +67,14 @@ def __call__(self, x: Array, training: bool) -> Array: x = self.conv(x) if self.test_sow: - self.sow('intermediates', 'INTERM', x) + self.sow("intermediates", "INTERM", x) x = self.bn(x, use_running_average=not training) x = self.dropout(x, deterministic=not training) x = nn.relu(x) return x + class CNN(nn.Module): test_sow: bool @@ -85,11 +89,11 @@ def cnn_method(self, x: Array, training: bool) -> Array: x = x.mean(axis=(1, 2)) if self.test_sow: - self.sow('intermediates', 'INTERM', x) + self.sow("intermediates", "INTERM", x) x = self.dense(x) - return x, dict(a=x, b=x+1.0) + return x, dict(a=x, b=x + 1.0) def __call__(self, x: Array, training: bool) -> Array: x = self.block1.block_method(x, training=training) @@ -97,11 +101,12 @@ def __call__(self, x: Array, training: bool) -> Array: x = x.mean(axis=(1, 2)) if self.test_sow: - self.sow('intermediates', 'INTERM', x) + self.sow("intermediates", "INTERM", x) x = self.dense(x) - return x, dict(a=x, b=x+1.0) + return x, dict(a=x, b=x + 1.0) + class SummaryTest(absltest.TestCase): @@ -117,8 +122,10 @@ def test_module_summary(self): module = CNN(test_sow=False) table = summary._get_module_table(module, depth=None, show_repeated=True)( - {"dropout":random.PRNGKey(0), "params": random.PRNGKey(1)}, - x, training=True, mutable=True, + {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + x, + training=True, + mutable=True, ) # get values for inputs and outputs from their _ValueRepresentation for row in table: @@ -145,30 +152,46 @@ def test_module_summary(self): # check outputs shapes self.assertEqual( - (table[0].inputs[0].shape, table[0].inputs[1]), - (x.shape, dict(training=True)), + (table[0].inputs[0].shape, table[0].inputs[1]), + (x.shape, dict(training=True)), ) self.assertEqual( - _get_shapes(table[0].outputs), - ((batch_size, 10), dict(a=(batch_size, 10), b=(batch_size, 10))), + _get_shapes(table[0].outputs), + ((batch_size, 10), dict(a=(batch_size, 10), b=(batch_size, 10))), ) - self.assertEqual(_get_shapes(table[1].inputs), ((batch_size, 28, 28, 1), {'training': True})) + self.assertEqual( + _get_shapes(table[1].inputs), ((batch_size, 28, 28, 1), {"training": True}) + ) self.assertEqual(table[1].outputs.shape, (batch_size, 28, 28, 32)) self.assertEqual(table[2].inputs.shape, (batch_size, 28, 28, 1)) self.assertEqual(table[2].outputs.shape, (batch_size, 28, 28, 32)) - self.assertEqual(_get_shapes(table[3].inputs), ((batch_size, 28, 28, 32), {'use_running_average': False})) + self.assertEqual( + _get_shapes(table[3].inputs), + ((batch_size, 28, 28, 32), {"use_running_average": False}), + ) self.assertEqual(table[3].outputs.shape, (batch_size, 28, 28, 32)) - self.assertEqual(_get_shapes(table[4].inputs), ((batch_size, 28, 28, 32), {'deterministic': False})) + self.assertEqual( + _get_shapes(table[4].inputs), + ((batch_size, 28, 28, 32), {"deterministic": False}), + ) self.assertEqual(table[4].outputs.shape, (batch_size, 28, 28, 32)) - self.assertEqual(_get_shapes(table[5].inputs), ((batch_size, 28, 28, 32), {'training': True})) + self.assertEqual( + _get_shapes(table[5].inputs), ((batch_size, 28, 28, 32), {"training": True}) + ) self.assertEqual(table[5].outputs.shape, (batch_size, 28, 28, 64)) self.assertEqual(table[6].inputs.shape, (batch_size, 28, 28, 32)) self.assertEqual(table[6].outputs.shape, (batch_size, 28, 28, 64)) - self.assertEqual(_get_shapes(table[7].inputs), ((batch_size, 28, 28, 64), {'use_running_average': False})) + self.assertEqual( + _get_shapes(table[7].inputs), + ((batch_size, 28, 28, 64), {"use_running_average": False}), + ) self.assertEqual(table[7].outputs.shape, (batch_size, 28, 28, 64)) - self.assertEqual(_get_shapes(table[8].inputs), ((batch_size, 28, 28, 64), {'deterministic': False})) + self.assertEqual( + _get_shapes(table[8].inputs), + ((batch_size, 28, 28, 64), {"deterministic": False}), + ) self.assertEqual(table[8].outputs.shape, (batch_size, 28, 28, 64)) self.assertEqual(table[9].inputs.shape, (batch_size, 64)) @@ -177,8 +200,8 @@ def test_module_summary(self): # check no summary is performed for row in table: self.assertEqual( - row.module_variables, - row.counted_variables, + row.module_variables, + row.counted_variables, ) def test_module_summary_with_depth(self): @@ -192,8 +215,10 @@ def test_module_summary_with_depth(self): module = CNN(test_sow=False) table = summary._get_module_table(module, depth=1, show_repeated=True)( - {"dropout":random.PRNGKey(0), "params": random.PRNGKey(1)}, - x, training=True, mutable=True, + {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + x, + training=True, + mutable=True, ) # get values for inputs and outputs from their _ValueRepresentation @@ -213,18 +238,22 @@ def test_module_summary_with_depth(self): # check outputs shapes self.assertEqual( - (table[0].inputs[0].shape, table[0].inputs[1]), - (x.shape, dict(training=True)), + (table[0].inputs[0].shape, table[0].inputs[1]), + (x.shape, dict(training=True)), ) self.assertEqual( - _get_shapes(table[0].outputs), - ((batch_size, 10), dict(a=(batch_size, 10), b=(batch_size, 10))), + _get_shapes(table[0].outputs), + ((batch_size, 10), dict(a=(batch_size, 10), b=(batch_size, 10))), ) - self.assertEqual(_get_shapes(table[1].inputs), ((batch_size, 28, 28, 1), {'training': True})) + self.assertEqual( + _get_shapes(table[1].inputs), ((batch_size, 28, 28, 1), {"training": True}) + ) self.assertEqual(table[1].outputs.shape, (batch_size, 28, 28, 32)) - self.assertEqual(_get_shapes(table[2].inputs), ((batch_size, 28, 28, 32), {'training': True})) + self.assertEqual( + _get_shapes(table[2].inputs), ((batch_size, 28, 28, 32), {"training": True}) + ) self.assertEqual(table[2].outputs.shape, (batch_size, 28, 28, 64)) self.assertEqual(table[3].inputs.shape, (batch_size, 64)) @@ -238,7 +267,6 @@ def test_module_summary_with_depth(self): self.assertEqual(table[0].module_variables, table[0].counted_variables) self.assertEqual(table[3].module_variables, table[3].counted_variables) - def test_tabulate(self): """ This test creates a string representation of a Module using `Module.tabulate` @@ -250,7 +278,7 @@ def test_tabulate(self): module = CNN(test_sow=False) module_repr = module.tabulate( - {"dropout":random.PRNGKey(0), "params": random.PRNGKey(1)}, + {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, x, training=True, console_kwargs=CONSOLE_TEST_KWARGS, @@ -285,37 +313,34 @@ def test_tabulate(self): self.assertIn("19,850", lines[-3]) self.assertIn("79.4 KB", lines[-3]) - def test_tabulate_with_sow(self): - batch_size = 32 x = jnp.ones((batch_size, 28, 28, 1)) module = CNN(test_sow=True) module_repr = module.tabulate( - {"dropout":random.PRNGKey(0), "params": random.PRNGKey(1)}, - x, - training=True, - console_kwargs=CONSOLE_TEST_KWARGS, + {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + x, + training=True, + console_kwargs=CONSOLE_TEST_KWARGS, ) self.assertIn("intermediates", module_repr) self.assertIn("INTERM", module_repr) def test_tabulate_with_method(self): - batch_size = 32 x = jnp.ones((batch_size, 28, 28, 1)) module = CNN(test_sow=False) module_repr = module.tabulate( - {"dropout":random.PRNGKey(0), "params": random.PRNGKey(1)}, - x, - training=True, - method=CNN.cnn_method, - console_kwargs=CONSOLE_TEST_KWARGS, + {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + x, + training=True, + method=CNN.cnn_method, + console_kwargs=CONSOLE_TEST_KWARGS, ) self.assertIn("(block_method)", module_repr) @@ -332,12 +357,12 @@ def test_tabulate_function(self): module = CNN(test_sow=False) module_repr = nn.tabulate( - module, - {"dropout":random.PRNGKey(0), "params": random.PRNGKey(1)}, - console_kwargs=CONSOLE_TEST_KWARGS, + module, + {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + console_kwargs=CONSOLE_TEST_KWARGS, )( - x, - training=True, + x, + training=True, ) lines = module_repr.split("\n") @@ -366,33 +391,32 @@ def test_tabulate_function(self): self.assertIn("19,850", lines[-3]) self.assertIn("79.4 KB", lines[-3]) - def test_lifted_transform(self): class LSTM(nn.Module): features: int @nn.compact def __call__(self, x): - carry = nn.LSTMCell(self.features).initialize_carry( - random.PRNGKey(0), x[:, 0].shape - ) - ScanLSTM = nn.scan( - nn.LSTMCell, - variable_broadcast="params", - split_rngs={"params": False}, - in_axes=1, - out_axes=1, - ) - return ScanLSTM(self.features, name="ScanLSTM")(carry, x) - + carry = nn.LSTMCell(self.features).initialize_carry( + random.PRNGKey(0), x[:, 0].shape + ) + ScanLSTM = nn.scan( + nn.LSTMCell, + variable_broadcast="params", + split_rngs={"params": False}, + in_axes=1, + out_axes=1, + ) + return ScanLSTM(self.features, name="ScanLSTM")(carry, x) lstm = LSTM(features=128) with jax.check_tracer_leaks(True): module_repr = lstm.tabulate( - random.PRNGKey(0), - x=jnp.ones((32, 128, 64)), - console_kwargs=CONSOLE_TEST_KWARGS) + random.PRNGKey(0), + x=jnp.ones((32, 128, 64)), + console_kwargs=CONSOLE_TEST_KWARGS, + ) lines = module_repr.splitlines() @@ -408,26 +432,26 @@ class LSTM(nn.Module): @nn.compact def __call__(self, x): - carry = nn.LSTMCell(self.features).initialize_carry( - random.PRNGKey(0), x[:, 0].shape - ) - ScanLSTM = nn.scan( - nn.LSTMCell, - variable_broadcast="params", - split_rngs={"params": False}, - in_axes=1, - out_axes=1, - ) - return ScanLSTM(self.features)(carry, x) - + carry = nn.LSTMCell(self.features).initialize_carry( + random.PRNGKey(0), x[:, 0].shape + ) + ScanLSTM = nn.scan( + nn.LSTMCell, + variable_broadcast="params", + split_rngs={"params": False}, + in_axes=1, + out_axes=1, + ) + return ScanLSTM(self.features)(carry, x) lstm = LSTM(features=128) with jax.check_tracer_leaks(True): module_repr = lstm.tabulate( - random.PRNGKey(0), - x=jnp.ones((32, 128, 64)), - console_kwargs=CONSOLE_TEST_KWARGS) + random.PRNGKey(0), + x=jnp.ones((32, 128, 64)), + console_kwargs=CONSOLE_TEST_KWARGS, + ) lines = module_repr.splitlines() @@ -439,6 +463,7 @@ def __call__(self, x): def test_module_reuse(self): class ConvBlock(nn.Module): + @nn.compact def __call__(self, x): x = nn.Conv(32, [3, 3])(x) @@ -448,6 +473,7 @@ def __call__(self, x): return x class CNN(nn.Module): + @nn.compact def __call__(self, x): block = ConvBlock() @@ -458,10 +484,11 @@ def __call__(self, x): x = jnp.ones((4, 28, 28, 32)) module_repr = CNN().tabulate( - jax.random.PRNGKey(0), - x=x, - show_repeated=True, - console_kwargs=CONSOLE_TEST_KWARGS) + jax.random.PRNGKey(0), + x=x, + show_repeated=True, + console_kwargs=CONSOLE_TEST_KWARGS, + ) lines = module_repr.splitlines() # first call @@ -490,6 +517,7 @@ def __call__(self, x): def test_empty_input(self): class EmptyInput(nn.Module): + @nn.compact def __call__(self): return 1 @@ -498,14 +526,16 @@ def __call__(self): module_repr = module.tabulate({}, console_kwargs=CONSOLE_TEST_KWARGS) lines = module_repr.splitlines() - self.assertRegex(lines[5], r'|\s*|\s*EmptyInput\s*|\s*|\s*1\s*|') + self.assertRegex(lines[5], r"|\s*|\s*EmptyInput\s*|\s*|\s*1\s*|") def test_numpy_scalar(self): class Submodule(nn.Module): + def __call__(self, x): return x + 1 class EmptyInput(nn.Module): + @nn.compact def __call__(self): return Submodule()(x=np.pi) @@ -514,56 +544,57 @@ def __call__(self): module_repr = module.tabulate({}, console_kwargs=CONSOLE_TEST_KWARGS) lines = module_repr.splitlines() - self.assertIn('4.141592', lines[5]) - self.assertIn('x: 3.141592', lines[7]) - self.assertIn('4.141592', lines[7]) + self.assertIn("4.141592", lines[5]) + self.assertIn("x: 3.141592", lines[7]) + self.assertIn("4.141592", lines[7]) def test_partitioned_params(self): - class Classifier(nn.Module): + @nn.compact def __call__(self, x): hidden = nn.Dense( - features=1024, - kernel_init=nn.with_partitioning( - nn.initializers.lecun_normal(), (None, 'data') - ), - bias_init=nn.with_partitioning( - nn.initializers.zeros, (None,) - ), - name='hidden', + features=1024, + kernel_init=nn.with_partitioning( + nn.initializers.lecun_normal(), (None, "data") + ), + bias_init=nn.with_partitioning(nn.initializers.zeros, (None,)), + name="hidden", ) x = x / 255.0 x = x.reshape((x.shape[0], -1)) # flatten x = nn.relu(hidden(x)) - x = nn.Dense(features=10, name='head')(x) + x = nn.Dense(features=10, name="head")(x) return x module = Classifier() - lines = module.tabulate(jax.random.PRNGKey(0), jnp.empty((1, 28, 28, 1)), - console_kwargs=CONSOLE_TEST_KWARGS).splitlines() - self.assertIn('P(None,)', lines[7]) - self.assertIn('P(None, data)', lines[8]) + lines = module.tabulate( + jax.random.PRNGKey(0), + jnp.empty((1, 28, 28, 1)), + console_kwargs=CONSOLE_TEST_KWARGS, + ).splitlines() + self.assertIn("P(None,)", lines[7]) + self.assertIn("P(None, data)", lines[8]) def test_non_array_variables(self): - class Metadata(struct.PyTreeNode): names: tuple = struct.field(pytree_node=False) class Foo(nn.Module): + @nn.compact def __call__(self): - self.sow('foo', 'bar', Metadata(('baz', 'qux'))) + self.sow("foo", "bar", Metadata(("baz", "qux"))) module = Foo() - lines = module.tabulate({}, - console_kwargs=CONSOLE_TEST_KWARGS).splitlines() - self.assertIn('names', lines[6]) - self.assertIn('baz', lines[7]) - self.assertIn('qux', lines[8]) + lines = module.tabulate({}, console_kwargs=CONSOLE_TEST_KWARGS).splitlines() + self.assertIn("names", lines[6]) + self.assertIn("baz", lines[7]) + self.assertIn("qux", lines[8]) def test_tabulate_param_count(self): class Foo(nn.Module): + @nn.compact def __call__(self, x): h = nn.Dense(4)(x) @@ -572,8 +603,8 @@ def __call__(self, x): x = jnp.ones((16, 9)) rep = Foo().tabulate(jax.random.PRNGKey(0), x, console_kwargs=CONSOLE_TEST_KWARGS) lines = rep.splitlines() - self.assertIn('Total Parameters: 50', lines[-2]) + self.assertIn("Total Parameters: 50", lines[-2]) -if __name__ == '__main__': - absltest.main() \ No newline at end of file +if __name__ == "__main__": + absltest.main() diff --git a/tests/linen/toplevel_test.py b/tests/linen/toplevel_test.py index e4a043b004..992d5c0c49 100644 --- a/tests/linen/toplevel_test.py +++ b/tests/linen/toplevel_test.py @@ -28,11 +28,14 @@ # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() + class Dummy(nn.Module): + @nn.compact def __call__(self): self.param('foo', lambda rng: 1) + class ModuleTopLevelTest(absltest.TestCase): pass # def test_toplevel_immutable(self): diff --git a/tests/serialization_test.py b/tests/serialization_test.py index f6d5a384bb..895b2584ab 100644 --- a/tests/serialization_test.py +++ b/tests/serialization_test.py @@ -55,6 +55,7 @@ class WrongTuple(NamedTuple): class OriginalModule(nn.Module): + @nn.compact def __call__(self, x): x = nn.Dense(10)(x) @@ -62,6 +63,7 @@ def __call__(self, x): class WrongModule(nn.Module): + @nn.compact def __call__(self, x): x = nn.Dense(10)(x) @@ -74,10 +76,13 @@ class SerializationTest(parameterized.TestCase): def test_dataclass_serialization(self): p = Point(x=1, y=2, meta={'dummy': True}) state_dict = serialization.to_state_dict(p) - self.assertEqual(state_dict, { - 'x': 1, - 'y': 2, - }) + self.assertEqual( + state_dict, + { + 'x': 1, + 'y': 2, + }, + ) restored_p = serialization.from_state_dict(p, {'x': 3, 'y': 4}) expected_p = Point(x=3, y=4, meta={'dummy': True}) self.assertEqual(restored_p, expected_p) @@ -93,12 +98,15 @@ def test_model_serialization(self): x = jnp.ones((1, 1), jnp.float32) initial_params = module.init(rng, x) state = serialization.to_state_dict(initial_params) - self.assertEqual(state, { - 'params': { - 'kernel': np.ones((1, 1)), - 'bias': np.zeros((1,)), - } - }) + self.assertEqual( + state, + { + 'params': { + 'kernel': np.ones((1, 1)), + 'bias': np.zeros((1,)), + } + }, + ) state = { 'params': { 'kernel': np.zeros((1, 1)), @@ -111,10 +119,7 @@ def test_model_serialization(self): def test_partial_serialization(self): add_one = Partial(jnp.add, 1) state = serialization.to_state_dict(add_one) - self.assertEqual(state, { - 'args': {'0': 1}, - 'keywords': {} - }) + self.assertEqual(state, {'args': {'0': 1}, 'keywords': {}}) restored_add_one = serialization.from_state_dict(add_one, state) self.assertEqual(add_one.args, restored_add_one.args) @@ -130,13 +135,13 @@ def test_optimizer_serialization(self): '0': { 'trace': { 'params': { - 'bias': np.array([0.], dtype=jnp.float32), - 'kernel': np.array([[0.]], dtype=jnp.float32) - } + 'bias': np.array([0.0], dtype=jnp.float32), + 'kernel': np.array([[0.0]], dtype=jnp.float32), } - }, - '1': {} - } + } + }, + '1': {}, + } self.assertEqual(state, expected_state) state = jax.tree_map(lambda x: x + 1, expected_state) restored_tx_state = serialization.from_state_dict(tx_state, state) @@ -144,7 +149,6 @@ def test_optimizer_serialization(self): self.assertEqual(restored_tx_state, tx_state_plus1) def test_collection_serialization(self): - @struct.dataclass class DummyDataClass: x: float @@ -152,84 +156,155 @@ class DummyDataClass: @classmethod def initializer(cls, shape): del shape - return cls(x=0.) + return cls(x=0.0) class StatefulModule(nn.Module): + @nn.compact def __call__(self): state = self.variable('state', 'dummy', DummyDataClass.initializer, ()) - state.value = state.value.replace(x=state.value.x + 1.) + state.value = state.value.replace(x=state.value.x + 1.0) initial_variables = StatefulModule().init(random.PRNGKey(0)) _, variables = StatefulModule().apply(initial_variables, mutable=['state']) serialized_state_dict = serialization.to_state_dict(variables) - self.assertEqual(serialized_state_dict, - {'state': {'dummy': {'x': 2.0}}}) - deserialized_state = serialization.from_state_dict(variables, - serialized_state_dict) + self.assertEqual(serialized_state_dict, {'state': {'dummy': {'x': 2.0}}}) + deserialized_state = serialization.from_state_dict(variables, serialized_state_dict) self.assertEqual(variables, deserialized_state) - @parameterized.parameters( - ['byte', 'b', 'ubyte', 'short', - 'h', 'ushort', 'i', 'uint', 'intp', - 'p', 'uintp', 'long', 'l', 'longlong', - 'q', 'ulonglong', 'half', 'e', 'f', - 'double', 'd', 'longdouble', 'g', - 'cfloat', 'cdouble', 'clongdouble', 'm', - 'bool8', 'b1', 'int64', 'i8', 'uint64', 'u8', - 'float16', 'f2', 'float32', 'f4', 'float64', - 'f8', 'float128', 'f16', 'complex64', 'c8', - 'complex128', 'c16', 'complex256', 'c32', - 'm8', 'int32', 'i4', 'uint32', 'u4', 'int16', - 'i2', 'uint16', 'u2', 'int8', 'i1', 'uint8', - 'u1', 'complex_', 'int0', 'uint0', 'single', - 'csingle', 'singlecomplex', 'float_', 'intc', - 'uintc', 'int_', 'longfloat', 'clongfloat', - 'longcomplex', 'bool_', 'int', 'float', - 'complex', 'bool'] - ) + @parameterized.parameters([ + 'byte', + 'b', + 'ubyte', + 'short', + 'h', + 'ushort', + 'i', + 'uint', + 'intp', + 'p', + 'uintp', + 'long', + 'l', + 'longlong', + 'q', + 'ulonglong', + 'half', + 'e', + 'f', + 'double', + 'd', + 'longdouble', + 'g', + 'cfloat', + 'cdouble', + 'clongdouble', + 'm', + 'bool8', + 'b1', + 'int64', + 'i8', + 'uint64', + 'u8', + 'float16', + 'f2', + 'float32', + 'f4', + 'float64', + 'f8', + 'float128', + 'f16', + 'complex64', + 'c8', + 'complex128', + 'c16', + 'complex256', + 'c32', + 'm8', + 'int32', + 'i4', + 'uint32', + 'u4', + 'int16', + 'i2', + 'uint16', + 'u2', + 'int8', + 'i1', + 'uint8', + 'u1', + 'complex_', + 'int0', + 'uint0', + 'single', + 'csingle', + 'singlecomplex', + 'float_', + 'intc', + 'uintc', + 'int_', + 'longfloat', + 'clongfloat', + 'longcomplex', + 'bool_', + 'int', + 'float', + 'complex', + 'bool', + ]) def test_numpy_serialization(self, dtype): np.random.seed(0) - if (dtype in {'float128', 'f16', 'complex256', 'c32'}) and (platform.system() == 'Darwin') and (platform.machine() == 'arm64'): - pytest.skip(f'Mac M1 does not support dtype {dtype}') # skip testing these dtypes if user is on Mac M1 + if ( + (dtype in {'float128', 'f16', 'complex256', 'c32'}) + and (platform.system() == 'Darwin') + and (platform.machine() == 'arm64') + ): + pytest.skip( + f'Mac M1 does not support dtype {dtype}' + ) # skip testing these dtypes if user is on Mac M1 v = np.random.uniform(-100, 100, size=()).astype(dtype)[()] - restored_v = serialization.msgpack_restore( - serialization.msgpack_serialize(v)) + restored_v = serialization.msgpack_restore(serialization.msgpack_serialize(v)) self.assertEqual(restored_v.dtype, v.dtype) np.testing.assert_array_equal(restored_v, v) for shape in [(), (5,), (10, 10), (1, 20, 30, 1)]: arr = np.random.uniform(-100, 100, size=shape).astype(dtype) - restored_arr = serialization.msgpack_restore( - serialization.msgpack_serialize(arr)) + restored_arr = serialization.msgpack_restore(serialization.msgpack_serialize(arr)) self.assertEqual(restored_arr.dtype, arr.dtype) np.testing.assert_array_equal(restored_arr, arr) def test_jax_numpy_serialization(self): - jax_dtypes = [jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32, - jnp.int8, jnp.int16, jnp.int32, - jnp.bfloat16, jnp.float16, jnp.float32, - jnp.complex64] + jax_dtypes = [ + jnp.bool_, + jnp.uint8, + jnp.uint16, + jnp.uint32, + jnp.int8, + jnp.int16, + jnp.int32, + jnp.bfloat16, + jnp.float16, + jnp.float32, + jnp.complex64, + ] for dtype in jax_dtypes: v = jnp.array(np.random.uniform(-100, 100, size=())).astype(dtype)[()] - restored_v = serialization.msgpack_restore( - serialization.msgpack_serialize(v)) + restored_v = serialization.msgpack_restore(serialization.msgpack_serialize(v)) self.assertEqual(restored_v.dtype, v.dtype) np.testing.assert_array_equal(restored_v, v) for shape in [(), (5,), (10, 10), (1, 20, 30, 1)]: - arr = jnp.array( - np.random.uniform(-100, 100, size=shape)).astype(dtype) + arr = jnp.array(np.random.uniform(-100, 100, size=shape)).astype(dtype) restored_arr = serialization.msgpack_restore( - serialization.msgpack_serialize(arr)) + serialization.msgpack_serialize(arr) + ) self.assertEqual(restored_arr.dtype, arr.dtype) np.testing.assert_array_equal(restored_arr, arr) def test_complex_serialization(self): - for x in [1j, 1+2j]: - restored_x = serialization.msgpack_restore( - serialization.msgpack_serialize(x)) + for x in [1j, 1 + 2j]: + restored_x = serialization.msgpack_restore(serialization.msgpack_serialize(x)) self.assertEqual(x, restored_x) def test_restore_chunked(self): @@ -246,10 +321,12 @@ def test_restore_chunked(self): def test_restore_unchunked(self): """Check if mgspack_restore works for unchunked inputs.""" + def msgpack_serialize_legacy(pytree): """Old implementation that was not chunking.""" - return msgpack.packb(pytree, default=serialization._msgpack_ext_pack, - strict_types=True) + return msgpack.packb( + pytree, default=serialization._msgpack_ext_pack, strict_types=True + ) tmp = np.random.uniform(-100, 100, size=(21, 37)) serialized = msgpack_serialize_legacy(tmp) @@ -276,16 +353,8 @@ def test_namedtuple_restore_legacy(self): x1 = foo_class(a=1, b=2, c=3) legacy_encoding = { 'name': 'Foo', - 'fields': { - '0': 'a', - '1': 'b', - '2': 'c' - }, - 'values': { - '0': 1, - '1': 2, - '2': 3 - }, + 'fields': {'0': 'a', '1': 'b', '2': 'c'}, + 'values': {'0': 1, '1': 2, '2': 3}, } x2 = foo_class(a=0, b=0, c=0) restored_x1 = serialization.from_state_dict(x2, legacy_encoding) @@ -323,14 +392,8 @@ def test_serialization_chunking(self): ref = { 'a': { '__msgpack_chunked_array__': (), - 'chunks': { - '0': (91,), - '1': (9,) - }, - 'shape': { - '0': (), - '1': () - } + 'chunks': {'0': (91,), '1': (9,)}, + 'shape': {'0': (), '1': ()}, } } self.assertEqual(test, ref) @@ -359,33 +422,62 @@ def test_serialization_chunking3(self): jax.tree_map(np.testing.assert_array_equal, tmp, newtmp) @parameterized.parameters( - {'target': [[[1, 2, 3], [4, 5]]], 'wrong_target': [[[1, 2, 3], [4]]], - 'msg': ('The size of the list and the state dict do not match,' - ' got 1 and 2 at path ./0/1')}, - {'target': (((1, 2, 3), (4, 5)),), - 'wrong_target': (((1, 2, 3), (4,)),), - 'msg': ('The size of the list and the state dict do not match,' - ' got 1 and 2 at path ./0/1')}, - {'target': (((1, 2, 3), (OriginalTuple([4, 5]), 6)),), - 'wrong_target': (((1, 2, 3), (WrongTuple([4, 5]), 6)),), - 'msg': ("The field names of the state dict and the named tuple do " - "not match, got {'value'} and {'wrong_field'} at path ./0/1/0")}, - {'target': {'a': {'b': {'c': [1, 2, 3], 'd': [4, 5]}}}, - 'wrong_target': {'a': {'b': {'c': [1, 2, 3], 'd': [4]}}}, - 'msg': ('The size of the list and the state dict do not match,' - ' got 1 and 2 at path ./a/b/d')}, - {'target': {'a': {'b': {'c': [1, 2, 3], 'd': [4, 5]}}}, - 'wrong_target': {'a': {'b': {'c': [1, 2, 3], 'e': [4, 5]}}}, - 'msg': ("The target dict keys and state dict keys do not match, " - "target dict contains keys {'e'} which are not present in state dict at path ./a/b")}, - {'target': 'original_params', - 'wrong_target': 'wrong_params', - 'msg': ("The target dict keys and state dict keys do not match, " - "target dict contains keys {'Dense_1'} which are not present in state dict at path ./params")}, - {'target': 'original_train_state', - 'wrong_target': 'wrong_train_state', - 'msg': ("The target dict keys and state dict keys do not match, " - "target dict contains keys {'Dense_1'} which are not present in state dict at path ./params/params")} + { + 'target': [[[1, 2, 3], [4, 5]]], + 'wrong_target': [[[1, 2, 3], [4]]], + 'msg': ( + 'The size of the list and the state dict do not match,' + ' got 1 and 2 at path ./0/1' + ), + }, + { + 'target': (((1, 2, 3), (4, 5)),), + 'wrong_target': (((1, 2, 3), (4,)),), + 'msg': ( + 'The size of the list and the state dict do not match,' + ' got 1 and 2 at path ./0/1' + ), + }, + { + 'target': (((1, 2, 3), (OriginalTuple([4, 5]), 6)),), + 'wrong_target': (((1, 2, 3), (WrongTuple([4, 5]), 6)),), + 'msg': ( + 'The field names of the state dict and the named tuple do ' + "not match, got {'value'} and {'wrong_field'} at path ./0/1/0" + ), + }, + { + 'target': {'a': {'b': {'c': [1, 2, 3], 'd': [4, 5]}}}, + 'wrong_target': {'a': {'b': {'c': [1, 2, 3], 'd': [4]}}}, + 'msg': ( + 'The size of the list and the state dict do not match,' + ' got 1 and 2 at path ./a/b/d' + ), + }, + { + 'target': {'a': {'b': {'c': [1, 2, 3], 'd': [4, 5]}}}, + 'wrong_target': {'a': {'b': {'c': [1, 2, 3], 'e': [4, 5]}}}, + 'msg': ( + 'The target dict keys and state dict keys do not match, ' + "target dict contains keys {'e'} which are not present in state dict at path ./a/b" + ), + }, + { + 'target': 'original_params', + 'wrong_target': 'wrong_params', + 'msg': ( + 'The target dict keys and state dict keys do not match, ' + "target dict contains keys {'Dense_1'} which are not present in state dict at path ./params" + ), + }, + { + 'target': 'original_train_state', + 'wrong_target': 'wrong_train_state', + 'msg': ( + 'The target dict keys and state dict keys do not match, ' + "target dict contains keys {'Dense_1'} which are not present in state dict at path ./params/params" + ), + }, ) def test_serialization_errors(self, target, wrong_target, msg): if target == 'original_params': @@ -406,9 +498,11 @@ def test_serialization_errors(self, target, wrong_target, msg): tx = optax.sgd(learning_rate=0.1, momentum=0.9) target = train_state.TrainState.create( - apply_fn=original_module.apply, params=original_params, tx=tx) + apply_fn=original_module.apply, params=original_params, tx=tx + ) wrong_target = train_state.TrainState.create( - apply_fn=wrong_module.apply, params=wrong_params, tx=tx) + apply_fn=wrong_module.apply, params=wrong_params, tx=tx + ) encoded_bytes = serialization.to_bytes(target) with self.assertRaisesWithLiteralMatch(ValueError, msg): diff --git a/tests/struct_test.py b/tests/struct_test.py index 12049a6d52..122afd5d3d 100644 --- a/tests/struct_test.py +++ b/tests/struct_test.py @@ -60,12 +60,11 @@ def test_pytree_nodes(self): def test_keypath_error(self): # TODO(mattjj): avoid using internal prefix_errors by testing vmap error msg - e, = prefix_errors(Point(1., [2.], meta={}), Point(1., 2., meta={})) + (e,) = prefix_errors(Point(1.0, [2.0], meta={}), Point(1.0, 2.0, meta={})) with self.assertRaisesRegex(ValueError, r'in_axes\.y'): raise e('in_axes') def test_double_wrap_no_op(self): - class A: a: int @@ -74,7 +73,7 @@ class A: A = struct.dataclass(A) self.assertTrue(hasattr(A, '_flax_dataclass')) - A = struct.dataclass(A) # no-op + A = struct.dataclass(A) # no-op self.assertTrue(hasattr(A, '_flax_dataclass')) def test_wrap_pytree_node_no_error(self): @@ -82,5 +81,6 @@ def test_wrap_pytree_node_no_error(self): class A(struct.PyTreeNode): a: int + if __name__ == '__main__': absltest.main() diff --git a/tests/tensorboard_test.py b/tests/tensorboard_test.py index 91b02b2b7f..a1f0161a33 100644 --- a/tests/tensorboard_test.py +++ b/tests/tensorboard_test.py @@ -27,9 +27,10 @@ from flax.metrics.tensorboard import SummaryWriter, _flatten_dict + def _process_event(event): for value in event.summary.value: - yield {'wall_time': event.wall_time, 'step': event.step, 'value': value} + yield {"wall_time": event.wall_time, "step": event.step, "value": value} def _disk_usage(path: pathlib.Path): @@ -52,16 +53,18 @@ def parse_and_return_summary_value(self, path): only summary value.""" event_value_list = [] event_file_generator = directory_watcher.DirectoryWatcher( - path, event_file_loader.EventFileLoader).Load() + path, event_file_loader.EventFileLoader + ).Load() event_values = itertools.chain.from_iterable( - map(_process_event, event_file_generator)) + map(_process_event, event_file_generator) + ) for value_dict in event_values: event_value_list.append(value_dict) self.assertLen(event_value_list, 1) - self.assertEqual(event_value_list[0]['step'], 1) - self.assertGreater(event_value_list[0]['wall_time'], 0.0) - return event_value_list[0]['value'] + self.assertEqual(event_value_list[0]["step"], 1) + self.assertGreater(event_value_list[0]["wall_time"], 0.0) + return event_value_list[0]["value"] def test_summarywriter_flush_after_close(self): log_dir = tempfile.mkdtemp() @@ -75,62 +78,61 @@ def test_summarywriter_scalar(self): summary_writer = SummaryWriter(log_dir=log_dir) # Write the scalar and check if the event exists and check data. float_value = 99.1232 - summary_writer.scalar(tag='scalar_test', value=float_value, step=1) + summary_writer.scalar(tag="scalar_test", value=float_value, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, 'scalar_test') - self.assertTrue(np.allclose( - tensor_util.make_ndarray(summary_value.tensor).item(), - float_value)) + self.assertEqual(summary_value.tag, "scalar_test") + self.assertTrue( + np.allclose(tensor_util.make_ndarray(summary_value.tensor).item(), float_value) + ) def test_summarywriter_text(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) - text = 'hello world.' - summary_writer.text(tag='text_test', textdata=text, step=1) + text = "hello world." + summary_writer.text(tag="text_test", textdata=text, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, 'text_test') + self.assertEqual(summary_value.tag, "text_test") self.assertEqual( - tensor_util.make_ndarray(summary_value.tensor).item().decode('utf-8'), - text) + tensor_util.make_ndarray(summary_value.tensor).item().decode("utf-8"), text + ) def test_summarywriter_image(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) - expected_img = np.random.uniform(low=0., high=255., size=(30, 30, 3)) + expected_img = np.random.uniform(low=0.0, high=255.0, size=(30, 30, 3)) expected_img = expected_img.astype(np.uint8) - summary_writer.image(tag='image_test', image=expected_img, step=1) + summary_writer.image(tag="image_test", image=expected_img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, 'image_test') + self.assertEqual(summary_value.tag, "image_test") actual_img = tf.image.decode_image(summary_value.tensor.string_val[2]) self.assertTrue(np.allclose(actual_img, expected_img)) def test_summarywriter_image_float_pixel_values(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) - expected_img = np.random.uniform(low=0., high=1., size=(30, 30, 3)) - summary_writer.image(tag='image_test', image=expected_img, step=1) + expected_img = np.random.uniform(low=0.0, high=1.0, size=(30, 30, 3)) + summary_writer.image(tag="image_test", image=expected_img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) # convert and scale expected_img appropriately to numpy uint8. - expected_img = tf.image.convert_image_dtype( - image=expected_img, dtype=np.uint8) + expected_img = tf.image.convert_image_dtype(image=expected_img, dtype=np.uint8) - self.assertEqual(summary_value.tag, 'image_test') + self.assertEqual(summary_value.tag, "image_test") actual_img = tf.image.decode_image(summary_value.tensor.string_val[2]) self.assertTrue(np.allclose(actual_img, expected_img)) def test_summarywriter_2dimage_scaled(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) - img = np.random.uniform(low=0., high=255., size=(30, 30)) + img = np.random.uniform(low=0.0, high=255.0, size=(30, 30)) img = img.astype(np.uint8) - summary_writer.image(tag='2dimage_test', image=img, step=1) + summary_writer.image(tag="2dimage_test", image=img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, '2dimage_test') + self.assertEqual(summary_value.tag, "2dimage_test") actual_img = tf.image.decode_image(summary_value.tensor.string_val[2]) # assert the image was increased in dimension self.assertEqual(actual_img.shape, (30, 30, 3)) @@ -138,12 +140,12 @@ def test_summarywriter_2dimage_scaled(self): def test_summarywriter_single_channel_image_scaled(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) - img = np.random.uniform(low=0., high=255., size=(30, 30, 1)) + img = np.random.uniform(low=0.0, high=255.0, size=(30, 30, 1)) img = img.astype(np.uint8) - summary_writer.image(tag='2dimage_1channel_test', image=img, step=1) + summary_writer.image(tag="2dimage_1channel_test", image=img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, '2dimage_1channel_test') + self.assertEqual(summary_value.tag, "2dimage_1channel_test") actual_img = tf.image.decode_image(summary_value.tensor.string_val[2]) # assert the image was increased in dimension self.assertEqual(actual_img.shape, (30, 30, 3)) @@ -151,66 +153,61 @@ def test_summarywriter_single_channel_image_scaled(self): def test_summarywriter_multiple_images(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) - expected_img = np.random.uniform(low=0., high=255., size=(2, 30, 30, 3)) + expected_img = np.random.uniform(low=0.0, high=255.0, size=(2, 30, 30, 3)) expected_img = expected_img.astype(np.uint8) - summary_writer.image(tag='multiple_images_test', image=expected_img, step=1) + summary_writer.image(tag="multiple_images_test", image=expected_img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, 'multiple_images_test') - actual_imgs = [tf.image.decode_image(s) - for s in summary_value.tensor.string_val[2:]] + self.assertEqual(summary_value.tag, "multiple_images_test") + actual_imgs = [ + tf.image.decode_image(s) for s in summary_value.tensor.string_val[2:] + ] self.assertTrue(np.allclose(np.stack(actual_imgs, axis=0), expected_img)) def test_summarywriter_multiple_2dimages_scaled(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) - img = np.random.uniform(low=0., high=255., size=(2, 30, 30)) + img = np.random.uniform(low=0.0, high=255.0, size=(2, 30, 30)) img = img.astype(np.uint8) - summary_writer.image(tag='multiple_2dimages_test', image=img, step=1) + summary_writer.image(tag="multiple_2dimages_test", image=img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, 'multiple_2dimages_test') - actual_imgs = [tf.image.decode_image(s) - for s in summary_value.tensor.string_val[2:]] + self.assertEqual(summary_value.tag, "multiple_2dimages_test") + actual_imgs = [ + tf.image.decode_image(s) for s in summary_value.tensor.string_val[2:] + ] # assert the images were increased in dimension self.assertEqual(np.stack(actual_imgs, axis=0).shape, (2, 30, 30, 3)) def test_summarywriter_audio(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) - expected_audio_samples = np.random.uniform( - low=-1., high=1., size=(2, 48000, 2)) - summary_writer.audio( - tag='audio_test', audiodata=expected_audio_samples, step=1) + expected_audio_samples = np.random.uniform(low=-1.0, high=1.0, size=(2, 48000, 2)) + summary_writer.audio(tag="audio_test", audiodata=expected_audio_samples, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, 'audio_test') + self.assertEqual(summary_value.tag, "audio_test") # Assert two audio files are parsed. self.assertLen(summary_value.tensor.string_val, 2) # Assert values. - actual_audio_1 = tf.audio.decode_wav( - summary_value.tensor.string_val[0]).audio - self.assertTrue(np.allclose( - expected_audio_samples[0], actual_audio_1, atol=1e-04)) + actual_audio_1 = tf.audio.decode_wav(summary_value.tensor.string_val[0]).audio + self.assertTrue(np.allclose(expected_audio_samples[0], actual_audio_1, atol=1e-04)) - actual_audio_2 = tf.audio.decode_wav( - summary_value.tensor.string_val[1]).audio - self.assertTrue(np.allclose( - expected_audio_samples[1], actual_audio_2, atol=1e-04)) + actual_audio_2 = tf.audio.decode_wav(summary_value.tensor.string_val[1]).audio + self.assertTrue(np.allclose(expected_audio_samples[1], actual_audio_2, atol=1e-04)) def test_summarywriter_audio_sampled_output(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) - expected_audio_samples = np.random.uniform( - low=-1., high=1., size=(2, 48000, 2)) + expected_audio_samples = np.random.uniform(low=-1.0, high=1.0, size=(2, 48000, 2)) summary_writer.audio( - tag='audio_test', audiodata=expected_audio_samples, step=1, - max_outputs=1) + tag="audio_test", audiodata=expected_audio_samples, step=1, max_outputs=1 + ) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, 'audio_test') + self.assertEqual(summary_value.tag, "audio_test") # Assert only the first audio clip is available. self.assertLen(summary_value.tensor.string_val, 1) @@ -218,113 +215,124 @@ def test_summarywriter_audio_sampled_output(self): # Assert values. actual_audio = tf.audio.decode_wav(summary_value.tensor.string_val[0]).audio - self.assertTrue(np.allclose( - expected_audio_samples[0], actual_audio, atol=1e-04)) + self.assertTrue(np.allclose(expected_audio_samples[0], actual_audio, atol=1e-04)) def test_summarywriter_clipped_audio(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) - expected_audio_samples = np.random.uniform( - low=-2., high=2., size=(2, 48000, 2)) + expected_audio_samples = np.random.uniform(low=-2.0, high=2.0, size=(2, 48000, 2)) summary_writer.audio( - tag='audio_test', audiodata=expected_audio_samples, step=1, - max_outputs=1) + tag="audio_test", audiodata=expected_audio_samples, step=1, max_outputs=1 + ) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, 'audio_test') + self.assertEqual(summary_value.tag, "audio_test") # Assert one audio files is parsed. self.assertLen(summary_value.tensor.string_val, 1) # actual_audio is clipped. - actual_audio = tf.audio.decode_wav( - summary_value.tensor.string_val[0]).audio - self.assertFalse(np.allclose( - expected_audio_samples[0], actual_audio, atol=1e-04)) + actual_audio = tf.audio.decode_wav(summary_value.tensor.string_val[0]).audio + self.assertFalse(np.allclose(expected_audio_samples[0], actual_audio, atol=1e-04)) clipped_audio = np.clip(np.array(expected_audio_samples[0]), -1, 1) - self.assertTrue( - np.allclose(clipped_audio, actual_audio, atol=1e-04)) + self.assertTrue(np.allclose(clipped_audio, actual_audio, atol=1e-04)) def test_summarywriter_histogram_defaultbins(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) histogram = np.arange(1000) # Histogram will be created for 30 (default) bins. - summary_writer.histogram(tag='histogram_test', values=histogram, step=1) + summary_writer.histogram(tag="histogram_test", values=histogram, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, 'histogram_test') + self.assertEqual(summary_value.tag, "histogram_test") actual_histogram = tensor_util.make_ndarray(summary_value.tensor) self.assertTrue(actual_histogram.shape, (30, 3)) - self.assertTrue( - np.allclose(actual_histogram[0], (0.0, 33.3, 34.0), atol=1e-01)) + self.assertTrue(np.allclose(actual_histogram[0], (0.0, 33.3, 34.0), atol=1e-01)) def test_summarywriter_histogram_2bins(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) histogram = np.arange(1000) - summary_writer.histogram( - tag='histogram_test', values=histogram, step=1, bins=2) + summary_writer.histogram(tag="histogram_test", values=histogram, step=1, bins=2) summary_value = self.parse_and_return_summary_value(path=log_dir) - self.assertEqual(summary_value.tag, 'histogram_test') + self.assertEqual(summary_value.tag, "histogram_test") actual_histogram = tensor_util.make_ndarray(summary_value.tensor) self.assertTrue(actual_histogram.shape, (2, 3)) - self.assertTrue( - np.allclose(actual_histogram[0], (0.0, 499.5, 500.0), atol=1e-01)) - self.assertTrue( - np.allclose(actual_histogram[1], (499.5, 999.0, 500.0), atol=1e-01)) + self.assertTrue(np.allclose(actual_histogram[0], (0.0, 499.5, 500.0), atol=1e-01)) + self.assertTrue(np.allclose(actual_histogram[1], (499.5, 999.0, 500.0), atol=1e-01)) def test_flatten_dict(self): # Valid types according to https://github.com/tensorflow/tensorboard/blob/1204566da5437af55109f7a4af18f9f8b7c4f864/tensorboard/plugins/hparams/summary_v2.py - input_hparams={ - # Example Invalid Types - "None": None, "List": [1, 2, 3], "Tuple": (1, 2, 3), "Complex": complex("1+1j"), "np.complex_": np.complex_("1+1j"), - # Valid Python Types - "Bool": True, "Int": 1, "Float": 1.0, "Str": "test", - # Valid Numpy Types - "np.bool_": np.bool_(1), "np.integer": np.int_(1), "np.floating": np.float_(1.0), "np.character": np.str_("test"), - # Nested dict to flatten - "Nested_Dict": { + input_hparams = { + # Example Invalid Types "None": None, "List": [1, 2, 3], "Tuple": (1, 2, 3), "Complex": complex("1+1j"), "np.complex_": np.complex_("1+1j"), + # Valid Python Types "Bool": True, "Int": 1, "Float": 1.0, "Str": "test", + # Valid Numpy Types "np.bool_": np.bool_(1), "np.integer": np.int_(1), "np.floating": np.float_(1.0), - "np.character": np.str_("test") - } + "np.character": np.str_("test"), + # Nested dict to flatten + "Nested_Dict": { + "None": None, + "List": [1, 2, 3], + "Tuple": (1, 2, 3), + "Complex": complex("1+1j"), + "np.complex_": np.complex_("1+1j"), + "Bool": True, + "Int": 1, + "Float": 1.0, + "Str": "test", + "np.bool_": np.bool_(1), + "np.integer": np.int_(1), + "np.floating": np.float_(1.0), + "np.character": np.str_("test"), + }, } result_hparams = _flatten_dict(input_hparams) - expected_hparams={ - "None": "None", "List": "[1, 2, 3]", "Tuple": "(1, 2, 3)", "Complex": "(1+1j)", "np.complex_": "(1+1j)", - # Valid Python Types - "Bool": True, "Int": 1, "Float": 1.0, "Str": "test", - # Valid Numpy Types - "np.bool_": np.bool_(1), "np.integer": np.int_(1), "np.floating": np.float_(1.0), "np.character": np.str_("test"), - # Nested Dict - "Nested_Dict.None": "None", - "Nested_Dict.List": "[1, 2, 3]", - "Nested_Dict.Tuple": "(1, 2, 3)", - "Nested_Dict.Complex": "(1+1j)", - "Nested_Dict.np.complex_": "(1+1j)", - "Nested_Dict.Bool": True, - "Nested_Dict.Int": 1, - "Nested_Dict.Float": 1.0, - "Nested_Dict.Str": "test", - "Nested_Dict.np.bool_": np.bool_(1), - "Nested_Dict.np.integer": np.int_(1), - "Nested_Dict.np.floating": np.float_(1.0), - "Nested_Dict.np.character": np.str_("test") + expected_hparams = { + "None": "None", + "List": "[1, 2, 3]", + "Tuple": "(1, 2, 3)", + "Complex": "(1+1j)", + "np.complex_": "(1+1j)", + # Valid Python Types + "Bool": True, + "Int": 1, + "Float": 1.0, + "Str": "test", + # Valid Numpy Types + "np.bool_": np.bool_(1), + "np.integer": np.int_(1), + "np.floating": np.float_(1.0), + "np.character": np.str_("test"), + # Nested Dict + "Nested_Dict.None": "None", + "Nested_Dict.List": "[1, 2, 3]", + "Nested_Dict.Tuple": "(1, 2, 3)", + "Nested_Dict.Complex": "(1+1j)", + "Nested_Dict.np.complex_": "(1+1j)", + "Nested_Dict.Bool": True, + "Nested_Dict.Int": 1, + "Nested_Dict.Float": 1.0, + "Nested_Dict.Str": "test", + "Nested_Dict.np.bool_": np.bool_(1), + "Nested_Dict.np.integer": np.int_(1), + "Nested_Dict.np.floating": np.float_(1.0), + "Nested_Dict.np.character": np.str_("test"), } self.assertDictEqual(result_hparams, expected_hparams) @@ -348,5 +356,5 @@ def test_no_auto_flush(self): self.assertLess(filesize_before_flush, filesize_after_flush) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tests/traceback_util_test.py b/tests/traceback_util_test.py index 0d98cc4817..d3a79c8d90 100644 --- a/tests/traceback_util_test.py +++ b/tests/traceback_util_test.py @@ -42,18 +42,22 @@ def test_exclusion_list(self): traceback_util.hide_flax_in_tracebacks() exclusion_len_w_flax = len(jax_traceback_util._exclude_paths) self.assertLen( - traceback_util._flax_exclusions, - exclusion_len_w_flax - exclusion_len_wo_flax) + traceback_util._flax_exclusions, exclusion_len_w_flax - exclusion_len_wo_flax + ) def test_simple_exclusion_tracebackhide(self): if not TRACEBACKHIDE_SUPPORTED: return + class Test1(nn.Module): + @nn.remat @nn.compact def __call__(self, x): return Test2()(x) + class Test2(nn.Module): + @nn.jit @nn.compact def __call__(self, x): @@ -81,14 +85,16 @@ def __call__(self, x): self.assertEqual(filtered_frames, 3) self.assertGreater(unfiltered_frames, filtered_frames) - def test_simple_exclusion_remove_frames(self): class Test1(nn.Module): + @nn.remat @nn.compact def __call__(self, x): return Test2()(x) + class Test2(nn.Module): + @nn.jit @nn.compact def __call__(self, x): @@ -119,18 +125,19 @@ def __call__(self, x): self.assertEqual(filtered_frames, 3) self.assertGreater(unfiltered_frames, filtered_frames) - def test_dynamic_exclusion(self): - if not TRACEBACKHIDE_SUPPORTED: return class Test1(nn.Module): + @nn.remat @nn.compact def __call__(self, x): return Test2()(x) + class Test2(nn.Module): + @nn.jit @nn.compact def __call__(self, x): @@ -185,10 +192,14 @@ def __call__(self, x): else: filtered_frames_w_flax += 1 - self.assertEqual(unfiltered_frames_all + filtered_frames_all, - unfiltered_frames_w_flax + filtered_frames_w_flax) - self.assertEqual(unfiltered_frames_all + filtered_frames_all, - unfiltered_frames_no_flax + filtered_frames_no_flax) + self.assertEqual( + unfiltered_frames_all + filtered_frames_all, + unfiltered_frames_w_flax + filtered_frames_w_flax, + ) + self.assertEqual( + unfiltered_frames_all + filtered_frames_all, + unfiltered_frames_no_flax + filtered_frames_no_flax, + ) self.assertEqual(unfiltered_frames_no_flax, 3) self.assertGreater(unfiltered_frames_all, unfiltered_frames_w_flax) self.assertGreater(unfiltered_frames_w_flax, unfiltered_frames_no_flax) diff --git a/tests/traverse_util_test.py b/tests/traverse_util_test.py index 16b72c768d..506088be65 100644 --- a/tests/traverse_util_test.py +++ b/tests/traverse_util_test.py @@ -103,8 +103,9 @@ def test_traverse_dataclass_attr(self): def test_traverse_merge(self): x = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}] traversal_base = traverse_util.t_identity.each() - traversal = traversal_base.merge(traverse_util.TraverseItem('foo'), - traverse_util.TraverseItem('bar')) + traversal = traversal_base.merge( + traverse_util.TraverseItem('foo'), traverse_util.TraverseItem('bar') + ) self.assertEqual(list(traversal.iterate(x)), [1, 2, 3, 4]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, [{'foo': 2, 'bar': 4}, {'foo': 6, 'bar': 8}]) @@ -150,60 +151,72 @@ def test_traversal_set(self): def test_flatten_dict(self): xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} flat_xs = traverse_util.flatten_dict(xs) - self.assertEqual(flat_xs, { - ('foo',): 1, - ('bar', 'a'): 2, - }) + self.assertEqual( + flat_xs, + { + ('foo',): 1, + ('bar', 'a'): 2, + }, + ) flat_xs = traverse_util.flatten_dict(freeze(xs)) - self.assertEqual(flat_xs, { - ('foo',): 1, - ('bar', 'a'): 2, - }) + self.assertEqual( + flat_xs, + { + ('foo',): 1, + ('bar', 'a'): 2, + }, + ) flat_xs = traverse_util.flatten_dict(xs, sep='/') - self.assertEqual(flat_xs, { - 'foo': 1, - 'bar/a': 2, - }) + self.assertEqual( + flat_xs, + { + 'foo': 1, + 'bar/a': 2, + }, + ) def test_unflatten_dict(self): - expected_xs = { - 'foo': 1, - 'bar': {'a': 2} - } + expected_xs = {'foo': 1, 'bar': {'a': 2}} xs = traverse_util.unflatten_dict({ - ('foo',): 1, - ('bar', 'a'): 2, + ('foo',): 1, + ('bar', 'a'): 2, }) self.assertEqual(xs, expected_xs) - xs = traverse_util.unflatten_dict({ - 'foo': 1, - 'bar/a': 2, - }, sep='/') + xs = traverse_util.unflatten_dict( + { + 'foo': 1, + 'bar/a': 2, + }, + sep='/', + ) self.assertEqual(xs, expected_xs) def test_flatten_dict_keep_empty(self): xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} flat_xs = traverse_util.flatten_dict(xs, keep_empty_nodes=True) - self.assertEqual(flat_xs, { - ('foo',): 1, - ('bar', 'a'): 2, - ('bar', 'b'): traverse_util.empty_node, - }) + self.assertEqual( + flat_xs, + { + ('foo',): 1, + ('bar', 'a'): 2, + ('bar', 'b'): traverse_util.empty_node, + }, + ) xs_restore = traverse_util.unflatten_dict(flat_xs) self.assertEqual(xs, xs_restore) def test_flatten_dict_is_leaf(self): xs = {'foo': {'c': 4}, 'bar': {'a': 2, 'b': {}}} flat_xs = traverse_util.flatten_dict( - xs, - is_leaf=lambda k, x: len(k) == 1 and len(x) == 2) - self.assertEqual(flat_xs, { - ('foo', 'c'): 4, - ('bar',): { - 'a': 2, - 'b': {} + xs, is_leaf=lambda k, x: len(k) == 1 and len(x) == 2 + ) + self.assertEqual( + flat_xs, + { + ('foo', 'c'): 4, + ('bar',): {'a': 2, 'b': {}}, }, - }) + ) xs_restore = traverse_util.unflatten_dict(flat_xs) self.assertEqual(xs, xs_restore) @@ -235,13 +248,15 @@ def test_param_selection(self): 'kernel': 6, 'bias': 4, }, - 'z': {} + 'z': {}, }, } names = [] + def filter_fn(name, _): names.append(name) # track names passed to filter_fn for testing return 'kernel' in name + traversal = traverse_util.ModelParamTraversal(filter_fn) values = list(traversal.iterate(params)) @@ -251,59 +266,73 @@ def filter_fn(name, _): ] for model, expected_model in configs: self.assertEqual(values, [1, 3]) - self.assertEqual(set(names), set([ - '/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias'])) + self.assertEqual( + set(names), set(['/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias']) + ) new_model = traversal.update(lambda x: x + x, model) self.assertEqual(new_model, expected_model) def test_path_value(self): params_in = {'a': {'b': 10, 'c': 2}} params_out = traverse_util.path_aware_map( - lambda path, x: x + 1 if 'b' in path else -x, params_in) + lambda path, x: x + 1 if 'b' in path else -x, params_in + ) self.assertEqual(params_out, {'a': {'b': 11, 'c': -2}}) def test_path_aware_map_with_multi_transform(self): - params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, - 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}} + params = { + 'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, + 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}, + } gradients = jax.tree_util.tree_map(jnp.ones_like, params) # dummy gradients param_labels = traverse_util.path_aware_map( - lambda path, x: 'kernel' if 'w' in path else 'bias', params) + lambda path, x: 'kernel' if 'w' in path else 'bias', params + ) tx = optax.multi_transform( - {'kernel': optax.sgd(1.0), 'bias': optax.set_to_zero()}, param_labels) + {'kernel': optax.sgd(1.0), 'bias': optax.set_to_zero()}, param_labels + ) state = tx.init(params) updates, new_state = tx.update(gradients, state, params) new_params = optax.apply_updates(params, updates) - self.assertTrue(np.allclose(new_params['linear_1']['b'], params['linear_1']['b'])) self.assertTrue(np.allclose(new_params['linear_2']['b'], params['linear_2']['b'])) self.assertFalse(np.allclose(new_params['linear_1']['w'], params['linear_1']['w'])) self.assertFalse(np.allclose(new_params['linear_2']['w'], params['linear_2']['w'])) def test_path_aware_map_with_masked(self): - params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, - 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}} + params = { + 'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, + 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}, + } gradients = jax.tree_util.tree_map(jnp.ones_like, params) # dummy gradients - params_mask = traverse_util.path_aware_map( - lambda path, x: 'w' in path, params) + params_mask = traverse_util.path_aware_map(lambda path, x: 'w' in path, params) tx = optax.masked(optax.sgd(1.0), params_mask) state = tx.init(params) updates, new_state = tx.update(gradients, state, params) new_params = optax.apply_updates(params, updates) - - self.assertTrue(np.allclose(new_params['linear_1']['b'], gradients['linear_1']['b'])) - self.assertTrue(np.allclose(new_params['linear_2']['b'], gradients['linear_2']['b'])) - self.assertTrue(np.allclose(new_params['linear_1']['w'], -gradients['linear_1']['w'])) - self.assertTrue(np.allclose(new_params['linear_2']['w'], -gradients['linear_2']['w'])) + self.assertTrue( + np.allclose(new_params['linear_1']['b'], gradients['linear_1']['b']) + ) + self.assertTrue( + np.allclose(new_params['linear_2']['b'], gradients['linear_2']['b']) + ) + self.assertTrue( + np.allclose(new_params['linear_1']['w'], -gradients['linear_1']['w']) + ) + self.assertTrue( + np.allclose(new_params['linear_2']['w'], -gradients['linear_2']['w']) + ) def test_path_aware_map_with_empty_nodes(self): params_in = {'a': {'b': 10, 'c': 2}, 'b': {}} params_out = traverse_util.path_aware_map( - lambda path, x: x + 1 if 'b' in path else -x, params_in) + lambda path, x: x + 1 if 'b' in path else -x, params_in + ) self.assertEqual(params_out, {'a': {'b': 11, 'c': -2}, 'b': {}}) From 97d038ccfeef101dd4c72486ec12c309402a9db8 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 21 Jul 2023 19:01:25 +0000 Subject: [PATCH 3/3] fix mypy + add .git-blame-ignore-revs --- .git-blame-ignore-revs | 2 ++ flax/linen/attention.py | 4 ++-- flax/linen/module.py | 14 ++++++-------- flax/linen/recurrent.py | 20 ++++++++------------ 4 files changed, 18 insertions(+), 22 deletions(-) create mode 100644 .git-blame-ignore-revs diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000000..ec1e14033b --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# apply pyink +40a6e074e5224d733f964be00e21e0a1cb98bd2e \ No newline at end of file diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 28b277b5bc..98a15dfcf7 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -381,10 +381,10 @@ class SelfAttention(MultiHeadDotProductAttention): """Self-attention special case of multi-head dot-product attention.""" @compact - def __call__( + def __call__( # type: ignore self, inputs_q: Array, - mask: Optional[Array] = None, # type: ignore + mask: Optional[Array] = None, deterministic: Optional[bool] = None, ): """Applies multi-head dot product self-attention on the input data. diff --git a/flax/linen/module.py b/flax/linen/module.py index e3a22fc078..2567124e04 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -828,13 +828,11 @@ def _customized_dataclass_transform(cls, kw_only: bool): for name, annotation, default in extra_fields: # pytype: disable=invalid-annotation setattr(cls, name, default) cls.__annotations__[name] = annotation - dataclasses.dataclass( + dataclasses.dataclass( # type: ignore[call-overload] unsafe_hash='__hash__' not in cls.__dict__, repr=False, kw_only=True, - )( - cls - ) # type: ignore[call-overload] + )(cls) else: raise TypeError('`kw_only` is not available before Py 3.10.') else: @@ -1900,8 +1898,8 @@ def sow( name: str, value: T, reduce_fn: Callable[[K, T], K] = tuple_reduce, - init_fn: Callable[[], K] = tuple_init, - ) -> bool: # type: ignore + init_fn: Callable[[], K] = tuple_init, # type: ignore + ) -> bool: ... def sow( @@ -1910,8 +1908,8 @@ def sow( name: str, value: T, reduce_fn: Callable[[K, T], K] = tuple_reduce, - init_fn: Callable[[], K] = tuple_init, - ) -> bool: # type: ignore + init_fn: Callable[[], K] = tuple_init, # type: ignore + ) -> bool: """Stores a value in a collection. Collections can be used to collect intermediate values without diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 8506e22321..1a676347e8 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -342,20 +342,16 @@ def _concat_dense( param_dtype=self.param_dtype, kernel_init=self.kernel_init, bias_init=self.bias_init, - name=f'i{component}', - )( - inputs - ) # type: ignore[call-arg] + name=f'i{component}', # type: ignore[call-arg] + )(inputs) dense_params_h[component] = DenseParams( features=hidden_features, use_bias=True, param_dtype=self.param_dtype, kernel_init=self.recurrent_kernel_init, bias_init=self.bias_init, - name=f'h{component}', - )( - h - ) # type: ignore[call-arg] + name=f'h{component}', # type: ignore[call-arg] + )(h) dense_h = _concat_dense(h, dense_params_h, use_bias=True) dense_i = _concat_dense(inputs, dense_params_i, use_bias=False) @@ -809,8 +805,8 @@ def __call__( if reverse: inputs = jax.tree_map( lambda x: flip_sequences( - x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major - ), # type: ignore + x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major # type: ignore + ), inputs, ) @@ -867,8 +863,8 @@ def scan_fn( if reverse and keep_order: outputs = jax.tree_map( lambda x: flip_sequences( - x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major - ), # type: ignore + x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major # type: ignore + ), outputs, )