diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp
index 50e90afa90..e8d07e271d 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp
@@ -101,24 +101,47 @@ void nccl_comm_init_rank(
       "ncclCommInitRank");
 }
 
-void nccl_allgather(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
-  using namespace c10d;
-  TORCH_CHECK(src.is_contiguous());
-  TORCH_CHECK(dst.is_contiguous());
-  ncclDataType_t type;
-  switch (src.scalar_type()) {
+ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
+  switch (type) {
     case at::kFloat:
-      type = ncclDataType_t::ncclFloat;
-      break;
+      return ncclDataType_t::ncclFloat;
     case at::kHalf:
-      type = ncclDataType_t::ncclHalf;
-      break;
+      return ncclDataType_t::ncclHalf;
+    case at::kDouble:
+      return ncclDataType_t::ncclDouble;
+    case at::kLong:
+      return ncclDataType_t::ncclInt64;
+    case at::kInt:
+      return ncclDataType_t::ncclInt;
+    case at::kChar:
+      return ncclDataType_t::ncclChar;
+    case at::kByte:
+      return ncclDataType_t::ncclUint8;
+    case at::kBool:
+      return ncclDataType_t::ncclUint8;
+#if defined(USE_ROCM)
+    case at::kFloat8_e4m3fnuz:
+      return ncclDataType_t::ncclUint8;
+    case at::kFloat8_e5m2fnuz:
+      return ncclDataType_t::ncclUint8;
+#else
+    case at::kFloat8_e4m3fn:
+      return ncclDataType_t::ncclUint8;
+    case at::kFloat8_e5m2:
+      return ncclDataType_t::ncclUint8;
+#endif
     case at::kBFloat16:
-      type = ncclDataType_t::ncclBfloat16;
-      break;
+      return ncclDataType_t::ncclBfloat16;
     default:
-      TORCH_CHECK(false, "unsupported type: ", src.scalar_type());
+      TORCH_CHECK(false, "Unconvertible NCCL type ", type);
   }
+}
+
+void nccl_allgather(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
+  using namespace c10d;
+  TORCH_CHECK(src.is_contiguous());
+  TORCH_CHECK(dst.is_contiguous());
+  ncclDataType_t type = to_nccl_data_type(src.scalar_type());
   C10D_NCCL_CHECK(
       ncclAllGather(
           src.data_ptr(),
diff --git a/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py b/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py
index 057651c0f4..47065972c4 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py
@@ -19,6 +19,7 @@
 import numpy as np
 
 import torch
+from hypothesis import given, settings, strategies as st, Verbosity
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 
 logger: logging.Logger = logging.getLogger()
@@ -35,28 +36,24 @@ def has_nvswitch() -> bool:
     return "GRANDTETON" in model or "SUPERMICRO" in model
 
 
-def _run_allgather_inner(rdvz: str) -> None:
+def _run_allgather_inner(rdvz: str, dtype: torch.dtype) -> None:
     rank = int(os.environ["LOCAL_RANK"])
     W = int(os.environ["WORLD_SIZE"])
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
     torch.ops.fbgemm.nccl_init(rank, W, rdvz)
-    # torch.distributed.init_process_group(backend="nccl")
 
     B, T, D = 2, 4096, 1024
-    y = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda")
+    y = torch.empty(size=(B, T, D), dtype=dtype, device="cuda")
     y[:] = rank
-    y_gather = torch.zeros(size=(W, B, T, D), dtype=torch.bfloat16, device="cuda")
-    y_gather[:] = -1
+    y_gather = torch.full(size=(W, B, T, D), fill_value=-1, dtype=dtype, device="cuda")
     # Here we test to confirm that allgather is compatible with torch.compile.
     torch.compile(torch.ops.fbgemm.nccl_allgather)(y_gather, y)
     for w in range(W):
         torch.testing.assert_close(
             y_gather[w],
-            torch.full(
-                size=(B, T, D), fill_value=w, dtype=torch.bfloat16, device=y.device
-            ),
+            torch.full(size=(B, T, D), fill_value=w, dtype=dtype, device=y.device),
         )
 
     for _ in range(20):
@@ -268,7 +265,30 @@ def round_up(a: int, b: int) -> int:
     "Skip when CUDA is not available or when there are not enough GPUs; these tests require at least two GPUs",
 )
 class LLamaMultiGpuTests(unittest.TestCase):
-    def test_allgather(self) -> None:
+    @given(
+        dtype=st.sampled_from(
+            [
+                torch.bfloat16,
+                torch.float16,
+                torch.int,
+                torch.long,
+                torch.float,
+                torch.float8_e4m3fn,
+            ]
+        )
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=60000)
+    def test_allgather(self, dtype: torch.dtype) -> None:
+        # float8 is only supported in H100 or MI300x
+        if dtype == torch.float8_e4m3fn:
+            if torch.version.hip:
+                dtype = torch.float8_e4m3fnuz
+            elif torch.cuda.get_device_capability() < (9, 0):
+                self.skipTest(
+                    "float8_e4m3fn is only supported in H100 or MI300x, but we're running "
+                    f"on {torch.cuda.get_device_capability()}"
+                )
+
         with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as path:
             lc = LaunchConfig(
                 min_nodes=1,
@@ -283,7 +303,7 @@ def test_allgather(self) -> None:
                 max_restarts=0,
             )
             elastic_launch(config=lc, entrypoint=_run_allgather_inner)(
-                os.path.join(path, "rdvz")
+                os.path.join(path, "rdvz"), dtype
             )
 
     def test_allreduce(self) -> None:
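
Side note, not part of the patch: NCCL has no native float8 (or bool) data type, so the new to_nccl_data_type helper maps those scalar types to ncclUint8. That is safe for an allgather because the collective only copies bytes. The short sketch below is an illustration under the assumption of a PyTorch build that provides torch.float8_e4m3fn; it shows that reinterpreting a float8 tensor as uint8 preserves the exact bytes, which is why gathering float8 data as ncclUint8 is lossless.

# Illustration only (not part of the diff above): float8 -> uint8 is a pure
# reinterpretation of the same storage, so a collective typed as ncclUint8
# reproduces the float8 payload exactly.
import torch

x = torch.randn(4, 4, dtype=torch.float32).to(torch.float8_e4m3fn)
raw = x.view(torch.uint8)                  # same storage, viewed as bytes
roundtrip = raw.view(torch.float8_e4m3fn)  # view back; no values change
assert torch.equal(raw, roundtrip.view(torch.uint8))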