fp8_model_init doesn't work with DDP #1135

Open
MaciejBalaNV opened this issue Aug 26, 2024 · 3 comments

Comments

@MaciejBalaNV

When I try to use the fp8_model_init feature, it doesn't seem to be compatible with DDP. It throws this error:
RuntimeError: Modules with uninitialized parameters can't be used with "DistributedDataParallel". Run a dummy forward pass to correctly initialize the modules

Running a dummy forward pass doesn't help, and neither does calling reset_parameters. Using a separate stream for DDP does not fix the issue either.
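To see which parameters trip this error before wrapping, one can list the parameters that the check would classify as uninitialized. A minimal sketch, assuming DDP's check is an `isinstance` test against `torch.nn.parameter.UninitializedParameter` (the placeholder class used by PyTorch's lazy modules); `find_uninitialized_params` is a hypothetical helper, and the lazy module below only demonstrates the check on CPU without Transformer Engine.

```python
import torch
import torch.nn as nn
from torch.nn.parameter import UninitializedParameter


def find_uninitialized_params(module: nn.Module) -> list[str]:
    """Hypothetical helper: names of params that isinstance() reports as UninitializedParameter.

    Assumption: this mirrors the kind of check DDP runs before raising
    "Modules with uninitialized parameters can't be used ...".
    """
    return [
        name
        for name, param in module.named_parameters()
        if isinstance(param, UninitializedParameter)
    ]


if __name__ == "__main__":
    eager = nn.Linear(4, 4)   # parameters are materialized at construction
    lazy = nn.LazyLinear(4)   # parameters stay uninitialized until the first forward
    print(find_uninitialized_params(eager))
    print(find_uninitialized_params(lazy))
    lazy(torch.randn(2, 3))   # a dummy forward pass materializes the lazy parameters
    print(find_uninitialized_params(lazy))
```

In the repro, printing `find_uninitialized_params(model)` just before `DDP(model)` would show whether the FP8 weights created under fp8_model_init are the parameters being flagged, even after the dummy forward pass.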

A simple reproducible case:

import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import transformer_engine as te

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12364"

    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = te.pytorch.Linear(1024, 1024)
        self.fc2 = te.pytorch.Linear(1024, 10)

    def forward(self, x):
        return self.fc2(self.fc1(x))


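def ddp_main(rank, world_size):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    # Create the TE modules with FP8 parameters.
    with te.pytorch.fp8.fp8_model_init(enabled=True):
        model = Net().to(rank)

    # Attempted workaround 1: re-run parameter initialization on every module.
    for i, m in enumerate(model.modules()):
        if hasattr(m, "reset_parameters"):
            print(f"resetting {i}")
            m.reset_parameters()

    # Attempted workaround 2: the dummy forward pass suggested by the DDP error message.
    input_data = torch.randn((16, 1024), device="cuda")
    with torch.no_grad():
        model(input_data)
    torch.cuda.synchronize()

    # Fails here with "Modules with uninitialized parameters can't be used ...".
    model = DDP(model)
    torch.cuda.synchronize()

    dist.barrier()
    cleanup()


if __name__ == "__main__":
    WORLD_SIZE = 8
    mp.spawn(ddp_main, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)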


@timmoon10
Collaborator

Do you need both DDP and FP8 params for your use case? We haven't considered this combination so far, since optimizing FP8 params tends to converge poorly. There are a few ways to proceed:

  • Initialize your weights in higher precision and rely on TE's automatic FP8 casting. If your training loop involves multiple grad accumulation steps, you can pass in is_first_microbatch=True/False to cache the FP8 weights.
  • Maintain a separate set of FP32 master params for the grad all-reduce and optimizer. This is generally how DDP is implemented in Megatron-LM and NeMo.
  • Use FSDP mixed precision support to store the sharded weights in FP32 and the gathered weights in FP8. This isn't supported with TE yet, but FSDP has some callback hooks where we can add FP8-related logic (see fsdp_pre_all_gather and fsdp_post_all_gather).
  • DDP doesn't actually need the param values, just the grads. We could debug this case and figure out a way to bypass this error.
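The second option (FP32 master params) can be sketched in plain PyTorch. This is not the Megatron-LM or NeMo implementation: `MasterParamOptimizer` is a hypothetical helper, bf16 stands in for FP8 so the sketch runs on CPU without Transformer Engine, and it omits the initial parameter broadcast that DDP performs at construction. The low-precision model params only take part in forward and backward; their grads are upcast into FP32 master copies, optionally averaged across a data-parallel group with `torch.distributed.all_reduce`, and stepped by a standard optimizer.

```python
import torch
import torch.distributed as dist
import torch.nn as nn


class MasterParamOptimizer:
    """Hypothetical helper: FP32 master copies of low-precision model params.

    bf16 stands in for FP8 here so the sketch runs on CPU without Transformer Engine.
    """

    def __init__(self, model: nn.Module, lr: float = 1e-3, dp_group=None):
        self.dp_group = dp_group
        self.model_params = [p for p in model.parameters() if p.requires_grad]
        # FP32 master copies: the optimizer only ever sees these.
        self.master_params = [
            p.detach().to(torch.float32, copy=True).requires_grad_(True)
            for p in self.model_params
        ]
        self.optim = torch.optim.Adam(self.master_params, lr=lr)

    def _dp_world_size(self) -> int:
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size(self.dp_group)
        return 1

    def step(self) -> None:
        world_size = self._dp_world_size()
        for master, p in zip(self.master_params, self.model_params):
            if p.grad is None:
                master.grad = None
                continue
            # Upcast the low-precision grad into a fresh FP32 buffer.
            master.grad = p.grad.detach().to(torch.float32, copy=True)
            if world_size > 1:
                # Average FP32 grads across data-parallel ranks (the part DDP would do).
                dist.all_reduce(master.grad, group=self.dp_group)
                master.grad.div_(world_size)
        self.optim.step()
        # Cast the updated FP32 values back into the low-precision model params.
        with torch.no_grad():
            for master, p in zip(self.master_params, self.model_params):
                p.copy_(master)

    def zero_grad(self) -> None:
        for p in self.model_params:
            p.grad = None
        for master in self.master_params:
            master.grad = None


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(8, 8).to(torch.bfloat16)
    opt = MasterParamOptimizer(model, lr=1e-2)
    loss = model(torch.randn(4, 8, dtype=torch.bfloat16)).float().sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
```

Because the all-reduce runs on the FP32 master grads rather than inside a DDP wrapper, DDP's constructor-time parameter checks never see the low-precision params; the trade-off is that gradient bucketing and communication overlap have to be handled by hand.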

@denera
Collaborator

denera commented Aug 26, 2024

@MaciejBalaNV Transformer Engine modules that are initialized under te.pytorch.fp8_model_init() still need to be executed under te.pytorch.fp8_autocast() with an FP8 recipe for operations that we have to perform in higher precision. Missing this context might be the reason why the model parameters were not correctly initialized in your case, and if so, we should definitely catch that and show a useful error message.

For reference, here's a modified version of your DDP example that works correctly on my end:

import os
import socket
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import transformer_engine as te


class BasicMLP(nn.Module):
    """Basic MLP block"""

    def __init__(self, hidden_size, ffn_hidden_size, **kwargs):
        super().__init__()
        tp_group = kwargs.pop("tp_group", None)
        parallel_mode = kwargs.pop("parallel_mode", None)
        fc1_parallel_mode = fc2_parallel_mode = parallel_mode
        if tp_group is not None:
            # Megatron-style MLP: FC1 splits its output features across the TP group,
            # FC2 splits its input features and all-reduces its output.
            fc1_parallel_mode = "column"
            fc2_parallel_mode = "row"
        # tp_group was popped from kwargs above, so pass it to the TE layers explicitly.
        self.fc1 = te.pytorch.Linear(hidden_size, ffn_hidden_size, parallel_mode=fc1_parallel_mode,
                                     tp_group=tp_group, **kwargs)
        self.fc2 = te.pytorch.Linear(ffn_hidden_size, hidden_size, parallel_mode=fc2_parallel_mode,
                                     tp_group=tp_group, **kwargs)

    def forward(self, x):
        """Forward pass: FC2(FC1(x)), with no activation in between"""
        return self.fc2(self.fc1(x))


def _ddp_main(rank, world_size, num_replicas):
    SEQ_LENGTH = 512
    BATCH_SIZE = 2
    HIDDEN_SIZE = 256
    FFN_HIDDEN_SIZE = 4 * HIDDEN_SIZE

    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = socket.gethostname()
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(rank)

    if num_replicas == 1:
        dp_group = None
        tp_group = dist.new_group()
    elif num_replicas == world_size:
        dp_group = dist.new_group()
        tp_group = None
    else:
        assert num_replicas > 0 and num_replicas < world_size and world_size % num_replicas == 0
        replica_size = world_size // num_replicas
        mesh_2d = dist.init_device_mesh("cuda", (num_replicas, replica_size))
        dp_group, tp_group = mesh_2d.get_all_groups()

    with te.pytorch.fp8.fp8_model_init(enabled=True):
        model = BasicMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, tp_group=tp_group)

    if dp_group is not None:
        model = DDP(model, process_group=dp_group)
    optim = torch.optim.Adam(model.parameters())

    for _ in range(10):
        optim.zero_grad()
        input_data = torch.randn((SEQ_LENGTH, BATCH_SIZE, HIDDEN_SIZE), device="cuda")
        with te.pytorch.fp8_autocast(enabled=True):
            output = model(input_data)
        loss = output.sum()
        loss.backward()
        optim.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    NUM_REPLICAS = 2

    if "TORCHELASTIC_RUN_ID" in os.environ:
        # Using the `torchrun` utility
        WORLD_RANK = int(os.getenv("RANK", "0"))
        WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
        _ddp_main(WORLD_RANK, WORLD_SIZE, NUM_REPLICAS)

    else:
        WORLD_SIZE = 8
        mp.spawn(_ddp_main, args=(WORLD_SIZE, NUM_REPLICAS), nprocs=WORLD_SIZE, join=True)

@MaciejBalaNV
Author

@timmoon10
The use case is large-model training with TP/SP but without FSDP (FSDP is problematic for many reasons, and fp8_model_init not working with it is one of them), while still using data parallelism through DDP. Are you suggesting that there are better options than DDP for data parallelism?

@denera
I don't think the error is caused by the lack of te.pytorch.fp8_autocast(): if I delete the forward pass before DDP wrapping, the error still happens at the wrapping stage. I only included that forward pass because the error message suggested running a dummy forward. Thanks for this example. I've played around with it and found that it only works on a nightly build with PyTorch 2.5 and TE 1.10. It still breaks with the same error message in the 24.07 PyTorch container (which ships PyTorch 2.4), even if I reinstall TE 1.10 or 1.11. It seems something changed in PyTorch very recently, in which case I'm not sure any fixes are necessary on the TE side.
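Until the root cause is confirmed, a training script could warn when it runs on a PyTorch release older than the one reported working here. A minimal sketch: the 2.5 threshold comes only from the observation above (fails with PyTorch 2.4 in the 24.07 container, works on a 2.5 nightly), not from a confirmed PyTorch changelog entry, and `torch_version_tuple` and `warn_if_fp8_ddp_may_fail` are hypothetical helper names.

```python
import re
import warnings

import torch


def torch_version_tuple(version: str) -> tuple[int, int]:
    """Parse (major, minor) from forms like '2.4.0', '2.4.0a0+<git hash>' or '2.5.0.dev<date>+cu124'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def warn_if_fp8_ddp_may_fail(version: str = torch.__version__) -> bool:
    """Warn and return True if PyTorch predates 2.5.

    Assumption: the 2.5 threshold is taken from the report in this thread only.
    """
    if torch_version_tuple(version) < (2, 5):
        warnings.warn(
            f"PyTorch {version}: wrapping fp8_model_init() modules in DDP was "
            "reported to fail on versions before 2.5",
            RuntimeWarning,
        )
        return True
    return False


if __name__ == "__main__":
    # Call this before DDP(model) to surface the likely cause early.
    warn_if_fp8_ddp_may_fail()
```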

