Fixing batch_dim_name attribute #20674

Merged: 11 commits, Jan 7, 2025
12 changes: 8 additions & 4 deletions keras/src/backend/jax/distribution_lib.py
@@ -100,7 +100,7 @@ def distribute_tensor(tensor, layout):
     return global_value


-def distribute_data_input(per_process_batch, layout):
+def distribute_data_input(per_process_batch, layout, batch_dim_name):
     """Distribute the input data with the corresponding layout.

     Note that the inputs here is a local worker batch. Within the local worker,
@@ -117,9 +117,13 @@ def distribute_data_input(per_process_batch, layout):
     if not isinstance(layout, jax.sharding.Sharding):
         layout = _to_jax_layout(layout)

-    mesh_shape = list(layout.mesh.shape.values())
-    num_model_replicas_total = mesh_shape[0]  # batch dimension of the mesh
-    mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1
+    num_model_replicas_total = layout.mesh.shape[batch_dim_name]
+
+    mesh_model_dim_size = 1
+    for name, dim_size in layout.mesh.shape.items():
+        if not name == batch_dim_name:
+            mesh_model_dim_size *= dim_size

     num_model_replicas_per_process = num_model_replicas_total / num_processes()
     per_process_batch_size = per_process_batch.shape[0]

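The substance of the fix is visible above: the old code assumed the batch axis was the first entry of the mesh shape, while the new code looks it up by name. Below is a minimal, self-contained sketch of that computation (the replica_counts helper and the 2D mesh values are hypothetical, chosen to show where the positional lookup went wrong); layout.mesh.shape in JAX behaves like an ordered mapping from axis name to size.

import collections

def replica_counts(mesh_shape, batch_dim_name):
    # mesh_shape mimics jax Mesh.shape: an OrderedDict {axis_name: size}.
    num_model_replicas_total = mesh_shape[batch_dim_name]
    # Every non-batch axis contributes to the model-parallel dimension.
    mesh_model_dim_size = 1
    for name, dim_size in mesh_shape.items():
        if name != batch_dim_name:
            mesh_model_dim_size *= dim_size
    return num_model_replicas_total, mesh_model_dim_size

# With the batch axis second, the old positional code would have read
# mesh_shape[0] == 4 as the replica count; the name lookup returns 2.
shape = collections.OrderedDict([("model", 4), ("batch", 2)])
assert replica_counts(shape, "batch") == (2, 4)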
4 changes: 3 additions & 1 deletion keras/src/backend/jax/distribution_lib_test.py
@@ -337,7 +337,9 @@ def test_distribute_data_input(self):
             mesh, jax.sharding.PartitionSpec("batch", None)
         )

-        result = backend_dlib.distribute_data_input(per_process_batch, layout)
+        result = backend_dlib.distribute_data_input(
+            per_process_batch, layout, "batch"
+        )

         # Check the shape of the global batch array
         self.assertEqual(
8 changes: 6 additions & 2 deletions keras/src/backend/jax/trainer.py
@@ -1,5 +1,6 @@
 import collections
 import itertools
+from functools import partial

 import jax
 import numpy as np
@@ -988,15 +989,18 @@ def _get_jax_state(

 def _distribute_data(data, layouts=None):
     distribution = distribution_lib.distribution()
+
     if distribution is not None:
         if layouts is None:
             layouts = tree.map_structure(
                 lambda d: distribution.get_data_layout(d.shape),
                 data,
             )
-        return tree.map_structure(
-            jax_distribution_lib.distribute_data_input, data, layouts
+        jax_dist_data_input = partial(
+            jax_distribution_lib.distribute_data_input,
+            batch_dim_name=distribution.batch_dim_name,
         )
+        return tree.map_structure(jax_dist_data_input, data, layouts)

     return tree.map_structure(jax.device_put, data)
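The partial import exists because tree.map_structure calls its function with exactly one positional argument per mapped structure, so the extra batch_dim_name keyword has to be bound beforehand. A minimal sketch of the pattern, using toy stand-ins for distribute_data_input and Keras's tree utilities rather than the real APIs:

from functools import partial

def distribute_data_input(batch, layout, batch_dim_name):
    # Stand-in: just records what it was called with.
    return (batch, layout, batch_dim_name)

def map_structure(fn, *structures):
    # Toy stand-in: maps fn over parallel flat lists.
    return [fn(*args) for args in zip(*structures)]

# Bind the keyword once, then map over (data, layouts) pairs.
dist_fn = partial(distribute_data_input, batch_dim_name="batch")
out = map_structure(dist_fn, ["x1", "x2"], ["layout1", "layout2"])
assert out == [("x1", "layout1", "batch"), ("x2", "layout2", "batch")]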
20 changes: 12 additions & 8 deletions keras/src/distribution/distribution_lib.py
@@ -287,8 +287,9 @@ class Distribution:
         device_mesh: A `DeviceMesh` instance.
     """

-    def __init__(self, device_mesh):
+    def __init__(self, device_mesh, batch_dim_name=None):
         self._device_mesh = device_mesh
+        self._batch_dim_name = batch_dim_name

     def get_data_layout(self, data_shape):
         """Retrieve the `TensorLayout` for the input data.
@@ -341,6 +342,10 @@ def scope(self):
     def device_mesh(self):
         return self._device_mesh

+    @property
+    def batch_dim_name(self):
+        return self._batch_dim_name
+
     def distribute_dataset(self, dataset):
         """Create a distributed dataset instance from the original user dataset.

@@ -395,7 +400,6 @@ def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
         else:
             self._initialize_mesh_from_list_devices()

-        self._batch_dim_name = self.device_mesh.axis_names[0]
         # Those following attributes might get convert to public methods.
         self._num_process = distribution_lib.num_processes()
         self._process_id = distribution_lib.process_id()
@@ -408,7 +412,7 @@ def _initialize_with_device_mesh(self, device_mesh):
                 "Expect `mesh` to be an instance of `DeviceMesh`. "
                 f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
             )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, device_mesh.axis_names[0])
Review thread on this line:

  Member: Shouldn't device_mesh.axis_names[0] be DEFAULT_BATCH_DIM_NAME here too? Otherwise, we'll rely on the order to get the batch dim name again, right?

  Contributor (author): We could do this either way when the user provides a device_mesh:
    • either treat the first dimension of device_mesh as the data sharding dimension, whatever its name,
    • or require that one of the dimensions of the device mesh be explicitly called "batch".
  The docstring stated "In case that the mesh has multiple axes, then the first axis will be treated as the data parallel dimension", so I kept that behavior.

  Contributor (author): If we want to let users shard data on the second dimension in DataParallel, then we'll have to go for your solution. Is there a use case where it could be useful?

  Member: Oh, this is DataParallel so the mesh is 1D anyway! Sounds good!

         if self.device_mesh.devices.ndim != 1:
             warnings.warn(
                 "Expect the input mesh to be 1D, but received "
@@ -424,7 +428,7 @@ def _initialize_mesh_from_devices(self, devices):
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)

     def _initialize_mesh_from_list_devices(self):
         devices = np.array(list_devices())
@@ -433,11 +437,11 @@ def _initialize_mesh_from_list_devices(self):
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)

     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
         return TensorLayout(data_shard_spec, self.device_mesh)

     def get_variable_layout(self, variable):
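The get_data_layout change only swaps the private attribute for the new property; the layout itself still shards just the leading dimension. A small sketch of the resulting shard spec, with made-up shapes:

def data_shard_spec(data_shape, batch_dim_name):
    # Shard the leading (batch) axis on the mesh axis named by
    # batch_dim_name; replicate every other dimension (None).
    spec = [None] * len(data_shape)
    spec[0] = batch_dim_name  # shard on the first dim only
    return spec

# e.g. a (64, 28, 28, 3) image batch sharded on a "batch" mesh axis
assert data_shard_spec((64, 28, 28, 3), "batch") == ["batch", None, None, None]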
@@ -590,7 +594,7 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):

     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
         return TensorLayout(data_shard_spec, self.device_mesh)

     def get_variable_layout(self, variable):
@@ -631,7 +635,7 @@ def distribute_dataset(self, dataset):
         # Note that this might be smaller than one if model replicas are sharded
         # across multiple processes.
         mesh_batch_dim_index = self.device_mesh.axis_names.index(
-            self._batch_dim_name
+            self.batch_dim_name
         )
         num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]
         if num_model_replicas == 1:
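In distribute_dataset, the replica count is likewise resolved by name rather than by assuming index 0. A toy illustration with a hypothetical 2D mesh where the batch axis is not first:

# Hypothetical ("model", "batch") mesh of shape (4, 2): looking the
# batch axis up by name yields 2 model replicas; a hard-coded index 0
# would have returned 4.
axis_names = ["model", "batch"]
mesh_shape = (4, 2)  # axis sizes, in axis_names order

mesh_batch_dim_index = axis_names.index("batch")
num_model_replicas = mesh_shape[mesh_batch_dim_index]
assert num_model_replicas == 2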
6 changes: 3 additions & 3 deletions keras/src/distribution/distribution_lib_test.py
@@ -186,7 +186,7 @@ def test_create_with_device_mesh(self):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["data"])
-        self.assertEqual(distribution._batch_dim_name, "data")
+        self.assertEqual(distribution.batch_dim_name, "data")

         self.assertFalse(distribution._is_multi_process)
         self.assertEqual(distribution._process_id, 0)
@@ -197,7 +197,7 @@ def test_create_with_devices(self):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["batch"])
-        self.assertEqual(distribution._batch_dim_name, "batch")
+        self.assertEqual(distribution.batch_dim_name, "batch")

     @mock.patch.object(
         distribution_lib,
@@ -211,7 +211,7 @@ def test_create_with_list_devices(self, mock_list_devices):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["batch"])
-        self.assertEqual(distribution._batch_dim_name, "batch")
+        self.assertEqual(distribution.batch_dim_name, "batch")

     def test_get_data_layout(self):
         distribution = distribution_lib.DataParallel(