[PyTorch] Add grouped linear op and experimental fusion for grouped MLP #2622
base: main
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.
The test is too permissive, since it should still be failing: the weights are not properly interleaved yet.
* Expose option for custom op fusions: refactor fusion functions to remove index bookkeeping; refactor fused ops to use consistent operation order.
* Add tests for custom ops
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fix linter warnings and numerical test failures
* Tweak pattern matching logic with fixed window sizes
* Use TF32 tols in fused op tests
* Apply review suggestion from @greptile-apps
* Backpropagate fixes from #2622
/te-ci pytorch L1
Greptile Summary: This PR introduces grouped linear operations and experimental fused kernels for Mixture-of-Experts (MoE) models.
The code quality is high with proper error handling, quantization support, and backward pass implementations. All previously identified issues in review threads have been addressed. Confidence Score: 5/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant GroupedLinear as GroupedLinear (FC1)
    participant ScaledSwiGLU
    participant GroupedLinear2 as GroupedLinear (FC2)
    participant FusedKernel as CuTeGEMMSwiGLU (Fused)

    Note over User,FusedKernel: Standard Path (No Fusion)
    User->>GroupedLinear: input, split_sizes
    GroupedLinear->>GroupedLinear: Split input by split_sizes
    GroupedLinear->>GroupedLinear: Quantize inputs (MXFP8)
    GroupedLinear->>GroupedLinear: Quantize weights (MXFP8)
    GroupedLinear->>GroupedLinear: general_grouped_gemm
    GroupedLinear->>ScaledSwiGLU: FC1 output
    ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving
    ScaledSwiGLU->>ScaledSwiGLU: Apply SwiGLU activation
    ScaledSwiGLU->>ScaledSwiGLU: Multiply by scales
    ScaledSwiGLU->>GroupedLinear2: Scaled output
    GroupedLinear2->>GroupedLinear2: Split and quantize inputs
    GroupedLinear2->>GroupedLinear2: general_grouped_gemm
    GroupedLinear2->>User: Final output

    Note over User,FusedKernel: Fused Path (SM100+, MXFP8)
    User->>FusedKernel: input, split_sizes, scales
    FusedKernel->>FusedKernel: Split and quantize FC1 inputs
    FusedKernel->>FusedKernel: Pack FC1 input/weight data
    FusedKernel->>FusedKernel: grouped_gemm_swiglu_kernel
    Note over FusedKernel: Single kernel: FC1 GEMM + SwiGLU + post-scale
    FusedKernel->>FusedKernel: Unpack kernel outputs
    FusedKernel->>FusedKernel: Construct MXFP8 tensors for FC2
    FusedKernel->>FusedKernel: general_grouped_gemm (FC2)
    FusedKernel->>User: Final output
```
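To make the standard path concrete, here is a minimal reference sketch in plain PyTorch of what the op sequence computes per expert group. It is illustrative only: there is no quantization, the function and argument names are not the actual TransformerEngine API, and the gate/value layout is an assumption.

```python
import torch
import torch.nn.functional as F

def grouped_mlp_reference(x, split_sizes, fc1_weights, fc2_weights, scales):
    """Unquantized reference for the standard path, group by group.

    x:            (num_tokens, hidden) activations, packed by expert
    split_sizes:  tokens per expert, summing to num_tokens
    fc1_weights:  list of (2 * ffn_hidden, hidden) tensors, one per expert
    fc2_weights:  list of (hidden, ffn_hidden) tensors, one per expert
    scales:       (num_tokens,) post-activation scales (e.g. router probs)
    """
    outputs = []
    for x_i, s_i, w1, w2 in zip(
        torch.split(x, split_sizes), torch.split(scales, split_sizes),
        fc1_weights, fc2_weights,
    ):
        h = x_i @ w1.t()                  # FC1 GEMM for this group
        gate, value = h.chunk(2, dim=-1)  # split gate/value (layout assumed)
        h = F.silu(gate) * value          # SwiGLU activation
        h = h * s_i.unsqueeze(-1)         # multiply by per-token scales
        outputs.append(h @ w2.t())        # FC2 GEMM for this group
    return torch.cat(outputs)
```

The fused path computes the same FC1 GEMM + SwiGLU + post-scale in a single CuTe DSL kernel before the FC2 grouped GEMM.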
```python
quantizer.optimize_for_gemm = True
fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)

# Pack data tensors
```
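For intuition, a plain-PyTorch sketch of just the splitting step behind the quoted tex.split_quantize call; the per-group quantization is omitted, and nothing here reflects the actual tex implementation.

```python
import torch

def split_by_expert(fc1_x, split_sizes_cpu):
    """Split packed FC1 activations into per-expert chunks along the token dim.

    In the PR, tex.split_quantize appears to fuse this split with per-group
    MXFP8 quantization using one quantizer per expert; that step is omitted
    here.
    """
    return list(torch.split(fc1_x, split_sizes_cpu, dim=0))
```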
Maybe a silly question: is this packing and unpacking code just for verification, or will it be in the final version?
Description
This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
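A rough sketch of how the new op might be composed with the existing te.ops operation-fuser API follows; the constructor arguments and forward signature are assumptions inferred from the sequence diagram above, not a documented interface.

```python
import transformer_engine.pytorch as te

# Hypothetical sizes, for illustration only.
num_experts, hidden_size, ffn_hidden_size = 8, 1024, 4096

# Assumed constructor signatures: GroupedLinear(num_groups, in_features,
# out_features) and a parameter-free ScaledSwiGLU op. With an MXFP8 recipe on
# SM100+ hardware, the operation fuser may replace FC1 GEMM + SwiGLU +
# post-scale with the fused CuTe DSL kernel.
grouped_mlp = te.ops.Sequential(
    te.ops.GroupedLinear(num_experts, hidden_size, 2 * ffn_hidden_size, bias=False),
    te.ops.ScaledSwiGLU(),
    te.ops.GroupedLinear(num_experts, ffn_hidden_size, hidden_size, bias=False),
)

# Forward call is also assumed: tokens packed by expert, with per-expert
# token counts and per-token scales (e.g. router probabilities).
# out = grouped_mlp(x, split_sizes=split_sizes, scales=router_probs)
```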