custom allgather support multiple dtypes (#3498)
Summary:
Pull Request resolved: #3498

X-link: facebookresearch/FBGEMM#576

Adding more dtype support for CAR allgather

Reviewed By: jasonjk-park

Differential Revision: D66941779

fbshipit-source-id: f54112907e4a3fda9f889dba94f32410306cde30
xw285cornell authored and facebook-github-bot committed Dec 14, 2024
1 parent 6e9b083 commit c932a35
Showing 2 changed files with 66 additions and 23 deletions.
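To illustrate what the broadened dtype coverage means for callers, here is a minimal, hypothetical usage sketch. It assumes `torch.ops.fbgemm.nccl_init` has already been run on every rank with a shared rendezvous path (as the test below does via `elastic_launch`); the shape and the int64 dtype are arbitrary examples, not part of this change.

```python
import os

import torch

# Hypothetical per-rank snippet; assumes torch.ops.fbgemm.nccl_init(rank, W, rdvz)
# has already been called on every rank, as in the test below.
rank = int(os.environ["LOCAL_RANK"])
W = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(torch.device(f"cuda:{rank}"))

# Previously only float, half, and bfloat16 were accepted; int64 (and int32,
# double, uint8, bool, float8, ...) now map to NCCL types as well.
src = torch.full((4, 1024), rank, dtype=torch.int64, device="cuda")
dst = torch.empty((W, 4, 1024), dtype=torch.int64, device="cuda")
torch.ops.fbgemm.nccl_allgather(dst, src)
# dst[w] now holds rank w's src tensor.
```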
49 changes: 36 additions & 13 deletions fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp

@@ -101,24 +101,47 @@ void nccl_comm_init_rank(
       "ncclCommInitRank");
 }
 
-void nccl_allgather(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
-  using namespace c10d;
-  TORCH_CHECK(src.is_contiguous());
-  TORCH_CHECK(dst.is_contiguous());
-  ncclDataType_t type;
-  switch (src.scalar_type()) {
+ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
+  switch (type) {
     case at::kFloat:
-      type = ncclDataType_t::ncclFloat;
-      break;
+      return ncclDataType_t::ncclFloat;
     case at::kHalf:
-      type = ncclDataType_t::ncclHalf;
-      break;
+      return ncclDataType_t::ncclHalf;
+    case at::kDouble:
+      return ncclDataType_t::ncclDouble;
+    case at::kLong:
+      return ncclDataType_t::ncclInt64;
+    case at::kInt:
+      return ncclDataType_t::ncclInt;
+    case at::kChar:
+      return ncclDataType_t::ncclChar;
+    case at::kByte:
+      return ncclDataType_t::ncclUint8;
+    case at::kBool:
+      return ncclDataType_t::ncclUint8;
+#if defined(USE_ROCM)
+    case at::kFloat8_e4m3fnuz:
+      return ncclDataType_t::ncclUint8;
+    case at::kFloat8_e5m2fnuz:
+      return ncclDataType_t::ncclUint8;
+#else
+    case at::kFloat8_e4m3fn:
+      return ncclDataType_t::ncclUint8;
+    case at::kFloat8_e5m2:
+      return ncclDataType_t::ncclUint8;
+#endif
     case at::kBFloat16:
-      type = ncclDataType_t::ncclBfloat16;
-      break;
+      return ncclDataType_t::ncclBfloat16;
     default:
-      TORCH_CHECK(false, "unsupported type: ", src.scalar_type());
+      TORCH_CHECK(false, "Unconvertible NCCL type ", type);
   }
+}
+
+void nccl_allgather(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
+  using namespace c10d;
+  TORCH_CHECK(src.is_contiguous());
+  TORCH_CHECK(dst.is_contiguous());
+  ncclDataType_t type = to_nccl_data_type(src.scalar_type());
   C10D_NCCL_CHECK(
       ncclAllGather(
           src.data_ptr(),
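One detail worth noting in the mapping above: bool and both float8 variants are sent as `ncclUint8`. Allgather only copies payload bytes and never does arithmetic on them, so moving 1-byte elements as raw uint8 while keeping the original dtype on the destination tensor is enough. A tiny PyTorch sketch of why that byte-wise reinterpretation is lossless (illustrative only, assuming a PyTorch build with float8 dtypes; it uses plain `Tensor.view`, no FBGEMM API involved):

```python
import torch

# A float8 element is one byte, so reinterpreting the storage as uint8 and
# back changes nothing; this is all a byte-transporting allgather relies on.
x = torch.randn(8).to(torch.float8_e4m3fn)
as_bytes = x.view(torch.uint8)                  # zero-copy reinterpretation
roundtrip = as_bytes.view(torch.float8_e4m3fn)  # same bytes, original dtype

assert torch.equal(roundtrip.view(torch.uint8), x.view(torch.uint8))
```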
40 changes: 30 additions & 10 deletions fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py

@@ -19,6 +19,7 @@
 
 import numpy as np
 import torch
+from hypothesis import given, settings, strategies as st, Verbosity
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 
 logger: logging.Logger = logging.getLogger()
@@ -35,28 +36,24 @@ def has_nvswitch() -> bool:
     return "GRANDTETON" in model or "SUPERMICRO" in model
 
 
-def _run_allgather_inner(rdvz: str) -> None:
+def _run_allgather_inner(rdvz: str, dtype: torch.dtype) -> None:
     rank = int(os.environ["LOCAL_RANK"])
     W = int(os.environ["WORLD_SIZE"])
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
     torch.ops.fbgemm.nccl_init(rank, W, rdvz)
-    # torch.distributed.init_process_group(backend="nccl")
 
     B, T, D = 2, 4096, 1024
-    y = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda")
+    y = torch.empty(size=(B, T, D), dtype=dtype, device="cuda")
     y[:] = rank
-    y_gather = torch.zeros(size=(W, B, T, D), dtype=torch.bfloat16, device="cuda")
-    y_gather[:] = -1
+    y_gather = torch.full(size=(W, B, T, D), fill_value=-1, dtype=dtype, device="cuda")
     # Here we test to confirm that allgather is compatible with torch.compile.
     torch.compile(torch.ops.fbgemm.nccl_allgather)(y_gather, y)
     for w in range(W):
         torch.testing.assert_close(
             y_gather[w],
-            torch.full(
-                size=(B, T, D), fill_value=w, dtype=torch.bfloat16, device=y.device
-            ),
+            torch.full(size=(B, T, D), fill_value=w, dtype=dtype, device=y.device),
         )
 
     for _ in range(20):
@@ -268,7 +265,30 @@ def round_up(a: int, b: int) -> int:
     "Skip when CUDA is not available or when there are not enough GPUs; these tests require at least two GPUs",
 )
 class LLamaMultiGpuTests(unittest.TestCase):
-    def test_allgather(self) -> None:
+    @given(
+        dtype=st.sampled_from(
+            [
+                torch.bfloat16,
+                torch.float16,
+                torch.int,
+                torch.long,
+                torch.float,
+                torch.float8_e4m3fn,
+            ]
+        )
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=60000)
+    def test_allgather(self, dtype: torch.dtype) -> None:
+        # float8 is only supported in H100 or MI300x
+        if dtype == torch.float8_e4m3fn:
+            if torch.version.hip:
+                dtype = torch.float8_e4m3fnuz
+            elif torch.cuda.get_device_capability() < (9, 0):
+                self.skipTest(
+                    "float8_e4m3fn is only supported in H100 or MI300x, but we're running "
+                    f"on {torch.cuda.get_device_capability()}"
+                )
+
         with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as path:
             lc = LaunchConfig(
                 min_nodes=1,
@@ -283,7 +303,7 @@ def test_allgather(self) -> None:
                 max_restarts=0,
             )
             elastic_launch(config=lc, entrypoint=_run_allgather_inner)(
-                os.path.join(path, "rdvz")
+                os.path.join(path, "rdvz"), dtype
             )
 
     def test_allreduce(self) -> None:
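The test now sweeps dtypes with Hypothesis instead of hard-coding bfloat16; each drawn dtype launches its own multi-rank job. For readers unfamiliar with the pattern, here is a stripped-down, self-contained sketch of the same `@given`/`sampled_from` parametrization (the class name and test body are placeholders, not the FBGEMM test):

```python
import unittest

import torch
from hypothesis import Verbosity, given, settings, strategies as st


class DtypeSweepSketch(unittest.TestCase):
    @given(dtype=st.sampled_from([torch.bfloat16, torch.float16, torch.int, torch.long]))
    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
    def test_fill_roundtrip(self, dtype: torch.dtype) -> None:
        # Hypothesis calls this once per drawn dtype, mirroring how
        # test_allgather above receives its dtype argument.
        x = torch.full((8,), 3, dtype=dtype)
        self.assertEqual(x.dtype, dtype)
        self.assertTrue(bool((x == 3).all()))


if __name__ == "__main__":
    unittest.main()
```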
