diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index 97f18b037b..1af004f1ad 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -13,25 +13,25 @@
     LayerNormLinear,
     LayerNormMLP,
     Linear,
-    make_graphed_callables,
     MultiheadAttention,
     TransformerLayer,
     fp8_autocast,
     fp8_model_init,
+    make_graphed_callables,
 )
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import is_bf16_compatible
 import transformer_engine.pytorch.ops as te_ops
 
 
-# Only run FP8 tests on H100.
+# Check if FP8 is supported.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 
+# Record initial RNG state.
 seed = 1234
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)
-# Record initial RNG state from script run.
 _cpu_rng_state = torch.get_rng_state()
 _cuda_rng_state = torch.cuda.get_rng_state()
 
@@ -49,25 +49,14 @@ class ModelConfig:
 
 
 model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
 
-modules = [
-    "transformer",
-    "layernorm_mlp",
-    "layernorm_linear",
-    "linear",
-    "mha",
-    "dpa",
-    "linear_op",
-]
-
-all_boolean = [True, False]
-
-dtypes = [torch.float32, torch.float16]
+# Supported data types
+dtypes: List[torch.dtype] = [torch.float32, torch.float16]
 if is_bf16_compatible():  # bf16 requires sm_80 or higher
     dtypes.append(torch.bfloat16)
 
 
 def reset_rng_states() -> None:
-    """revert back to initial RNG state."""
+    """Revert to initial RNG state."""
     torch.set_rng_state(_cpu_rng_state)
     torch.cuda.set_rng_state(_cuda_rng_state)
 
@@ -79,64 +68,40 @@ def reset_global_fp8_state():
 
 
 def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
-    """Ensures two lists are equal."""
+    """Check that two lists of tensors match exactly."""
     assert len(l1) == len(l2), "Unequal number of outputs."
-    failed = False
-    failed_tensors = ""
+    failure_message = "Output mismatches in:"
+    failed_tensors = []
     for i, (t1, t2) in enumerate(zip(l1, l2)):
         if not torch.equal(t1, t2):
-            failed = True
-            failed_tensors += (
-                f"    {names[i]}\n" if names is not None else f"    tensor at idx={i}\n"
-            )
-    assert not failed, "Output mismatches in:\n" + failed_tensors
+            failure_message += "\n    "
+            if names is None:
+                failure_message += f"tensor at idx={i}"
+            else:
+                failure_message += names[i]
+            failed_tensors.append((t1, t2))
+    if failed_tensors:
+        print(failure_message)
+        t1, t2 = failed_tensors[0]
+        torch.testing.assert_close(t1, t2, rtol=0, atol=0)
 
 
 def generate_data(
-    config: ModelConfig,
+    model_config: ModelConfig,
     dtype: torch.dtype,
-    dpa: bool = False,
     warmup: bool = False,
-    return_grad_output: bool = False,
-) -> Tuple[List[torch.Tensor], torch.Tensor]:
+    requires_grad: bool = True,
+) -> torch.Tensor:
     """Generate synthetic data."""
     gen_func = torch.ones if warmup else torch.randn
