feat: BHVK (K-last) state layout for Lightning Attention prefill & decode#56
Merged
Conversation
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Implement cooperative GMEM↔SMEM transpose for VK state access:
- 128 CE threads load/store states cooperatively (coalesced GMEM)
- 2 strips of 64 V-rows with padded SMEM (stride D+4) to eliminate bank conflicts
- Row-based addressing: each iteration covers 4 rows × 32 cols with LDG/STG.128
- Per-thread SMEM read/write uses padded stride for conflict-free bank access

Performance (vs FLA): h0_ht avg 1.33x, varlen avg 1.44x, no_state avg 1.46x
Precision: RMSE matches FLA exactly (0.2347% O, 0.0198% Ht)
All 10 tests pass
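The stride-padding argument in the commit message can be illustrated with a toy model. This is not the kernel's code: it assumes one fp32 element per thread and the classic 32 four-byte SMEM banks, ignoring the kernel's 128-bit vectorized accesses, so it shows the padding breaking the worst-case pathology rather than the full conflict-free access pattern.

```python
# Toy model (not the kernel): NVIDIA SMEM has 32 four-byte banks. With an
# unpadded row stride of D=128 fp32 elements, a column-walk puts every thread
# of a warp into the same bank; padding the stride to D+4 spreads them out.
BANKS = 32

def distinct_banks(stride_elems):
    # Thread t of a warp reads smem[t * stride_elems] (one fp32 per thread).
    return len({(t * stride_elems) % BANKS for t in range(32)})

print(distinct_banks(128))  # 1  -> 32-way conflict (stride % 32 == 0)
print(distinct_banks(132))  # 8  -> conflicts broken up (stride % 32 == 4)
```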
`run_la_decode` now transposes `state_4d` from BHKV to BHVK before passing it to the kernel, and transposes the output state back to BHKV for comparison. The E2E test also converts the prefill BHVK output to BHKV before passing it to `run_la_decode` and `torch_la_decode_ref`. All 20 tests pass.
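The layout plumbing described above can be sketched as a round trip. This uses NumPy in place of torch, and the shapes and variable names are illustrative only:

```python
# Sketch (NumPy standing in for torch): convert a BHKV state to BHVK before a
# kernel call, and back to BHKV for comparison; the round trip is lossless.
import numpy as np

B, H, K, V = 2, 4, 128, 128
state_bhkv = np.arange(B * H * K * V, dtype=np.float32).reshape(B, H, K, V)

# BHKV -> BHVK: swap the last two axes and materialize contiguously,
# mirroring .transpose(-1, -2).contiguous() in the tests.
state_bhvk = np.ascontiguousarray(state_bhkv.swapaxes(-1, -2))
assert state_bhvk.shape == (B, H, V, K)
assert state_bhvk.flags["C_CONTIGUOUS"]  # K is now the fast (contiguous) axis

# BHVK -> BHKV for the reference comparison: values are unchanged.
back = state_bhvk.swapaxes(-1, -2)
assert np.array_equal(back, state_bhkv)
```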
Contributor
Code Review
This pull request transitions the lightning attention kernel to a BHVK (K-contiguous) layout for initial and final states, improving global memory access efficiency through a shared memory transpose buffer and cooperative loading/storing logic. Benchmarks and tests have been updated to accommodate this layout change, and a new end-to-end prefill-to-decode test has been added. I have no feedback to provide.
feat: BHVK (K-last) state layout for Lightning Attention prefill & decode
NOTE: based on code base from #36, higgsboson1710
Motivation
The decode kernel (`la_decode`) uses the BHVK state layout `(B, H, V, K)`, where K is contiguous, while the prefill kernel (`lightning_attn`) previously used BHKV `(B, H, K, V)`. This mismatch required a transpose between prefill and decode, adding latency on the critical serving path. This PR unifies both kernels on the BHVK layout so that the prefill output state flows directly into decode without any transpose.
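The cost of the mismatch can be seen at the stride level: a transpose by itself is just a stride swap (a view), but a kernel that requires K contiguous forces a real copy. A NumPy illustration (standing in for torch):

```python
# Illustration: why a BHKV state can't feed a K-contiguous kernel without a copy.
import numpy as np

B, H, K, V = 1, 1, 128, 128
bhkv = np.zeros((B, H, K, V), dtype=np.float32)
print(bhkv.strides[-2:])           # (512, 4): V is the contiguous axis

view = bhkv.swapaxes(-1, -2)       # logically BHVK, but no data has moved
print(view.flags["C_CONTIGUOUS"])  # False: K is still strided in memory

bhvk = np.ascontiguousarray(view)  # the actual copy removed from the serving path
print(bhvk.strides[-2:])           # (512, 4): K is now the contiguous axis
```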
Changes
Kernel (`cula/ops/lightning_attn.py`)

- `fstate_layout` stride changed from `(1, D, ...)` (K-major, BHKV) to `(D, 1, ...)` (K-contiguous, BHVK)
- Padded SMEM row stride `D + 4` (132 for D=128) eliminates bank conflicts (stride mod 32 = 4)
- Added `sStateBuf` field to the `SharedStorage` struct (~33KB, within the 228KB SMEM budget at occ=1)

Tests (`tests/test_lightning_attn.py`, `tests/test_la_decode.py`)

- Initial states get `.transpose(-1, -2).contiguous()` before passing to the CuTe kernel
- Output states get `.transpose(-1, -2)` before comparing with the FLA/PyTorch reference
- New `test_prefill_decode_e2e`: verifies the prefill `ht` passes directly into decode without a manual transpose

Benchmarks (`benchmarks/bench_lightning_attn.py`)

- `h0_cute` is converted to BHVK before timing; `ht` is converted back to BHKV for accuracy comparison
- Conversions happen outside `time_fn` — not included in reported timings

Performance (vs FLA Triton baseline, GB200)

The `no_state` mode is unaffected by this change (no state load/store). The `h0_ht` and `varlen` modes show ~1-5% overhead vs the previous BHKV kernel (from the SMEM transpose), which is acceptable given the elimination of the prefill→decode transpose on the serving path.

Performance Evolution (vs origin/main BHKV baseline)
The naive VK layout change caused severe varlen regression (-27%) due to uncoalesced column-strided GMEM accesses (each thread writing stride-D elements). The SMEM cooperative transpose recovers most of the loss:
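The coalescing difference behind that regression can be sketched with a simple transaction-counting model. This is a simplification (real transaction sizing depends on the architecture and vector width), assuming fp32 elements and 128-byte GMEM segments:

```python
# Toy model: count 128-byte GMEM transactions one warp needs for 32 fp32
# elements under a column-strided vs a coalesced row access pattern.
D = 128      # inner dimension, from the PR
ELEM = 4     # fp32 bytes
SEG = 128    # transaction segment size in bytes (simplified)

def segments(addresses):
    return len({a // SEG for a in addresses})

# Naive column write: thread t writes element t*D (stride-D apart), so each
# thread lands in its own segment.
naive = segments([t * D * ELEM for t in range(32)])

# Cooperative row write after the SMEM transpose: thread t writes element t,
# so the whole warp shares one segment.
coop = segments([t * ELEM for t in range(32)])

print(naive, coop)  # 32 1
```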
Precision
RMSE matches FLA exactly — no precision loss from the layout change:
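The percentage figures quoted in this PR (e.g. 0.2347% for O) read like a relative RMSE. A common way to compute such a metric is shown below; the PR's exact definition is not given here, so treat this as an assumption:

```python
# One common relative-RMSE definition (assumed; may differ from the PR's metric):
# sqrt(mean((out - ref)^2)) / sqrt(mean(ref^2)), reported as a percentage.
import numpy as np

def rel_rmse(out, ref):
    out = out.astype(np.float64)
    ref = ref.astype(np.float64)
    return float(np.sqrt(np.mean((out - ref) ** 2)) / np.sqrt(np.mean(ref ** 2)))

ref = np.random.default_rng(0).standard_normal((4, 128)).astype(np.float32)
noisy = (ref + 1e-3 * np.random.default_rng(1).standard_normal(ref.shape)).astype(np.float32)
print(f"{100 * rel_rmse(noisy, ref):.4f}%")  # small relative error, in percent
```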
Test Results
🚀 Pull Request Checklist
Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the git hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
⚡ Performance
Reviewer Notes