Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ device = 'cuda'
q = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, H, V, device=device, dtype=torch.bfloat16, requires_grad=True)
g = torch.randn(B, T, H, K, device=device, dtype=torch.float32) * 0.1 # gate (log space)
beta = torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid()
g = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16) * 0.1 # gate (log space)
beta = torch.randn(B, T, H, device=device, dtype=torch.bfloat16).sigmoid()
A_log = torch.randn(H, device=device, dtype=torch.float32) * 0.01
dt_bias = torch.zeros(H * K, device=device, dtype=torch.float32)
init_state = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
Expand Down Expand Up @@ -92,7 +92,7 @@ print(f'Final state shape: {final_state.shape}') # [2, 32, 128, 128]

**Notes:**
- `safe_gate=True` is required to leverage TensorCore acceleration.
- `beta` and `initial_state` must be `float32`.
- `beta` supports both `float32` and `bfloat16`; `initial_state` must be `float32`.
Comment thread
cherhh marked this conversation as resolved.
- `cu_seqlens` (for variable-length sequences) must be `int32`.

## Usage
Expand Down Expand Up @@ -177,7 +177,7 @@ See [REPO_LAYOUT.md](REPO_LAYOUT.md) for the full directory structure and a summ
* [ ] Continuous optimization via agentic methods such as [AVO](https://arxiv.org/abs/2603.24517).
* [ ] Support for more algorithms.
* [ ] Small B/H/S optimizations.
* [ ] Support for BF16 beta input.
* [x] Support for BF16 beta input.

**Train**

Expand Down
10 changes: 5 additions & 5 deletions USAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Both are drop-in replacements for [FLA](https://github.com/fla-org/flash-linear-attention).
**General Notes**

- **`safe_gate=True`** is required to leverage TensorCore (M=16) acceleration.
- **`beta`** and **`initial_state`** must be **`float32`**.
- **`beta`** must be **`float32`** or **`bfloat16`**; **`initial_state`** must be **`float32`**.
- **`cu_seqlens`** (for variable-length sequences) must be **`int32`**.

---
Expand All @@ -39,8 +39,8 @@ device = 'cuda'
q = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, H, V, device=device, dtype=torch.bfloat16, requires_grad=True)
g = torch.randn(B, T, H, K, device=device, dtype=torch.float32) * 0.1
beta = torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid()
g = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16) * 0.1
beta = torch.randn(B, T, H, device=device, dtype=torch.bfloat16).sigmoid()
A_log = torch.randn(H, device=device, dtype=torch.float32) * 0.01
dt_bias = torch.zeros(H * K, device=device, dtype=torch.float32)
init_state = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
Expand Down Expand Up @@ -86,8 +86,8 @@ device = 'cuda'
q = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device=device, dtype=torch.bfloat16)
g = torch.randn(B, T, H, K, device=device, dtype=torch.float32) * 0.1
beta = torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid()
g = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16) * 0.1
beta = torch.randn(B, T, H, device=device, dtype=torch.bfloat16).sigmoid()
A_log = torch.randn(H, device=device, dtype=torch.float32) * 0.01
dt_bias = torch.zeros(H * K, device=device, dtype=torch.float32)
init_state = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
Expand Down
6 changes: 6 additions & 0 deletions csrc/api/kda_sm100.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ ChunkKDAFwdIntra(
params.scale = scale;
params.use_tf32_inverse = use_tf32_inverse;
params.unified_gref = unified_gref;
TORCH_CHECK(beta.dtype() == torch::kFloat32 || beta.dtype() == torch::kBFloat16,
"beta must be float32 or bfloat16, got ", beta.dtype());
params.is_beta_bf16 = (beta.dtype() == torch::kBFloat16);
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
params.g_ptr = g.data_ptr();
Expand Down Expand Up @@ -83,6 +86,9 @@ ChunkKDAFwdRecompWU(
params.h = k.size(2);
params.d = k.size(3);
params.chunk_size = chunk_size;
TORCH_CHECK(beta.dtype() == torch::kFloat32 || beta.dtype() == torch::kBFloat16,
"beta must be float32 or bfloat16, got ", beta.dtype());
params.is_beta_bf16 = (beta.dtype() == torch::kBFloat16);
params.k_ptr = k.data_ptr();
params.v_ptr = v.data_ptr();
params.beta_ptr = beta.data_ptr();
Expand Down
68 changes: 46 additions & 22 deletions csrc/api/kda_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ kda_fwd_prefill(

// Extract optional pointers
float const* alpha_ptr = nullptr;
float const* beta_ptr = nullptr;
float const* input_state_ptr = nullptr;

if (alpha_.has_value()) {
Expand All @@ -90,11 +89,11 @@ kda_fwd_prefill(
}
if (beta_.has_value()) {
auto& beta = beta_.value();
TORCH_CHECK(beta.dtype() == torch::kFloat32, "beta must be float32");
TORCH_CHECK(beta.dtype() == torch::kFloat32 || beta.dtype() == torch::kBFloat16,
"beta must be float32 or bfloat16, got ", beta.dtype());
TORCH_CHECK(beta.is_contiguous(), "beta must be contiguous");
TORCH_CHECK(
beta.size(0) == packed_seq && beta.size(1) == num_heads, "beta shape must be [packed_seq, num_heads]");
beta_ptr = beta.data_ptr<float>();
}
if (input_state_.has_value()) {
auto& input_state = input_state_.value();
Expand All @@ -114,25 +113,50 @@ kda_fwd_prefill(
using bf16 = cute::bfloat16_t;
using Sm90 = cutlass::arch::Sm90;

kda::sm90::launch_kda_fwd_prefill_kernel<Sm90, bf16, bf16, float>(
stream,
reinterpret_cast<bf16*>(output.data_ptr()),
output_state.data_ptr<float>(),
reinterpret_cast<bf16 const*>(q.data_ptr()),
reinterpret_cast<bf16 const*>(k.data_ptr()),
reinterpret_cast<bf16 const*>(v.data_ptr()),
input_state_ptr,
alpha_ptr,
beta_ptr,
cu_seqlens.data_ptr<int32_t>(),
workspace_buffer.data_ptr<uint8_t>(),
static_cast<int32_t>(num_seqs),
static_cast<int32_t>(num_heads),
static_cast<int32_t>(head_size),
static_cast<int64_t>(packed_seq),
scale,
safe_gate,
static_cast<int32_t>(sm_count));
bool beta_is_bf16 = beta_.has_value() && beta_.value().dtype() == torch::kBFloat16;

if (beta_is_bf16) {
kda::sm90::launch_kda_fwd_prefill_kernel<Sm90, bf16, bf16, float, bf16>(
stream,
reinterpret_cast<bf16*>(output.data_ptr()),
output_state.data_ptr<float>(),
reinterpret_cast<bf16 const*>(q.data_ptr()),
reinterpret_cast<bf16 const*>(k.data_ptr()),
reinterpret_cast<bf16 const*>(v.data_ptr()),
input_state_ptr,
alpha_ptr,
reinterpret_cast<bf16 const*>(beta_.value().data_ptr()),
cu_seqlens.data_ptr<int32_t>(),
workspace_buffer.data_ptr<uint8_t>(),
static_cast<int32_t>(num_seqs),
static_cast<int32_t>(num_heads),
static_cast<int32_t>(head_size),
static_cast<int64_t>(packed_seq),
scale,
safe_gate,
static_cast<int32_t>(sm_count));
} else {
float const* beta_ptr = beta_.has_value() ? beta_.value().data_ptr<float>() : nullptr;
kda::sm90::launch_kda_fwd_prefill_kernel<Sm90, bf16, bf16, float, float>(
stream,
reinterpret_cast<bf16*>(output.data_ptr()),
output_state.data_ptr<float>(),
reinterpret_cast<bf16 const*>(q.data_ptr()),
reinterpret_cast<bf16 const*>(k.data_ptr()),
reinterpret_cast<bf16 const*>(v.data_ptr()),
input_state_ptr,
alpha_ptr,
beta_ptr,
cu_seqlens.data_ptr<int32_t>(),
workspace_buffer.data_ptr<uint8_t>(),
static_cast<int32_t>(num_seqs),
static_cast<int32_t>(num_heads),
static_cast<int32_t>(head_size),
static_cast<int64_t>(packed_seq),
scale,
safe_gate,
static_cast<int32_t>(sm_count));
}

return {output, output_state};
}
2 changes: 2 additions & 0 deletions csrc/kda/sm100/kda_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct KDA_fwd_intra_params {
float scale;
bool use_tf32_inverse;
bool unified_gref;
bool is_beta_bf16;

void* __restrict__ q_ptr; //[b, t, h, d]
void* __restrict__ k_ptr; //[b, t, h, d]
Expand Down Expand Up @@ -55,6 +56,7 @@ struct KDA_fwd_recomp_w_u_params {
int h;
int d;
int chunk_size;
bool is_beta_bf16;

void* __restrict__ k_ptr; //[b, t, h, d]
void* __restrict__ v_ptr; //[b, t, h, d]
Expand Down
43 changes: 36 additions & 7 deletions csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,19 @@ using KdaChunkFwdIntraKernelSm100_TF32_GHalf =
using KdaChunkFwdIntraKernelSm100_FP16_GHalf =
KdaChunkFwdIntraKernelSm100<KdaChunkFwdIntraMainloopSm100<false, false, false>>;

// BetaBF16 variants: same as above but load beta from bf16 GMEM
using KdaChunkFwdIntraKernelSm100_TF32_BetaBF16 =
KdaChunkFwdIntraKernelSm100<KdaChunkFwdIntraMainloopSm100<true, false, true, __nv_bfloat16>>;

using KdaChunkFwdIntraKernelSm100_FP16_BetaBF16 =
KdaChunkFwdIntraKernelSm100<KdaChunkFwdIntraMainloopSm100<false, false, true, __nv_bfloat16>>;

using KdaChunkFwdIntraKernelSm100_TF32_GHalf_BetaBF16 =
KdaChunkFwdIntraKernelSm100<KdaChunkFwdIntraMainloopSm100<true, false, false, __nv_bfloat16>>;

using KdaChunkFwdIntraKernelSm100_FP16_GHalf_BetaBF16 =
KdaChunkFwdIntraKernelSm100<KdaChunkFwdIntraMainloopSm100<false, false, false, __nv_bfloat16>>;

// ===================================================================
// __global__ kernel wrapper (free function — CUDA requires this)
// ===================================================================
Expand Down Expand Up @@ -378,17 +391,33 @@ run_kda_fwd_intra_sm100_impl_dispatch(KDA_fwd_intra_params& params, cudaStream_t
// ===================================================================
inline void
run_kda_fwd_intra_sm100_impl(KDA_fwd_intra_params& params, cudaStream_t stream) {
if (params.use_tf32_inverse) {
if (params.unified_gref) {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_TF32>(params, stream);
if (params.is_beta_bf16) {
if (params.use_tf32_inverse) {
if (params.unified_gref) {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_TF32_BetaBF16>(params, stream);
} else {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_TF32_GHalf_BetaBF16>(params, stream);
}
} else {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_TF32_GHalf>(params, stream);
if (params.unified_gref) {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_FP16_BetaBF16>(params, stream);
} else {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_FP16_GHalf_BetaBF16>(params, stream);
}
}
} else {
if (params.unified_gref) {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_FP16>(params, stream);
if (params.use_tf32_inverse) {
if (params.unified_gref) {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_TF32>(params, stream);
} else {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_TF32_GHalf>(params, stream);
}
} else {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_FP16_GHalf>(params, stream);
if (params.unified_gref) {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_FP16>(params, stream);
} else {
run_kda_fwd_intra_sm100_impl_dispatch<KdaChunkFwdIntraKernelSm100_FP16_GHalf>(params, stream);
}
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct KdaChunkFwdIntraSm100NamedBarriers {
// constants, and the persistent loop bodies for each warp role.
// The Kernel struct is templated on this Mainloop.
// ===================================================================
template <bool UseTF32Inverse_ = true, bool RoundingTF32_ = false, bool UnifiedGRef_ = false>
template <bool UseTF32Inverse_ = true, bool RoundingTF32_ = false, bool UnifiedGRef_ = false, typename ElementBeta_ = float>
struct KdaChunkFwdIntraMainloopSm100 {
// ===================== Tile / Buffer Constants =====================
static constexpr int SubTileT = 16;
Expand Down Expand Up @@ -890,8 +890,8 @@ struct KdaChunkFwdIntraMainloopSm100 {
if (thread_idx < TileT) {
shared_plan->beta_smem[beta_pipe_state_write.index()][thread_idx] =
(thread_idx < sub_seq_len)
? reinterpret_cast<float*>(
params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]
? float(reinterpret_cast<ElementBeta_*>(
params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx])
: float(0);
}
fence_view_async_shared();
Expand Down
18 changes: 16 additions & 2 deletions csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,12 @@ struct KdaChunkFwdRecompWUKernelSm100 {
using KdaChunkFwdRecompWUKernelSm100Default = KdaChunkFwdRecompWUKernelSm100<KdaChunkFwdRecompWUMainloopSm100<false>>;
using KdaChunkFwdRecompWUKernelSm100StoreQG = KdaChunkFwdRecompWUKernelSm100<KdaChunkFwdRecompWUMainloopSm100<true>>;

// BetaBF16 variants: loads beta from bf16 GMEM
using KdaChunkFwdRecompWUKernelSm100Default_BetaBF16 =
KdaChunkFwdRecompWUKernelSm100<KdaChunkFwdRecompWUMainloopSm100<false, __nv_bfloat16>>;
using KdaChunkFwdRecompWUKernelSm100StoreQG_BetaBF16 =
KdaChunkFwdRecompWUKernelSm100<KdaChunkFwdRecompWUMainloopSm100<true, __nv_bfloat16>>;

// ===================================================================
// __global__ kernel wrapper (free function — CUDA requires this)
// ===================================================================
Expand Down Expand Up @@ -492,9 +498,17 @@ run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cu
inline void
run_kda_fwd_recomp_w_u_sm100_impl(KDA_fwd_recomp_w_u_params& params, cudaStream_t stream) {
if (params.store_qg) {
run_kda_fwd_recomp_w_u_sm100_impl_dispatch<KdaChunkFwdRecompWUKernelSm100StoreQG>(params, stream);
if (params.is_beta_bf16) {
run_kda_fwd_recomp_w_u_sm100_impl_dispatch<KdaChunkFwdRecompWUKernelSm100StoreQG_BetaBF16>(params, stream);
} else {
run_kda_fwd_recomp_w_u_sm100_impl_dispatch<KdaChunkFwdRecompWUKernelSm100StoreQG>(params, stream);
}
} else {
run_kda_fwd_recomp_w_u_sm100_impl_dispatch<KdaChunkFwdRecompWUKernelSm100Default>(params, stream);
if (params.is_beta_bf16) {
run_kda_fwd_recomp_w_u_sm100_impl_dispatch<KdaChunkFwdRecompWUKernelSm100Default_BetaBF16>(params, stream);
} else {
run_kda_fwd_recomp_w_u_sm100_impl_dispatch<KdaChunkFwdRecompWUKernelSm100Default>(params, stream);
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct KdaChunkFwdRecompWUSm100NamedBarriers {
// constants, and the persistent loop bodies for each warp role.
// The Kernel struct is templated on this Mainloop.
// ===================================================================
template <bool StoreQG_ = false>
template <bool StoreQG_ = false, typename ElementBeta_ = float>
struct KdaChunkFwdRecompWUMainloopSm100 {
// ===================== Tile / Buffer Constants =====================
static constexpr int TileT = 64;
Expand Down Expand Up @@ -1008,8 +1008,8 @@ struct KdaChunkFwdRecompWUMainloopSm100 {
if (thread_idx < TileT) {
float beta_val =
(thread_idx < sub_seq_len)
? reinterpret_cast<float*>(
params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]
? float(reinterpret_cast<ElementBeta_*>(
params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx])
Comment thread
KevinZeng08 marked this conversation as resolved.
: float(0);
shared_plan->beta_smem[beta_pipe_state_write.index()][thread_idx] = beta_val;
}
Expand Down
12 changes: 6 additions & 6 deletions csrc/kda/sm90/collective/mainloop_kda_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd {
using ElementAccumulatorKV = ElementAccumulatorKV_;
using ElementO = Element;
using ElementAlpha = float;
// TODO: support bf16 beta
using ElementBeta = float;
using ElementBeta = float; // SMEM + compute stays fp32
using ElementBetaGmem = find_option_t<Tag::kElementBetaGmem, float, Options>; // GMEM type (float or bf16)
using ElementGatedMMA = cutlass::tfloat32_t;

using TileShape = TileShape_;
Expand Down Expand Up @@ -459,7 +459,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd {
using LoadBeta = CollectiveLoadVector<
LoadKindVector::kBeta,
MainloopBetaPipeline,
ElementBeta,
ElementBetaGmem,
GmemLayoutBeta,
ElementBeta,
SmemLayoutBeta,
Expand All @@ -474,7 +474,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd {
float* ptr_output_state; // layout fixed (kdim, vdim, num_heads, num_seqs):LayoutLeft{}
float const* ptr_input_state;
float scale;
ElementBeta const* beta_ptr; GmemStrideBeta beta_stride;
ElementBetaGmem const* beta_ptr; GmemStrideBeta beta_stride;
}; // clang-format on

struct Params {
Expand All @@ -489,7 +489,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd {
float* ptr_output_state;
float const* ptr_input_state;

ElementBeta const* beta_ptr;
ElementBetaGmem const* beta_ptr;
GmemLayoutBeta beta_layout;
};

Expand Down Expand Up @@ -646,7 +646,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd {
// auto collective_load = LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/1.0f, pipeline,
// storage.smem_beta};
auto collective_load =
LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/0.0f, pipeline, storage.smem_beta};
LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/ElementBetaGmem(0), pipeline, storage.smem_beta};
auto src_dst = collective_load.partition_SD(problem_size, tile_shape, work_desc);

CUTE_NO_UNROLL
Expand Down
Loading