-    if dpa:
-        inputs = [
-            gen_func(
-                config.sequence_length,
-                config.batch_size,
-                config.num_heads,
-                config.kv_channels,
-                device="cuda",
-                requires_grad=True,
-                dtype=dtype,
-            )
-            for _ in range(3)
-        ]
-    else:
-        inputs = [
-            gen_func(
-                config.sequence_length,
-                config.batch_size,
-                config.hidden_size,
-                device="cuda",
-                requires_grad=True,
-                dtype=dtype,
-            )
-        ]
-
-    if not return_grad_output:
-        return inputs
-
-    grad_output = torch.randn(
-        config.sequence_length,
-        config.batch_size,
-        config.hidden_size,
+    return gen_func(
+        model_config.sequence_length,
+        model_config.batch_size,
+        model_config.hidden_size,
         device="cuda",
+        requires_grad=requires_grad,
         dtype=dtype,
     )
-    return inputs, grad_output
 
 
 def get_outputs(
@@ -166,33 +131,43 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
         return x
 
 
+# Supported modules
+_test_cuda_graphs_modules: List[str] = [
+    "transformer",
+    "layernorm_mlp",
+    "layernorm_linear",
+    "linear",
+    "mha",
+    "linear_op",
+]
+
+
 def _test_cuda_graphs(
     *,
-    config: ModelConfig,
+    graph_mode: str,
+    module: str,
+    model_config: ModelConfig,
     num_layers: int,
     dtype: torch.dtype,
     fp8: bool,
     fp8_params: bool,
     fp8_weight_caching: bool,
-    module: str,
-    graph_mode: str,
 ) -> List[torch.Tensor]:
     """Helper function for CUDA graph test."""
     reset_rng_states()
     FP8GlobalStateManager.reset()
 
-    dpa = module == "dpa"
+    # Operation-based API does not support FP8 weight caching.
     if module == "linear_op":
         fp8_weight_caching = False
 
+    # Create modules.
     with fp8_model_init(enabled=fp8_params):
-        # Create modules.
if module == "transformer": modules = [ TransformerLayer( - config.hidden_size, - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.hidden_size, + model_config.num_heads, hidden_dropout=0.0, attention_dropout=0.0, fuse_qkv_params=True, @@ -202,41 +177,51 @@ def _test_cuda_graphs( ] elif module == "layernorm_mlp": modules = [ - LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype) + LayerNormMLP( + model_config.hidden_size, + model_config.hidden_size, + params_dtype=dtype, + ) for _ in range(num_layers) ] elif module == "layernorm_linear": modules = [ - LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype) + LayerNormLinear( + model_config.hidden_size, + model_config.hidden_size, + params_dtype=dtype, + ) for _ in range(num_layers) ] elif module == "mha": modules = [ MultiheadAttention( - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.num_heads, attention_dropout=0.0, params_dtype=dtype, fuse_qkv_params=True, ) for _ in range(num_layers) ] - elif module == "dpa": - assert config.hidden_size % config.num_heads == 0, "Err." - assert num_layers == 1, "Err." - modules = [ - DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0) - for _ in range(num_layers) - ] elif module == "linear": modules = [ - Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype) + Linear( + model_config.hidden_size, + model_config.hidden_size, + device="cuda", + params_dtype=dtype, + ) for _ in range(num_layers) ] elif module == "linear_op": modules = [ te_ops.Sequential( - te_ops.Linear(config.hidden_size, config.hidden_size, dtype=dtype), + te_ops.Linear( + model_config.hidden_size, + model_config.hidden_size, + dtype=dtype, + ), ) for _ in range(num_layers) ] @@ -251,111 +236,207 @@ def _test_cuda_graphs( # Generate model and wrap API to return graphed version. if graph_mode == "full": # Graph entire model at once. - model = modules[0] if dpa else torch.nn.Sequential(*modules) + model = torch.nn.Sequential(*modules) model = make_graphed_callables( model, - generate_data(config, dtype, dpa=dpa, warmup=True), + (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, ) elif graph_mode == "individual": - # Graph individual modules + # Graph individual modules. modules = [ make_graphed_callables( module, - generate_data(config, dtype, dpa=dpa, warmup=True), + (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, ) for module in modules ] - model = modules[0] if dpa else _Sequential(*modules) + model = _Sequential(*modules) else: - model = modules[0] if dpa else _Sequential(*modules) + model = _Sequential(*modules) # Optimizer. - if not dpa: - optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) - # Launch. + # Training steps. 
     for _ in range(3):
-        if not dpa:
-            optimizer.zero_grad(set_to_none=False)
+        optimizer.zero_grad(set_to_none=False)
         for grad_accumulation_step in range(2):
-            inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True)
+            input_ = generate_data(model_config, dtype)
+            grad_output = generate_data(model_config, dtype, requires_grad=False)
             with fp8_autocast(enabled=fp8):
                 kwargs = {}
                 if fp8_weight_caching:
                     kwargs["is_first_microbatch"] = grad_accumulation_step == 0
-                output = model(*inputs, **kwargs)
+                output = model(input_, **kwargs)
             output.backward(grad_output)
-        if not dpa:
-            optimizer.step()
+        optimizer.step()
 
     return get_outputs(model, output)
 
 
+@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
 @pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("num_layers", [1, 3])
