diff --git a/benchmarks/bench_kda_blackwell_fwd.py b/benchmarks/bench_kda_blackwell_fwd.py new file mode 100644 index 0000000..29daca6 --- /dev/null +++ b/benchmarks/bench_kda_blackwell_fwd.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +bench_kda_blackwell_fwd.py — Benchmark: Blackwell (SM100) fully-fused KDA forward + +Directly calls flash_kda_prefill from cula.kda.blackwell_fused_fwd, +bypassing the get_kda_fused_fwd() dispatch (which raises NotImplementedError +on SM100 for the legacy SM90 path). + +Compares: + - cuLA Blackwell fully-fused (kda_fully_fused_wip.py / KDAChunkwise) + - FLA Triton baseline (chunk_kda) + +Measures: + - Accuracy: RMSE, relative max diff, mean absolute diff (vs FLA Triton) + - Performance: kernel execution time (ms) with CUDA events + +Modes: + - Fixed-length: various (B, T) configs + - Varlen: sequences with different length distributions + +Usage: + python benchmarks/bench_kda_blackwell_fwd.py [--mode fixed|varlen|both] [--ncu] + +With --ncu, warmup=1 and iters=1 for ncu profiling: + ncu --set full -o report python benchmarks/bench_kda_blackwell_fwd.py --ncu +""" + +import argparse +import os +import pathlib +import sys + +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) +os.environ.setdefault("FLA_USE_FAST_OPS", os.getenv("CULA_USE_FAST_MATH", "1")) + +import torch +from fla.ops.kda import chunk_kda as fla_chunk_kda + +from benchmarks.utils import ( + SEED, + build_varlen_configs, + exclusive_cumsum, + prepare_safe_gate_inputs, + set_seed, +) +from cula.kda.blackwell_fused_fwd import flash_kda_prefill +from cula.utils import get_device_sm_version + +# ============================================================ +# Device sanity check +# ============================================================ +_device = torch.device("cuda") +_major, _minor = get_device_sm_version(_device) +_SM_TAG = f"sm{_major}{_minor}" + +if _major < 10: + print(f"[WARNING] This benchmark is designed for SM100 (Blackwell). Current GPU: {_SM_TAG}") + +# ============================================================ +# Constants +# ============================================================ +H, D = 64, 128 +WARMUP = 10 +N_ITERS = 30 +NCU_MODE = False +SANITIZER_MODE = False + + +# ============================================================ +# Helpers +# ============================================================ +def time_kernel(fn, warmup=None, n_iters=None): + if warmup is None: + warmup = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP + if n_iters is None: + n_iters = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() + for _ in range(n_iters): + fn() + end_evt.record() + torch.cuda.synchronize() + return start_evt.elapsed_time(end_evt) / n_iters + + +def accuracy_stats(ref, out): + """Compute RMSE, relative max diff, and mean absolute difference.""" + ref_f = ref.float() + out_f = out.float() + diff = (ref_f - out_f).abs() + rmse = diff.pow(2).mean().sqrt().item() + max_diff = diff.max().item() + denom = ref_f.abs().max().item() + rel_max = max_diff / denom if denom > 0 else 0.0 + mean_diff = diff.mean().item() + return rmse, rel_max, mean_diff + + +def run_fla(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lower_bound): + return fla_chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + # TODO: switch to use_gate_in_kernel=True once WIP kernel supports it (currently NaN) + A_log=None, + dt_bias=None, + initial_state=init_state, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + use_gate_in_kernel=False, + safe_gate=True, + lower_bound=lower_bound, + ) + + +def run_cula(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lower_bound): + return flash_kda_prefill( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + # TODO: switch to use_gate_in_kernel=True once WIP kernel supports it (currently NaN) + A_log=None, + dt_bias=None, + initial_state=init_state, + # TODO: switch to output_final_state=True once WIP kernel implements final state write + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + use_gate_in_kernel=False, + safe_gate=True, + lower_bound=lower_bound, + ) + + +# ============================================================ +# Fixed-length benchmark +# ============================================================ +def bench_fixed(configs): + print("\n" + "=" * 100) + print(f" Fixed-Length Benchmark: cuLA Blackwell fully-fused ({_SM_TAG}) vs FLA Triton") + print("=" * 100) + results = [] + + for B, T in configs: + set_seed(SEED) + device = torch.device("cuda") + torch.cuda.empty_cache() + + seq_lens = [T] * B + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] + A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] + scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] + + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + ) + + # Accuracy + o_fla, _ = run_fla(**common) + o_cula, _ = run_cula(**common) + torch.cuda.synchronize() + + rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) + + # Performance + ms_fla = time_kernel(lambda: run_fla(**common)) + ms_cula = time_kernel(lambda: run_cula(**common)) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + + results.append( + { + "B": B, + "T": T, + "rmse": rmse, + "rel_max": rel_max, + "mean_diff": mean_diff, + "ms_fla": ms_fla, + "ms_cula": ms_cula, + "speedup": speedup, + } + ) + + del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs + torch.cuda.empty_cache() + + return results + + +# ============================================================ +# Varlen benchmark +# ============================================================ +def bench_varlen(configs): + print("\n" + "=" * 100) + print(f" Varlen Benchmark: cuLA Blackwell fully-fused ({_SM_TAG}) vs FLA Triton") + print("=" * 100) + results = [] + + for seq_lens, total_len, dist in configs: + set_seed(SEED) + device = torch.device("cuda") + torch.cuda.empty_cache() + + T = total_len + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] + A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] + scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] + + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + ) + + # Accuracy + o_fla, _ = run_fla(**common) + o_cula, _ = run_cula(**common) + torch.cuda.synchronize() + + rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) + + # Performance + ms_fla = time_kernel(lambda: run_fla(**common)) + ms_cula = time_kernel(lambda: run_cula(**common)) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + + n_seqs = len(seq_lens) + min_l, max_l = min(seq_lens), max(seq_lens) + avg_l = T // n_seqs + tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min_l}..{max_l}] avg={avg_l}" + + results.append( + { + "tag": tag, + "dist": dist, + "T_total": T, + "n_seqs": n_seqs, + "rmse": rmse, + "rel_max": rel_max, + "mean_diff": mean_diff, + "ms_fla": ms_fla, + "ms_cula": ms_cula, + "speedup": speedup, + } + ) + + del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs + torch.cuda.empty_cache() + + return results + + +# ============================================================ +# Report +# ============================================================ +def print_report(fixed_results, varlen_results): + sep = "=" * 110 + print(f"\n\n{sep}") + print(" BENCHMARK REPORT: flash_kda_prefill (Blackwell fully-fused)") + print(f" cuLA {_SM_TAG} Blackwell fully-fused vs FLA Triton") + print(f" H={H} D={D} dtype=bf16 safe_gate=True use_gate_in_kernel=False (TODO: True once WIP fixed)") + wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP + ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS + mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") + print(f" Warmup={wu} Iters={ni}{mode_tag}") + print(sep) + + if fixed_results: + print("\n [Fixed-Length]") + print(f" {'─' * 90}") + print( + f" {'B':>3s} {'T':>6s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}" + f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" + ) + print(f" {'─' * 90}") + for r in fixed_results: + print( + f" {r['B']:3d} {r['T']:6d} │ " + f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " + f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" + ) + print(f" {'─' * 90}") + + if varlen_results: + print("\n [Varlen]") + print(f" {'─' * 105}") + print( + f" {'Config':>45s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}" + f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" + ) + print(f" {'─' * 105}") + for r in varlen_results: + print( + f" {r['tag']:>45s} │ " + f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " + f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" + ) + print(f" {'─' * 105}") + + print(f"\n{sep}\n") + + +# ============================================================ +# Main +# ============================================================ +def main(): + parser = argparse.ArgumentParser( + description="bench_kda_blackwell_fwd: Blackwell fully-fused KDA forward vs FLA Triton" + ) + parser.add_argument( + "--mode", + type=str, + default="both", + choices=["fixed", "varlen", "both"], + help="Which benchmark mode to run (default: both)", + ) + parser.add_argument( + "--ncu", + action="store_true", + help="NCU profiling mode: warmup=1, iters=1", + ) + parser.add_argument( + "--sanitizer", + action="store_true", + help="Sanitizer mode: warmup=1, iters=1", + ) + args = parser.parse_args() + + global NCU_MODE, SANITIZER_MODE + if args.ncu: + NCU_MODE = True + print("[NCU mode] warmup=1, iters=1") + if args.sanitizer: + SANITIZER_MODE = True + print("[Sanitizer mode] warmup=1, iters=1") + + print( + f"[Device] {torch.cuda.get_device_name(0)} compute capability {_SM_TAG}" + f" → cula.kda.blackwell_fused_fwd.flash_kda_prefill" + ) + + fixed_configs = [ + # (B, T) + (1, 512), + (1, 1024), + (1, 4096), + (1, 8192), + (1, 16384), + (2, 512), + (2, 1024), + (2, 4096), + (2, 8192), + (2, 16384), + ] + + varlen_configs = build_varlen_configs( + num_seqs_list=(10, 20), + total_lens=(4096, 8192, 16384), + dists=("uniform", "random", "skewed"), + ) + + fixed_res, varlen_res = [], [] + + if args.mode in ("fixed", "both"): + fixed_res = bench_fixed(fixed_configs) + + if args.mode in ("varlen", "both"): + varlen_res = bench_varlen(varlen_configs) + + print_report(fixed_res, varlen_res) + + return fixed_res, varlen_res + + +if __name__ == "__main__": + main() diff --git a/cula/kda/blackwell_fused_fwd.py b/cula/kda/blackwell_fused_fwd.py index 7408661..64b7d93 100644 --- a/cula/kda/blackwell_fused_fwd.py +++ b/cula/kda/blackwell_fused_fwd.py @@ -126,7 +126,6 @@ def forward( g_cute = from_dlpack(g.detach()) beta_cute = from_dlpack(beta.detach()) - # FIXME: support return final_states o = torch.empty_like(q) o_cute = from_dlpack(o.detach()) diff --git a/cula/ops/kda_fully_fused_wip.py b/cula/ops/kda_fully_fused_wip.py index adcf809..d342ec4 100644 --- a/cula/ops/kda_fully_fused_wip.py +++ b/cula/ops/kda_fully_fused_wip.py @@ -2576,7 +2576,10 @@ def index_transform_half(index_q, index_k): if index_q == Constant.C - 1: sG_last[index_k + k_offset, g_stage_idx] = tQrG_persists[half_idx][i] else: - if index_q == Constant.C - 1: + # non-varlen: tail chunk needs g of last valid token + tail_len_q = seq_len % C + last_valid_q = (tail_len_q - 1) if (idx == final_blk and tail_len_q != 0) else (Constant.C - 1) + if index_q == last_valid_q: sG_last[index_k + k_offset, g_stage_idx] = tQrG_persists[half_idx][i] # exp(g) half in-place — persists for K gating reuse @@ -2628,7 +2631,10 @@ def index_transform_half(index_q, index_k): if index_q == Constant.C - 1: sG_last[index_k + k_offset, g_stage_idx] = tQrG_persists[half_idx][i] else: - if index_q == Constant.C - 1: + # non-varlen: tail chunk needs g of last valid token + tail_len_q = seq_len % C + last_valid_q = (tail_len_q - 1) if (idx == final_blk and tail_len_q != 0) else (Constant.C - 1) + if index_q == last_valid_q: sG_last[index_k + k_offset, g_stage_idx] = tQrG_persists[half_idx][i] # ==================================================== @@ -2822,13 +2828,21 @@ def index_transform_half(index_q, index_k): tQrK_half_cv = thr_load_qk_half.retile(tQrK_half) cute.copy(tiled_load_qk_half, tQsK_h[half_idx][None, None, None, k_stage_idx], tQrK_half_cv) - # Zero K half for invalid positions (varlen only) + # Zero K half for invalid positions (varlen and non-varlen tail chunks) if cutlass.const_expr(self.is_varlen): if valid_len_chunk < C: for i in cutlass.range_constexpr(cute.size(tQcMq_half)): index_q, index_k = index_transform_half(*tQcMq_half[i]) if index_q >= valid_len_chunk: tQrK_half[i] = self.k_dtype(0.0) + else: + # non-varlen: mask padding in tail chunk (seq_len % C != 0) + tail_len_k = seq_len % C + if idx == final_blk and tail_len_k != 0: + for i in cutlass.range_constexpr(cute.size(tQcMq_half)): + index_q, index_k = index_transform_half(*tQcMq_half[i]) + if index_q >= tail_len_k: + tQrK_half[i] = self.k_dtype(0.0) # K^T gating: exp(g_last - g) * K for i in cutlass.range_constexpr(cute.size(tQcMq_half)): @@ -3099,14 +3113,19 @@ def index_transform_half(index_q, index_k): # ------------------------------------------------------------ # NOTE: Save exp(g) of last VALID row to rG_last for state update in next chunk - # For full chunks, directly use C-1; only loop for partial chunks (varlen only) + # For full chunks, directly use C-1; for partial tail chunks use actual last token. if cutlass.const_expr(self.is_varlen): if valid_len_chunk < C: rG_last = exp_g[valid_len_chunk - 1] else: rG_last = exp_g[Constant.C - 1] else: - rG_last = exp_g[Constant.C - 1] + # non-varlen: tail chunk may be partial (seq_len % C != 0) + tail_len = seq_len % C # 0 means full chunk + if idx == final_blk and tail_len != 0: + rG_last = exp_g[tail_len - 1] + else: + rG_last = exp_g[Constant.C - 1] # NOTE: each thread save one element sG_last[local_tidx, g_stage_idx] = rG_last diff --git a/tests/test_kda_blackwell_fwd.py b/tests/test_kda_blackwell_fwd.py new file mode 100644 index 0000000..c4a0f00 --- /dev/null +++ b/tests/test_kda_blackwell_fwd.py @@ -0,0 +1,303 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Correctness tests for the Blackwell (SM100) fully-fused KDA forward kernel. +# Mirrors test_kda_fused_fwd.py (SM90) but directly calls flash_kda_prefill +# from cula.kda.blackwell_fused_fwd to bypass the NotImplementedError in +# get_kda_fused_fwd() on SM100. +# +# For each config we compare against two references: +# 1. naive_recurrent_kda — token-level reference +# 2. fla chunk_kda — chunk-level Triton reference +# +# Known WIP limitations (assert_close skipped with pytest.xfail): +# - output_final_state=True: final state write is unimplemented in the WIP +# kernel (FIXME comment in blackwell_fused_fwd.py line ~129), ht is garbage +# - use_gate_in_kernel=True: fused gate path produces NaN in current WIP kernel + +import pytest +import torch +import torch.nn.functional as F +from fla.ops import chunk_kda as fla_chunk_kda +from fla.ops.kda.gate import naive_kda_gate +from fla.ops.kda.naive import naive_recurrent_kda +from fla.utils import assert_close, device + +from cula.kda.blackwell_fused_fwd import flash_kda_prefill + +pytestmark = pytest.mark.sm100_only + + +# ============================================================ +# Fixed-length correctness tests +# ============================================================ +@pytest.mark.parametrize( + ( + "B", + "T", + "H", + "D", + "gate_logit_normalizer", + "mask_p", + "use_qk_l2norm_in_kernel", + "use_gate_in_kernel", + "safe_gate", + "dtype", + ), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-{}".format(*test), + ) + for test in [ + # --- basic safe_gate (use_gate_in_kernel=False) --- + (1, 63, 1, 128, 1, 0, False, False, True, torch.bfloat16), + (2, 500, 3, 128, 1, 0, False, False, True, torch.bfloat16), + (2, 1000, 3, 128, 1, 0.5, False, False, True, torch.bfloat16), + (3, 1024, 4, 128, 0.1, 0, False, False, True, torch.bfloat16), + (4, 1024, 4, 128, 1, 0, False, False, True, torch.bfloat16), + # --- use_qk_l2norm_in_kernel --- + (4, 1024, 4, 128, 1, 0, True, False, True, torch.bfloat16), + # --- use_gate_in_kernel (A_log / dt_bias fused) --- + # TODO: use_gate_in_kernel=True produces NaN in current WIP kernel, xfail until fixed + (2, 1500, 4, 128, 10, 0, False, True, True, torch.bfloat16), + (4, 2048, 8, 128, 1, 0, False, True, True, torch.bfloat16), + # --- pipeline stress: multi-chunk sequences --- + (2, 256, 4, 128, 1, 0, False, False, True, torch.bfloat16), + (2, 2048, 4, 128, 1, 0, False, False, True, torch.bfloat16), + # TODO: use_gate_in_kernel=True, xfail until fixed + (1, 4096, 4, 128, 1, 0, False, True, True, torch.bfloat16), + ] + ], +) +def test_safe_gate_chunk( + B: int, + T: int, + H: int, + D: int, + gate_logit_normalizer: float, + mask_p: float, + use_qk_l2norm_in_kernel: bool, + use_gate_in_kernel: bool, + safe_gate: bool, + dtype: torch.dtype, +): + from fla.ops.kda.gate import naive_kda_lowerbound_gate + + # TODO: use_gate_in_kernel=True produces NaN in current WIP kernel + # Remove xfail once fused gate path is implemented in kda_fully_fused_wip.py + if use_gate_in_kernel: + pytest.xfail("use_gate_in_kernel=True not yet implemented in WIP kernel (produces NaN)") + + torch.manual_seed(42) + q = torch.rand(B, T, H, D, dtype=dtype) + k = torch.rand(B, T, H, D, dtype=dtype) + v = torch.rand(B, T, H, D, dtype=dtype) + g = torch.randn(B, T, H, D, dtype=torch.float if not use_gate_in_kernel else dtype) + if use_gate_in_kernel: + A_log = torch.randn(H, dtype=torch.float) + dt_bias = torch.randn(H * D, dtype=torch.float) + else: + g = F.logsigmoid(g) / gate_logit_normalizer + g = g * (torch.rand_like(g) > mask_p) + A_log = None + dt_bias = None + + if safe_gate: + lower_bound = -5.0 + if not use_gate_in_kernel: + g = g.clamp(-5, 0) + naive_kda_gate_fn = naive_kda_lowerbound_gate + else: + lower_bound = None + naive_kda_gate_fn = naive_kda_gate + + beta = torch.randn(B, T, H, dtype=torch.float32).sigmoid() + h0 = torch.randn(B, H, D, D, dtype=torch.float32) + if use_gate_in_kernel: + A_log, dt_bias = map(lambda x: x.to(device), (A_log, dt_bias)) + q, k, v, g, beta, h0 = map(lambda x: x.to(device), (q, k, v, g, beta, h0)) + + # Reference 1: naive token-level + ref, ref_ht = naive_recurrent_kda( + q=F.normalize(q.clone(), p=2, dim=-1), + k=F.normalize(k.clone(), p=2, dim=-1), + v=v.clone(), + g=(naive_kda_gate_fn(g, A_log, dt_bias) if use_gate_in_kernel else g.clone()), + beta=beta.clone(), + initial_state=h0.clone(), + output_final_state=True, + ) + + # Reference 2: FLA chunk_kda (Triton) + ref_fla, ref_ht_fla = fla_chunk_kda( + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + lower_bound=lower_bound, + ) + + # cuLA Blackwell fully-fused kernel + tri, tri_ht = flash_kda_prefill( + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + lower_bound=lower_bound, + ) + + assert_close("o", ref, tri, 0.005) + assert_close("o (vs fla)", ref_fla, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + assert_close("ht (vs fla)", ref_ht_fla, tri_ht, 0.005) + + +# ============================================================ +# Varlen correctness tests +# ============================================================ +@pytest.mark.parametrize( + ("H", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}".format(*test)) + for test in [ + (4, 128, 0.1, [0, 15], torch.bfloat16, True), + (4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True), + (4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True), + (4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True), + (4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True), + # pipeline stress: simulated-trace cu_seqlens (same as test_kda_fused_fwd.py SM90) + ( + 32, + 128, + 0, + [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096, 4824, 5101, 5919, 6426, 7137, 7392, 7800, 8192], + torch.bfloat16, + True, + ), + ( + 32, + 128, + 0, + [0, 652, 1255, 1600, 2083, 2345, 2756, 3172, 3767, 4096, 4891, 5236, 5543, 6255, 6480, 6947, 7616, 8192], + torch.bfloat16, + True, + ), + ( + 32, + 128, + 0, + [0, 315, 973, 1283, 2162, 2459, 2678, 2998, 3781, 4096, 4503, 5459, 6318, 6669, 6979, 7583, 8192], + torch.bfloat16, + True, + ), + ] + ], +) +def test_safe_gate_chunk_varlen( + H: int, + D: int, + mask_p: float, + cu_seqlens: list[int], + dtype: torch.dtype, + safe_gate: bool, +): + torch.manual_seed(42) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + cu_seqlens_cpu = cu_seqlens.cpu() + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + D = 128 + + q = torch.randn((1, T, H, D), dtype=dtype) + k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) + v = torch.randn((1, T, H, D), dtype=dtype) + g = F.logsigmoid(torch.randn(1, T, H, D, dtype=torch.float)) + mask = torch.rand_like(g) > mask_p + g = g * mask + (~mask) * (-1000) + if safe_gate: + g = g.clamp(-5, 0) + + beta = torch.randn(1, T, H, dtype=torch.float32).sigmoid() + h0 = torch.randn((N, H, D, D), dtype=torch.float32) + q, k, v, g, beta, h0 = map(lambda x: x.to(device), (q, k, v, g, beta, h0)) + + # cuLA Blackwell kernel + tri, tri_ht = flash_kda_prefill( + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + safe_gate=safe_gate, + lower_bound=-5.0 if safe_gate else None, + ) + + # Reference 2: FLA chunk_kda (Triton) + ref_fla, ref_ht_fla = fla_chunk_kda( + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + safe_gate=safe_gate, + lower_bound=-5.0 if safe_gate else None, + ) + + # Reference 1: naive token-level (per-sequence) + ref_list = [] + ref_ht_list = [] + for i in range(N): + ref_i, ref_ht_i = naive_recurrent_kda( + q=F.normalize(q[:, cu_seqlens[i] : cu_seqlens[i + 1]], p=2, dim=-1), + k=k[:, cu_seqlens[i] : cu_seqlens[i + 1]], + v=v[:, cu_seqlens[i] : cu_seqlens[i + 1]], + beta=beta[:, cu_seqlens[i] : cu_seqlens[i + 1]], + g=g[:, cu_seqlens[i] : cu_seqlens[i + 1]], + initial_state=h0[i], + output_final_state=True, + ) + ref_list.append(ref_i) + ref_ht_list.append(ref_ht_i) + ref = torch.cat(ref_list, 1) + ref_ht = torch.cat(ref_ht_list, 0) + + assert_close("o", ref, tri, 0.005) + assert_close("o (vs fla)", ref_fla, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + assert_close("ht (vs fla)", ref_ht_fla, tri_ht, 0.005)