test/bench: add SM100 Blackwell fused forward correctness tests and benchmark#35

Open
meinie0826 wants to merge 1 commit into inclusionAI:main from meinie0826:feat/sm100-test-bench

Conversation

@meinie0826

Summary

Add test and benchmark infrastructure for the Blackwell (SM100) fully-fused
KDA forward kernel (kda_fully_fused_wip.py / blackwell_fused_fwd.py).

Closes part of #16

Known WIP limitations (tracked)

Feature                    Status
─────────────────────────  ────────────────────────────────
output_final_state=True    Not implemented; returns garbage
use_gate_in_kernel=True    Produces NaN

Test results: 16 passed, 3 xfailed.
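The xfails correspond to the WIP limitations above. A minimal sketch of how such a known-broken configuration can be expressed as an expected failure (hypothetical test shape, not the PR's actual suite) so it reports as xfailed rather than a hard error:

```python
import pytest

# Hedged sketch: mark the known-broken WIP configuration as an expected
# failure so the suite reports it as "xfailed" instead of an error.
# The parameter name mirrors the kernel flag; the test body is illustrative.
final_state_case = pytest.param(
    True,
    marks=pytest.mark.xfail(
        reason="output_final_state=True: final-state write unimplemented in WIP kernel"
    ),
    id="output_final_state",
)

@pytest.mark.parametrize("output_final_state", [False, final_state_case])
def test_forward(output_final_state):
    # a real test would call flash_kda_prefill(...) and assert_close here
    assert output_final_state is False
```

Once the kernel implements the final-state write, dropping the mark turns the case back into a hard assertion with no other changes.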

python benchmarks/bench_kda_blackwell_fwd.py
[Device] NVIDIA B300 SXM6 AC  compute capability sm103  →  cula.kda.blackwell_fused_fwd.flash_kda_prefill

====================================================================================================
 Fixed-Length Benchmark: cuLA Blackwell fully-fused (sm103) vs FLA Triton
====================================================================================================
g_smem_layout_staged: S<3,4,3> o 0 o ((64,16),1,(2,4),2):((32,1),0,(16,2048),8192)
tma_atom_g: Copy Atom
  ThrID:         1:0
  TV Layout Src: (1,2048):(0,1)
  TV Layout Dst: (1,2048):(0,1)
  Value type:    f32
g_smem_layout: S<3,4,3> o 0 o ((64,16),1,(2,4)):((32,1),0,(16,2048))
q_copy_size: 16384, k_copy_size: 16384, v_copy_size: 16384, g_copy_size: 32768
size of storage: 232448
m_smem_layout_staged: S<3,4,3> o 0 o ((8,8),(64,1),(1,2)):((64,512),(1,0),(0,4096))
grid: (1, Int32(?), Int32(?))
sK_neg_g: tensor<ptr<bf16, smem, align<1024>, S<3,4,3>> o ((64,16),1,(4,2),2):((64,1),0,(16,4096),16384)>
sK_neg_g_b: tensor<ptr<bf16, smem, align<1024>, S<3,4,3>> o (((64,2),16),1,4,2):(((1,4096),64),0,1024,16384)>
sK_g: tensor<ptr<bf16, smem, align<1024>, S<3,4,3>> o ((64,16),1,(4,2),2):((64,1),0,(16,4096),8192)>
sG_last: tensor<ptr<f32, smem, align<1024>> o (128,2):(1,128)>
sBeta: tensor<ptr<f32, smem, align<1024>> o (64,1):(1,64)>
sP: tensor<ptr<bf16, smem, align<1024>, S<3,4,3>> o ((64,16),1,4,2):((64,1),0,16,4096)>
sM: tensor<ptr<bf16, smem, align<1024>, S<3,4,3>> o ((8,8),(64,1),(1,2)):((64,512),(1,0),(0,4096))>
sG_flat: tensor<ptr<f32, smem, align<1024>, S<3,4,3>> o (64,(32,4),2):(32,(1,2048),8192)>
sG_flat_as_bf16: tensor<ptr<bf16, smem, align<1024>> o (64,(32,4),2):(32,(1,2048),16384)>
g_smem_layout_epi: S<3,4,3> o 0 o ((8,8),(32,4),(1,2)):((32,256),(1,2048),(0,8192))
g_smem_layout_coalesce: S<3,4,3> o 0 o (64,(32,4),2):(32,(1,2048),8192)
-------------------- SMEM STORE: KK
copy_atom_r2s_x: Copy Atom
  ThrID:         32:1
  TV Layout Src: (32,(2,4)):(2,(1,64))
  TV Layout Dst: (32,8):(8,1)
  Value type:    f16
tiled_t2r_x: Tiled Copy
  Tiler MN:        ((4,16):(16,1),64:1)
  TV Layout tiled: ((32,4),(64,16)):((0,1),(64,4))
Copy Atom
  ThrID:           32:1
  TV Layout Src:   (32,1024):(0,1)
  TV Layout Dst:   ((4,8),(2,2,8)):((2,64),(1,512,8))
  Value type:      f32
thr_t2r_x: Tiled Copy
  Tiler MN:        ((4,16):(16,1),64:1)
  TV Layout tiled: ((4,8,4),((2,2,8),1)):((128,4,1),((64,32,512),0))
Copy Atom
  ThrID:           32:1
  TV Layout Src:   (32,(2,4)):(2,(1,64))
  TV Layout Dst:   (32,8):(8,1)
  Value type:      f16
before partition_D: tensor<ptr<f16, smem, align<1024>, S<3,4,3>> o ((8,8),(64,1),(1,2)):((64,512),(1,0),(0,4096))>
after partition_D, tXsX_r2s: tensor<ptr<f16, smem, align<16>, S<3,4,3>> o ((8,4),1,1,(1,2)):((1,16),0,0,(0,4096))>
before retile tXrX_t2r: tensor<ptr<f16, rmem, align<32>> o (((2,2,8),1),1,1):(((1,2,4),0),0,0)>
after retile tXrX_r2s: tensor<ptr<f16, rmem, align<32>> o ((8,4),1,1):((1,8),0,0)>

====================================================================================================
 Varlen Benchmark: cuLA Blackwell fully-fused (sm103) vs FLA Triton
====================================================================================================


==============================================================================================================
                  BENCHMARK REPORT: flash_kda_prefill (Blackwell fully-fused)
                  cuLA sm103 Blackwell fully-fused vs FLA Triton
                  H=64  D=128  dtype=bf16  safe_gate=True  use_gate_in_kernel=True
                  Warmup=10  Iters=30
==============================================================================================================

  [Fixed-Length]
  ──────────────────────────────────────────────────────────────────────────────────────────
    B       T  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────
    1     512  │    0.000026    0.010274    0.000014  │     0.8008      0.5960     1.34x
    1    1024  │    0.000024    0.008824    0.000012  │     0.7891      0.5820     1.36x
    1    4096  │    0.000024    0.007246    0.000013  │     0.7930      0.8212     0.97x
    1    8192  │    0.000026    0.009174    0.000014  │     1.4054      1.5773     0.89x
    1   16384  │    0.000026    0.008929    0.000014  │     2.7329      3.0870     0.89x
    2     512  │    0.000024    0.008824    0.000012  │     0.7931      0.5893     1.35x
    2    1024  │    0.000028    0.008721    0.000015  │     0.7916      0.5934     1.33x
    2    4096  │    0.000026    0.009174    0.000014  │     1.4001      0.9681     1.45x
    2    8192  │    0.000026    0.008929    0.000013  │     2.7237      1.8568     1.47x
    2   16384  │    0.000026    0.010256    0.000014  │     5.3809      3.6481     1.47x
  ──────────────────────────────────────────────────────────────────────────────────────────

  [Varlen]
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
                                         Config  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
       uniform 10seqs T=4096 [409..415] avg=409  │    0.000024    0.009662    0.000012  │     0.8037      0.6454     1.25x
        random 10seqs T=4096 [24..1201] avg=409  │    0.000024    0.009662    0.000012  │     0.7985      0.6297     1.27x
       skewed 10seqs T=4096 [227..2053] avg=409  │    0.000024    0.007246    0.000012  │     0.8394      0.6619     1.27x
       uniform 20seqs T=4096 [204..220] avg=204  │    0.000024    0.009662    0.000012  │     0.8642      0.7121     1.21x
          random 20seqs T=4096 [5..787] avg=204  │    0.000024    0.009662    0.000012  │     0.8381      0.7006     1.20x
       skewed 20seqs T=4096 [107..2063] avg=204  │    0.000024    0.007246    0.000012  │     0.8169      0.6705     1.22x
       uniform 10seqs T=8192 [819..821] avg=819  │    0.000026    0.009174    0.000014  │     1.4277      1.0841     1.32x
        random 10seqs T=8192 [48..2401] avg=819  │    0.000026    0.009174    0.000014  │     1.4453      1.2541     1.15x
       skewed 10seqs T=8192 [455..4097] avg=819  │    0.000026    0.009174    0.000014  │     1.4519      1.2334     1.18x
       uniform 20seqs T=8192 [409..421] avg=409  │    0.000026    0.009174    0.000014  │     1.4921      1.1562     1.29x
         random 20seqs T=8192 [9..1574] avg=409  │    0.000026    0.009174    0.000014  │     1.4867      1.2125     1.23x
       skewed 20seqs T=8192 [215..4107] avg=409  │    0.000026    0.009174    0.000014  │     1.4939      1.2756     1.17x
   uniform 10seqs T=16384 [1638..1642] avg=1638  │    0.000026    0.008929    0.000013  │     2.7190      2.0002     1.36x
      random 10seqs T=16384 [95..4802] avg=1638  │    0.000026    0.008929    0.000013  │     2.7280      2.0462     1.33x
     skewed 10seqs T=16384 [910..8194] avg=1638  │    0.000026    0.008929    0.000013  │     2.7368      2.3269     1.18x
      uniform 20seqs T=16384 [819..823] avg=819  │    0.000026    0.008929    0.000013  │     2.7464      1.9582     1.40x
       random 20seqs T=16384 [19..3147] avg=819  │    0.000026    0.008929    0.000013  │     2.7592      2.1832     1.26x
      skewed 20seqs T=16384 [431..8195] avg=819  │    0.000026    0.008929    0.000013  │     2.7401      2.3157     1.18x
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================

Copilot AI review requested due to automatic review settings April 5, 2026 19:00

@gemini-code-assist bot left a comment


Code Review

This pull request introduces benchmark and correctness test scripts for the Blackwell (SM100) fully-fused KDA forward kernel. The benchmark script compares the cuLA Blackwell implementation against the FLA Triton baseline, measuring accuracy and execution time. The test suite validates the kernel's correctness against token-level and chunk-level references for both fixed-length and variable-length sequences. Feedback identifies a hardcoded variable shadowing a function parameter in the tests and suggests using .item() for safer scalar extraction from tensors.

torch.manual_seed(42)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
cu_seqlens_cpu = cu_seqlens.cpu()
T = cu_seqlens[-1]


Severity: medium

It is safer to use .item() when extracting a scalar from a tensor to be used as a dimension in torch.randn or other shape-dependent functions. This ensures the value is a Python integer and avoids potential type mismatch issues in some PyTorch versions.

Suggested change
T = cu_seqlens[-1]
T = cu_seqlens[-1].item()
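To illustrate the reviewer's point in isolation (standalone sketch, not the PR's code): indexing a 1-D tensor yields a 0-dim tensor, while `.item()` produces a plain Python int that is always safe as a shape argument:

```python
import torch

cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32)

T_tensor = cu_seqlens[-1]   # 0-dim int32 tensor, not a Python int
T = cu_seqlens[-1].item()   # plain Python int

assert isinstance(T, int) and T == 7
x = torch.randn(T, 4)       # unambiguous shape argument
assert x.shape == (7, 4)
```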

cu_seqlens_cpu = cu_seqlens.cpu()
T = cu_seqlens[-1]
N = len(cu_seqlens) - 1
D = 128


Severity: medium

The variable D is already provided as a function parameter. Hardcoding D = 128 here shadows the parameter and makes the test less flexible, as it will ignore any value passed via @pytest.mark.parametrize. This line should be removed to allow testing with different head dimensions.
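The shadowing pattern the reviewer describes, reduced to a standalone sketch (hypothetical names, not the PR's actual test):

```python
# Hypothetical reduction of the bug: the parametrized D is silently
# overwritten inside the function body, so every "case" runs with
# D == 128 regardless of what the caller (or parametrize) passed in.
def run_case(B, T, D):
    D = 128  # shadows the parameter -- the bug
    return D

assert run_case(B=1, T=512, D=64) == 128  # D=64 is silently ignored
```

Removing the reassignment (or replacing it with `assert D == 128` if 128 is a hard kernel requirement) makes the parametrization meaningful again.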


Copilot AI left a comment


Pull request overview

Adds SM10x/Blackwell-focused correctness tests and a dedicated benchmark runner for the fully-fused KDA forward (“prefill”) kernel, using a direct call path to cula.kda.blackwell_fused_fwd.flash_kda_prefill.

Changes:

  • Add SM100-only correctness tests for fixed-length and varlen configs comparing cuLA vs token-level naive and FLA Triton refs.
  • Add a standalone Blackwell benchmark script that reports accuracy stats vs FLA and CUDA-event timings.
  • Wire benchmark to bypass the existing SM90-oriented dispatcher by importing flash_kda_prefill directly.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

File Description
tests/test_kda_blackwell_fwd.py New SM100-only correctness suite for flash_kda_prefill (fixed + varlen) with FLA + naive references and WIP gating/final-state notes.
benchmarks/bench_kda_blackwell_fwd.py New Blackwell-only benchmark script comparing cuLA vs FLA for accuracy and runtime, intended for SM10x devices.


Comment on lines +117 to +157
def run_fla(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lower_bound):
    return fla_chunk_kda(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=scale,
        # TODO: switch to use_gate_in_kernel=True once WIP kernel supports it (currently NaN)
        A_log=None,
        dt_bias=None,
        initial_state=init_state,
        output_final_state=False,
        use_qk_l2norm_in_kernel=True,
        cu_seqlens=cu_seqlens,
        use_gate_in_kernel=False,
        safe_gate=True,
        lower_bound=lower_bound,
    )


def run_cula(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lower_bound):
    return flash_kda_prefill(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=scale,
        # TODO: switch to use_gate_in_kernel=True once WIP kernel supports it (currently NaN)
        A_log=None,
        dt_bias=None,
        initial_state=init_state,
        # TODO: switch to output_final_state=True once WIP kernel implements final state write
        output_final_state=False,
        use_qk_l2norm_in_kernel=True,
        cu_seqlens=cu_seqlens,
        use_gate_in_kernel=False,
        safe_gate=True,
        lower_bound=lower_bound,
    )

Copilot AI Apr 5, 2026


prepare_safe_gate_inputs() generates raw g plus (A_log, dt_bias) intended for use_gate_in_kernel=True (and the safe_gate assumptions). Here run_fla forces use_gate_in_kernel=False and drops A_log/dt_bias, so both FLA and cuLA are receiving an unprocessed random g while still setting safe_gate=True. This makes the accuracy numbers and performance comparison misleading (and may violate the kernel’s “safe gate” range assumptions). Consider either (a) switching back to use_gate_in_kernel=True once the WIP path is fixed, or (b) changing input prep to produce a precomputed log-space/cumsummed gate (e.g., via prepare_intra_inputs or equivalent preprocessing) and/or disabling safe_gate when g is not guaranteed to be in the safe range.

cu_seqlens_cpu = cu_seqlens.cpu()
T = cu_seqlens[-1]
N = len(cu_seqlens) - 1
D = 128

Copilot AI Apr 5, 2026


In test_safe_gate_chunk_varlen, the D argument from parametrization is immediately overwritten with D = 128, so the parameter is effectively ignored. This is easy to miss and will break the test silently if a non-128 D case is ever added; consider removing the reassignment and using the provided D (or assert it equals 128 if that’s a hard requirement).

Suggested change
D = 128

initial_state=h0.clone(),
output_final_state=True,
cu_seqlens=cu_seqlens,
cu_seqlens_cpu=cu_seqlens_cpu,

Copilot AI Apr 5, 2026


flash_kda_prefill() doesn’t take a cu_seqlens_cpu parameter (it will be swallowed by **kwargs and ignored), so passing it here is misleading and can hide real argument mismatches. Consider dropping cu_seqlens_cpu from the cuLA call (or updating flash_kda_prefill to explicitly accept/validate it if it’s meant to be part of the API).

Suggested change
cu_seqlens_cpu=cu_seqlens_cpu,
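The silent-swallowing behavior is easy to reproduce in isolation (hedged sketch; `prefill_stub` is a made-up stand-in for `flash_kda_prefill`):

```python
# A function that accepts **kwargs silently ignores unknown keyword
# arguments -- including typos and stale parameters -- instead of
# raising TypeError.
def prefill_stub(q, cu_seqlens=None, **kwargs):
    return sorted(kwargs)  # anything landing here was swallowed, not used

swallowed = prefill_stub(q="q", cu_seqlens=[0, 4], cu_seqlens_cpu=[0, 4])
assert swallowed == ["cu_seqlens_cpu"]

# Without **kwargs the same call fails loudly, surfacing the mismatch:
def strict_prefill(q, cu_seqlens=None):
    return q

try:
    strict_prefill(q="q", cu_seqlens_cpu=[0, 4])
except TypeError:
    pass  # the unknown kwarg is reported immediately
else:
    raise AssertionError("expected TypeError for unknown kwarg")
```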

Comment on lines +304 to +313
def print_report(fixed_results, varlen_results):
    sep = "=" * 110
    print(f"\n\n{sep}")
    print(" BENCHMARK REPORT: flash_kda_prefill (Blackwell fully-fused)")
    print(f" cuLA {_SM_TAG} Blackwell fully-fused vs FLA Triton")
    print(f" H={H} D={D} dtype=bf16 safe_gate=True use_gate_in_kernel=False (TODO: True once WIP fixed)")
    wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP
    ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS
    mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "")
    print(f" Warmup={wu} Iters={ni}{mode_tag}")

Copilot AI Apr 5, 2026


The PR description’s sample benchmark output reports use_gate_in_kernel=True, but this script hard-codes use_gate_in_kernel=False (and drops A_log/dt_bias) with a TODO to re-enable it later. Please update either the PR description output (so it reflects the current script behavior) or the script so it matches what was benchmarked, to avoid confusion when others reproduce results.

Comment on lines +24 to +26
# Known WIP limitations (assert_close skipped with pytest.xfail):
# - output_final_state=True: final state write is unimplemented in the WIP
# kernel (FIXME comment in blackwell_fused_fwd.py line ~129), ht is garbage

Copilot AI Apr 5, 2026


Header comment says the output_final_state=True limitation is handled by skipping assert_close with pytest.xfail, but the test currently just comments out the ht assertions (no xfail/skip). Consider either adding an explicit pytest.xfail/pytest.skip for output_final_state=True cases, or updating the comment to match the current behavior.

Suggested change
# Known WIP limitations (assert_close skipped with pytest.xfail):
# - output_final_state=True: final state write is unimplemented in the WIP
# kernel (FIXME comment in blackwell_fused_fwd.py line ~129), ht is garbage
# Known WIP limitations:
# - output_final_state=True: final state write is unimplemented in the WIP
# kernel (FIXME comment in blackwell_fused_fwd.py line ~129), so `ht`
# is currently not asserted in these tests

_SM_TAG = f"sm{_major}{_minor}"

if _major < 10:
    print(f"[WARNING] This benchmark is designed for SM100 (Blackwell). Current GPU: {_SM_TAG}")

Copilot AI Apr 5, 2026


The script only prints a warning when running on non-SM10x (_major < 10), but flash_kda_prefill() ultimately enforces Blackwell via assert_blackwell() and will raise at runtime on other GPUs. Consider exiting early with a clear error (or gating run_cula behind an SM10x check) so the behavior matches the message and users don’t hit a confusing runtime exception mid-benchmark.

Suggested change
print(f"[WARNING] This benchmark is designed for SM100 (Blackwell). Current GPU: {_SM_TAG}")
raise RuntimeError(
    f"This benchmark requires an SM100/SM10x (Blackwell) GPU because it calls "
    f"flash_kda_prefill(), which only supports Blackwell. Current GPU: {_SM_TAG}"
)
