Skip to content

Commit

Permalink
added data sharding for 3D+ meshes
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-gorner committed Dec 20, 2024
1 parent a64e093 commit 64293de
Showing 1 changed file with 1 addition and 7 deletions.
8 changes: 1 addition & 7 deletions keras/src/backend/jax/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,10 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name):
num_model_replicas_total = layout.mesh.shape[batch_dim_name]
mesh_shape = list(layout.mesh.shape.values())

# TODO: THIS IS COMPLETELY WRONG AS WELL FOR REPLICATING DATA ON "MODEL"
# dimensions: there may be more than one and the index ins not always "1"
mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1

# TODO: proper fix for this quick and dirty hack
# this only works for 2D meshes
mesh_model_dim_size = 1
for name, dim_size in layout.mesh.shape.items():
if not name == batch_dim_name:
mesh_model_dim_size = dim_size
mesh_model_dim_size *= dim_size

num_model_replicas_per_process = num_model_replicas_total / num_processes()
per_process_batch_size = per_process_batch.shape[0]
Expand Down

0 comments on commit 64293de

Please sign in to comment.