Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions cula/ops/lightning_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,9 @@ def __call__(
# and the parameter is eliminated at compile time via const_expr guards.
# For varlen: state pool is [pool_size, H, D, D]. We use B (=N) as the
# pool dimension — strides are correct regardless of actual pool_size.
fstate_layout = cute.make_layout(
(D, D, (H, B)),
stride=(1, D, (D * D, D * D * H)),
)
fstate_layout = cute.make_layout((D, D, (H, B)),
stride=(D, 1, (D * D, D * D * H)),
)
Comment thread
higgsboson1710 marked this conversation as resolved.
if cutlass.const_expr(self.has_initial_state):
initial_state = cute.make_tensor(initial_state_in.iterator, fstate_layout)
else:
Expand Down Expand Up @@ -1703,8 +1702,8 @@ def kernel(
# -------------- Initial State Loading (h0) ----------------
if cutlass.const_expr(self.has_initial_state):
gState_h0 = initial_state[None, None, (hidx, state_idx)]
gRow_h0 = cute.make_tensor(gState_h0.iterator + local_tidx, cute.make_layout(_D, stride=_D))
cute.autovec_copy(gRow_h0, init_flat)
gCol_h0 = cute.make_tensor(gState_h0.iterator + local_tidx * _D, cute.make_layout(_D, stride=1))
cute.autovec_copy(gCol_h0, init_flat)

# Store raw h0 as BF16 to kv16 TMEM for SQ MMA at idx=0
tmem_store_rAccKVAsBF16.store(tTR_rKV.load().to(self.io_dtype))
Expand Down Expand Up @@ -1877,10 +1876,10 @@ def kernel(
gState_ht = initial_state[None, None, (hidx, state_idx)]
else:
gState_ht = final_state[None, None, (hidx, state_idx)]
gRow_ht = cute.make_tensor(gState_ht.iterator + local_tidx, cute.make_layout(_D, stride=_D))


gCol_ht = cute.make_tensor(gState_ht.iterator + local_tidx * _D, cute.make_layout(_D, stride=1))
out_flat = cute.make_tensor(tTR_rKV.iterator, layout=cute.make_layout(_D))
cute.autovec_copy(out_flat, gRow_ht)
cute.autovec_copy(out_flat, gCol_ht)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to track the performance change here. Could you share the results of bench_lightning_attn.py


# Advance k_stage_offset by number of chunks in this WU
# so next WU's k_stage_idx stays in sync with the K pipeline.
Expand Down Expand Up @@ -2895,7 +2894,7 @@ def lightning_attn_fwd(
V: (B, S, H, D) bf16 value
decay: (H,) f32 per-head decay coefficients
scale: attention scale factor (default: 1.0)
initial_state: (B, H, D, D) f32 initial state or None
initial_state: (B, H, D, D) f32 initial state in BHVK layout or None
output_final_state: whether to output final state
chunk_size: chunk size (default: 64)

Expand Down Expand Up @@ -3111,7 +3110,7 @@ def lightning_attn_fwd_varlen(
decay: (H,) f32 per-head decay coefficients
cu_seqlens: (N+1,) int32 cumulative sequence lengths
scale: attention scale factor (default: 1.0)
state_pool: (pool_size, H, D, D) f32 state pool, or None
state_pool: (pool_size, H, D, D) f32 state pool in BHVK layout, or None
If None, a zero state pool is allocated with pool_size=N.
States are updated in-place (INPLACE_UPDATE).
initial_state_indices: (N,) int32 indices into state_pool per sequence.
Expand Down
46 changes: 37 additions & 9 deletions tests/test_la_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,8 @@ def make_inputs(B, H, D, device="cuda", seed=42):
def run_la_decode(q, k, v, state_4d, decay_scales, scale):
"""Run la_decode with proper state layout conversion."""
B, H, D, _ = state_4d.shape
# la_decode state layout: [B*H, V, K] (pretransposed)
state_cute = (
state_4d.clone()
.permute(0, 1, 3, 2) # [B, H, V, K]
.reshape(B * H, D, D)
.contiguous()
)
# la_decode state layout: [B*H, D, D] (BHVK layout)
state_cute = state_4d.clone().reshape(B * H, D, D)
out = torch.zeros(B, H, D, device=q.device, dtype=torch.bfloat16)
s_offsets = torch.arange(B, device=q.device, dtype=torch.int32)

Expand All @@ -111,8 +106,7 @@ def run_la_decode(q, k, v, state_4d, decay_scales, scale):
K_SPLIT_DIM=D,
V_SPLIT_DIM=D,
)
# Convert state back to [B, H, K, V]
state_out = state_cute.reshape(B, H, D, D).permute(0, 1, 3, 2).contiguous()
state_out = state_cute.reshape(B, H, D, D)
return out, state_out


Expand Down Expand Up @@ -230,6 +224,40 @@ def test_vs_fla(B):
max_ref = torch.abs(o_fla.float()).max().item()
assert rmse / (max_ref + 1e-8) < 0.005, f"B={B}: vs fla mismatch, rel_rmse={rmse / (max_ref + 1e-8):.6f}"

# ---------------------------------------------------------------------------
# End-to-End Prefill -> Decode Test
# ---------------------------------------------------------------------------
def test_prefill_decode_e2e():
    """Verify prefill output state passes directly into decode without transpose.

    Runs a short prefill with `lightning_attn_fwd` to produce a final state
    `ht` (BHVK layout), feeds that state straight into `run_la_decode`, and
    checks both the decode output and the updated state against the PyTorch
    reference. Requires a CUDA device.
    """
    from cula.ops.lightning_attn import lightning_attn_fwd

    B, S, H, D = 2, 64, 8, 128
    scale = D**-0.5
    decay_scales = 0.5 * torch.arange(H, device="cuda", dtype=torch.float32) / H

    # Dummy prefill tokens
    q_pre = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
    k_pre = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
    v_pre = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)

    # 1. Run Prefill (Generates BHVK ht)
    _, ht = lightning_attn_fwd(
        q_pre, k_pre, v_pre, decay_scales, scale=scale, output_final_state=True
    )

    # Dummy decode tokens
    q_dec = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    k_dec = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
    v_dec = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)

    # 2. Run Decode (Accepts ht directly!)
    out_dec, state_new = run_la_decode(q_dec, k_dec, v_dec, ht, decay_scales, scale)

    # 3. Check output against the PyTorch reference (relative-RMSE tolerance,
    # matching the style of the other tests in this file).
    out_ref, state_new_ref = torch_la_decode_ref(q_dec, k_dec, v_dec, ht, decay_scales, scale)

    rmse = torch.sqrt(torch.mean((out_dec.float() - out_ref.float()) ** 2)).item()
    max_ref = torch.abs(out_ref.float()).max().item()
    assert rmse / (max_ref + 1e-8) < 0.01, "E2E Output mismatch"

    # 4. The updated state must match too. Previously it was computed but never
    # asserted on, leaving the state half of the decode contract untested.
    s_rmse = torch.sqrt(torch.mean((state_new.float() - state_new_ref.float()) ** 2)).item()
    s_max = torch.abs(state_new_ref.float()).max().item()
    assert s_rmse / (s_max + 1e-8) < 0.01, "E2E State mismatch"
# Allow running this test module directly (e.g. `python test_la_decode.py`)
# instead of through a pytest invocation.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])
7 changes: 3 additions & 4 deletions tests/test_lightning_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def test_against_fla_with_state(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, ato
Q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1
K = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1
V = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1
h0 = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32) * 0.01
h0 = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32).contiguous() * 0.01

decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32)
g_gamma = -decay
Expand Down Expand Up @@ -525,7 +525,7 @@ def test_varlen_with_initial_state(seq_lens=None, H=4, D=128, C=64, decay_val=0.

# State pool with 3 slots, use indices [2, 0]
pool_size = 3
state_pool = torch.randn(pool_size, H, D, D, dtype=torch.float32, device="cuda") * 0.01
state_pool = torch.randn(pool_size, H, D, D, dtype=torch.float32, device="cuda").contiguous() * 0.01
indices = torch.tensor([2, 0], dtype=torch.int32, device="cuda")

O_var, sp = run_cute_kernel_varlen(
Expand Down Expand Up @@ -589,8 +589,7 @@ def test_varlen_against_pytorch_ref(
K = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1
V = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1
decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32)

state_pool = torch.randn(N, H, D, D, dtype=torch.float32, device="cuda") * 0.01
state_pool = torch.randn(N, H, D, D, dtype=torch.float32, device="cuda").contiguous() * 0.01

O_var, sp = run_cute_kernel_varlen(
Q,
Expand Down