Skip to content

Commit

Permalink
lint3
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-gorner committed Dec 20, 2024
1 parent 64293de commit d635f47
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 0 additions & 1 deletion keras/src/backend/jax/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name):
layout = _to_jax_layout(layout)

num_model_replicas_total = layout.mesh.shape[batch_dim_name]
mesh_shape = list(layout.mesh.shape.values())

mesh_model_dim_size = 1
for name, dim_size in layout.mesh.shape.items():
Expand Down
4 changes: 3 additions & 1 deletion keras/src/backend/jax/distribution_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,9 @@ def test_distribute_data_input(self):
mesh, jax.sharding.PartitionSpec("batch", None)
)

result = backend_dlib.distribute_data_input(per_process_batch, layout, "batch")
result = backend_dlib.distribute_data_input(
per_process_batch, layout, "batch"
)

# Check the shape of the global batch array
self.assertEqual(
Expand Down

0 comments on commit d635f47

Please sign in to comment.