[CPU] Allow deepspeed.comm.inference_all_reduce in torch.compile graph #5604
This PR allows `deepspeed.comm.inference_all_reduce()` to enter the torch.compile graph even though it is implemented as a C++ kernel in DeepSpeed.

The previous implementation registered the `inference_all_reduce()` C++ kernel as a pybind function so it could be called from Python code. However, a pybind function cannot be recognized by PyTorch, so the graph breaks when `inference_all_reduce` is called.

We address this issue by registering `inference_all_reduce` as a PyTorch custom op, `torch.ops.deepspeed.inference_all_reduce`, so it can be built into the PyTorch graph.

The output trace code from torchinductor:
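The registration pattern can be sketched in Python as follows. This is an illustrative analogue only, assuming the usual `torch.library` API; the PR itself does the registration in C++, and the single-process kernel below is a hypothetical stand-in for the real all-reduce:

```python
import torch

# Hypothetical sketch: defining "deepspeed::inference_all_reduce" as a custom
# op lets torch.compile keep the call in the graph instead of graph-breaking
# on an opaque pybind function. (The PR does the equivalent in C++.)
lib = torch.library.Library("deepspeed", "DEF")
lib.define("inference_all_reduce(Tensor t) -> Tensor")

def _inference_all_reduce_cpu(t: torch.Tensor) -> torch.Tensor:
    # Single-process stand-in; a real kernel would all-reduce across ranks.
    return t.clone()

# Register an eager CPU kernel plus a Meta (shape-only) kernel so the op
# can be traced with fake tensors during compilation.
lib.impl("inference_all_reduce", _inference_all_reduce_cpu, "CPU")
lib.impl("inference_all_reduce", lambda t: torch.empty_like(t), "Meta")

out = torch.ops.deepspeed.inference_all_reduce(torch.ones(4))
```

Once an op is registered this way, a `torch.compile`-traced function calling `torch.ops.deepspeed.inference_all_reduce` records the op as a first-class graph node rather than hitting a graph break at a pybind boundary.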
Note that in this PR the `inference_all_reduce` op for CPU does not handle multi-node setups or the FP16 data type. For FP16 data type support, we will align with the PyTorch CPU FP16 plan. For multi-node, we are still looking at the possibility of upstreaming oneCCL integration into PyTorch, so we can make use of oneCCL for multi-node tensor parallel inference with PyTorch.
This PR is independent of #5571. They can work separately or together without issue.