Tests for distributed #1196
base: main
Conversation
As mentioned by @ptrendx, we'll need to include these tests in one of the QA scripts (see …).
WORLD_RANK = dist.get_rank()
WORLD_SIZE = dist.get_world_size()
assert WORLD_SIZE == 2, "This test uses 2 GPUs. Run with torchrun --nproc_per_node=2."
test_numerics.py can launch with a different number of GPUs:

NUM_PROCS: int = min(torch.cuda.device_count(), 4)
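A minimal sketch of how the check could adapt to the available GPU count instead of hard-coding 2; the variable names follow the snippets above, and the surrounding process-group setup is assumed:

import torch
import torch.distributed as dist

# Launch with as many ranks as there are GPUs (capped at 4, as in test_numerics.py).
NUM_PROCS: int = min(torch.cuda.device_count(), 4)

def check_world_size() -> int:
    """Require at least 2 ranks instead of exactly 2 (call after init_process_group)."""
    world_size = dist.get_world_size()
    assert world_size >= 2, (
        f"This test needs at least 2 GPUs, got world size {world_size}. "
        f"Run with torchrun --nproc_per_node={NUM_PROCS}."
    )
    return world_size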
class HalfGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
Is this needed?
ctx.save_for_backward(input)
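If the backward pass only scales the incoming gradient (which is what the name suggests; the backward body below is my assumption, not the PR's code), the saved tensor is never read and the call could be dropped:

import torch

class HalfGradient(torch.autograd.Function):
    """Identity in forward, halves the gradient in backward; no saved tensors needed."""

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Only grad_output is needed here, so save_for_backward can be omitted.
        return grad_output * 0.5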
output_failed, output_info = _compare_tensors(
    "outputs", output_distributed, output_single_node, rtol, atol
)
dist_print(output_info, src=WORLD_RANK, error=output_failed)
Printing successful checks might be too verbose; I'd only print when the comparison fails:

if output_failed:
    dist_print(output_info, src=WORLD_RANK, error=True)
rtol = 0.125 if FP8 else 0.025
atol = 0.0625 if FP8 else 0.00125
Ideally we would use the tightest tolerances possible for the dtype, like in torch.testing.assert_close. I see that the tensor dimensions are small (~64), so we should be able to get away with this.
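For example, per-dtype thresholds could be centralized in a small helper; the non-FP8 values below roughly follow the torch.testing.assert_close defaults, and the FP8 entry is just a placeholder carried over from the current hard-coded numbers (assert_close has no FP8 default):

import torch

def tolerances(dtype: torch.dtype, fp8: bool = False) -> dict:
    """Pick rtol/atol from the dtype instead of hard-coding one pair for all cases."""
    if fp8:
        # Placeholder FP8 thresholds, not derived from assert_close.
        return {"rtol": 0.125, "atol": 0.0625}
    defaults = {
        torch.float32: {"rtol": 1.3e-6, "atol": 1e-5},
        torch.float16: {"rtol": 1e-3, "atol": 1e-5},
        torch.bfloat16: {"rtol": 1.6e-2, "atol": 1e-5},
    }
    return defaults[dtype]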
return to_output


def _check_gradients(model_distributed, model_single, main_grad_check=False):
Similar questions as in _check_outputs.
for kwargs in kwargs_list:
    for parallel_mode in ["column", "row"]:
        for sequence_parallel in [False, True]:
            _test_linear(parallel_mode, sequence_parallel, **kwargs)
I find it helpful to print the test configs so we can narrow things down when there's a test failure:

dist_print(f"_test_linear with {kwargs=}, {parallel_mode=}, {sequence_parallel=}")
_test_linear(parallel_mode, sequence_parallel, **kwargs)
We'd want to do similar logging in all other layer tests.
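One way to get that logging everywhere without repeating the pattern is a small driver over itertools.product; dist_print, kwargs_list and _test_linear are the names from the snippets above, and the helper itself is only a sketch:

import itertools

def run_linear_tests(kwargs_list):
    # Log each configuration before running it so a failure points at a specific config.
    for kwargs, parallel_mode, sequence_parallel in itertools.product(
        kwargs_list, ["column", "row"], [False, True]
    ):
        dist_print(f"_test_linear with {kwargs=}, {parallel_mode=}, {sequence_parallel=}")
        _test_linear(parallel_mode, sequence_parallel, **kwargs)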
@pytest.mark.parametrize("fp8", all_boolean)
def test_linear(fp8):
Having a separate test for each layer is nice for error reporting, but it adds a few seconds of overhead for every parallel-job launch. It's more scalable to launch a single parallel job and test all layers and layer configurations internally.
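A rough sketch of that shape: one worker script launched once, looping over all layer tests internally. The per-layer functions below are placeholder stubs, not the PR's actual helpers:

import torch
import torch.distributed as dist

# Placeholder layer tests; the real suite would plug in its own _test_* bodies.
def _test_linear(fp8: bool) -> None: ...
def _test_layernorm_linear(fp8: bool) -> None: ...

LAYER_TESTS = [_test_linear, _test_layernorm_linear]

def main() -> None:
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    for fp8 in (False, True):
        for test_fn in LAYER_TESTS:
            if dist.get_rank() == 0:
                # Log the current layer/config so failures are easy to localize.
                print(f"Running {test_fn.__name__} with {fp8=}")
            test_fn(fp8)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()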
Description
I am working on the debug API. Before it can be merged, it needs to be tested. We need to ensure that all of the new layers also work properly in the distributed case. The tests already present in the repo focus on testing gemm/comm overlapping; the tests I want to add focus more on checking the numerical correctness of multiple configurations of TE layers. Moreover, current …
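A sketch of how such a distributed correctness test is typically driven from pytest, following the "Run with torchrun --nproc_per_node=2" hint quoted above; the worker file name run_numerics.py is a placeholder, not necessarily the file added by this PR:

import subprocess
import torch
import pytest

NUM_GPUS = torch.cuda.device_count()

@pytest.mark.skipif(NUM_GPUS < 2, reason="needs at least 2 GPUs")
def test_distributed_numerics():
    # Launch the distributed worker once; its internal asserts decide pass/fail.
    result = subprocess.run(
        ["torchrun", "--nproc_per_node=2", "run_numerics.py"],
        capture_output=True,
        text=True,
    )
    assert result.returncode == 0, result.stderr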
Type of change
Changes
Checklist: