Skip to content

Feat/pretransposed states#36

Open
higgsboson1710 wants to merge 2 commits intoinclusionAI:mainfrom
higgsboson1710:feat/pretransposed-states
Open

Feat/pretransposed states#36
higgsboson1710 wants to merge 2 commits intoinclusionAI:mainfrom
higgsboson1710:feat/pretransposed-states

Conversation

@higgsboson1710
Copy link
Copy Markdown

This PR implements the pre-transposed BHVK state layout optimization.

Updated the core C++/CUDA kernel and Python API to natively handle the BHVK layout.

Updated tests/test_lightning_attn.py and tests/test_la_decode.py to match the new layout.

Added an end-to-end prefill → decode test to verify the state passes directly without manual transposes

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request transitions the attention state layout from Column-Major to Row-Major (BHVK) within the lightning attention kernels and updates the associated loading and storing logic. The review feedback highlights critical concerns regarding the use of non-contiguous transposed tensors with kernels that assume fixed memory layouts, which could lead to silent data corruption. The reviewer recommends using contiguous allocations, simplifying the test suite by removing redundant transpose operations, and correcting minor indentation inconsistencies.

# Allocate state pool if not provided
if state_pool is None:
state_pool = torch.zeros(N, H, D, D, dtype=torch.float32, device=Q.device)
state_pool = torch.zeros(N, H, D, D, dtype=torch.float32, device=Q.device).transpose(-1,-2)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This in-place update on a non-contiguous state_pool is highly problematic. Because the kernel operates on the raw pointer using a fixed layout, it will write data to memory locations that do not align with the non-contiguous strides of the transposed PyTorch view. This will lead to silent data corruption or unexpected results when the state_pool is accessed later in Python. The state pool should be allocated as a contiguous tensor in the layout the kernel expects.

Suggested change
state_pool = torch.zeros(N, H, D, D, dtype=torch.float32, device=Q.device).transpose(-1,-2)
state_pool = torch.zeros(N, H, D, D, dtype=torch.float32, device=Q.device)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We assume the input states & output states are all pretransposed. Hence, we do not need to transpose the input states explicitly.

Comment on lines +413 to +415
fstate_layout = cute.make_layout((D, D, (H, B)),
stride=(D, 1, (D * D, D * D * H)),
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The indentation of the stride argument and the closing parenthesis for make_layout is inconsistent with the surrounding code. It is better to keep the arguments aligned for readability.

        fstate_layout = cute.make_layout((D, D, (H, B)),
                                         stride=(D, 1, (D * D, D * D * H)))


if output_final_state:
ht = torch.zeros(B, H, D, D, dtype=torch.float32, device=Q.device)
ht = torch.zeros(B, H, D, D, dtype=torch.float32, device=Q.device).transpose(-1,-2)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Applying .transpose(-1, -2) to the newly allocated ht tensor creates a non-contiguous view. Since the kernel (at line 417/424) wraps the raw pointer in a fixed fstate_layout that assumes a specific physical memory ordering (Row-Major with strides D, 1), passing a non-contiguous tensor is dangerous. If the kernel writes to the pointer in Row-Major order, the resulting ht tensor in Python will appear transposed logically. It would be much cleaner to allocate the tensor with the correct shape and ensure it is contiguous, avoiding reliance on the kernel ignoring PyTorch strides.

Suggested change
ht = torch.zeros(B, H, D, D, dtype=torch.float32, device=Q.device).transpose(-1,-2)
ht = torch.zeros(B, H, D, D, dtype=torch.float32, device=Q.device)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@higgsboson1710 since the ht is a zero tensor, we do not need the transpose.

k = torch.randn(B, H, D, device=device, dtype=torch.bfloat16)
v = torch.randn(B, H, D, device=device, dtype=torch.bfloat16)
state = torch.randn(B, H, D, D, device=device, dtype=torch.float32) * 0.01
state = torch.randn(B, H, D, D, device=device, dtype=torch.float32).transpose(-1, -2).contiguous().transpose(-1, -2) * 0.01
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The double-transpose transpose(-1, -2).contiguous().transpose(-1, -2) creates a non-contiguous tensor. However, in run_la_decode (line 87), this tensor is immediately cloned, which produces a contiguous version and negates the effect of the previous layout manipulation. This adds unnecessary complexity to the test setup. If the goal is to test the kernel with BHVK layout, simply providing a standard contiguous tensor of the appropriate shape is sufficient and clearer.

Suggested change
state = torch.randn(B, H, D, D, device=device, dtype=torch.float32).transpose(-1, -2).contiguous().transpose(-1, -2) * 0.01
state = torch.randn(B, H, D, D, device=device, dtype=torch.float32) * 0.01

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed above, I think you could just remove the transpose here.

V = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1
decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32)
h0 = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32) * 0.01
h0 = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32).transpose(-1, -2).contiguous().transpose(-1, -2) * 0.01
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of the double-transpose pattern here and in subsequent tests (lines 398, 528, 592) to create non-contiguous inputs is confusing. Given that the kernel ignores input strides and uses a fixed Row-Major layout, passing a physically Column-Major tensor (which this pattern produces) will result in the kernel seeing a transposed matrix. It is recommended to use standard contiguous tensors to ensure the layout is handled predictably.

Suggested change
h0 = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32).transpose(-1, -2).contiguous().transpose(-1, -2) * 0.01
h0 = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32) * 0.01

@higgsboson1710
Copy link
Copy Markdown
Author

"Hi @icavan, I've completed the roadmap. I see the bot is flagging the non-contiguous state pool allocations and the double-transpose pattern in the tests. I used these to ensure the memory matches the new BHVK layout, but let me know if you'd prefer I refactor these to standard contiguous allocations to satisfy the linter/bot."


if output_final_state:
ht = torch.zeros(B, H, D, D, dtype=torch.float32, device=Q.device)
ht = torch.zeros(B, H, D, D, dtype=torch.float32, device=Q.device).transpose(-1,-2)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@higgsboson1710 since the ht is a zero tensor, we do not need the transpose.

# Allocate state pool if not provided
if state_pool is None:
state_pool = torch.zeros(N, H, D, D, dtype=torch.float32, device=Q.device)
state_pool = torch.zeros(N, H, D, D, dtype=torch.float32, device=Q.device).transpose(-1,-2)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We assume the input states & output states are all pretransposed. Hence, we do not need to transpose the input states explicitly.

k = torch.randn(B, H, D, device=device, dtype=torch.bfloat16)
v = torch.randn(B, H, D, device=device, dtype=torch.bfloat16)
state = torch.randn(B, H, D, D, device=device, dtype=torch.float32) * 0.01
state = torch.randn(B, H, D, D, device=device, dtype=torch.float32).transpose(-1, -2).contiguous().transpose(-1, -2) * 0.01
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed above, I think you could just remove the transpose here.

gCol_ht = cute.make_tensor(gState_ht.iterator + local_tidx * _D, cute.make_layout(_D, stride=1))
out_flat = cute.make_tensor(tTR_rKV.iterator, layout=cute.make_layout(_D))
cute.autovec_copy(out_flat, gRow_ht)
cute.autovec_copy(out_flat, gCol_ht)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to track the performance change here. Could you share the results of bench_lightning_attn.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants