From d635f472105395ab2db79b9a2878221b56ef54f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20G=C3=B6rner?= Date: Fri, 20 Dec 2024 14:59:02 +0100 Subject: [PATCH] lint3 --- keras/src/backend/jax/distribution_lib.py | 1 - keras/src/backend/jax/distribution_lib_test.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py index 124027938dc..5dc5c057d29 100644 --- a/keras/src/backend/jax/distribution_lib.py +++ b/keras/src/backend/jax/distribution_lib.py @@ -118,7 +118,6 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name): layout = _to_jax_layout(layout) num_model_replicas_total = layout.mesh.shape[batch_dim_name] - mesh_shape = list(layout.mesh.shape.values()) mesh_model_dim_size = 1 for name, dim_size in layout.mesh.shape.items(): diff --git a/keras/src/backend/jax/distribution_lib_test.py b/keras/src/backend/jax/distribution_lib_test.py index 2605f058487..81ceddfd305 100644 --- a/keras/src/backend/jax/distribution_lib_test.py +++ b/keras/src/backend/jax/distribution_lib_test.py @@ -337,7 +337,9 @@ def test_distribute_data_input(self): mesh, jax.sharding.PartitionSpec("batch", None) ) - result = backend_dlib.distribute_data_input(per_process_batch, layout, "batch") + result = backend_dlib.distribute_data_input( + per_process_batch, layout, "batch" + ) # Check the shape of the global batch array self.assertEqual(