diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index a23de29e02..b1dd7b69eb 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -7,6 +7,7 @@ from collections.abc import Iterable import io import math +import random from typing import Optional import pytest @@ -172,6 +173,29 @@ def make_reference_and_test_tensors( return ref, test +def assert_close( + a: Optional[torch.Tensor], + b: Optional[torch.Tensor], + *, + rtol: float, + atol: float, +) -> None: + """Assert that two tensors are close.""" + if a is None and b is None: + return + assert a is not None + assert b is not None + a = a.detach() + b = b.detach() + if isinstance(a, QuantizedTensor): + a = a.dequantize() + if isinstance(b, QuantizedTensor): + b = b.dequantize() + a = a.to(dtype=torch.float64, device="cpu") + b = b.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + class TestSequentialContainer: """Tests for sequential container""" @@ -1680,6 +1704,7 @@ def test_swiglu( quantization: Optional[str], quantize_forward: bool, quantize_backward: bool, + glu_interleave_size: Optional[int] = None, ): # Tensor dimensions @@ -1706,7 +1731,17 @@ def test_swiglu( ) # Plain PyTorch implementation - x1, x2 = x_ref.chunk(2, dim=-1) + x = x_ref + if glu_interleave_size is not None: + x = x.reshape( + *in_shape[:-1], + in_shape[-1] // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(-3, -2) + x = x.reshape(in_shape) + x1, x2 = x.chunk(2, dim=-1) y_ref = torch.nn.functional.silu(x1) * x2 y_ref.backward(dy_ref) @@ -1714,7 +1749,7 @@ def test_swiglu( recipe = make_recipe(quantization) forward = te_ops.Sequential( te_ops.Quantize(forward=False, backward=quantize_backward), - te_ops.SwiGLU(), + te_ops.SwiGLU(glu_interleave_size=glu_interleave_size), te_ops.Quantize(forward=quantize_forward, backward=False), ) with te.autocast(enabled=quantized_compute, recipe=recipe): @@ -1727,10 +1762,18 @@ def test_swiglu( tols = quantization_tols(quantization) # Check results - y_test = y_test.to(dtype=torch.float64, device="cpu") - dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(y_test, y_ref, **tols) - torch.testing.assert_close(dx_test, x_ref.grad, **tols) + assert_close(y_test, y_ref, **tols) + assert_close(x_test.grad, x_ref.grad, **tols) + + def test_interleaved_swiglu(self): + self.test_swiglu( + out_shape=(32, 192), + dtype=torch.float32, + quantization=None, + quantize_forward=False, + quantize_backward=False, + glu_interleave_size=32, + ) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) @@ -1924,6 +1967,231 @@ def test_dropout( abs(z_score) < 2.5758 ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_compute", (False, True)) + @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("weight_requires_grad", (False, True)) + def test_grouped_linear( + self, + *, + group_size: int = 4, + bias: bool, + weight_shape: tuple[int, int] = (128, 128), + split_alignment: int = 128, + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_compute: bool, + 
quantized_weight: bool, + input_requires_grad: bool, + weight_requires_grad: bool, + ) -> None: + """Grouped GEMM""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = (split_sizes.sum().item(), in_features) + out_shape = (in_shape[0], out_features) + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") + if quantization is not None and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + ws_ref, ws_test = [], [] + bs_ref, bs_test = [], [] + for _ in range(group_size): + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + ws_ref.append(w_ref) + ws_test.append(w_test) + bs_ref.append(b_ref) + bs_test.append(b_test) + + # Plain PyTorch implementation + xs_ref = torch.split(x_ref, split_sizes.tolist()) + ys_ref = [] + for x, w, b in zip(xs_ref, ws_ref, bs_ref): + ys_ref.append(torch.nn.functional.linear(x, w, bias=b)) + y_ref = torch.cat(ys_ref) + if input_requires_grad or weight_requires_grad: + y_ref.backward(dy_ref) + + # Construct fusible operation + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + op = te_ops.GroupedLinear( + group_size, + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + ) + with torch.no_grad(): + for group_idx in range(group_size): + getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) + if bias: + getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) + del ws_test, bs_test + for param in op.parameters(): + param.requires_grad_(requires_grad=weight_requires_grad) + + # Forward and backward pass with op + with te.autocast(enabled=quantized_compute, recipe=recipe): + y_test = op(x_test, split_sizes) + if input_requires_grad or weight_requires_grad: + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + if input_requires_grad: + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx_test, x_ref.grad, 
**tols) + else: + assert x_test.grad is None + for group_idx in range(group_size): + w_test = getattr(op, f"weight{group_idx}") + if weight_requires_grad: + dw_test = w_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols) + else: + assert w_test.grad is None + if bias: + b_test = getattr(op, f"bias{group_idx}") + if weight_requires_grad: + db_test = b_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols) + else: + assert b_test.grad is None + + @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) + @pytest.mark.parametrize("glu_interleave_size", (None, 32)) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("scales_requires_grad", (False, True)) + def test_scaled_swiglu( + self, + *, + in_shape: Iterable[int], + glu_interleave_size: Optional[int], + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + input_requires_grad: bool, + scales_requires_grad: bool, + ) -> None: + """Multiply two tensors""" + + # Tensor dims + out_shape = list(in_shape) + out_shape[-1] //= 2 + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + scales_ref, scales_test = make_reference_and_test_tensors( + in_shape[:-1], + test_dtype=dtype, + test_device=device, + requires_grad=scales_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x = x_ref + if glu_interleave_size is not None: + x = x.reshape( + -1, + in_shape[-1] // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(1, 2) + x = x.reshape(in_shape) + x1, x2 = x.chunk(2, dim=-1) + y = torch.nn.functional.silu(x1) * x2 + y_ref = scales_ref.unsqueeze(-1) * y + if input_requires_grad or scales_requires_grad: + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + y_test = op(x_test, scales_test) + if input_requires_grad or scales_requires_grad: + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + if input_requires_grad: + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + else: + assert x_test.grad is None + if scales_requires_grad: + ds_test = scales_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(ds_test, scales_ref.grad, **tols) + else: + assert scales_test.grad is None + class TestFusedOps: """Tests for fused operations""" @@ -2931,6 +3199,215 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("glu_interleave_size", (None, 32)) + def test_grouped_mlp( + self, + *, + group_size: int = 4, + bias: bool, + hidden_size: int = 256, + dtype: torch.dtype, + quantization: Optional[str], + device: torch.device = "cuda", + split_alignment: int = 256, + 
glu_interleave_size: Optional[int], + ) -> None: + """GroupedLinear + ScaledSwiGLU + GroupedLinear""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + + # Make input shape + in_shape = (split_sizes.sum().item(), hidden_size) + out_shape = in_shape + + # Skip invalid configurations + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + if with_quantization and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + probs_ref, probs_test = make_reference_and_test_tensors( + (in_shape[0],), + test_dtype=dtype, + test_device=device, + ) + fc1_ws_ref, fc1_ws_test = [], [] + fc1_bs_ref, fc1_bs_test = [], [] + fc2_ws_ref, fc2_ws_test = [], [] + fc2_bs_ref, fc2_bs_test = [], [] + for _ in range(group_size): + fc1_w_ref, fc1_w_test = make_reference_and_test_tensors( + (2 * hidden_size, hidden_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + fc2_w_ref, fc2_w_test = make_reference_and_test_tensors( + (hidden_size, hidden_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + fc1_b_ref, fc1_b_test = None, None + fc2_b_ref, fc2_b_test = None, None + if bias: + fc1_b_ref, fc1_b_test = make_reference_and_test_tensors( + (2 * hidden_size,), + test_dtype=dtype, + test_device=device, + ) + fc2_b_ref, fc2_b_test = make_reference_and_test_tensors( + (hidden_size,), + test_dtype=dtype, + test_device=device, + ) + fc1_ws_ref.append(fc1_w_ref) + fc1_bs_ref.append(fc1_b_ref) + fc1_ws_test.append(fc1_w_test) + fc1_bs_test.append(fc1_b_test) + fc2_ws_ref.append(fc2_w_ref) + fc2_bs_ref.append(fc2_b_ref) + fc2_ws_test.append(fc2_w_test) + fc2_bs_test.append(fc2_b_test) + with torch.no_grad(): + for t in fc1_ws_ref + fc1_ws_test + fc2_ws_ref + fc2_ws_test: + t -= 0.5 + t *= 1 / 2 + for t in (x_ref, x_test, dy_ref, dy_test): + t -= 0.5 + t *= 1 / 2 + if bias: + for t in fc1_bs_ref + fc1_bs_test + fc2_bs_ref + fc2_bs_test: + t -= 0.5 + + # Reference implementation + xs = torch.split(x_ref, split_sizes.tolist()) + probs = torch.split(probs_ref, split_sizes.tolist()) + ys = [] + for group_idx in range(group_size): + x = xs[group_idx] + x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) + if glu_interleave_size is not None: + x = x.reshape( + -1, + 2 * hidden_size // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(1, 2) + x = x.reshape(-1, 2 * hidden_size) + x1, x2 = x.chunk(2, dim=-1) + x = torch.nn.functional.silu(x1) * x2 + x = x * probs[group_idx].unsqueeze(-1) + x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx]) + ys.append(x) + y_ref = torch.cat(ys) + y_ref.backward(dy_ref) + + # Construct operations + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=with_quantization, recipe=recipe): + fc1 = te_ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=bias, + device=device, + dtype=dtype, + ) + fc2 = te_ops.GroupedLinear( + 
group_size, + hidden_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + ) + module = te_ops.Sequential( + fc1, + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), + fc2, + ) + + # Copy weights + with torch.no_grad(): + for group_idx in range(group_size): + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) + if bias: + getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test + + # Fuse ops and perform forward and backward pass + with te.autocast(enabled=with_quantization, recipe=recipe): + y_test = module(x_test, split_sizes, probs_test, split_sizes) + y_test.backward(dy_test) + + # Check for expected fusions + if ( + te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported() + and quantization == "mxfp8" + and dtype == torch.bfloat16 + and not bias + and glu_interleave_size == 32 + ): + forward_ops = module._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance( + forward_ops[0][0], + te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ) + + def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """Convert to FP64 CPU tensor""" + if tensor is None: + return None + out = tensor.detach().to(dtype=torch.float64, device="cpu") + out = out.requires_grad_(requires_grad=tensor.requires_grad) + return out + + # Loose tols for sanity checking + tols = {"rtol": 0.25, "atol": 0.5} + if quantization == "nvfp4": + tols = {"rtol": 0.5, "atol": 1} + + # Check values + assert_close(y_test, y_ref, **tols) + assert_close(x_test.grad, x_ref.grad, **tols) + assert_close(probs_test.grad, probs_ref.grad, **tols) + for group_idx in range(group_size): + assert_close( + getattr(fc2, f"weight{group_idx}").grad, + fc2_ws_ref[group_idx].grad, + **tols, + ) + assert_close( + getattr(fc1, f"weight{group_idx}").grad, + fc1_ws_ref[group_idx].grad, + **tols, + ) + class TestCustomOps: """Test with ops that are defined externally""" diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 665ffe359c..32da121cce 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -14,8 +14,6 @@ SReLU, SReGLU, SiLU, - SwiGLU, - ClampedSwiGLU, ) from .add_extra_input import AddExtraInput from .all_gather import AllGather @@ -24,6 +22,7 @@ from .bias import Bias from .constant_scale import ConstantScale from .dropout import Dropout +from .grouped_linear import GroupedLinear from .identity import Identity from .l2normalization import L2Normalization from .layer_norm import LayerNorm @@ -32,3 +31,4 @@ from .reduce_scatter import ReduceScatter from .reshape import Reshape from .rmsnorm import RMSNorm +from .swiglu import ClampedSwiGLU, ScaledSwiGLU, SwiGLU diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 9d54e12dba..2f1debdf5e 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -27,8 +27,6 @@ "SReLU", "SReGLU", "SiLU", - "SwiGLU", - "ClampedSwiGLU", ] @@ -355,76 +353,3 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dsilu(*args, **kwargs) - - -class SwiGLU(_ActivationOperation): - 
r"""Swish gated linear unit - - The input tensor is split into chunks :math:`a` and :math:`b` - along the last dimension and the following is computed: - - .. math:: - - \text{GEGLU}(a,b) = \text{SiLU}(a) * b - - where - - .. math:: - - \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} - - .. warning:: - - Transformer Engine's gated activations and PyTorch's GLU - activation follow opposite conventions for :math:`a` and - :math:`b`. Transformer Engine applies the gating function to - the first half of the input tensor, while PyTorch applies it to - the second half. - - The Sigmoid Linear Unit (SiLU) gating function is also known as - the swish function. See - `GLU Variants Improve Transformer`__ - and `Gaussian Error Linear Units (GELUs)`__. - - """ - - def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.swiglu(*args, **kwargs) - - def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.dswiglu(*args, **kwargs) - - -class ClampedSwiGLU(_ActivationOperation): - r"""GPT-OSS - Implementation based on `GPT-OSS`__. - - This activation has two differences compared to the original SwiGLU - 1. Both gate and pre-activations are clipped based on parameter limit. - 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. - - .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt - from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. - - Parameters - ---------- - limit : float - The clamp limit. - alpha : float - The scaling factor for the sigmoid function used in the activation. - cache_quantized_input : bool, default = False - Quantize input tensor when caching for use in the backward pass. - """ - - def __init__( - self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False - ): - super().__init__(cache_quantized_input=cache_quantized_input) - self.limit = limit - self.alpha = alpha - - def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) - - def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py new file mode 100644 index 0000000000..ac6b5665e3 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -0,0 +1,606 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Fusible operation for bias.""" + +from __future__ import annotations +from collections.abc import Callable, Iterable +import contextlib +import math +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...cpp_extensions import general_grouped_gemm +from ...distributed import CudaRNGStatesTracker +from ...module.base import ( + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, + get_dummy_wgrad, +) +from ...quantization import FP8GlobalStateManager, Recipe +from ...tensor import Quantizer +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) +from .._common import is_quantized_tensor, maybe_dequantize +from ..op import BasicOperation, OperationContext + + +class GroupedLinear(BasicOperation): + """Apply multiple linear transformations: :math:``y_i = x_i W_i^T + b_i`` + + This is equivalent to splitting the input tensor along its first + dimension, applying a separate ``torch.nn.Linear`` to each split, + and concatenating along the first dimension. + + Parameters + ---------- + num_groups : int + Number of linear transformations. + in_features : int + Inner dimension of input tensor. + out_features : int + Inner dimension of output tensor. + bias : bool, default = ``True`` + Apply additive bias. + device : torch.device, default = default CUDA device + Tensor device. + dtype : torch.dtype, default = default dtype + Tensor datatype. + rng_state_tracker_function : callable + Function that returns ``CudaRNGStatesTracker``, which is used + for model-parallel weight initialization. + accumulate_into_main_grad : bool, default = ``False`` + Whether to directly accumulate weight gradients into the + weight's ``main_grad`` attribute instead of relying on PyTorch + autograd. The weight's ``main_grad`` must be set externally + and there is no guarantee that `grad` will be set or be + meaningful. This is primarily intented to integrate with + Megatron-LM. This argument along with weight tensor having + attribute ``overwrite_main_grad`` set to True will overwrite + ``main_grad`` instead of accumulating. 
+ + """ + + # Operation expects input split sizes + num_extra_inputs: int = 1 + + def __init__( + self, + num_groups: int, + in_features: int, + out_features: int, + *, + bias: bool = True, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, + accumulate_into_main_grad: bool = False, + ) -> None: + super().__init__() + + # Weight tensor dimensions + self.num_groups: int = num_groups + self.in_features: int = in_features + self.out_features: int = out_features + if self.num_groups <= 0: + raise ValueError(f"Invalid number of groups ({self.num_groups})") + if self.in_features <= 0: + raise ValueError(f"Invalid input size ({self.in_features})") + if self.out_features <= 0: + raise ValueError(f"Invalid output size ({self.out_features})") + + # Weight tensor attributes + device = canonicalize_device(device) + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Initialize recipe state if needed for natively quantized weight + self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters() + if self._with_quantized_weight: + self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe()) + + # RNG state tracker + self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] + self._rng_state_tracker_function = rng_state_tracker_function + + # Register weights + self.weight0: torch.nn.Parameter + for group_idx in range(self.num_groups): + weight_tensor = torch.empty( + self.out_features, + self.in_features, + device=device, + dtype=dtype, + ) + self.register_parameter( + f"weight{group_idx}", + torch.nn.Parameter(weight_tensor), + ) + + # Register biases + self.bias0: Optional[torch.nn.Parameter] + for group_idx in range(self.num_groups): + bias_tensor = None + if bias: + bias_tensor = torch.empty( + self.out_features, + device=device, + dtype=dtype, + ) + bias_tensor = torch.nn.Parameter(bias_tensor) + self.register_parameter(f"bias{group_idx}", bias_tensor) + + # Initialize weights if needed + if device.type != "meta": + self.reset_parameters() + + # Whether to accumulate weight gradient into main_grad + self._accumulate_into_main_grad: bool = accumulate_into_main_grad + + def num_quantizers(self, mode: str) -> int: + if mode == "forward": + return 2 * self.num_groups + if mode == "backward": + return self.num_groups + return 0 + + @property + def has_bias(self) -> bool: + """Whether an additive bias is being applied""" + return self.bias0 is not None + + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + # Parameter device + device = self.weight0.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize weights + for group_idx in range(self.num_groups): + weight = getattr(self, f"weight{group_idx}") + + # Allocate buffers if needed + if is_quantized_tensor(weight): + weight = torch.empty( + weight.size(), + dtype=weight.dtype, + device=device, + ) + elif not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) + + # Initialize values + init_context = contextlib.nullcontext() + if self._rng_state_tracker_function is not None: + init_context = self._rng_state_tracker_function().fork() + with init_context: + torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + + # Quantize weight if needed + if 
self._with_quantized_weight: + quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + if quantizer is None: + raise RuntimeError( + "Tried to quantize weight with deferred initialization " + "due to meta device, but no quantizer was available. " + "This is most likely because the weight was initialized " + "within quantized_model_init, but the forward pass was not " + "performed within autocast." + ) + quantizer.set_usage( + rowwise=True, + columnwise=torch.is_grad_enabled(), + ) + quantizer.internal = False + with torch.no_grad(): + weight = quantizer(weight) + + # Save updated parameters + if not isinstance(weight, torch.nn.Parameter): + weight = torch.nn.Parameter(weight) + setattr(self, f"weight{group_idx}", weight) + + # Initialize biases if needed + if self.bias0 is not None: + with torch.no_grad(): + for group_idx in range(self.num_groups): + bias = getattr(self, f"bias{group_idx}") + if not devices_match(bias.device, device): + bias = torch.empty_like(bias, device=device) + bias.zero_() + if not isinstance(bias, torch.nn.Parameter): + bias = torch.nn.Parameter(bias) + setattr(self, f"bias{group_idx}", bias) + + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() + + # Initialize params if needed + if any(param.device.type == "meta" for param in self.parameters()): + self.reset_parameters() + + # Check that weights are consistent + dtype = self.weight0.dtype + device = self.weight0.device + weight_requires_grad = self.weight0.requires_grad + weight_tensor_type = type(self.weight0.data) + for group_idx in range(self.num_groups): + weight = getattr(self, f"weight{group_idx}") + if weight.dtype != dtype: + raise RuntimeError( + f"Weight {group_idx} has invalid dtype (expected {dtype}, got {weight.dtype})." + ) + if not devices_match(weight.device, device): + raise RuntimeError( + f"Weight {group_idx} has invalid device " + f"(expected {device}, got {weight.device})." + ) + if weight.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Weight {group_idx} has requires_grad={weight.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck + raise RuntimeError( + f"Weight {group_idx} has invalid tensor type " + f"(expected {weight_tensor_type.__name__}, " + f"got {type(weight.data).__name__})." + ) + + # Check that biases are consistent + for group_idx in range(self.num_groups): + bias = getattr(self, f"bias{group_idx}") + if self.has_bias: + if bias is None: + raise RuntimeError(f"Expected biases, but bias {group_idx} is uninitialized") + if bias.dtype != dtype: + raise RuntimeError( + f"Bias {group_idx} has invalid dtype (expected {dtype}, got {bias.dtype})." + ) + if not devices_match(bias.device, device): + raise RuntimeError( + f"Bias {group_idx} has invalid device " + f"(expected {device}, got {bias.device})." + ) + if bias.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Bias {group_idx} has requires_grad={bias.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." 
+ ) + else: + if bias is not None: + raise RuntimeError(f"Expected no biases, but bias {group_idx} is initialized") + + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Assume weights have consistent grad requirement + weight_requires_grad = requires_grad and self.weight0.requires_grad + + # Configure quantizer usages + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + for group_idx in range(self.num_groups): + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: + super().reset_recipe_state(recipe=recipe) + + for group_idx in range(self.num_groups): + # Input/grad output quantizers use internal tensors + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + if input_quantizer is not None: + input_quantizer.internal = True + if grad_output_quantizer is not None: + grad_output_quantizer.internal = True + + # Handle weight quantizer + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + if weight_quantizer is None: + pass + elif is_quantized_tensor(getattr(self, f"weight{group_idx}", None)): + # Make sure weight param has correct quantizer + weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) + weight_quantizer.internal = False + getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy()) + else: + # Use internal tensors if quantized weights will not be + # exposed externally + weight_quantizer.internal = ( + not FP8GlobalStateManager.with_fp8_parameters() + and not getattr(self, "_with_quantized_weight", False) + ) + + # Recipe-specific configuration + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + if recipe is not None: + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon + grad_output_quantizer.force_pow_2_scales = ( + recipe.fp8_quant_bwd_grad.power_2_scale + ) + grad_output_quantizer.amax_epsilon_scales = ( + recipe.fp8_quant_bwd_grad.amax_epsilon + ) + + def op_forward(self, *args, **kwargs): + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs): + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. 
" + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + num_groups = self.num_groups + has_bias = self.has_bias + device = self.weight0.device + + # Check which grads are required + ctx = basic_op_ctxs[0] + input_requires_grad = ctx.requires_grad + weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad + + # Quantizers + input_quantizers = [None] * num_groups + weight_quantizers = [None] * num_groups + grad_output_quantizers = [None] * num_groups + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + for group_idx in range(num_groups): + input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx) + weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx) + + # Get autocast dtype if needed + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = self.weight0.dtype + + # Extract split sizes from extra input + split_sizes = basic_op_extra_inputs[0][0] + split_sizes_int = [int(s) for s in split_sizes.tolist()] + if len(split_sizes_int) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_int)}.") + + # Extract params + weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)] + bs = None + if has_bias: + bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(num_groups)] + + # Convert weight dtype if needed + ws = [] + for w, quantizer in zip(weights, weight_quantizers): + if not with_quantized_compute: + w = maybe_dequantize(w, dtype) + elif with_quantized_compute and not is_quantized_tensor(w): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + w = quantizer(w) + ws.append(w) + + # Split input tensor and convert dtypes if needed + x = maybe_dequantize(input_, dtype) + xs = None + if with_quantized_compute: + for quantizer in input_quantizers: + quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + xs = tex.split_quantize(x, split_sizes_int, input_quantizers) + else: + xs = torch.split(x, split_sizes_int) + + # Allocate output tensor + in_shape = list(input_.size()) + out_shape = in_shape[:-1] + [self.out_features] + out = torch.empty(out_shape, dtype=dtype, device=device) + + # Perform GEMMs + general_grouped_gemm( + ws, + xs, + [out], + [None] * num_groups, # quantization_params + dtype, + m_splits=split_sizes_int, + bias=bs, + use_bias=has_bias, + use_split_accumulator=_2X_ACC_FPROP, + single_output=True, + ) + + # Prepare weight tensors for backward pass + if not input_requires_grad: + ws = [None] * num_groups + elif with_quantized_compute: + for w, weight_param in zip(ws, weights): + if w is not weight_param: + w.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Prepare input tensor for backward pass + if not weight_requires_grad: + xs = [None] * num_groups + elif with_quantized_compute: + for x in xs: + x.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Save state for backward pass + if ctx.requires_grad: + ctx.save_for_backward(split_sizes, *xs, *ws) + 
ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizers = input_quantizers + ctx.weight_quantizers = weight_quantizers + ctx.grad_output_quantizers = grad_output_quantizers + ctx.grad_input_quantizers = None + ctx.dtype = dtype + ctx.input_requires_grad = input_requires_grad + ctx.weight_requires_grad = weight_requires_grad + + return out, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + num_groups = self.num_groups + has_bias = self.has_bias + device = self.weight0.device + + # Saved tensors from forward pass + ctx = basic_op_ctxs[0] + saved_tensors = ctx.saved_tensors + split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:] + xs, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] + ws, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] + + # Split grad output tensor and convert dtypes if needed + split_sizes_int = [int(s) for s in split_sizes.tolist()] + dy = maybe_dequantize(grad_output, ctx.dtype) + dys = None + grad_biases = [None] * num_groups + if ctx.with_quantized_compute: + for quantizer in ctx.grad_output_quantizers: + quantizer.set_usage( + rowwise=ctx.input_requires_grad, + columnwise=ctx.weight_requires_grad, + ) + dys = tex.split_quantize(dy, split_sizes_int, ctx.grad_output_quantizers) + if has_bias: + grad_biases = [ + dy.reshape(-1, dy.size(-1)).sum(dim=0) + for dy in torch.split(grad_output, split_sizes_int) + ] + else: + dys = torch.split(dy, split_sizes_int) + if has_bias: + grad_biases = [dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys] + + # Initialize grad weight grads + accumulate_into_main_grad = self._accumulate_into_main_grad + grad_weights = [None] * num_groups + if ctx.weight_requires_grad: + if accumulate_into_main_grad: + # Megatron-LM wgrad fusion + # Note: Get grad tensors from params so we can + # accumulate directly into it. 
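+                # If a weight param sets ``overwrite_main_grad``, the wgrad
+                # GEMM below writes into ``main_grad`` (accumulate=False)
+                # instead of accumulating into it; FSDP params fetch
+                # ``main_grad`` through ``get_main_grad()`` first.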
+ for group_idx in range(num_groups): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr( + weight_param, "overwrite_main_grad", False + ) + grad_weights[group_idx] = weight_param.main_grad + else: + weight_shape = ws[0].size() + for group_idx in range(num_groups): + grad_weights[group_idx] = torch.empty( + weight_shape, + dtype=ctx.dtype, + device=device, + ) + else: + accumulate_into_main_grad = False + + # Perform dgrad GEMMs + grad_input = None + if ctx.input_requires_grad: + out_shape = list(grad_output.size()) + in_shape = out_shape[:-1] + [self.in_features] + grad_input = torch.empty( + in_shape, + dtype=ctx.dtype, + device=device, + ) + general_grouped_gemm( + ws, + dys, + [grad_input], + [None] * num_groups, # quantization_params + ctx.dtype, + layout="NN", + m_splits=split_sizes_int, + use_split_accumulator=_2X_ACC_DGRAD, + single_output=True, + ) + + # Perform wgrad GEMMs + if ctx.weight_requires_grad: + general_grouped_gemm( + xs, + dys, + grad_weights, + [None] * num_groups, # quantization_params + ctx.dtype, + layout="NT", + m_splits=split_sizes_int, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_into_main_grad, + ) + + # Clear input tensors if possible + clear_tensor_data(*xs) + + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. + if accumulate_into_main_grad: + grad_weights = [None] * num_groups + for group_idx in range(num_groups): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weights[group_idx] = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + + grad_params = grad_weights + grad_biases if has_bias else grad_weights + return grad_input, [grad_params], [(None,)] diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py new file mode 100644 index 0000000000..8f1068ba46 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -0,0 +1,415 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for multiplying with extra input tensor.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...tensor import Float8CurrentScalingQuantizer, Quantizer +from ...utils import clear_tensor_data +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize + +__all__ = ["SwiGLU", "ClampedSwiGLU", "ScaledSwiGLU"] + + +class SwiGLU(BasicOperation): + r"""Swish gated linear unit + + The input tensor is split into chunks :math:``a`` and :math:``b`` + along the last dimension and the following is computed: + + .. math:: + + \text{GEGLU}(a,b) = \text{SiLU}(a) * b + + where + + .. math:: + + \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:``a`` and + :math:``b``. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. 
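+
+    When ``glu_interleave_size`` is set, the two chunks are taken to be
+    interleaved along the last dimension in blocks of that size rather
+    than stored as contiguous halves. A reference sketch of the
+    equivalent de-interleaving (shapes are illustrative) is
+
+    .. code-block:: python
+
+        g = glu_interleave_size
+        d = x.size(-1)
+        x = x.reshape(-1, d // (2 * g), 2, g).transpose(1, 2).reshape(-1, d)
+        a, b = x.chunk(2, dim=-1)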
+ + The Sigmoid Linear Unit (SiLU) gating function is also known as + the swish function. See + ``GLU Variants Improve Transformer``__ + and ``Gaussian Error Linear Units (GELUs)``__. + + """ + + def __init__( + self, *, cache_quantized_input: bool = False, glu_interleave_size: Optional[int] = None + ): + super().__init__() + self.cache_quantized_input: bool = cache_quantized_input + self.glu_interleave_size: Optional[int] = glu_interleave_size + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + input_ = maybe_dequantize(input_.contiguous(), dtype) + + # Remove interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Launch kernel + out = tex.swiglu(swiglu_in, next_op_input_quantizer) + + # Quantize input to FP8 before caching if needed + if self.cache_quantized_input: + input_quantizer = Float8CurrentScalingQuantizer( + tex.DType.kFloat8E4M3, + input_.device, + ) + input_quantizer.set_usage(rowwise=True, columnwise=False) + input_ = input_quantizer(input_) + + # Save state for backward pass + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(input_) + ctx.save_for_backward(input_) + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer + + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (input_,) = ctx.saved_tensors + + # Make sure tensors have correct dtypes + x = maybe_dequantize(input_.contiguous(), ctx.dtype) + dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype) + + # Remove interleaving if needed + swiglu_in = x + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Quantizer for grad input + quantizer = ctx.prev_op_grad_output_quantizer + if self.glu_interleave_size is not None: + quantizer = None + + # Launch kernel + grad_swiglu_in = tex.dswiglu(dy, swiglu_in, quantizer) + + # Apply interleaving if needed + dx = grad_swiglu_in + if self.glu_interleave_size is not None: + shape = dx.size() + dx = dx.reshape( + -1, + 2, + shape[-1] // (2 * self.glu_interleave_size), + self.glu_interleave_size, + ) + dx = dx.transpose(1, 2).contiguous() + dx = dx.view(shape) + + # Clear input tensor if possible + clear_tensor_data(input_) + + return dx, () + + +class ClampedSwiGLU(BasicOperation): + r"""GPT-OSS + Implementation based on ``GPT-OSS``__. + + This activation has two differences compared to the original SwiGLU + 1. Both gate and pre-activations are clipped based on parameter limit. + 2. 
Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. + + .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt + from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. + + Parameters + ---------- + limit : float + The clamp limit. + alpha : float + The scaling factor for the sigmoid function used in the activation. + cache_quantized_input : bool, default = ``False`` + Quantize input tensor when caching for use in the backward pass. + """ + + def __init__( + self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False + ): + super().__init__() + self.limit: float = limit + self.alpha: float = alpha + self.cache_quantized_input: bool = cache_quantized_input + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + x = maybe_dequantize(input_.contiguous(), dtype) + + # Launch kernel + y = tex.clamped_swiglu( + x, + next_op_input_quantizer, + limit=self.limit, + alpha=self.alpha, + ) + + # Quantize input to FP8 before caching if needed + if self.cache_quantized_input: + input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + input_quantizer.set_usage(rowwise=True, columnwise=False) + x = input_quantizer(x) + + # Save state for backward pass + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x) + ctx.save_for_backward(x) + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (input_,) = ctx.saved_tensors + + # Make sure tensors have correct dtypes + x = maybe_dequantize(input_.contiguous(), ctx.dtype) + dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype) + + # Launch kernel + dx = tex.clamped_dswiglu( + dy, + x, + ctx.prev_op_grad_output_quantizer, + limit=self.limit, + alpha=self.alpha, + ) + + # Clear input tensor if possible + clear_tensor_data(input_) + + return dx, () + + +class ScaledSwiGLU(BasicOperation): + """SwiGLU with post-scaling + + If the SwiGLU output has shape ``(d_1, ..., d_n)``, it is + multiplied with an extra input tensor of shape + ``(d_1, ..., d_{n-1})``. + + """ + + # Operation expects scales + num_extra_inputs: int = 1 + + def __init__(self, glu_interleave_size: Optional[int] = None): + super().__init__() + self.glu_interleave_size: Optional[int] = glu_interleave_size + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. 
" + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + extra_input = basic_op_extra_inputs[0][0] + + # Determine compute dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + elif isinstance(input_, torch.Tensor): + dtype = input_.dtype + else: + dtype = extra_input.dtype + + # Make sure inputs are in correct dtype + input_ = maybe_dequantize(input_, dtype) + scales = maybe_dequantize(extra_input, dtype) + + # Remove gate interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Compute scaled SwiGLU + swiglu_out = tex.swiglu(swiglu_in, None) + out = swiglu_out * scales.unsqueeze(-1) + + # Save state for backward pass + ctx = basic_op_ctxs[0] + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(input_) + ctx.input_requires_grad = True + ctx.extra_input_requires_grad = extra_input.requires_grad + ctx.dtype = dtype + ctx.save_for_backward( + input_, + scales if ctx.input_requires_grad else None, + ) + + return out, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + ctx = basic_op_ctxs[0] + input_, scales = ctx.saved_tensors + input_ = maybe_dequantize(input_, ctx.dtype) + if scales is not None: + scales = maybe_dequantize(scales, ctx.dtype) + grad_output = maybe_dequantize(grad_output, ctx.dtype) + + # Remove gate interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Compute input grad + grad_input = None + if ctx.input_requires_grad: + grad_swiglu_out = grad_output * scales.unsqueeze(-1) + grad_swiglu_in = tex.dswiglu(grad_swiglu_out, swiglu_in, None) + grad_input = grad_swiglu_in + if self.glu_interleave_size is not None: + shape = grad_input.size() + grad_input = grad_input.reshape( + -1, + 2, + shape[-1] // (2 * self.glu_interleave_size), + self.glu_interleave_size, + ) + grad_input = grad_input.transpose(1, 2).contiguous() + grad_input = grad_input.view(shape) + + # Compute scales grad by recomputing SwiGLU + grad_extra_input = None + if ctx.extra_input_requires_grad: + swiglu_out = tex.swiglu(swiglu_in, None) + grad_extra_input = torch.linalg.vecdot(swiglu_out, grad_output) + + # Clear input tensor if possible + clear_tensor_data(ctx.saved_tensors[0]) # input_ + + return grad_input, [()], [(grad_extra_input,)] diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 19608894e0..52cda9caf6 100644 --- 
a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -28,3 +28,9 @@ register_backward_fusion(BackwardLinearScale.fuse_backward_ops) register_backward_fusion(BackwardActivationBias.fuse_backward_ops) register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops) + +# Import experimental fusions +# Note: Registration logic is non-trivial, so submodule handles it internally. +from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position + ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, +) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py new file mode 100644 index 0000000000..19901cb4af --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -0,0 +1,453 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused operation for MoE grouped MLP.""" + +from __future__ import annotations +from collections.abc import Callable, Iterable +import functools +import itertools +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...cpp_extensions import general_grouped_gemm +from ...quantization import Recipe +from ...tensor import MXFP8Tensor, Quantizer +from ...utils import get_device_compute_capability +from ..basic import GroupedLinear, ScaledSwiGLU +from ..fuser import register_forward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import is_quantized_tensor, maybe_dequantize + + +class ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(FusedOperation): + """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU + GroupedLinear + + Uses experimental CuTe DSL kernel from cuDNN front-end. + + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_swiglu_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, SwiGLU, and post-multiplication.""" + from cudnn import grouped_gemm_swiglu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_swiglu_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether this fused operation is supported on the current system.""" + if get_device_compute_capability() < (10, 0): + # Kernel requires SM100+ + return False + try: + # Make sure kernel is available + cls.grouped_gemm_swiglu_kernel() + except ImportError: + return False + return True + + def __init__( + self, + *, + fc1: GroupedLinear, + swiglu: ScaledSwiGLU, + fc2: GroupedLinear, + ) -> None: + super().__init__((fc1, swiglu, fc2)) + + # Check for unsupported configurations + if not self.is_supported(): + self.grouped_gemm_swiglu_kernel() # Try triggering import error + raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") + if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0: + raise ValueError( + f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, " + f"in_features={fc1.in_features}, out_features={fc1.out_features})." + ) + if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0: + raise ValueError( + f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " + f"in_features={fc2.in_features}, out_features={fc2.out_features})." 
+ ) + if fc1.out_features != 2 * fc2.in_features or fc1.num_groups != fc2.num_groups: + raise ValueError( + f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, " + f"out_features={fc1.out_features}) " + f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " + f"out_features={fc2.out_features}) do not match." + ) + if fc1.has_bias or fc2.has_bias: + raise ValueError("Fused kernel does not support bias.") + if swiglu.glu_interleave_size != 32: + raise ValueError( + "Fused kernel requires 32-wide GLU interleaving, " + "but got glu_interleave_size={swiglu.glu_interleave_size}." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Get basic operations + fc1_op, _, fc2_op = self.basic_ops + fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + + # Tensor properties + in_shape = list(input_.size()) + assert len(in_shape) == 2, f"Expected 2D input tensor, got shape={in_shape}." + fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) + fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features) + num_groups = fc1_op.num_groups + device = fc1_op.weight0.device + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = fc1_op.weight0.dtype + + # Check which grads are required + requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) + input_requires_grad = requires_grad + weight_requires_grad = requires_grad and ( + fc1_op.weight0.requires_grad or fc2_op.weight0.requires_grad + ) + + # Quantizers + fc1_input_quantizers = [None] * num_groups + fc1_weight_quantizers = [None] * num_groups + fc1_grad_output_quantizers = [None] * num_groups + fc2_input_quantizers = [None] * num_groups + fc2_weight_quantizers = [None] * num_groups + fc2_grad_output_quantizers = [None] * num_groups + for idx in range(num_groups): + fc1_input_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx) + fc1_weight_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx + 1) + fc1_grad_output_quantizers[idx] = fc1_op.get_quantizer("backward", idx) + fc2_input_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * idx) + fc2_weight_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * idx + 1) + fc2_grad_output_quantizers[idx] = fc2_op.get_quantizer("backward", idx) + + # Extract split sizes from extra input + fc1_split_sizes = basic_op_extra_inputs[0][0] + fc2_split_sizes = basic_op_extra_inputs[2][0] + if ( + fc1_split_sizes.size() != fc2_split_sizes.size() + or fc1_split_sizes.data_ptr() != fc2_split_sizes.data_ptr() + ): + raise RuntimeError( + f"{self.__class__.__name__} got different split points for FC1 and FC2." 
+        # Extract split sizes from extra input
+        fc1_split_sizes = basic_op_extra_inputs[0][0]
+        fc2_split_sizes = basic_op_extra_inputs[2][0]
+        if (
+            fc1_split_sizes.size() != fc2_split_sizes.size()
+            or fc1_split_sizes.data_ptr() != fc2_split_sizes.data_ptr()
+        ):
+            raise RuntimeError(
+                f"{self.__class__.__name__} got different split points for FC1 and FC2."
+            )
+        split_sizes = fc1_split_sizes
+        split_sizes_cpu = [int(s) for s in split_sizes.tolist()]
+        if len(split_sizes_cpu) != num_groups:
+            raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_cpu)}.")
+        split_sizes = split_sizes.to(dtype=torch.int, device=device)
+        split_points = torch.zeros(
+            split_sizes.numel() + 1,
+            dtype=torch.int,
+            device=device,
+        )
+        torch.cumsum(split_sizes, 0, out=split_points[1:])
+
+        # Extract post-scales from extra input
+        scales = basic_op_extra_inputs[1][0]
+
+        # Extract params
+        fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)]
+        fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)]
+
+        # Convert weight dtype if needed
+        fc1_ws = []
+        fc2_ws = []
+        for w, quantizer in zip(fc1_weights, fc1_weight_quantizers):
+            if not is_quantized_tensor(w):
+                quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
+                w = quantizer(w)
+            fc1_ws.append(w)
+        for w, quantizer in zip(fc2_weights, fc2_weight_quantizers):
+            if not is_quantized_tensor(w):
+                quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
+                w = quantizer(w)
+            fc2_ws.append(w)
+
+        # Split input tensor and convert dtypes if needed
+        fc1_x = maybe_dequantize(input_, dtype)
+        fc1_xs = None
+        for quantizer in fc1_input_quantizers:
+            quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
+            quantizer.optimize_for_gemm = True
+        fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)
+
+        # Pack data tensors
+        fc1_x_data = torch.cat([x._rowwise_data for x in fc1_xs])
+        fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn)
+        fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0)
+        fc1_x_scales = torch.cat([x._rowwise_scale_inv for x in fc1_xs])
+        fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu)
+        fc1_x_scales = fc1_x_scales.view(
+            1,
+            in_shape[0] // 128,
+            in_shape[1] // 128,
+            32,
+            4,
+            4,
+        )
+        fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0)
+
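+        # NOTE: The per-group MXFP8 scale factors are stored as (rows, cols / 32)
+        # E8M0 tensors. The view/permute above (and for the weights below)
+        # rearranges them into 128-row x 4-scale-column blocks, i.e. the
+        # swizzled scale-factor layout that is passed to the CuTe kernel.
+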
+        # Pack weight tensors
+        fc1_w_data = torch.stack([w._rowwise_data for w in fc1_weights])
+        fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn)
+        fc1_w_data = fc1_w_data.view(
+            num_groups, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1]
+        )
+        fc1_w_data = fc1_w_data.flip(2).contiguous()  # Swap SwiGLU gate/activation
+        fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1])
+        fc1_w_data = fc1_w_data.permute(1, 2, 0)
+        fc1_w_scales = torch.stack([w._rowwise_scale_inv for w in fc1_weights])
+        fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu)
+        fc1_w_scales = fc1_w_scales.view(
+            num_groups, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] // 32
+        )
+        fc1_w_scales = fc1_w_scales.flip(2).contiguous()  # Swap SwiGLU gate/activation
+        fc1_w_scales = fc1_w_scales.view(
+            num_groups, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4
+        )
+        fc1_w_scales = fc1_w_scales.permute(
+            0, 1, 4, 3, 2, 5
+        ).contiguous()  # Convert to swizzled layout
+        fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0)
+
+        # Kernel tile logic
+        mma_tiler_mn = (256, 256)
+        tile_points = torch.arange(
+            0,
+            in_shape[0],
+            mma_tiler_mn[0],
+            dtype=torch.int,
+            device=device,
+        )
+        tile_idx_to_expert_idx = torch.searchsorted(
+            split_points[1:],
+            tile_points,
+            out_int32=True,
+            side="right",
+        )
+        num_non_exiting_tiles = torch.full(
+            (1,),
+            in_shape[0] // mma_tiler_mn[0],
+            dtype=torch.int,
+            device=device,
+        )
+
+        # Fused kernel for FC1 + SwiGLU + post-scale
+        fc1_kernel_out = self.grouped_gemm_swiglu_kernel()(
+            fc1_x_data,
+            fc1_w_data,
+            fc1_x_scales,
+            fc1_w_scales,
+            tile_idx_to_expert_idx,
+            num_non_exiting_tiles,
+            torch.ones(num_groups, dtype=dtype, device=device),  # alpha_tensor
+            torch.ones(1, dtype=dtype, device=device),  # norm_const_tensor
+            scales.detach().reshape(-1, 1, 1),
+            split_points,
+            acc_dtype=torch.float32,
+            c_dtype=torch.bfloat16,
+            d_dtype=torch.float8_e4m3fn,
+            cd_major="n",
+            mma_tiler_mn=mma_tiler_mn,
+            cluster_shape_mn=(2, 1),
+            sf_vec_size=32,
+        )
+
+        # Unpack kernel outputs
+        swiglu_in = fc1_kernel_out["c_tensor"]
+        swiglu_in = swiglu_in.permute(2, 0, 1)
+        swiglu_in = swiglu_in.view(in_shape[0], fc1_weight_shape[0] // 64, 2, 32)
+        swiglu_in = swiglu_in.flip(2)  # Undo swapped SwiGLU gate/activation
+        swiglu_in = swiglu_in.contiguous().view(in_shape[0], fc1_weight_shape[0])
+        fc2_in_row_data = fc1_kernel_out["d_tensor"]
+        fc2_in_row_data = fc2_in_row_data.permute(2, 0, 1)
+        fc2_in_row_data = fc2_in_row_data.view(in_shape[0], fc2_weight_shape[1])
+        fc2_in_row_data = torch.split(fc2_in_row_data.contiguous(), split_sizes_cpu)
+        fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"]
+        fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 4, 0, 1, 3)
+        fc2_in_row_scale = fc2_in_row_scale.view(in_shape[0], fc2_weight_shape[1] // 32)
+        fc2_in_row_scale = torch.split(fc2_in_row_scale.contiguous(), split_sizes_cpu)
+        fc2_in_col_data = fc1_kernel_out["d_col_tensor"]
+        fc2_in_col_data = fc2_in_col_data.permute(2, 0, 1)
+        fc2_in_col_data = fc2_in_col_data.view(in_shape[0], fc2_weight_shape[1])
+        fc2_in_col_data = torch.split(fc2_in_col_data.contiguous(), split_sizes_cpu)
+        fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"]
+        fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3)
+        fc2_in_col_scale = torch.split(fc2_in_col_scale, [s // 128 for s in split_sizes_cpu], dim=2)
+        fc2_in_col_scale = [s.contiguous().view(-1, fc2_weight_shape[1]) for s in fc2_in_col_scale]
+
+        # Construct MXFP8 tensors for FC2
+        fc2_xs = []
+        for group_idx in range(num_groups):
+            x = MXFP8Tensor(
+                shape=(split_sizes_cpu[group_idx], fc2_weight_shape[1]),
+                dtype=dtype,
+                fp8_dtype=tex.DType.kFloat8E4M3,
+                rowwise_data=fc2_in_row_data[group_idx],
+                rowwise_scale_inv=fc2_in_row_scale[group_idx],
+                columnwise_data=fc2_in_col_data[group_idx],
+                columnwise_scale_inv=fc2_in_col_scale[group_idx],
+                quantizer=fc2_input_quantizers[group_idx],
+                requires_grad=False,
+                with_gemm_swizzled_scales=True,
+            )
+            fc2_xs.append(x)
+
+        # FC2 GEMM
+        fc2_out_shape = in_shape[:-1] + [fc2_weight_shape[0]]
+        fc2_out = torch.empty(fc2_out_shape, dtype=dtype, device=device)
+        general_grouped_gemm(
+            fc2_ws,
+            fc2_xs,
+            [fc2_out],
+            [None] * num_groups,  # quantization_params
+            dtype,
+            m_splits=split_sizes_cpu,
+            bias=[None] * num_groups,
+            use_bias=False,
+            single_output=True,
+        )
+
+        # Prepare input tensors for backward pass
+        for x in itertools.chain(fc1_xs, fc2_xs):
+            x.update_usage(rowwise_usage=False, columnwise_usage=True)
+
+        # Save state for backward pass
+        if requires_grad:
+            # FC1
+            fc1_ctx.save_for_backward(split_sizes, *fc1_xs, *fc1_ws)
+            fc1_ctx.with_quantized_compute = True
+            fc1_ctx.input_quantizers = fc1_input_quantizers
+            fc1_ctx.weight_quantizers = fc1_weight_quantizers
+            fc1_ctx.grad_output_quantizers = fc1_grad_output_quantizers
+            fc1_ctx.grad_input_quantizers = None
+            fc1_ctx.dtype = dtype
+            fc1_ctx.input_requires_grad = input_requires_grad
+            fc1_ctx.weight_requires_grad = weight_requires_grad
+
+            # Scaled SwiGLU
+            swiglu_ctx.save_for_backward(swiglu_in, scales)
+            swiglu_ctx.input_requires_grad = True
+            swiglu_ctx.extra_input_requires_grad = True
+            swiglu_ctx.dtype = dtype
+
+            # FC2 state
+            fc2_ctx.save_for_backward(split_sizes, *fc2_xs, *fc2_ws)
+            fc2_ctx.with_quantized_compute = True
+            fc2_ctx.input_quantizers = fc2_input_quantizers
+            fc2_ctx.weight_quantizers = fc2_weight_quantizers
+            fc2_ctx.grad_output_quantizers = fc2_grad_output_quantizers
+            fc2_ctx.grad_input_quantizers = None
+            fc2_ctx.dtype = dtype
+            fc2_ctx.input_requires_grad = input_requires_grad
+            fc2_ctx.weight_requires_grad = weight_requires_grad
+
+        return fc2_out, [(), (), ()]
+
+
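+# The scan in fuse_forward_ops slides a three-op window over the pipeline:
+# whenever the window is GroupedLinear -> ScaledSwiGLU -> GroupedLinear with no
+# biases, matching num_groups, feature dims divisible by 256, and 32-wide GLU
+# interleaving, the three ops are replaced with a single
+# ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8; otherwise the window shifts forward
+# by one op. Fusion is only attempted when the active recipe is MXFP8.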
+def fuse_forward_ops(
+    ops: list[FusibleOperation],
+    *,
+    recipe: Optional[Recipe] = None,
+    **unused,  # pylint: disable=unused-argument
+) -> list[FusibleOperation]:
+    """Apply operation fusion for forward pass.
+
+    Parameters
+    ----------
+    ops : list of FusibleOperation
+        Forward pass operations.
+    recipe : Recipe, optional
+        Quantization recipe.
+
+    Returns
+    -------
+    ops : list of FusibleOperation
+        Updated forward pass operations
+
+    """
+
+    # Return immediately if fused kernel is not supported
+    if not ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
+        return ops
+
+    # Check if recipe is supported
+    if recipe is None:
+        return ops
+    if not recipe.mxfp8():
+        return ops
+
+    # Scan through ops, fusing if possible
+    out = []
+    window, ops = ops[:3], ops[3:]
+    while len(window) == 3:
+
+        # Check if window matches pattern
+        matches_pattern = True
+        if not (
+            isinstance(window[0], GroupedLinear)
+            and isinstance(window[1], ScaledSwiGLU)
+            and isinstance(window[2], GroupedLinear)
+        ):
+            matches_pattern = False
+        elif window[0].has_bias or window[2].has_bias:
+            matches_pattern = False
+        elif window[0].num_groups != window[2].num_groups:
+            matches_pattern = False
+        elif (
+            window[0].in_features % 256 != 0
+            or window[0].out_features % 256 != 0
+            or window[2].in_features % 256 != 0
+            or window[2].out_features % 256 != 0
+        ):
+            matches_pattern = False
+        elif window[1].glu_interleave_size != 32:
+            matches_pattern = False
+
+        if matches_pattern:
+            # Construct fused op if window matches pattern
+            op = ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(
+                fc1=window[0],
+                swiglu=window[1],
+                fc2=window[2],
+            )
+            window = [op]
+        else:
+            # Shift window if window doesn't match pattern
+            out.extend(window[:-2])
+            window = window[-2:]
+
+        # Adjust window to expected size
+        out.extend(window[:-3])
+        window = window[-3:]
+        while ops and len(window) < 3:
+            window.append(ops[0])
+            ops = ops[1:]
+
+    # Return list of ops
+    out.extend(window)
+    return out
+
+
+# Register fusion if available
+if ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
+    register_forward_fusion(fuse_forward_ops, prepend=True)
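+# NOTE: prepend=True so that this pattern is attempted before other registered
+# forward fusions get a chance to consume its GroupedLinear ops.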