
support bf16 beta for KDA (SM90 & SM10X)#34

Open
cherhh wants to merge 2 commits into inclusionAI:main from cherhh:dev.beta_bf16

Conversation

@cherhh (Collaborator) commented Apr 5, 2026

Done:

  • support bf16 beta for SM90 and SM100 forward kernels
  • support bf16 beta test
  • misc: update README beta dtype docs and roadmap

Related issues: #12
Follow‑up to #26 . Thanks @meinie0826 for the original implementation.
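
At the API level, the change amounts to accepting beta in either fp32 or bf16 and recording the choice for kernel dispatch. A minimal Python sketch of that contract (function and names are illustrative, not the real C++ API):

```python
# Hypothetical sketch of the beta dtype check added by this PR: beta may be
# fp32 or bf16, anything else is rejected, and a flag records the choice so
# the launcher can dispatch to the matching kernel instantiation.
ALLOWED_BETA_DTYPES = {"float32", "bfloat16"}

def check_beta_dtype(dtype: str) -> bool:
    """Return is_beta_bf16; raise for unsupported dtypes."""
    if dtype not in ALLOWED_BETA_DTYPES:
        raise TypeError(f"beta must be float32 or bfloat16, got {dtype}")
    return dtype == "bfloat16"
```

In the actual C++ API the same decision is taken from `beta.dtype` and stored as `is_beta_bf16` in the SM100 param structs.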

@gemini-code-assist bot left a comment

Code Review

This pull request introduces support for bfloat16 as a data type for the beta input in KDA kernels across SM90 (Hopper) and SM100 (Blackwell) architectures. The changes include updating the C++ API with appropriate type checks, modifying kernel dispatch logic to handle different global memory types, and templating mainloops to support bfloat16 while maintaining float32 for shared memory and computation. Additionally, the documentation has been updated, and a comprehensive test suite was added to verify numerical parity between float32 and bfloat16 beta inputs. I have no feedback to provide as the implementation is thorough and includes necessary validation.

Copilot AI left a comment

Pull request overview

Adds support for passing beta as BF16 (in addition to FP32) for KDA forward kernels on SM90 (Hopper) and SM100/SM10X (Blackwell), with runtime/compile-time dispatch to keep compute in FP32 while reducing GMEM bandwidth.

Changes:

  • SM90: Templatize the prefill kernel launcher and mainloop beta load path to accept BF16 beta from GMEM (while keeping SMEM/compute in FP32), plus API dispatch based on beta.dtype.
  • SM100: Add is_beta_bf16 to kernel params and dispatch to beta-load variants that reinterpret beta as BF16 and upcast to FP32 for compute.
  • Add BF16-beta regression tests and update README/USAGE docs to reflect the new supported beta dtype.
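
The scheme above (store beta in GMEM as bf16, upcast to fp32 at load, compute in fp32) works because every bf16 value is exactly representable in fp32, so the only precision loss is the one-time rounding when beta is first stored. A minimal pure-Python emulation, assuming round-to-nearest-even bf16 conversion:

```python
import struct

def _bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]

def _float(b: int) -> float:
    return struct.unpack("<f", struct.pack("<I", b))[0]

def to_bf16(x: float) -> float:
    """Round an fp32 value to bf16 (round-to-nearest-even), returned as fp32."""
    b = _bits(x)
    lsb = (b >> 16) & 1
    return _float((b + 0x7FFF + lsb) & 0xFFFF0000)

beta = 0.1
beta_bf16 = to_bf16(beta)               # one-time rounding when beta is stored as bf16
assert to_bf16(beta_bf16) == beta_bf16  # the bf16 -> fp32 upcast at load is exact
```

This is why the kernels can keep SMEM and compute in fp32: the upcast is lossless, and only GMEM bandwidth for beta is halved.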

Reviewed changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated no comments.

Summary per file:

  • USAGE.md: Updates user-facing dtype requirements for beta vs initial_state.
  • README.md: Documents BF16 beta support and marks the roadmap item complete.
  • tests/test_bf16_beta.py: New SM90/SM100 test coverage comparing BF16-beta vs FP32-beta outputs (and SM100 grads).
  • csrc/kda/sm90/prefill_kernel.hpp: Extends the SM90 launcher template to accept a configurable beta pointer element type.
  • csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh: Wires the beta GMEM element type through options to the mainloop and launcher signature.
  • csrc/kda/sm90/kernel/options.hpp: Adds an option tag to carry the beta GMEM element type (float vs bf16).
  • csrc/kda/sm90/kda_fwd_sm90.cu: Adds explicit instantiations and updates the dispatch wrapper for templated TBeta.
  • csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu: Adds explicit instantiations for the BF16-beta safe-gate variants.
  • csrc/kda/sm90/collective/mainloop_kda_fwd.hpp: Loads beta from GMEM as float/bf16 while keeping SMEM/compute beta as float.
  • csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp: Templatizes the beta load type and upcasts to float in shared beta staging.
  • csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp: Introduces a BF16-beta kernel alias and runtime dispatch via params.is_beta_bf16.
  • csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp: Adds a beta element-type template parameter and a BF16->FP32 upcast on load.
  • csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp: Adds BF16-beta kernel variants and dispatch logic keyed by params.is_beta_bf16.
  • csrc/kda/sm100/kda_config.hpp: Adds is_beta_bf16 to the SM100 forward param structs for dispatch.
  • csrc/api/kda_sm90.cu: Accepts BF16 beta and dispatches to the BF16-beta SM90 instantiation.
  • csrc/api/kda_sm100.cu: Validates the beta dtype (fp32/bf16), sets is_beta_bf16, and uses SM100 dispatch.


cherhh added 2 commits April 5, 2026 13:13
SM90 (Hopper):
- Introduce ElementBetaGmem type alias to decouple GMEM load type from
  SMEM/compute type, enabling bf16 GMEM load with fp32 compute
- CollectiveLoadVector handles implicit bf16->fp32 conversion via
  Copy_Atom<UniversalCopy<ElementSrc>, ElementDst>
- Add TBeta template parameter and explicit instantiations for both
  float and bf16 in kda_fwd_sm90.cu and kda_fwd_sm90_safe_gate.cu

SM100 (Blackwell):
- Add is_beta_bf16 runtime flag to KDA_fwd_intra_params and
  KDA_fwd_recomp_w_u_params
- Template-specialize intra and recomp_w_u mainloops on ElementBeta_
- Dispatch to bf16/fp32 kernel variants at runtime (4->8 intra, 1->2 recomp)

Tests:
- Add comprehensive bf16 beta tests (tests/test_bf16_beta.py) covering
  SM90 and SM100, fixed-length and varlen configurations
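
The "4->8 intra, 1->2 recomp" dispatch in the commit message can be pictured as a table of pre-compiled variants keyed by (variant, is_beta_bf16). A hypothetical Python sketch (variant names are illustrative, not the real kernel symbols):

```python
# Hypothetical sketch of the SM100 runtime dispatch: each pre-compiled intra
# kernel variant exists in an fp32-beta and a bf16-beta flavor, doubling the
# intra variants from 4 to 8 (and recomp_w_u from 1 to 2).
INTRA_VARIANTS = ("v0", "v1", "v2", "v3")  # illustrative names

KERNELS = {
    (v, is_bf16): f"kda_fwd_intra_{v}_{'bf16' if is_bf16 else 'fp32'}_beta"
    for v in INTRA_VARIANTS
    for is_bf16 in (False, True)
}

def dispatch_intra(variant: str, is_beta_bf16: bool) -> str:
    """Select a compiled variant at launch time, as params.is_beta_bf16 does."""
    return KERNELS[(variant, is_beta_bf16)]
```

The runtime flag trades a doubled set of instantiations for zero per-element branching inside the kernel.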
@icavan (Collaborator) left a comment

LGTM

@@ -0,0 +1,334 @@
# Copyright 2025-2026 Ant Group Co., Ltd.
Collaborator

I think it's more reasonable to compare against the FLA and naive implementations.
Can we add a pytest param for beta with float32 and bfloat16 in the sm100 & sm90 KDA tests, just like the disable_recompute flag?
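
The suggestion could look like the following sketch; the test name and body are placeholders, and the real test would run the kernel with beta cast to `beta_dtype` and compare against an fp32-beta reference within tolerance:

```python
import pytest

# Sketch of the suggested parametrization: sweep beta's dtype alongside the
# existing knobs, the way the disable_recompute flag is already swept.
# The body is a placeholder standing in for the real kernel comparison.
@pytest.mark.parametrize("beta_dtype", ["float32", "bfloat16"])
def test_kda_beta_dtype(beta_dtype):
    assert beta_dtype in ("float32", "bfloat16")
```

Stacking this on the existing parametrizations would cover every (config, beta dtype) pair without a separate test file.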

Collaborator Author

Sure, I'll update the tests.

```diff
-    ? reinterpret_cast<float*>(
-          params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]
+    ? float(reinterpret_cast<ElementBeta_*>(
+          params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx])
```
@KevinZeng08 (Collaborator) commented Apr 5, 2026

There may be a risk of numerical differences, because the FLA recompute_wu kernel computes K/V * beta directly without converting beta to float.
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/wy_fast.py#L69
I will check it tomorrow.

Collaborator Author

I see, this is a valid concern. I'll also check the numerical impact of this forward/backward precision mismatch and follow up.

Collaborator

There may be a risk of numerical differences, because the FLA recompute_wu kernel computes K/V * beta directly without converting beta to float. https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/wy_fast.py#L69 I will check it tomorrow.

Did you mean the safe way to align with FLA is to keep the bf16 K*beta?

Collaborator

Exactly. If beta is bf16, first compute bf16 K*beta and then convert to float for gating. If beta is fp32, convert bf16 K to fp32 first, then multiply by beta and do the gating.

Collaborator

Exactly. If beta is bf16, first compute bf16 K*beta and then convert to float for gating. If beta is fp32, convert bf16 K to fp32 first, then multiply by beta and do the gating.

Looks like it's fine to perform the upcast (bf16 -> fp32), but it's dangerous to do the downcast (fp32 inputs to bf16).
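
The concern can be reproduced with a small pure-Python bf16 emulation: multiplying two bf16 values in fp32 and then rounding the product back to bf16 (the FLA-style path) generally differs from keeping the product in fp32, while the upcast itself is lossless. A sketch under that assumption, not tied to the actual kernels:

```python
import struct

def to_bf16(x: float) -> float:
    """Round an fp32 value to bf16 (round-to-nearest-even), returned as fp32."""
    b = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (b >> 16) & 1
    b = (b + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", b))[0]

k, beta = to_bf16(0.3), to_bf16(0.7)  # bf16 inputs, held as fp32 values

prod_fp32 = k * beta                  # upcast both, multiply in fp32 (exact here)
prod_bf16 = to_bf16(k * beta)         # FLA-style: product rounded back to bf16

assert to_bf16(k) == k                # upcast bf16 -> fp32 is exact
assert prod_fp32 != prod_bf16         # downcasting the product loses bits
```

So the two orderings agree on the inputs but not on the product, which is exactly the forward/backward precision mismatch being discussed.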


4 participants