Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing batch_dim_name attribute #20674

Merged
merged 11 commits into from
Jan 7, 2025
Merged

Conversation

martin-gorner
Copy link
Contributor

@martin-gorner martin-gorner commented Dec 20, 2024

ModelParallel(batch_dim_name='batch') is currently dysfunctional and will work only if batch_dim_name corresponds to the first dimension of the mesh, which is the default anyway. There is also a problem for meshes with 3 and more dimensions.

Minimal repro 1 (showing error):
https://colab.research.google.com/drive/1jzmCZ2WNlKtD4j2heSaq-mxBoG-9WeeS?usp=sharing

Minimal repro 2 with a 3D mesh (showing error):
https://colab.research.google.com/drive/1AGku4hjwhTN_2h5yiU7Q-a6vvSrc8nRH

Real-world repro 1 (showing successful run with fix):
https://colab.research.google.com/drive/1cyn_XUFwdLUJE4pRNWPgZ2H5wzKzto-T?usp=sharing

Real-world repro 2 (showing a run without errors - but unfortunately no convergence):
https://colab.research.google.com/drive/1kY9qq27YxpowqYDT3gL98U5RuN6CYQ7b?usp=sharing

The use case is not just hypothetical.
With DeviceMesh((4,2), ("model", "batch")), fine-tuning proceeds at 147ms/step.
With DeviceMesh((2,4), ("batch", "model")), fine-tuning proceeds at 205ms/step.
The fix makes the first, faster use case work, as tested with the real-world repro 1 notebook on TPU v5e.

Remaining issues:

  • Real-world repro 2 does not show convergence. Maybe there is another bug in loss aggregation??
  • The fixes should work for combined data and model parallelism where the data is sharded along one axis and the model along a different set of axes. That is the assumption in backend/jax/distribution_lb.py:
    ** num_model_replicas_total = layout.mesh.shape[batch_dim_name] i.e. the number of model replicas is the nb of devices along the "batch" axis of the mesh
    ** mesh_model_dim_size computation: data is replicated as many times as there are unique model shards.
  • However, the default layout map for Gemma shards the model also along the "batch" dimension. This will work as long as the "batch" dimension is 1 but is useless in that case. When the "batch" dimension is >=2, I don't know what it means, i.e. how many model model replicas there are and therefore how input data should be split. The Keras team should chime in on this.

@codecov-commenter
Copy link

codecov-commenter commented Dec 20, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.83%. Comparing base (881d8da) to head (a7afa74).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #20674   +/-   ##
=======================================
  Coverage   81.83%   81.83%           
=======================================
  Files         552      552           
  Lines       51363    51370    +7     
  Branches     7944     7946    +2     
=======================================
+ Hits        42034    42041    +7     
  Misses       7375     7375           
  Partials     1954     1954           
