feat: BHVK (K-last) state layout for Lightning Attention prefill & decode #56

Merged

icavan merged 10 commits into main from feat/vk_states on Apr 22, 2026
Conversation

@icavan (Collaborator) commented Apr 19, 2026

NOTE: based on the code base from #36 by higgsboson1710.

Motivation

The decode kernel (la_decode) uses BHVK state layout (B, H, V, K) where K is contiguous, while the prefill kernel (lightning_attn) previously used BHKV (B, H, K, V). This mismatch required a transpose between prefill and decode, adding latency on the critical serving path.

This PR unifies both kernels on the BHVK layout so that prefill output state flows directly into decode without any transpose.
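The layout difference can be pictured with a short NumPy sketch (illustrative shapes only; NumPy stands in for the torch tensors the kernels actually exchange):

```python
import numpy as np

# Illustrative sizes only, not the kernel's actual tile configuration.
B, H, K, V = 1, 2, 128, 128

# Old flow: prefill emitted BHKV (V contiguous) while decode expects
# BHVK (K contiguous), so a transpose + copy sat on the serving path.
ht_bhkv = np.zeros((B, H, K, V), dtype=np.float32)
ht_for_decode = np.ascontiguousarray(np.swapaxes(ht_bhkv, -1, -2))  # extra copy

# New flow: prefill emits BHVK directly, already contiguous in K,
# so the state can be handed to decode as-is.
ht_bhvk = np.zeros((B, H, V, K), dtype=np.float32)
assert ht_bhvk.shape == ht_for_decode.shape == (B, H, V, K)
assert ht_bhvk.strides[-1] == ht_bhvk.itemsize  # K is the fastest-varying axis
```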

Changes

Kernel (cula/ops/lightning_attn.py)

  • Changed fstate_layout stride from (1, D, ...) (K-major, BHKV) to (D, 1, ...) (K-contiguous, BHVK)
  • Implemented SMEM-mediated cooperative state load/store to maintain coalesced GMEM access despite the VK memory layout:
    • 128 CE threads cooperatively load/store states in row-major order (coalesced 128-bit GMEM transactions)
    • 2 strips × 64 V-rows processed sequentially through a 33KB padded SMEM buffer
    • SMEM row stride = D + 4 (132 for D=128) eliminates bank conflicts (stride mod 32 = 4)
    • Added sStateBuf field to SharedStorage struct (~33KB, within 228KB SMEM budget at occ=1)
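The padding arithmetic behind the D + 4 row stride can be checked in a few lines (assuming 32 four-byte SMEM banks and 128-bit accesses serviced in phases of 8 threads; a sketch of the argument, not kernel code):

```python
# Sketch of the bank-conflict argument behind the D + 4 SMEM row stride
# (assumes 32 four-byte banks and 128-bit accesses split into phases of
# 8 threads; illustration of the arithmetic, not kernel code).
D = 128
WORDS_PER_VEC = 4  # one 128-bit access spans 4 consecutive 32-bit banks

def start_banks(stride, rows):
    """Bank holding the first word of each row's 128-bit access in column 0."""
    return [(r * stride) % 32 for r in range(rows)]

# Unpadded stride 128: every row of a column starts at bank 0, so a
# column walk serializes completely.
assert set(start_banks(D, 8)) == {0}

# Padded stride 132 (132 mod 32 = 4): the 8 rows of one phase start at
# banks 0, 4, 8, ..., 28; each 4-wide vector covers a disjoint bank
# group, so one phase touches all 32 banks exactly once.
starts = start_banks(D + 4, 8)
covered = sorted(b for s in starts for b in range(s, s + WORDS_PER_VEC))
assert covered == list(range(32))
```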

Tests (tests/test_lightning_attn.py, tests/test_la_decode.py)

  • All state inputs converted BHKV→BHVK via .transpose(-1, -2).contiguous() before passing to CuTe kernel
  • All state outputs converted BHVK→BHKV via .transpose(-1, -2) before comparing with FLA/PyTorch reference
  • Added test_prefill_decode_e2e: verifies prefill ht passes directly into decode without manual transpose
  • API docstrings updated to document BHVK layout requirement
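A minimal sketch of the conversion shim the tests wrap around the kernel (NumPy standing in for the torch `.transpose(-1, -2).contiguous()` calls; names are illustrative):

```python
import numpy as np

def to_bhvk(state_bhkv):
    """BHKV -> BHVK, materialized so K becomes the contiguous axis."""
    return np.ascontiguousarray(np.swapaxes(state_bhkv, -1, -2))

def to_bhkv(state_bhvk):
    """BHVK -> BHKV view for comparison against the reference."""
    return np.swapaxes(state_bhvk, -1, -2)

# The round-trip is exact: the layout shim itself adds no precision error.
h0 = np.random.default_rng(0).standard_normal((1, 2, 64, 64)).astype(np.float32)
assert np.array_equal(to_bhkv(to_bhvk(h0)), h0)
```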

Benchmarks (benchmarks/bench_lightning_attn.py)

  • h0_cute converted to BHVK before timing; ht converted back to BHKV for accuracy comparison
  • Transpose is outside time_fn — not included in reported timings

Performance (vs FLA Triton baseline, GB200)

Mode                         Avg Speedup   Min     Max
h0_ht (prefill with state)   1.33x         0.90x   1.88x
varlen (persistent)          1.44x         0.83x   2.05x
no_state (prefill only)      1.50x         1.12x   1.89x

The no_state mode is unaffected by this change (no state load/store). The h0_ht and varlen modes show ~1-5% overhead vs the previous BHKV kernel (from the SMEM transpose), which is acceptable given the elimination of the prefill→decode transpose on the serving path.

Performance Evolution (vs origin/main BHKV baseline)

Mode      origin/main (BHKV)   VK version 1 (no SMEM)   Final (SMEM-mediated)   Recovery
h0_ht     avg 1.32x            avg 1.24x (-6%)          avg 1.33x (+0.8%)       fully recovered
varlen    avg 1.54x            avg 1.12x (-27%)         avg 1.44x (-6.5%)       mostly recovered
no_state  avg 1.49x            avg 1.50x (flat)         avg 1.50x (flat)        unaffected

The naive VK layout change caused severe varlen regression (-27%) due to uncoalesced column-strided GMEM accesses (each thread writing stride-D elements). The SMEM cooperative transpose recovers most of the loss:

  • h0_ht: fully recovered (+0.8% vs BHKV, within noise)
  • varlen: recovered from -27% to -6.5%, remaining gap from SMEM transpose overhead in the persistent loop
  • no_state: unchanged (no state I/O)
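The coalescing difference can be seen by listing the byte addresses one warp touches under each access pattern (4-byte elements, D = 128; illustration only):

```python
# Why the naive VK-layout store was uncoalesced: byte addresses touched
# by one warp (32 threads, 4-byte elements, D = 128; illustration only).
D, ELEM = 128, 4

# Naive column walk: thread t writes element t * D, so consecutive
# threads are 512 bytes apart and the warp spans 32 distinct 128-byte
# segments (one GMEM transaction each).
naive = [t * D * ELEM for t in range(32)]
assert len({addr // 128 for addr in naive}) == 32

# Cooperative row-major store via SMEM: thread t writes element t, so
# the warp's accesses fall in a single 128-byte segment (fully coalesced).
row_major = [t * ELEM for t in range(32)]
assert len({addr // 128 for addr in row_major}) == 1
```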

Precision

RMSE matches FLA exactly — no precision loss from the layout change:

  • Output O RMSE: 0.2347% (identical to FLA)
  • State Ht RMSE: 0.0198% (identical to FLA)
  • la_decode: 20/20 tests pass, state relative RMSE < 0.1%

Test Results

tests/test_lightning_attn.py — 10/10 passed
tests/test_la_decode.py      — 20/20 passed

🚀 Pull Request Checklist

Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing.

⚡ Performance

Reviewer Notes

higgsboson1710 and others added 10 commits April 5, 2026 21:11
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Implement cooperative GMEM↔SMEM transpose for VK state access:
- 128 CE threads load/store states cooperatively (coalesced GMEM)
- 2 strips of 64 V-rows with padded SMEM (stride D+4) to eliminate bank conflicts
- Row-based addressing: each iteration covers 4 rows × 32 cols with LDG/STG.128
- Per-thread SMEM read/write uses padded stride for conflict-free bank access

Performance (vs FLA): h0_ht avg 1.33x, varlen avg 1.44x, no_state avg 1.46x
Precision: RMSE matches FLA exactly (0.2347% O, 0.0198% Ht)
All 10 tests pass
run_la_decode now transposes state_4d from BHKV to BHVK before passing
to the kernel, and transposes output state back to BHKV for comparison.
E2E test also converts prefill BHVK output to BHKV before passing to
run_la_decode and torch_la_decode_ref.

All 20 tests pass.
@gemini-code-assist (Bot) left a comment

Code Review

This pull request transitions the lightning attention kernel to a BHVK (K-contiguous) layout for initial and final states, improving global memory access efficiency through a shared memory transpose buffer and cooperative loading/storing logic. Benchmarks and tests have been updated to accommodate this layout change, and a new end-to-end prefill-to-decode test has been added. I have no feedback to provide.

@KevinZeng08 (Collaborator) left a comment

LGTM

@zheyang0825 (Collaborator) left a comment

LGTM

@icavan icavan mentioned this pull request Apr 20, 2026
@icavan icavan merged commit 33ec0b3 into main Apr 22, 2026