disable collective fusion when chunk size is too small
Signed-off-by: Bill Nell <bill@neuralmagic.com>
bnellnm committed Nov 25, 2024
1 parent 97703da commit 7ebd94c
Showing 3 changed files with 31 additions and 24 deletions.
23 changes: 7 additions & 16 deletions vllm/compilation/collective_fusion.py
@@ -8,7 +8,7 @@

 import vllm.envs as envs
 from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem,
-                                    last_node_in_match)
+                                    last_node_in_match, use_cc_kernels)
 from vllm.config import CompilationConfig
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
@@ -26,25 +26,16 @@
 try:
     import flux
     use_flux = True
-    logger.info("USING FLUX")
+    logger.info("Using flux kernels for collective communication fusion.")
 except ImportError:
+    logger.info("Attempting to use flux but flux not installed.")
     use_flux = False

-FLUX_TILE_SIZE: int = 128
-
-
 # TODO: is this ok?
 def get_world_name() -> str:
     return torch.distributed.group.WORLD.group_name


-# Note: this heuristic is unique to flux
-def should_slice(shape: torch.Size) -> bool:
-    n_slices = get_tensor_model_parallel_world_size()
-    return (shape[0] % (FLUX_TILE_SIZE * n_slices) == 0
-            and shape[0] >= FLUX_TILE_SIZE * n_slices)
-
-
 def residual_slice_shape(residual: torch.Tensor, rank: int) -> int:
     n_slices = get_tensor_model_parallel_world_size()
     chunk, rem = divmod(residual.shape[0], n_slices)
@@ -173,15 +164,15 @@ def gemm_rs_ag_gemm(
         first_layer: bool
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

-    if first_layer and should_slice(residual.shape):
+    if first_layer and use_cc_kernels(residual.shape[0]):
         slice_shape = residual_slice_shape(residual, rank)
         residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape)
         my_residual = residual_chunk[0]
     else:
         my_residual = residual
         slice_shape = residual.shape[0]

-    if not should_slice(residual.shape):
+    if not use_cc_kernels(residual.shape[0]):
         output = torch.ops.aten.mm.default(gemm_1_activations,
                                            gemm_1_weights.transpose(1, 0))
         reduced_output = tensor_model_parallel_all_reduce(output)
@@ -222,7 +213,7 @@ def gemm_rs_ag_gemm_fake(
         gemm_2_weights: torch.Tensor,
         first_layer: bool,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    if first_layer and should_slice(gemm_1_activations.shape):
+    if first_layer and use_cc_kernels(gemm_1_activations.shape[0]):
         slice_shape = residual_slice_shape_fake(residual, rank)
         split_1 = torch.ops.aten.split.Tensor(residual, slice_shape)
         my_residual = split_1[0]
@@ -293,7 +284,7 @@ def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,

     reduced = tensor_model_parallel_all_reduce(mm_1)

-    if should_slice(gemm_1_activations.shape):
+    if use_cc_kernels(gemm_1_activations.shape[0]):
         wait_tensor = tensor_model_parallel_all_gather(my_residual)
     else:
         wait_tensor = my_residual
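For context, when the heuristic passes, the fused op splits the residual along dimension 0 so each tensor-parallel rank handles one chunk; when it fails, gemm_rs_ag_gemm falls back to a plain matmul followed by tensor_model_parallel_all_reduce. The chunk arithmetic uses divmod, as in residual_slice_shape above. A minimal standalone sketch (the remainder handling is truncated in this diff, so the give-extra-rows-to-low-ranks policy below is an assumption, and n_slices is passed explicitly instead of queried from vllm.distributed):

import torch

def residual_slice_shape_demo(residual: torch.Tensor, rank: int,
                              n_slices: int) -> int:
    # Split dim 0 into n_slices chunks; assumed policy: the first `rem`
    # ranks each take one extra row.
    chunk, rem = divmod(residual.shape[0], n_slices)
    return chunk + 1 if rank < rem else chunk

residual = torch.zeros(1026, 4096)
sizes = [residual_slice_shape_demo(residual, r, n_slices=4) for r in range(4)]
print(sizes)                 # [257, 257, 256, 256]
assert sum(sizes) == 1026    # every row lands on exactly one rank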
23 changes: 15 additions & 8 deletions vllm/compilation/utils.py
@@ -6,20 +6,20 @@
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import Match

-from vllm.config import CompilationConfig
 # yapf: disable
 from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
 from vllm.distributed import (
     get_tensor_model_parallel_world_size as get_tp_world_size)
 from vllm.distributed import model_parallel_is_initialized as p_is_init
 from vllm.logger import init_logger

 # yapf: enable

 logger = init_logger(__name__)

 COUNTS: Dict[str, int] = {}

+# Depends on arch, see auto_tile_shape in include/flux/gemm_hparams.h
+# Can be 256 on sm80.
+FLUX_TILE_SIZE: int = 128
+

 def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]:
     for node in nodes:
@@ -53,21 +53,28 @@ def last_node_in_match(match: Match) -> fx.Node:
     raise ValueError("No nodes in graph")


-def dump_graph(config: CompilationConfig.PassConfig, graph: fx.Graph,
-               name: str) -> None:
+def dump_graph(pass_config, graph: fx.Graph, name: str) -> None:
     global COUNTS
     count = COUNTS.get(name, 0)

     # Make sure filename includes rank in the distributed setting
     parallel = p_is_init() and get_tp_world_size() > 1
     rank = f"-{get_tp_rank()}" if parallel else ""
-    filepath = config.dump_graph_dir / f"{name}{rank}-{count}.py"
+    filepath = pass_config.dump_graph_dir / f"{name}{rank}-{count}.py"
     COUNTS[name] = count + 1

-    os.makedirs(config.dump_graph_dir, exist_ok=True)
+    os.makedirs(pass_config.dump_graph_dir, exist_ok=True)
     logger.info("%s printing graph to %s", name, filepath)
     with open(filepath, "w") as f:
         src = graph.owning_module.print_readable(print_output=False)
         # Add imports so it's not full of errors
         print("import torch; from torch import device", file=f)
         print(src, file=f)
+
+
+# Note: this heuristic is unique to flux
+def use_cc_kernels(m_shape: int, n_slices: Optional[int] = None) -> bool:
+    if n_slices is None:
+        n_slices = get_tp_world_size()
+    return (m_shape % (FLUX_TILE_SIZE * n_slices) == 0
+            and m_shape >= FLUX_TILE_SIZE * n_slices)
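To make the heuristic concrete: with FLUX_TILE_SIZE = 128 and tensor-parallel size n_slices, the M dimension must be a positive multiple of 128 * n_slices. A standalone restatement of the predicate (n_slices passed explicitly so it runs without an initialized distributed group; the demo name is ours):

FLUX_TILE_SIZE = 128

def use_cc_kernels_demo(m_shape: int, n_slices: int) -> bool:
    # Same predicate as use_cc_kernels above, minus the world-size lookup.
    return (m_shape % (FLUX_TILE_SIZE * n_slices) == 0
            and m_shape >= FLUX_TILE_SIZE * n_slices)

# With tensor-parallel size 4 the granularity is 128 * 4 = 512 rows:
assert use_cc_kernels_demo(512, 4)        # exactly one tile per rank
assert use_cc_kernels_demo(2048, 4)       # any multiple of 512 works
assert not use_cc_kernels_demo(256, 4)    # too few rows
assert not use_cc_kernels_demo(768, 4)    # 768 % 512 != 0

This is the predicate the new config.py check below reuses to decide whether to disable the collective fusion pass when max_num_batched_tokens is too small.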
9 changes: 9 additions & 0 deletions vllm/config.py
@@ -16,6 +16,7 @@

 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
+from vllm.compilation.utils import use_cc_kernels
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      get_quantization_config)
@@ -2431,6 +2432,14 @@ def __post_init__(self):
                 "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION

+        n_slices = self.parallel_config.world_size
+        max_tokens = self.scheduler_config.max_num_batched_tokens
+        if not use_cc_kernels(max_tokens / n_slices, n_slices):
+            logger.info(
+                ("Disabling collective fusion pass since chunked prefill size "
+                 "%d is too small."), max_tokens)
+            self.compilation_config.pass_config.enable_collective_fusion = False
+
         current_platform.check_and_update_config(self)

     def __str__(self):

Check failure on line 2437 in vllm/config.py (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): Argument 1 to "use_cc_kernels" has incompatible type "float"; expected "int" [arg-type]
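The mypy failure flagged above comes from max_tokens / n_slices, which produces a float while use_cc_kernels expects an int. A sketch of one possible follow-up fix (not part of this commit), using floor division; note that when max_tokens is not a multiple of n_slices, truncating can accept a value the exact quotient would have rejected, so whether this matches the intended semantics is an open question:

# Hypothetical variant of the new __post_init__ check; the `self` fields
# are those used in the diff above.
n_slices = self.parallel_config.world_size
max_tokens = self.scheduler_config.max_num_batched_tokens
if not use_cc_kernels(max_tokens // n_slices, n_slices):  # int, not float
    logger.info(
        ("Disabling collective fusion pass since chunked prefill size "
         "%d is too small."), max_tokens)
    self.compilation_config.pass_config.enable_collective_fusion = False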