Flag Coverage Δ
keras 81.66% <100.00%> (+<0.01%) ⬆️
keras-jax 63.95% <100.00%> (+<0.01%) ⬆️
keras-numpy 58.81% <26.31%> (-0.01%) ⬇️
keras-openvino 29.84% <26.31%> (+<0.01%) ⬆️
keras-tensorflow 64.62% <26.31%> (-0.01%) ⬇️
keras-torch 63.94% <26.31%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@github-actions github-actions bot added the Gemma Gemma model specific issues label Dec 20, 2024
jax_distribution_lib.distribute_data_input, data, layouts
jax_dist_data_input = partial(
jax_distribution_lib.distribute_data_input,
batch_dim_name=distribution._batch_dim_name,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do this without accessing the private variable _batch_dim_name? Could we consider passing the batch_dim_name as an argument to the relevant functions? Or, maybe the distribution object provides a public method or property to access the batch dimension name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll think of a cleaner way.

The goal at this point is to get a second pair of eyes on this fix and validate it is correct. See use cases at the end of the intro paragraph. Also, since you implemented the multi-host code, could you check if this fix does not break it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll run the internal multi-host test to make sure it still works.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will also need more tests. The failure is not in a complex case. This should have been covered by tests. I can add the a couple of tests on 8-core TPUs, but I'll let you extend them to multi-host settings.

But right now, what is your opinion on the case where model and data parallelism are used at the same time and the "batch" dimension is also a sharding dimension for the model, as is the default for Gemma and Llama? How should data batches be split in that case ? (And I don't think my fix covers that case - I'm not sure I understand how that case makes sense..).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran the internal multi-host test with your changes and it passed! I think we should be able to merge this PR after updating private variable usage (_batch_dim_name).

Sharding along the batch dimension should work! Our multi-host tests test that and they pass! We test all the following configs:

  @parameterized.named_parameters([
      ("data_only", (8, 1), 2, False,),
      ("data_model", (4, 2), 2, False,),
      ("model_data", (2, 4), 4, False,),
      ("model_only", (1, 8), 8, True,),
  ])

Could you point me to the colab that shows sharding along the batch dimension doesn't work for a 2D mesh?

I think what is not supported yet is 3D+ mesh. I agree that this would be a great feature to have. Maybe we can create a feature request issue and plan for supporting it.

PS: US holidays will start tomorrow and I'll be back after the new year! Happy Holidays, Martin!

Copy link
Contributor Author

@martin-gorner martin-gorner Dec 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the repro colabs are in the intro.

The tests may be passing but if we don't understand the use case, it could be by accident. The thing I do not understand and that the fix does not cover is:

num_model_replicas_total = layout.mesh.shape[batch_dim_name]
mesh_model_dim_size = nb_devices / num_model_replicas_total # not actual code but it amounts to this

It seems to me that these expressions assume that the model is NOT sharded on the 'batch' dimension. It is only when the model is replicated on the 'batch' dimension and sharded on all other dimensions that the expression num_model_replicas_total = layout.mesh.shape[batch_dim_name] is true. If the model is also sharded on the 'batch' dimension, I'm not sure how many model replicas there are ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Model should not be sharded on the batch dimension. Model should only be sharded on the model dimension. Data should be sharded on the batch dimension. The reason that the number of model replicas is the same as the batch dimension (number of data shards) is that when we shard the data, for each shard, we need the full replication of the model to process that data shard.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tend to agree with you but the default sharding for Gemma does shard the model on the "batch" dimension. See here. The code references this article as a rationale.

Copy link
Contributor Author

@martin-gorner martin-gorner Jan 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went through the article and the only mention of sharding or replication I could find is this sentence "Within a pod, we use 16-way model sharding and 16-way data replication for the 7B model." It does not say anything about sharding the model on the batch dimension.

The sharding on the batch dimension was added by Scott in PR1491. The PR discussion says "The new setting is based on the Gemma training script internally". Can you find those scripts to investigate? Are there other places we could check this? Maybe some JAX Gemma fine-tuning scripts?

Anyway, in the short term, I think we should just remove the batch dimension from the default layout map and then safely assume that the model is NOT sharded on the batch dim. This should work, I think, even for 3D sharding (model sharded on "model" and "sequence" dims, while data is sharded on the "batch" dim).

@SamanehSaadat
Copy link
Member

SamanehSaadat commented Dec 20, 2024

Real-world repro 2 does not show convergence. Maybe there is another bug in loss aggregation?

@martin-gorner Could you clarify what you meant by 'no coverage' in this context?

@martin-gorner
Copy link
Contributor Author

martin-gorner commented Dec 20, 2024

Real-world repro 2 does not show convergence. Maybe there is another bug in loss aggregation?

@martin-gorner Could you clarify what you meant by 'no coverage' in this context?

I meant "convergence", i.e. the loss is not decreasing.

@fchollet
Copy link
Collaborator

@SamanehSaadat is this LGTM?

@SamanehSaadat
Copy link
Member

@SamanehSaadat is this LGTM?

Not yet! I'll tag you when it's ready.

@fchollet
Copy link
Collaborator

fchollet commented Jan 4, 2025

Note that I will do a new release soon -- should this fix be in the release (i.e. is it blocking?)

@martin-gorner
Copy link
Contributor Author

Note that I will do a new release soon -- should this fix be in the release (i.e. is it blocking?)

Not blocking as long as you do not use the batch_dim_name parameter and the batch dimension is the first dimension of the device array (which is the default). But still pretty serious and sending the wrong message about Keras readiness for model parallel distributed computing.

I'll send a revised PR today fixing this and removing model sharding on the 'batch' dim from default layout maps. The mystery of why this was introduced in the first place in PR1491 will remain.

@martin-gorner
Copy link
Contributor Author

  • I fixed the the private variable _batch_dim_name, replaced it with a property
  • I filed Keras-hub PR 2035 to remove the "batch" dimension from default model shardings (they appeared in Gemma and Llama)

Copy link
Member

@SamanehSaadat SamanehSaadat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Martin! Just left a nit comment.

@@ -408,7 +412,7 @@ def _initialize_with_device_mesh(self, device_mesh):
"Expect `mesh` to be an instance of `DeviceMesh`. "
f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
)
super().__init__(device_mesh)
super().__init__(device_mesh, device_mesh.axis_names[0])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't device_mesh.axis_names[0] be DEFAULT_BATCH_DIM_NAME here too? Otherwise, we'll rely on the order to get the batch dim name again, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could do this either way when the user provides a device_mesh:

  • either treat the first dimension of device_mesh as the data sharding dimension, whatever its name
  • or require that one of the dimensions of device mesh be explicitly called "batch".
    Here, the docstring stated "In case that the mesh has multiple axes, then the first axis will be treated as the data parallel dimension" so I kept that behavior.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to let users shard data on the second dimension in DataParallel, then we'll have to go for your solution. Is there a use case where it could be useful?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this is DataParallel so the mesh is 1D anyway! Sounds good!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jan 6, 2025
@fchollet fchollet merged commit fbf0af7 into keras-team:master Jan 7, 2025
10 checks passed
@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Jan 7, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Gemma Gemma model specific issues size:S
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants