test/bench: add SM100 Blackwell fused forward correctness tests and benchmark #35
meinie0826 wants to merge 1 commit into inclusionAI:main from
Conversation
Code Review
This pull request introduces benchmark and correctness test scripts for the Blackwell (SM100) fully-fused KDA forward kernel. The benchmark script compares the cuLA Blackwell implementation against the FLA Triton baseline, measuring accuracy and execution time. The test suite validates the kernel's correctness against token-level and chunk-level references for both fixed-length and variable-length sequences. Feedback identifies a hardcoded variable shadowing a function parameter in the tests and suggests using .item() for safer scalar extraction from tensors.
    torch.manual_seed(42)
    cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
    cu_seqlens_cpu = cu_seqlens.cpu()
    T = cu_seqlens[-1]
It is safer to use .item() when extracting a scalar from a tensor to be used as a dimension in torch.randn or other shape-dependent functions. This ensures the value is a Python integer and avoids potential type mismatch issues in some PyTorch versions.
Suggested change:
-    T = cu_seqlens[-1]
+    T = cu_seqlens[-1].item()
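The reviewer's point can be illustrated with a minimal CPU-only snippet (the values here are hypothetical, not from the PR):

```python
import torch

# Indexing a 1-D tensor returns a 0-dim tensor, not a Python int.
# .item() converts it, so shape arguments to torch.randn etc. are
# guaranteed to be plain integers across PyTorch versions.
cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32)
T = cu_seqlens[-1].item()  # Python int, safe as a dimension
x = torch.randn(T, 4)
```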
Pull request overview
Adds SM10x/Blackwell-focused correctness tests and a dedicated benchmark runner for the fully-fused KDA forward (“prefill”) kernel, using a direct call path to cula.kda.blackwell_fused_fwd.flash_kda_prefill.
Changes:
- Add SM100-only correctness tests for fixed-length and varlen configs comparing cuLA vs token-level naive and FLA Triton refs.
- Add a standalone Blackwell benchmark script that reports accuracy stats vs FLA and CUDA-event timings.
- Wire benchmark to bypass the existing SM90-oriented dispatcher by importing flash_kda_prefill directly.
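The "CUDA-event timings" the benchmark reports can be sketched roughly as below — a generic harness with a wall-clock fallback, not the script's actual implementation:

```python
import time
import torch

def bench(fn, warmup=3, iters=10):
    """Average per-iteration time of fn() in milliseconds.

    Uses CUDA events when a GPU is present (accurate for async kernels),
    otherwise falls back to wall-clock timing.
    """
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # already in ms
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) * 1e3 / iters
```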
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| tests/test_kda_blackwell_fwd.py | New SM100-only correctness suite for flash_kda_prefill (fixed + varlen) with FLA + naive references and WIP gating/final-state notes. |
| benchmarks/bench_kda_blackwell_fwd.py | New Blackwell-only benchmark script comparing cuLA vs FLA for accuracy and runtime, intended for SM10x devices. |
    def run_fla(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lower_bound):
        return fla_chunk_kda(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            # TODO: switch to use_gate_in_kernel=True once WIP kernel supports it (currently NaN)
            A_log=None,
            dt_bias=None,
            initial_state=init_state,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=cu_seqlens,
            use_gate_in_kernel=False,
            safe_gate=True,
            lower_bound=lower_bound,
        )

    def run_cula(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lower_bound):
        return flash_kda_prefill(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            # TODO: switch to use_gate_in_kernel=True once WIP kernel supports it (currently NaN)
            A_log=None,
            dt_bias=None,
            initial_state=init_state,
            # TODO: switch to output_final_state=True once WIP kernel implements final state write
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=cu_seqlens,
            use_gate_in_kernel=False,
            safe_gate=True,
            lower_bound=lower_bound,
        )
prepare_safe_gate_inputs() generates raw g plus (A_log, dt_bias) intended for use_gate_in_kernel=True (and the safe_gate assumptions). Here run_fla forces use_gate_in_kernel=False and drops A_log/dt_bias, so both FLA and cuLA are receiving an unprocessed random g while still setting safe_gate=True. This makes the accuracy numbers and performance comparison misleading (and may violate the kernel’s “safe gate” range assumptions). Consider either (a) switching back to use_gate_in_kernel=True once the WIP path is fixed, or (b) changing input prep to produce a precomputed log-space/cumsummed gate (e.g., via prepare_intra_inputs or equivalent preprocessing) and/or disabling safe_gate when g is not guaranteed to be in the safe range.
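Option (b) could look roughly like the sketch below. The transform shown (−exp(A_log) · softplus(g + dt_bias), clamped to lower_bound) is an assumption loosely modeled on common KDA-style gate preprocessing; the actual prepare_safe_gate_inputs / prepare_intra_inputs transforms may differ:

```python
import torch
import torch.nn.functional as F

def precompute_log_gate(g, A_log, dt_bias, lower_bound=-5.0):
    """Hypothetical host-side gate preprocessing.

    Maps a raw random gate to a non-positive log-space decay (so the
    "safe gate" range assumption holds) instead of feeding unprocessed
    g to the kernels with use_gate_in_kernel=False.
    """
    # softplus >= 0 and exp > 0, so decay is always <= 0
    decay = -torch.exp(A_log.float()) * F.softplus(g.float() + dt_bias.float())
    # clamp into the safe range expected by safe_gate=True
    return torch.clamp(decay, min=lower_bound)
```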
    cu_seqlens_cpu = cu_seqlens.cpu()
    T = cu_seqlens[-1]
    N = len(cu_seqlens) - 1
    D = 128
In test_safe_gate_chunk_varlen, the D argument from parametrization is immediately overwritten with D = 128, so the parameter is effectively ignored. This is easy to miss and will break the test silently if a non-128 D case is ever added; consider removing the reassignment and using the provided D (or assert it equals 128 if that’s a hard requirement).
Suggested change:
-    D = 128
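One way to implement the reviewer's suggestion — a hypothetical helper that asserts the requirement instead of silently overwriting the parametrized value:

```python
def check_head_dim(D):
    """Guard the current hard requirement instead of silently
    overwriting the parametrized value with D = 128."""
    if D != 128:
        raise ValueError(f"WIP kernel currently supports head dim 128 only, got D={D}")
    return D
```

A future D=64 parametrization would then fail loudly at the guard rather than silently running with 128.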
        initial_state=h0.clone(),
        output_final_state=True,
        cu_seqlens=cu_seqlens,
        cu_seqlens_cpu=cu_seqlens_cpu,
flash_kda_prefill() doesn’t take a cu_seqlens_cpu parameter (it will be swallowed by **kwargs and ignored), so passing it here is misleading and can hide real argument mismatches. Consider dropping cu_seqlens_cpu from the cuLA call (or updating flash_kda_prefill to explicitly accept/validate it if it’s meant to be part of the API).
Suggested change:
-        cu_seqlens_cpu=cu_seqlens_cpu,
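If cu_seqlens_cpu is meant to become part of the API, an explicit guard could look like this sketch — the allowed-parameter set is inferred from the calls above and may be incomplete:

```python
ALLOWED_KWARGS = {
    "q", "k", "v", "g", "beta", "scale", "A_log", "dt_bias",
    "initial_state", "output_final_state", "use_qk_l2norm_in_kernel",
    "cu_seqlens", "use_gate_in_kernel", "safe_gate", "lower_bound",
}

def check_prefill_kwargs(**kwargs):
    """Hypothetical guard: surface kwargs that flash_kda_prefill's
    **kwargs would otherwise swallow silently (e.g. cu_seqlens_cpu)."""
    unexpected = set(kwargs) - ALLOWED_KWARGS
    if unexpected:
        raise TypeError(
            f"flash_kda_prefill() got unexpected keyword arguments: {sorted(unexpected)}"
        )
    return kwargs
```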
    def print_report(fixed_results, varlen_results):
        sep = "=" * 110
        print(f"\n\n{sep}")
        print(" BENCHMARK REPORT: flash_kda_prefill (Blackwell fully-fused)")
        print(f" cuLA {_SM_TAG} Blackwell fully-fused vs FLA Triton")
        print(f" H={H} D={D} dtype=bf16 safe_gate=True use_gate_in_kernel=False (TODO: True once WIP fixed)")
        wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP
        ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS
        mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "")
        print(f" Warmup={wu} Iters={ni}{mode_tag}")
The PR description’s sample benchmark output reports use_gate_in_kernel=True, but this script hard-codes use_gate_in_kernel=False (and drops A_log/dt_bias) with a TODO to re-enable it later. Please update either the PR description output (so it reflects the current script behavior) or the script so it matches what was benchmarked, to avoid confusion when others reproduce results.
    # Known WIP limitations (assert_close skipped with pytest.xfail):
    # - output_final_state=True: final state write is unimplemented in the WIP
    #   kernel (FIXME comment in blackwell_fused_fwd.py line ~129), ht is garbage
Header comment says the output_final_state=True limitation is handled by skipping assert_close with pytest.xfail, but the test currently just comments out the ht assertions (no xfail/skip). Consider either adding an explicit pytest.xfail/pytest.skip for output_final_state=True cases, or updating the comment to match the current behavior.
Suggested change:
-    # Known WIP limitations (assert_close skipped with pytest.xfail):
+    # Known WIP limitations:
     # - output_final_state=True: final state write is unimplemented in the WIP
-    #   kernel (FIXME comment in blackwell_fused_fwd.py line ~129), ht is garbage
+    #   kernel (FIXME comment in blackwell_fused_fwd.py line ~129), so `ht`
+    #   is currently not asserted in these tests
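A hedged sketch of the explicit pytest.xfail alternative (the helper name and message are hypothetical):

```python
import pytest

def maybe_xfail_final_state(output_final_state):
    """Mark known-WIP final-state cases as expected failures rather
    than silently commenting out the ht assertions."""
    if output_final_state:
        pytest.xfail("final state write unimplemented in WIP Blackwell kernel")
```

Called at the top of the test, this keeps the limitation visible in the test report (as an xfail) and automatically starts failing-to-xfail once the kernel implements the final-state write.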
    _SM_TAG = f"sm{_major}{_minor}"

    if _major < 10:
        print(f"[WARNING] This benchmark is designed for SM100 (Blackwell). Current GPU: {_SM_TAG}")
The script only prints a warning when running on non-SM10x (_major < 10), but flash_kda_prefill() ultimately enforces Blackwell via assert_blackwell() and will raise at runtime on other GPUs. Consider exiting early with a clear error (or gating run_cula behind an SM10x check) so the behavior matches the message and users don’t hit a confusing runtime exception mid-benchmark.
Suggested change:
-        print(f"[WARNING] This benchmark is designed for SM100 (Blackwell). Current GPU: {_SM_TAG}")
+        raise RuntimeError(
+            f"This benchmark requires an SM100/SM10x (Blackwell) GPU because it calls "
+            f"flash_kda_prefill(), which only supports Blackwell. Current GPU: {_SM_TAG}"
+        )
Summary

Add test and benchmark infrastructure for the Blackwell (SM100) fully-fused KDA forward kernel (kda_fully_fused_wip.py / blackwell_fused_fwd.py). Closes part of #16.

Known WIP limitations (tracked):
- output_final_state=True
- use_gate_in_kernel=True

16 passed, 3 xfailed,