diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py index c4f1a71e1d4..124027938dc 100644 --- a/keras/src/backend/jax/distribution_lib.py +++ b/keras/src/backend/jax/distribution_lib.py @@ -120,16 +120,10 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name): num_model_replicas_total = layout.mesh.shape[batch_dim_name] mesh_shape = list(layout.mesh.shape.values()) - # TODO: THIS IS COMPLETELY WRONG AS WELL FOR REPLICATING DATA ON "MODEL" - # dimensions: there may be more than one and the index ins not always "1" - mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1 - - # TODO: proper fix for this quick and dirty hack - # this only works for 2D meshes mesh_model_dim_size = 1 for name, dim_size in layout.mesh.shape.items(): if not name == batch_dim_name: - mesh_model_dim_size = dim_size + mesh_model_dim_size *= dim_size num_model_replicas_per_process = num_model_replicas_total / num_processes() per_process_batch_size = per_process_batch.shape[0]