-@pytest.mark.parametrize("fp8", all_boolean)
-@pytest.mark.parametrize("fp8_params", all_boolean)
-@pytest.mark.parametrize("fp8_weight_caching", all_boolean)
-@pytest.mark.parametrize("module", modules)
-def test_gpt_make_graphed_callables(
+@pytest.mark.parametrize("fp8", (False, True))
+@pytest.mark.parametrize("fp8_params", (False, True))
+def test_make_graphed_callables(
+    *,
+    module: str,
+    model_config: str = "small",
+    num_layers: int = 3,
     dtype: torch.dtype,
-    model: str,
-    num_layers: int,
     fp8: bool,
     fp8_params: bool,
-    fp8_weight_caching: bool,
-    module: str,
+    fp8_weight_caching: bool = False,
 ) -> None:
+
+    # Skip invalid configurations.
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if fp8_params and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
-    if module == "dpa" and num_layers > 1:
-        pytest.skip("Max 1 layer for DPA.")
-
-    config = model_configs[model]
 
+    # Run model with different CUDA graph settings.
+    model_config = model_configs[model_config]
     kwargs = dict(
-        config=config,
+        module=module,
+        model_config=model_config,
         num_layers=num_layers,
         dtype=dtype,
         fp8=fp8,
         fp8_params=fp8_params,
         fp8_weight_caching=fp8_weight_caching,
-        module=module,
     )
     outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
     graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
     graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
 
-    # Check that results match
+    # Check that results match.
     assert_all_equal(outputs, graph_outputs_mode1)
     assert_all_equal(outputs, graph_outputs_mode2)
 
 
-def _test_cuda_graphs_with_kwargs(
+_test_make_graphed_callables_with_fp8_weight_caching_modules = [
+    "transformer",
+    "layernorm_mlp",
+    "layernorm_linear",
+    "linear",
+    "mha",
+]
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+@pytest.mark.parametrize(
+    "module",
+    _test_make_graphed_callables_with_fp8_weight_caching_modules,
+)
+@pytest.mark.parametrize("fp8_params", (False, True))
+def test_make_graphed_callables_with_fp8_weight_caching(
     *,
-    config: ModelConfig,
+    module: str,
+    fp8_params: bool,
+) -> None:
+    test_make_graphed_callables(
+        module=module,
+        dtype=torch.float32,
+        fp8=True,
+        fp8_params=fp8_params,
+        fp8_weight_caching=True,
+    )
+
+
+def generate_data_for_dot_product_attention(
+    model_config: ModelConfig,
     dtype: torch.dtype,
+    warmup: bool = False,
+) -> List[torch.Tensor]:
+    """Generate synthetic data for dot product attention."""
+    gen_func = torch.ones if warmup else torch.randn
+    return [
+        gen_func(
+            model_config.sequence_length,
+            model_config.batch_size,
+            model_config.num_heads,
+            model_config.kv_channels,
+            device="cuda",
+            requires_grad=True,
+            dtype=dtype,
+        )
+        for _ in range(3)
+    ]
+
+
+def _test_cuda_graphs_with_dot_product_attention(
+    *,
     with_graph: bool,
+    model_config: ModelConfig,
+    dtype: torch.dtype,
 ) -> List[torch.Tensor]:
-    """Simulate Megatron-LM interleaved pipeline parallelism."""
+    """Helper function for CUDA graph test."""
+    reset_rng_states()
+    FP8GlobalStateManager.reset()
+
+    # Create dot product attention module.
+    assert model_config.hidden_size % model_config.num_heads == 0
+    model = DotProductAttention(
+        model_config.num_heads,
+        model_config.kv_channels,
+        attention_dropout=0.0,
+    )
+
+    # Graph model if needed.
+    if with_graph:
+        model = make_graphed_callables(
+            model,
+            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
+            num_warmup_iters=10,
+            fp8_enabled=False,
+        )
+
+    # Forward and backward passes.
+    for _ in range(3):
+        inputs = generate_data_for_dot_product_attention(model_config, dtype)
+        grad_output = generate_data(model_config, dtype, requires_grad=False)
+        output = model(*inputs)
+        output.backward(grad_output)
+
+    return get_outputs(model, output)
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+def test_make_graphed_callables_with_dot_product_attention(
+    *,
+    model_config: str = "small",
+    dtype: torch.dtype,
+) -> None:
+    """Test CUDA graphs with dot product attention."""
+    model_config = model_configs[model_config]
+    kwargs = dict(model_config=model_config, dtype=dtype)
+    outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **kwargs)
+    graph_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **kwargs)
+    assert_all_equal(outputs, graph_outputs)
+
+
+def _test_cuda_graphs_with_kwargs(
+    *,
+    with_graph: bool,
+    model_config: ModelConfig,
+    dtype: torch.dtype,
+) -> List[torch.Tensor]:
+    """Helper function for CUDA graph test with keyword arguments."""
     reset_rng_states()
 
     # Initialize model.
     model = TransformerLayer(
-        config.hidden_size,
-        config.hidden_size,
-        config.num_heads,
+        model_config.hidden_size,
+        model_config.hidden_size,
+        model_config.num_heads,
         hidden_dropout=0.0,
         attention_dropout=0.0,
         self_attn_mask_type="arbitrary",
@@ -370,13 +451,18 @@ def _test_cuda_graphs_with_kwargs(
     # Make graphed version of model if needed.
     if with_graph:
         attn_mask = torch.zeros(
-            (config.batch_size, 1, config.sequence_length, config.sequence_length),
+            (
+                model_config.batch_size,
+                1,
+                model_config.sequence_length,
+                model_config.sequence_length,
+            ),
             dtype=torch.bool,
             device="cuda",
         )
         model = make_graphed_callables(
             model,
-            generate_data(config, dtype, warmup=True),
+            (generate_data(model_config, dtype, warmup=True),),
             sample_kwargs=dict(attention_mask=attn_mask),
             allow_unused_input=True,
         )
@@ -388,14 +474,15 @@ def _test_cuda_graphs_with_kwargs(
     for _ in range(3):
         optimizer.zero_grad(set_to_none=False)
         for grad_accumulation_step in range(2):
-            inputs, grad_output = generate_data(config, dtype, return_grad_output=True)
+            input_ = generate_data(model_config, dtype)
+            grad_output = generate_data(model_config, dtype, requires_grad=False)
             attn_mask = torch.randint(
                 2,
-                (config.batch_size, 1, config.sequence_length, config.sequence_length),
+                (model_config.batch_size, 1, model_config.sequence_length, model_config.sequence_length),
                 dtype=torch.bool,
                 device="cuda",
             )
-            output = model(*inputs, attention_mask=attn_mask)
+            output = model(input_, attention_mask=attn_mask)
             output.backward(grad_output)
         optimizer.step()
 
@@ -403,12 +490,13 @@ def test_make_graphed_callables_with_kwargs(
+    *,
+    model_config: str = "small",
     dtype: torch.dtype = torch.float32,
-    model: str = "small",
 ) -> None:
     """Test CUDA graphs with keyword arguments."""
-    config = model_configs[model]
-    kwargs = dict(config=config, dtype=dtype)
+    model_config = model_configs[model_config]
+    kwargs = dict(model_config=model_config, dtype=dtype)
     outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
     graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
     assert_all_equal(outputs, graph_outputs)
@@ -416,9 +504,9 @@ def test_make_graphed_callables_with_kwargs(
 
 def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
     *,
-    config: ModelConfig,
-    dtype: torch.dtype,
     with_graph: bool,
+    model_config: ModelConfig,
+    dtype: torch.dtype,
 ) -> List[torch.Tensor]:
     """Simulate Megatron-LM interleaved pipeline parallelism."""
     reset_rng_states()
@@ -432,8 +520,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
     model = torch.nn.ModuleList(
         [
             Linear(
-                config.hidden_size,
-                config.hidden_size,
+                model_config.hidden_size,
+                model_config.hidden_size,
                 params_dtype=dtype,
             )
             for _ in range(num_layers)
@@ -451,7 +539,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
     }
     if with_graph:
         sample_args = tuple(
-            generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches)
+            (generate_data(model_config, dtype, warmup=True),)
+            for _ in range(num_layers * num_microbatches)
         )
         layer_forwards = make_graphed_callables(
             tuple(model),
@@ -476,9 +565,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
     grad_outputs = {}
     for layer_idx in range(num_layers):
         for microbatch_idx in range(num_microbatches):
-            x, dy = generate_data(config, dtype, return_grad_output=True)
+            x = generate_data(model_config, dtype)
+            dy = generate_data(model_config, dtype, requires_grad=False)
             idxs = (layer_idx, microbatch_idx)
-            inputs[idxs] = x[0]
+            inputs[idxs] = x
             grad_outputs[idxs] = dy
 
     # Cache for layer outputs.
@@ -515,12 +605,13 @@ def backward(layer_idx: int, microbatch_idx: int):
 
 
 def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
+    *,
+    model_config: str = "small",
     dtype: torch.dtype = torch.float16,
-    model: str = "small",
 ) -> None:
     """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
-    config = model_configs[model]
-    kwargs = dict(config=config, dtype=dtype)
+    model_config = model_configs[model_config]
+    kwargs = dict(model_config=model_config, dtype=dtype)
     outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
         with_graph=False,
         **kwargs,