diff --git a/cula/ops/lightning_attn.py b/cula/ops/lightning_attn.py index 057cc73..31eb037 100644 --- a/cula/ops/lightning_attn.py +++ b/cula/ops/lightning_attn.py @@ -410,10 +410,9 @@ def __call__( # and the parameter is eliminated at compile time via const_expr guards. # For varlen: state pool is [pool_size, H, D, D]. We use B (=N) as the # pool dimension — strides are correct regardless of actual pool_size. - fstate_layout = cute.make_layout( - (D, D, (H, B)), - stride=(1, D, (D * D, D * D * H)), - ) + fstate_layout = cute.make_layout((D, D, (H, B)), + stride=(D, 1, (D * D, D * D * H)), + ) if cutlass.const_expr(self.has_initial_state): initial_state = cute.make_tensor(initial_state_in.iterator, fstate_layout) else: @@ -1703,8 +1702,8 @@ def kernel( # -------------- Initial State Loading (h0) ---------------- if cutlass.const_expr(self.has_initial_state): gState_h0 = initial_state[None, None, (hidx, state_idx)] - gRow_h0 = cute.make_tensor(gState_h0.iterator + local_tidx, cute.make_layout(_D, stride=_D)) - cute.autovec_copy(gRow_h0, init_flat) + gCol_h0 = cute.make_tensor(gState_h0.iterator + local_tidx * _D, cute.make_layout(_D, stride=1)) + cute.autovec_copy(gCol_h0, init_flat) # Store raw h0 as BF16 to kv16 TMEM for SQ MMA at idx=0 tmem_store_rAccKVAsBF16.store(tTR_rKV.load().to(self.io_dtype)) @@ -1877,10 +1876,10 @@ def kernel( gState_ht = initial_state[None, None, (hidx, state_idx)] else: gState_ht = final_state[None, None, (hidx, state_idx)] - gRow_ht = cute.make_tensor(gState_ht.iterator + local_tidx, cute.make_layout(_D, stride=_D)) - + + gCol_ht = cute.make_tensor(gState_ht.iterator + local_tidx * _D, cute.make_layout(_D, stride=1)) out_flat = cute.make_tensor(tTR_rKV.iterator, layout=cute.make_layout(_D)) - cute.autovec_copy(out_flat, gRow_ht) + cute.autovec_copy(out_flat, gCol_ht) # Advance k_stage_offset by number of chunks in this WU # so next WU's k_stage_idx stays in sync with the K pipeline. @@ -2895,7 +2894,7 @@ def lightning_attn_fwd( V: (B, S, H, D) bf16 value decay: (H,) f32 per-head decay coefficients scale: attention scale factor (default: 1.0) - initial_state: (B, H, D, D) f32 initial state or None + initial_state: (B, H, D, D) f32 initial state in BHVK layout or None output_final_state: whether to output final state chunk_size: chunk size (default: 64) @@ -3111,7 +3110,7 @@ def lightning_attn_fwd_varlen( decay: (H,) f32 per-head decay coefficients cu_seqlens: (N+1,) int32 cumulative sequence lengths scale: attention scale factor (default: 1.0) - state_pool: (pool_size, H, D, D) f32 state pool, or None + state_pool: (pool_size, H, D, D) f32 state pool in BHVK layout, or None If None, a zero state pool is allocated with pool_size=N. States are updated in-place (INPLACE_UPDATE). initial_state_indices: (N,) int32 indices into state_pool per sequence. diff --git a/tests/test_la_decode.py b/tests/test_la_decode.py index d7a3a40..96687ee 100644 --- a/tests/test_la_decode.py +++ b/tests/test_la_decode.py @@ -83,13 +83,8 @@ def make_inputs(B, H, D, device="cuda", seed=42): def run_la_decode(q, k, v, state_4d, decay_scales, scale): """Run la_decode with proper state layout conversion.""" B, H, D, _ = state_4d.shape - # la_decode state layout: [B*H, V, K] (pretransposed) - state_cute = ( - state_4d.clone() - .permute(0, 1, 3, 2) # [B, H, V, K] - .reshape(B * H, D, D) - .contiguous() - ) + # la_decode state layout: [B*H, D, D] (BHVK layout) + state_cute = state_4d.clone().reshape(B * H, D, D) out = torch.zeros(B, H, D, device=q.device, dtype=torch.bfloat16) s_offsets = torch.arange(B, device=q.device, dtype=torch.int32) @@ -111,8 +106,7 @@ def run_la_decode(q, k, v, state_4d, decay_scales, scale): K_SPLIT_DIM=D, V_SPLIT_DIM=D, ) - # Convert state back to [B, H, K, V] - state_out = state_cute.reshape(B, H, D, D).permute(0, 1, 3, 2).contiguous() + state_out = state_cute.reshape(B, H, D, D) return out, state_out @@ -230,6 +224,40 @@ def test_vs_fla(B): max_ref = torch.abs(o_fla.float()).max().item() assert rmse / (max_ref + 1e-8) < 0.005, f"B={B}: vs fla mismatch, rel_rmse={rmse / (max_ref + 1e-8):.6f}" + # --------------------------------------------------------------------------- +# End-to-End Prefill -> Decode Test +# --------------------------------------------------------------------------- +def test_prefill_decode_e2e(): + """Verify prefill output state passes directly into decode without transpose.""" + from cula.ops.lightning_attn import lightning_attn_fwd + + B, S, H, D = 2, 64, 8, 128 + scale = D**-0.5 + decay_scales = 0.5 * torch.arange(H, device="cuda", dtype=torch.float32) / H + + # Dummy prefill tokens + q_pre = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + k_pre = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + v_pre = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + + # 1. Run Prefill (Generates BHVK ht) + _, ht = lightning_attn_fwd( + q_pre, k_pre, v_pre, decay_scales, scale=scale, output_final_state=True + ) + + # Dummy decode tokens + q_dec = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16) + k_dec = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16) + v_dec = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16) + + # 2. Run Decode (Accepts ht directly!) + out_dec, state_new = run_la_decode(q_dec, k_dec, v_dec, ht, decay_scales, scale) + + # 3. Check against PyTorch reference + out_ref, state_new_ref = torch_la_decode_ref(q_dec, k_dec, v_dec, ht, decay_scales, scale) + rmse = torch.sqrt(torch.mean((out_dec.float() - out_ref.float()) ** 2)).item() + max_ref = torch.abs(out_ref.float()).max().item() + assert rmse / (max_ref + 1e-8) < 0.01, "E2E Output mismatch" if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/test_lightning_attn.py b/tests/test_lightning_attn.py index 42fe241..07f7e60 100644 --- a/tests/test_lightning_attn.py +++ b/tests/test_lightning_attn.py @@ -395,7 +395,7 @@ def test_against_fla_with_state(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, ato Q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 K = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 V = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - h0 = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32) * 0.01 + h0 = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32).contiguous() * 0.01 decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32) g_gamma = -decay @@ -525,7 +525,7 @@ def test_varlen_with_initial_state(seq_lens=None, H=4, D=128, C=64, decay_val=0. # State pool with 3 slots, use indices [2, 0] pool_size = 3 - state_pool = torch.randn(pool_size, H, D, D, dtype=torch.float32, device="cuda") * 0.01 + state_pool = torch.randn(pool_size, H, D, D, dtype=torch.float32, device="cuda").contiguous() * 0.01 indices = torch.tensor([2, 0], dtype=torch.int32, device="cuda") O_var, sp = run_cute_kernel_varlen( @@ -589,8 +589,7 @@ def test_varlen_against_pytorch_ref( K = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 V = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32) - - state_pool = torch.randn(N, H, D, D, dtype=torch.float32, device="cuda") * 0.01 + state_pool = torch.randn(N, H, D, D, dtype=torch.float32, device="cuda").contiguous() * 0.01 O_var, sp = run_cute_kernel_varlen( Q,