
Conversation

timmoon10 (Collaborator) commented Jan 24, 2026

Description

This PR adds a grouped linear op, which can be used in the grouped MLP block of Mixture-of-Experts (MoE) models. It also adds an experimental fused operation for the grouped MLP block, built on a CuTe DSL kernel that computes an MXFP8 grouped GEMM fused with SwiGLU.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a grouped linear operation
  • Add a post-scaled SwiGLU op, with support for interleaved SwiGLU gate and linear units
  • Add a fused operation for the grouped MLP block (a plain-PyTorch sketch of the targeted computation follows this list)
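
For orientation, here is a minimal plain-PyTorch sketch of the unfused computation these ops target. The shapes, the non-interleaved [gate, linear] layout, and the per-token scales are illustrative assumptions, not the PR's API:

```python
import torch
import torch.nn.functional as F

def grouped_mlp_reference(x, split_sizes, w1s, w2s, scales):
    """Unfused grouped MLP: per-group FC1 -> SwiGLU -> post-scale -> FC2.

    x:           (num_tokens, hidden) input, grouped contiguously by expert
    split_sizes: per-group token counts, summing to num_tokens
    w1s, w2s:    per-group weights with shapes (2*ffn, hidden) and (hidden, ffn)
    scales:      (num_tokens, 1) post-SwiGLU scales (e.g. router probabilities)
    """
    outs = []
    for xg, sg, w1, w2 in zip(
        x.split(split_sizes), scales.split(split_sizes), w1s, w2s
    ):
        h = xg @ w1.t()                 # per-group FC1: (tokens_g, 2*ffn)
        gate, lin = h.chunk(2, dim=-1)  # assumes non-interleaved gate/linear halves
        h = F.silu(gate) * lin          # SwiGLU activation
        h = h * sg                      # post-scaling (the "ScaledSwiGLU" step)
        outs.append(h @ w2.t())         # per-group FC2: (tokens_g, hidden)
    return torch.cat(outs)
```

The GroupedLinear and ScaledSwiGLU ops implement a quantized, batched version of roughly this loop, and the experimental fusion collapses FC1, SwiGLU, and the post-scale into a single kernel.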

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 30 commits January 7, 2026 00:15
Signed-off-by: Tim Moon <tmoon@nvidia.com>

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

Test is too permissive since the test should still be failing. The weights are not properly interleaved yet.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 added the performance (Performance issues) label Jan 24, 2026
timmoon10 added a commit to timmoon10/TransformerEngine that referenced this pull request Jan 24, 2026
timmoon10 added a commit that referenced this pull request Jan 25, 2026
* Expose option for custom op fusions

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tests for custom ops

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings and numerical test failures

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak pattern matching logic with fixed window sizes

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use TF32 tols in fused op tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Backpropagate fixes from #2622

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
timmoon10 mentioned this pull request Jan 25, 2026
timmoon10 changed the title from "[PyTorch] Prototype of fused operation for grouped MLP" to "[PyTorch] Add grouped linear op and experimental fusion for grouped MLP" Jan 25, 2026
timmoon10 marked this pull request as ready for review January 25, 2026 01:00
timmoon10 (Collaborator, Author) commented:
/te-ci pytorch L1

greptile-apps bot (Contributor) commented Jan 25, 2026

Greptile Overview

Greptile Summary

This PR introduces grouped linear operations and experimental fused kernels for Mixture-of-Experts (MoE) models. The implementation adds GroupedLinear which applies multiple independent linear transformations to split inputs, ScaledSwiGLU for post-scaled SwiGLU activation with GLU interleaving support, and an experimental ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 fused operation that combines FC1 + SwiGLU + FC2 in a single CuTe DSL kernel for SM100+ GPUs.

Key changes:

  • GroupedLinear operation supports FP8/MXFP8 quantization, dynamic parameter registration for multiple groups, and Megatron-LM main_grad accumulation
  • ScaledSwiGLU enables element-wise scaling after SwiGLU activation with configurable gate/activation interleaving
  • Fused kernel packs MXFP8 data into specialized layouts, executes grouped GEMM + SwiGLU fusion, and unpacks outputs for the second GEMM
  • SwiGLU operations refactored from activation.py into dedicated swiglu.py module
  • Comprehensive test coverage added for all new operations

The code quality is high with proper error handling, quantization support, and backward pass implementations. All previously identified issues in review threads have been addressed.
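
As a concrete illustration of the ScaledSwiGLU semantics described above, here is a hedged plain-PyTorch sketch. The alternating-channel interleave shown is an assumed layout for illustration; the op's actual convention is defined in swiglu.py:

```python
import torch
import torch.nn.functional as F

def scaled_swiglu_reference(x, scales, interleaved=False):
    """Post-scaled SwiGLU: out = scales * (silu(gate) * linear).

    x:           (..., 2*ffn) FC1 output holding gate and linear units
    scales:      scaling factors broadcastable against the output
    interleaved: if True, gate/linear channels alternate along the last
        dim (assumed layout) instead of occupying contiguous halves
    """
    if interleaved:
        gate, lin = x[..., 0::2], x[..., 1::2]  # de-interleave channels
    else:
        gate, lin = x.chunk(2, dim=-1)
    return scales * (F.silu(gate) * lin)

# 4 tokens, ffn = 8, per-token scales -> output shape (4, 8)
y = scaled_swiglu_reference(torch.randn(4, 16), torch.rand(4, 1), interleaved=True)
```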

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk - all identified issues have been addressed
  • High-quality implementation with comprehensive error handling, proper quantization support, extensive test coverage, and all previously reported issues resolved. The experimental nature of the fused kernel is appropriately documented and gated behind capability checks.
  • No files require special attention - all previously identified issues have been resolved

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/ops/basic/grouped_linear.py | New grouped linear operation: comprehensive implementation with proper error handling, quantization support, and backward pass logic |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Experimental fused MXFP8 grouped MLP kernel using the CuTe DSL: handles FC1 + SwiGLU + FC2 fusion with proper tensor packing |
| transformer_engine/pytorch/ops/basic/swiglu.py | Adds SwiGLU, ClampedSwiGLU, and ScaledSwiGLU operations with proper interleaving support and backward implementations |
| tests/pytorch/test_fusible_ops.py | Comprehensive test coverage for grouped linear operations and the fused grouped MLP |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant GroupedLinear as GroupedLinear (FC1)
    participant ScaledSwiGLU
    participant GroupedLinear2 as GroupedLinear (FC2)
    participant FusedKernel as CuTeGEMMSwiGLU (Fused)

    Note over User,FusedKernel: Standard Path (No Fusion)
    User->>GroupedLinear: input, split_sizes
    GroupedLinear->>GroupedLinear: Split input by split_sizes
    GroupedLinear->>GroupedLinear: Quantize inputs (MXFP8)
    GroupedLinear->>GroupedLinear: Quantize weights (MXFP8)
    GroupedLinear->>GroupedLinear: general_grouped_gemm
    GroupedLinear->>ScaledSwiGLU: FC1 output
    ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving
    ScaledSwiGLU->>ScaledSwiGLU: Apply SwiGLU activation
    ScaledSwiGLU->>ScaledSwiGLU: Multiply by scales
    ScaledSwiGLU->>GroupedLinear2: Scaled output
    GroupedLinear2->>GroupedLinear2: Split and quantize inputs
    GroupedLinear2->>GroupedLinear2: general_grouped_gemm
    GroupedLinear2->>User: Final output

    Note over User,FusedKernel: Fused Path (SM100+, MXFP8)
    User->>FusedKernel: input, split_sizes, scales
    FusedKernel->>FusedKernel: Split and quantize FC1 inputs
    FusedKernel->>FusedKernel: Pack FC1 input/weight data
    FusedKernel->>FusedKernel: grouped_gemm_swiglu_kernel
    Note over FusedKernel: Single kernel: FC1 GEMM + SwiGLU + post-scale
    FusedKernel->>FusedKernel: Unpack kernel outputs
    FusedKernel->>FusedKernel: Construct MXFP8 tensors for FC2
    FusedKernel->>FusedKernel: general_grouped_gemm (FC2)
    FusedKernel->>User: Final output
```
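
The summary notes the fused path is gated behind capability checks. A hypothetical helper mirroring the documented constraints might look like the following (the real dispatch logic lives in the fused op's fusion machinery):

```python
import torch

def fused_grouped_mlp_supported() -> bool:
    """Hypothetical gate for the experimental
    ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 fusion: requires an SM100+
    (Blackwell) GPU; the MXFP8-recipe requirement is checked separately."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 10  # SM100 and newer report compute capability 10.x+
```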

greptile-apps bot (Contributor) left a comment:

No files reviewed, no comments

```python
quantizer.optimize_for_gemm = True
fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)

# Pack data tensors
```
# Pack data tensors
A reviewer (Member) commented on this hunk:

May be a silly question: is this packing and unpacking code just for verification, or will it be in the final version?


Labels: performance (Performance issues)