support bf16 beta for KDA (SM90 & SM10X) #34
Conversation
Code Review
This pull request introduces support for bfloat16 as a data type for the beta input in KDA kernels across SM90 (Hopper) and SM100 (Blackwell) architectures. The changes include updating the C++ API with appropriate type checks, modifying kernel dispatch logic to handle different global memory types, and templating mainloops to support bfloat16 while maintaining float32 for shared memory and computation. Additionally, the documentation has been updated, and a comprehensive test suite was added to verify numerical parity between float32 and bfloat16 beta inputs. I have no feedback to provide as the implementation is thorough and includes necessary validation.
Pull request overview
Adds support for passing beta as BF16 (in addition to FP32) for KDA forward kernels on SM90 (Hopper) and SM100/SM10X (Blackwell), with runtime/compile-time dispatch to keep compute in FP32 while reducing GMEM bandwidth.
Changes:
- SM90: Templatize the prefill kernel launcher and mainloop beta-load path to accept BF16 beta from GMEM (while keeping SMEM/compute in FP32), plus API dispatch based on `beta.dtype`.
- SM100: Add `is_beta_bf16` to kernel params and dispatch to beta-load variants that reinterpret beta as BF16 and upcast to FP32 for compute.
- Add BF16-beta regression tests and update README/USAGE docs to reflect the new supported `beta` dtype.
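The dtype-keyed dispatch described above can be sketched in a few lines of Python. This is an illustrative mock only; the function name `dispatch_kda_fwd` and the kernel strings are hypothetical stand-ins for the real dispatch in the C++ API files:

```python
# Hypothetical mock of dtype-keyed kernel dispatch; names are illustrative,
# not the real API. Compute stays in fp32 regardless of beta's GMEM dtype.
def dispatch_kda_fwd(beta_dtype: str) -> str:
    """Select a kernel instantiation from beta's dtype."""
    kernels = {
        "float32": "kda_fwd<float>",        # fp32 beta loaded directly
        "bfloat16": "kda_fwd<bfloat16_t>",  # bf16 beta, upcast to fp32 on load
    }
    if beta_dtype not in kernels:
        raise TypeError(f"beta must be float32 or bfloat16, got {beta_dtype}")
    return kernels[beta_dtype]

print(dispatch_kda_fwd("bfloat16"))  # -> kda_fwd<bfloat16_t>
```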
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| `USAGE.md` | Updates user-facing dtype requirements for `beta` vs `initial_state`. |
| `README.md` | Documents BF16 beta support and marks the roadmap item complete. |
| `tests/test_bf16_beta.py` | New SM90/SM100 test coverage comparing BF16-beta vs FP32-beta outputs (and SM100 grads). |
| `csrc/kda/sm90/prefill_kernel.hpp` | Extends the SM90 launcher template to accept a configurable beta pointer element type. |
| `csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh` | Wires the beta GMEM element type through options to the mainloop and launcher signature. |
| `csrc/kda/sm90/kernel/options.hpp` | Adds an option tag to carry the beta GMEM element type (float vs bf16). |
| `csrc/kda/sm90/kda_fwd_sm90.cu` | Adds explicit instantiations and updates the dispatch wrapper for the templated TBeta. |
| `csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu` | Adds explicit instantiations for the BF16-beta safe-gate variants. |
| `csrc/kda/sm90/collective/mainloop_kda_fwd.hpp` | Loads beta from GMEM as float/bf16 while keeping SMEM/compute beta as float. |
| `csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp` | Templatizes the beta load type and upcasts to float in shared beta staging. |
| `csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp` | Introduces a BF16-beta kernel alias and runtime dispatch via `params.is_beta_bf16`. |
| `csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp` | Adds a beta element-type template parameter and BF16->FP32 upcast on load. |
| `csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp` | Adds BF16-beta kernel variants and dispatch logic keyed by `params.is_beta_bf16`. |
| `csrc/kda/sm100/kda_config.hpp` | Adds `is_beta_bf16` to SM100 forward param structs for dispatch. |
| `csrc/api/kda_sm90.cu` | Accepts BF16 beta and dispatches to the BF16-beta SM90 instantiation. |
| `csrc/api/kda_sm100.cu` | Validates beta dtype (fp32/bf16), sets `is_beta_bf16`, and uses SM100 dispatch. |
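The new tests verify numerical parity between the two beta dtypes. A minimal sketch of that kind of check, with made-up placeholder outputs (the real tests compare actual kernel outputs, likely with `torch.testing.assert_close` or similar):

```python
# Sketch of a BF16-vs-FP32 beta parity check. The output values below are
# made-up placeholders; a real test would compare actual kernel outputs
# within a tolerance that absorbs the bf16 quantization of beta.
def max_rel_err(a, b, eps=1e-12):
    """Largest elementwise relative error between two sequences."""
    return max(abs(x - y) / max(abs(y), eps) for x, y in zip(a, b))

out_fp32_beta = [0.10000, 0.25000, 0.50000]   # placeholder fp32-beta outputs
out_bf16_beta = [0.09961, 0.25000, 0.49950]   # placeholder bf16-beta outputs

assert max_rel_err(out_bf16_beta, out_fp32_beta) < 5e-3
```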
SM90 (Hopper):
- Introduce an `ElementBetaGmem` type alias to decouple the GMEM load type from the SMEM/compute type, enabling bf16 GMEM loads with fp32 compute
- `CollectiveLoadVector` handles the implicit bf16->fp32 conversion via `Copy_Atom<UniversalCopy<ElementSrc>, ElementDst>`
- Add a `TBeta` template parameter and explicit instantiations for both float and bf16 in `kda_fwd_sm90.cu` and `kda_fwd_sm90_safe_gate.cu`

SM100 (Blackwell):
- Add an `is_beta_bf16` runtime flag to `KDA_fwd_intra_params` and `KDA_fwd_recomp_w_u_params`
- Template-specialize the intra and recomp_w_u mainloops on `ElementBeta_`
- Dispatch to bf16/fp32 kernel variants at runtime (4->8 intra, 1->2 recomp)

Tests:
- Add comprehensive bf16 beta tests (`tests/test_bf16_beta.py`) covering SM90 and SM100, fixed-length and varlen configurations
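The implicit conversion on load works because every bf16 value is exactly representable in fp32: bf16 is fp32 with the low 16 mantissa bits dropped, so the upcast is lossless while the narrowing direction is not. A bit-level Python sketch of this relationship (emulated helpers, not the CUDA path):

```python
import struct

def fp32_to_bf16_bits(x: float) -> int:
    """Narrow fp32 to bf16 by keeping the top 16 bits (truncation rounding)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

def bf16_bits_to_fp32(b: int) -> float:
    """Upcast bf16 to fp32 by zero-filling the low mantissa bits (exact)."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

beta = 0.3125  # exactly representable in bf16 (needs few mantissa bits)
assert bf16_bits_to_fp32(fp32_to_bf16_bits(beta)) == beta  # lossless round-trip

# A value needing more mantissa bits is changed by the narrowing step:
assert bf16_bits_to_fp32(fp32_to_bf16_bits(0.1)) != 0.1
```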
```diff
@@ -0,0 +1,334 @@
+# Copyright 2025-2026 Ant Group Co., Ltd.
```
I think it's more reasonable to compare against FLA and a naive implementation.
Can we add a pytest param for beta with float32 and bfloat16 in the sm100 & sm90 KDA tests, just like the `disable_recompute` flag?
sure, I'll update the tests
```diff
-    ? reinterpret_cast<float*>(
-          params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]
+    ? float(reinterpret_cast<ElementBeta_*>(
+          params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx])
```
There may be a risk of numerical differences, because the FLA recompute_wu kernel computes K/V * beta directly without converting beta to float.
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/wy_fast.py#L69
I will check it tomorrow.
I see, this is a valid concern. I'll also check the numerical impact of this forward/backward precision mismatch and follow up.
> There may be a risk of numerical differences, because the FLA recompute_wu kernel computes K/V * beta directly without converting beta to float. https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/wy_fast.py#L69 I will check it tomorrow.
Did you mean the safe way to align with FLA is to keep the bf16 K * beta?
Exactly. If beta is bf16, first compute bf16 K*beta and then convert to float for gating. If beta is fp32, convert bf16 K to fp32 first, then multiply with beta and do gating.
> exactly. If beta is bf16, first compute bf16 K*beta and then convert to float for gating. If beta is fp32, convert bf16 K to fp32 first and then multiply with beta and do gating.
Looks like it's fine to perform the upcast (bf16 -> fp32), but it's dangerous to do the downcast (fp32 inputs to bf16).
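The precision asymmetry discussed above can be demonstrated without CUDA. The sketch below emulates bf16 storage by truncating fp32 mantissa bits and shows that multiplying in bf16 versus upcasting first gives different K*beta values (the `to_bf16` helper is an emulation, not the kernel code):

```python
import struct

def to_bf16(x: float) -> float:
    """Emulate bf16 storage: round to fp32, then drop the low 16 mantissa bits."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

k_bf16, beta_bf16 = to_bf16(0.1), to_bf16(0.2)   # bf16 inputs as loaded from GMEM

kb_in_bf16 = to_bf16(k_bf16 * beta_bf16)  # multiply, store the product in bf16
kb_in_fp32 = k_bf16 * beta_bf16           # upcast first, keep the product in fp32

assert kb_in_bf16 != kb_in_fp32  # compute precision changes the result
```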
Done:
Related issues: #12
Follow-up to #26. Thanks @meinie0826 for the original implementation.