diff --git a/README.md b/README.md index f62c325..1622309 100644 --- a/README.md +++ b/README.md @@ -63,8 +63,8 @@ device = 'cuda' q = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16, requires_grad=True) k = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16, requires_grad=True) v = torch.randn(B, T, H, V, device=device, dtype=torch.bfloat16, requires_grad=True) -g = torch.randn(B, T, H, K, device=device, dtype=torch.float32) * 0.1 # gate (log space) -beta = torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid() +g = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16) * 0.1 # gate (log space) +beta = torch.randn(B, T, H, device=device, dtype=torch.bfloat16).sigmoid() A_log = torch.randn(H, device=device, dtype=torch.float32) * 0.01 dt_bias = torch.zeros(H * K, device=device, dtype=torch.float32) init_state = torch.zeros(B, H, K, V, device=device, dtype=torch.float32) @@ -92,7 +92,7 @@ print(f'Final state shape: {final_state.shape}') # [2, 32, 128, 128] **Notes:** - `safe_gate=True` is required to leverage TensorCore acceleration. -- `beta` and `initial_state` must be `float32`. +- `beta` supports both `float32` and `bfloat16`; `initial_state` must be `float32`. - `cu_seqlens` (for variable-length sequences) must be `int32`. ## Usage @@ -177,7 +177,7 @@ See [REPO_LAYOUT.md](REPO_LAYOUT.md) for the full directory structure and a summ * [ ] Continuous optimization via agentic methods such as [AVO](https://arxiv.org/abs/2603.24517). * [ ] Support for more algorithms. * [ ] Small B/H/S optimizations. -* [ ] Support for BF16 beta input. +* [x] Support for BF16 beta input. **Train** diff --git a/USAGE.md b/USAGE.md index 0b8b9c1..198d4b8 100644 --- a/USAGE.md +++ b/USAGE.md @@ -18,7 +18,7 @@ Both are drop-in replacements for [FLA](https://github.com/fla-org/flash-linear- **General Notes** - **`safe_gate=True`** is required to leverage TensorCore (M=16) acceleration. -- **`beta`** and **`initial_state`** must be **`float32`**. +- **`beta`** must be **`float32`** or **`bfloat16`**; **`initial_state`** must be **`float32`**. - **`cu_seqlens`** (for variable-length sequences) must be **`int32`**. --- @@ -39,8 +39,8 @@ device = 'cuda' q = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16, requires_grad=True) k = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16, requires_grad=True) v = torch.randn(B, T, H, V, device=device, dtype=torch.bfloat16, requires_grad=True) -g = torch.randn(B, T, H, K, device=device, dtype=torch.float32) * 0.1 -beta = torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid() +g = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16) * 0.1 +beta = torch.randn(B, T, H, device=device, dtype=torch.bfloat16).sigmoid() A_log = torch.randn(H, device=device, dtype=torch.float32) * 0.01 dt_bias = torch.zeros(H * K, device=device, dtype=torch.float32) init_state = torch.zeros(B, H, K, V, device=device, dtype=torch.float32) @@ -86,8 +86,8 @@ device = 'cuda' q = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16) k = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16) v = torch.randn(B, T, H, V, device=device, dtype=torch.bfloat16) -g = torch.randn(B, T, H, K, device=device, dtype=torch.float32) * 0.1 -beta = torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid() +g = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16) * 0.1 +beta = torch.randn(B, T, H, device=device, dtype=torch.bfloat16).sigmoid() A_log = torch.randn(H, device=device, dtype=torch.float32) * 0.01 dt_bias = torch.zeros(H * K, device=device, dtype=torch.float32) init_state = torch.zeros(B, H, K, V, device=device, dtype=torch.float32) diff --git a/csrc/api/kda_sm100.cu b/csrc/api/kda_sm100.cu index 918b99b..5c5b236 100644 --- a/csrc/api/kda_sm100.cu +++ b/csrc/api/kda_sm100.cu @@ -43,6 +43,9 @@ ChunkKDAFwdIntra( params.scale = scale; params.use_tf32_inverse = use_tf32_inverse; params.unified_gref = unified_gref; + TORCH_CHECK(beta.dtype() == torch::kFloat32 || beta.dtype() == torch::kBFloat16, + "beta must be float32 or bfloat16, got ", beta.dtype()); + params.is_beta_bf16 = (beta.dtype() == torch::kBFloat16); params.q_ptr = q.data_ptr(); params.k_ptr = k.data_ptr(); params.g_ptr = g.data_ptr(); @@ -83,6 +86,9 @@ ChunkKDAFwdRecompWU( params.h = k.size(2); params.d = k.size(3); params.chunk_size = chunk_size; + TORCH_CHECK(beta.dtype() == torch::kFloat32 || beta.dtype() == torch::kBFloat16, + "beta must be float32 or bfloat16, got ", beta.dtype()); + params.is_beta_bf16 = (beta.dtype() == torch::kBFloat16); params.k_ptr = k.data_ptr(); params.v_ptr = v.data_ptr(); params.beta_ptr = beta.data_ptr(); diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu index 842794c..5e49d94 100644 --- a/csrc/api/kda_sm90.cu +++ b/csrc/api/kda_sm90.cu @@ -76,7 +76,6 @@ kda_fwd_prefill( // Extract optional pointers float const* alpha_ptr = nullptr; - float const* beta_ptr = nullptr; float const* input_state_ptr = nullptr; if (alpha_.has_value()) { @@ -90,11 +89,11 @@ kda_fwd_prefill( } if (beta_.has_value()) { auto& beta = beta_.value(); - TORCH_CHECK(beta.dtype() == torch::kFloat32, "beta must be float32"); + TORCH_CHECK(beta.dtype() == torch::kFloat32 || beta.dtype() == torch::kBFloat16, + "beta must be float32 or bfloat16, got ", beta.dtype()); TORCH_CHECK(beta.is_contiguous(), "beta must be contiguous"); TORCH_CHECK( beta.size(0) == packed_seq && beta.size(1) == num_heads, "beta shape must be [packed_seq, num_heads]"); - beta_ptr = beta.data_ptr(); } if (input_state_.has_value()) { auto& input_state = input_state_.value(); @@ -114,25 +113,50 @@ kda_fwd_prefill( using bf16 = cute::bfloat16_t; using Sm90 = cutlass::arch::Sm90; - kda::sm90::launch_kda_fwd_prefill_kernel( - stream, - reinterpret_cast(output.data_ptr()), - output_state.data_ptr(), - reinterpret_cast(q.data_ptr()), - reinterpret_cast(k.data_ptr()), - reinterpret_cast(v.data_ptr()), - input_state_ptr, - alpha_ptr, - beta_ptr, - cu_seqlens.data_ptr(), - workspace_buffer.data_ptr(), - static_cast(num_seqs), - static_cast(num_heads), - static_cast(head_size), - static_cast(packed_seq), - scale, - safe_gate, - static_cast(sm_count)); + bool beta_is_bf16 = beta_.has_value() && beta_.value().dtype() == torch::kBFloat16; + + if (beta_is_bf16) { + kda::sm90::launch_kda_fwd_prefill_kernel( + stream, + reinterpret_cast(output.data_ptr()), + output_state.data_ptr(), + reinterpret_cast(q.data_ptr()), + reinterpret_cast(k.data_ptr()), + reinterpret_cast(v.data_ptr()), + input_state_ptr, + alpha_ptr, + reinterpret_cast(beta_.value().data_ptr()), + cu_seqlens.data_ptr(), + workspace_buffer.data_ptr(), + static_cast(num_seqs), + static_cast(num_heads), + static_cast(head_size), + static_cast(packed_seq), + scale, + safe_gate, + static_cast(sm_count)); + } else { + float const* beta_ptr = beta_.has_value() ? beta_.value().data_ptr() : nullptr; + kda::sm90::launch_kda_fwd_prefill_kernel( + stream, + reinterpret_cast(output.data_ptr()), + output_state.data_ptr(), + reinterpret_cast(q.data_ptr()), + reinterpret_cast(k.data_ptr()), + reinterpret_cast(v.data_ptr()), + input_state_ptr, + alpha_ptr, + beta_ptr, + cu_seqlens.data_ptr(), + workspace_buffer.data_ptr(), + static_cast(num_seqs), + static_cast(num_heads), + static_cast(head_size), + static_cast(packed_seq), + scale, + safe_gate, + static_cast(sm_count)); + } return {output, output_state}; } diff --git a/csrc/kda/sm100/kda_config.hpp b/csrc/kda/sm100/kda_config.hpp index 3f38c09..6f96529 100644 --- a/csrc/kda/sm100/kda_config.hpp +++ b/csrc/kda/sm100/kda_config.hpp @@ -28,6 +28,7 @@ struct KDA_fwd_intra_params { float scale; bool use_tf32_inverse; bool unified_gref; + bool is_beta_bf16; void* __restrict__ q_ptr; //[b, t, h, d] void* __restrict__ k_ptr; //[b, t, h, d] @@ -55,6 +56,7 @@ struct KDA_fwd_recomp_w_u_params { int h; int d; int chunk_size; + bool is_beta_bf16; void* __restrict__ k_ptr; //[b, t, h, d] void* __restrict__ v_ptr; //[b, t, h, d] diff --git a/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp b/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp index 4aaeabc..ad9155e 100644 --- a/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp @@ -317,6 +317,19 @@ using KdaChunkFwdIntraKernelSm100_TF32_GHalf = using KdaChunkFwdIntraKernelSm100_FP16_GHalf = KdaChunkFwdIntraKernelSm100>; +// BetaBF16 variants: same as above but load beta from bf16 GMEM +using KdaChunkFwdIntraKernelSm100_TF32_BetaBF16 = + KdaChunkFwdIntraKernelSm100>; + +using KdaChunkFwdIntraKernelSm100_FP16_BetaBF16 = + KdaChunkFwdIntraKernelSm100>; + +using KdaChunkFwdIntraKernelSm100_TF32_GHalf_BetaBF16 = + KdaChunkFwdIntraKernelSm100>; + +using KdaChunkFwdIntraKernelSm100_FP16_GHalf_BetaBF16 = + KdaChunkFwdIntraKernelSm100>; + // =================================================================== // __global__ kernel wrapper (free function — CUDA requires this) // =================================================================== @@ -378,17 +391,33 @@ run_kda_fwd_intra_sm100_impl_dispatch(KDA_fwd_intra_params& params, cudaStream_t // =================================================================== inline void run_kda_fwd_intra_sm100_impl(KDA_fwd_intra_params& params, cudaStream_t stream) { - if (params.use_tf32_inverse) { - if (params.unified_gref) { - run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + if (params.is_beta_bf16) { + if (params.use_tf32_inverse) { + if (params.unified_gref) { + run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + } else { + run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + } } else { - run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + if (params.unified_gref) { + run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + } else { + run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + } } } else { - if (params.unified_gref) { - run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + if (params.use_tf32_inverse) { + if (params.unified_gref) { + run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + } else { + run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + } } else { - run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + if (params.unified_gref) { + run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + } else { + run_kda_fwd_intra_sm100_impl_dispatch(params, stream); + } } } } diff --git a/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp b/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp index 2784006..bc1f85d 100644 --- a/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp @@ -44,7 +44,7 @@ struct KdaChunkFwdIntraSm100NamedBarriers { // constants, and the persistent loop bodies for each warp role. // The Kernel struct is templated on this Mainloop. // =================================================================== -template +template struct KdaChunkFwdIntraMainloopSm100 { // ===================== Tile / Buffer Constants ===================== static constexpr int SubTileT = 16; @@ -890,8 +890,8 @@ struct KdaChunkFwdIntraMainloopSm100 { if (thread_idx < TileT) { shared_plan->beta_smem[beta_pipe_state_write.index()][thread_idx] = (thread_idx < sub_seq_len) - ? reinterpret_cast( - params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx] + ? float(reinterpret_cast( + params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]) : float(0); } fence_view_async_shared(); diff --git a/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp b/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp index 3a77b9d..b2c96ed 100644 --- a/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp @@ -412,6 +412,12 @@ struct KdaChunkFwdRecompWUKernelSm100 { using KdaChunkFwdRecompWUKernelSm100Default = KdaChunkFwdRecompWUKernelSm100>; using KdaChunkFwdRecompWUKernelSm100StoreQG = KdaChunkFwdRecompWUKernelSm100>; +// BetaBF16 variants: loads beta from bf16 GMEM +using KdaChunkFwdRecompWUKernelSm100Default_BetaBF16 = + KdaChunkFwdRecompWUKernelSm100>; +using KdaChunkFwdRecompWUKernelSm100StoreQG_BetaBF16 = + KdaChunkFwdRecompWUKernelSm100>; + // =================================================================== // __global__ kernel wrapper (free function — CUDA requires this) // =================================================================== @@ -492,9 +498,17 @@ run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cu inline void run_kda_fwd_recomp_w_u_sm100_impl(KDA_fwd_recomp_w_u_params& params, cudaStream_t stream) { if (params.store_qg) { - run_kda_fwd_recomp_w_u_sm100_impl_dispatch(params, stream); + if (params.is_beta_bf16) { + run_kda_fwd_recomp_w_u_sm100_impl_dispatch(params, stream); + } else { + run_kda_fwd_recomp_w_u_sm100_impl_dispatch(params, stream); + } } else { - run_kda_fwd_recomp_w_u_sm100_impl_dispatch(params, stream); + if (params.is_beta_bf16) { + run_kda_fwd_recomp_w_u_sm100_impl_dispatch(params, stream); + } else { + run_kda_fwd_recomp_w_u_sm100_impl_dispatch(params, stream); + } } } diff --git a/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp b/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp index 85de05b..db09b08 100644 --- a/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp @@ -35,7 +35,7 @@ struct KdaChunkFwdRecompWUSm100NamedBarriers { // constants, and the persistent loop bodies for each warp role. // The Kernel struct is templated on this Mainloop. // =================================================================== -template +template struct KdaChunkFwdRecompWUMainloopSm100 { // ===================== Tile / Buffer Constants ===================== static constexpr int TileT = 64; @@ -1008,8 +1008,8 @@ struct KdaChunkFwdRecompWUMainloopSm100 { if (thread_idx < TileT) { float beta_val = (thread_idx < sub_seq_len) - ? reinterpret_cast( - params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx] + ? float(reinterpret_cast( + params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]) : float(0); shared_plan->beta_smem[beta_pipe_state_write.index()][thread_idx] = beta_val; } diff --git a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp index 2230cd1..5e045bf 100644 --- a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp +++ b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp @@ -74,8 +74,8 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { using ElementAccumulatorKV = ElementAccumulatorKV_; using ElementO = Element; using ElementAlpha = float; - // TODO: support bf16 beta - using ElementBeta = float; + using ElementBeta = float; // SMEM + compute stays fp32 + using ElementBetaGmem = find_option_t; // GMEM type (float or bf16) using ElementGatedMMA = cutlass::tfloat32_t; using TileShape = TileShape_; @@ -459,7 +459,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { using LoadBeta = CollectiveLoadVector< LoadKindVector::kBeta, MainloopBetaPipeline, - ElementBeta, + ElementBetaGmem, GmemLayoutBeta, ElementBeta, SmemLayoutBeta, @@ -474,7 +474,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { float* ptr_output_state; // layout fixed (kdim, vdim, num_heads, num_seqs):LayoutLeft{} float const* ptr_input_state; float scale; - ElementBeta const* beta_ptr; GmemStrideBeta beta_stride; + ElementBetaGmem const* beta_ptr; GmemStrideBeta beta_stride; }; // clang-format on struct Params { @@ -489,7 +489,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { float* ptr_output_state; float const* ptr_input_state; - ElementBeta const* beta_ptr; + ElementBetaGmem const* beta_ptr; GmemLayoutBeta beta_layout; }; @@ -646,7 +646,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { // auto collective_load = LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/1.0f, pipeline, // storage.smem_beta}; auto collective_load = - LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/0.0f, pipeline, storage.smem_beta}; + LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/ElementBetaGmem(0), pipeline, storage.smem_beta}; auto src_dst = collective_load.partition_SD(problem_size, tile_shape, work_desc); CUTE_NO_UNROLL diff --git a/csrc/kda/sm90/kda_fwd_sm90.cu b/csrc/kda/sm90/kda_fwd_sm90.cu index 248c56a..c13b1bd 100644 --- a/csrc/kda/sm90/kda_fwd_sm90.cu +++ b/csrc/kda/sm90/kda_fwd_sm90.cu @@ -32,7 +32,8 @@ template < typename ArchTag, typename TO, typename TQKV, - typename TState> + typename TState, + typename TBeta = float> void launch_kda_fwd_prefill_kernel_gbai( cudaStream_t stream, @@ -43,7 +44,7 @@ launch_kda_fwd_prefill_kernel_gbai( TQKV const* v, TState const* input_state, float const* alpha, - float const* beta, + TBeta const* beta, int32_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, @@ -57,7 +58,8 @@ template < typename ArchTag, // TODO: hide this typename TO, typename TQKV, - typename TState> + typename TState, + typename TBeta> void launch_kda_fwd_prefill_kernel( cudaStream_t stream, @@ -68,7 +70,7 @@ launch_kda_fwd_prefill_kernel( TQKV const* v, TState const* input_state, float const* alpha, - float const* beta, + TBeta const* beta, int32_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, @@ -120,8 +122,9 @@ launch_kda_fwd_prefill_kernel( using bf16 = cute::bfloat16_t; +// TBeta=float (default) template void -launch_kda_fwd_prefill_kernel( +launch_kda_fwd_prefill_kernel( cudaStream_t stream, bf16* output, float* state, @@ -141,4 +144,26 @@ launch_kda_fwd_prefill_kernel( bool safe_gate, int32_t sm_count); +// TBeta=bf16 +template void +launch_kda_fwd_prefill_kernel( + cudaStream_t stream, + bf16* output, + float* state, + bf16 const* q, + bf16 const* k, + bf16 const* v, + float const* input_state, + float const* alpha, + bf16 const* beta, + int32_t const* cu_seqlens, + uint8_t* workspace_buffer, + int32_t num_seqs, + int32_t num_heads, + int32_t head_size, + int64_t total_seqlen, + float scale, + bool safe_gate, + int32_t sm_count); + } // namespace kda::sm90 diff --git a/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu b/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu index 9044d63..93fe869 100644 --- a/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu +++ b/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu @@ -65,4 +65,46 @@ launch_kda_fwd_prefill_kernel_gbai( + cudaStream_t, + bf16*, + float*, + bf16 const*, + bf16 const*, + bf16 const*, + float const*, + float const*, + bf16 const*, + int32_t const*, + uint8_t*, + int32_t, + int32_t, + int32_t, + int64_t, + float, + int32_t); + +// SafeGate=true, InitState=true, BetaBF16 +template void +launch_kda_fwd_prefill_kernel_gbai( + cudaStream_t, + bf16*, + float*, + bf16 const*, + bf16 const*, + bf16 const*, + float const*, + float const*, + bf16 const*, + int32_t const*, + uint8_t*, + int32_t, + int32_t, + int32_t, + int64_t, + float, + int32_t); + } // namespace kda::sm90 diff --git a/csrc/kda/sm90/kernel/options.hpp b/csrc/kda/sm90/kernel/options.hpp index 6140d64..e25fe9d 100644 --- a/csrc/kda/sm90/kernel/options.hpp +++ b/csrc/kda/sm90/kernel/options.hpp @@ -81,6 +81,7 @@ enum class Tag { kNeedsBeta, // delta rule kInitStateFromInput, // if true, initialize state by reading global memory instead of zero initialization. kSafeGate, // KDA + kElementBetaGmem, // GMEM element type for beta (default float, can be bf16) }; } // namespace kda::sm90::kernel diff --git a/csrc/kda/sm90/prefill_kernel.hpp b/csrc/kda/sm90/prefill_kernel.hpp index e413cdc..00cef2d 100644 --- a/csrc/kda/sm90/prefill_kernel.hpp +++ b/csrc/kda/sm90/prefill_kernel.hpp @@ -24,7 +24,8 @@ template < typename ArchTag, // TODO: hide this typename TO, typename TQKV, - typename TState> + typename TState, + typename TBeta = float> void launch_kda_fwd_prefill_kernel( cudaStream_t stream, @@ -35,7 +36,7 @@ launch_kda_fwd_prefill_kernel( TQKV const* v, TState const* input_state, float const* alpha, - float const* beta, + TBeta const* beta, int32_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, diff --git a/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh b/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh index fcf92ee..2fb9bda 100644 --- a/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh +++ b/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh @@ -37,7 +37,8 @@ template < typename ArchTag, typename TO, typename TQKV, - typename TState> + typename TState, + typename TBeta = float> void launch_kda_fwd_prefill_kernel_gbai( cudaStream_t stream, @@ -48,7 +49,7 @@ launch_kda_fwd_prefill_kernel_gbai( TQKV const* v, TState const* input_state, float const* alpha, - float const* beta, + TBeta const* beta, int32_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, @@ -77,14 +78,16 @@ launch_kda_fwd_prefill_kernel_gbai( using NeedsAlphaType = std::conditional_t; using InitStateType = std::conditional_t; using Options = decltype(add_option( - Option{}, + Option{}, add_option( - Option{}, + Option{}, add_option( - Option{}, + Option{}, add_option( - Option{}, - add_option(Option{}, DefaultOptions{})))))); + Option{}, + add_option( + Option{}, + add_option(Option{}, DefaultOptions{}))))))); using TileShape = Shape<_64, _64, _128>; using Scheduler = cutlass::gemm::KernelTmaWarpSpecializedCooperative; diff --git a/tests/test_kda.py b/tests/test_kda.py index dff6ea2..fadadbf 100644 --- a/tests/test_kda.py +++ b/tests/test_kda.py @@ -28,6 +28,7 @@ pytestmark = pytest.mark.sm100_only +@pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) @pytest.mark.parametrize("disable_recompute", [True, False], ids=["no_recomp", "recomp"]) @pytest.mark.parametrize( ( @@ -71,6 +72,7 @@ def test_safe_gate_chunk( safe_gate: bool, dtype: torch.dtype, disable_recompute: bool, + beta_dtype: torch.dtype, ): torch.manual_seed(42) q = torch.rand(B, T, H, D, dtype=dtype) @@ -92,7 +94,7 @@ def test_safe_gate_chunk( lower_bound = None naive_kda_gate_fn = naive_kda_gate - beta = torch.randn(B, T, H, dtype=torch.float32).sigmoid() + beta = torch.randn(B, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) h0 = torch.randn(B, H, D, D, dtype=torch.float32) if use_gate_in_kernel: A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(True), (A_log, dt_bias)) @@ -153,6 +155,7 @@ def test_safe_gate_chunk( assert_close("dh0", ref_dh0, tri_dh0, 0.008) +@pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) @pytest.mark.parametrize("disable_recompute", [True, False], ids=["no_recomp", "recomp"]) @pytest.mark.parametrize( ("H", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"), @@ -208,6 +211,7 @@ def test_safe_gate_chunk_varlen( dtype: torch.dtype, safe_gate: bool, disable_recompute: bool, + beta_dtype: torch.dtype, ): torch.manual_seed(42) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) @@ -224,7 +228,7 @@ def test_safe_gate_chunk_varlen( if safe_gate: g = g.clamp(-5, 0) - beta = torch.randn(1, T, H, dtype=torch.float32).sigmoid() + beta = torch.randn(1, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) h0 = torch.randn((N, H, D, D), dtype=torch.float32) q, k, v, g, beta, h0 = map(lambda x: x.to(device).requires_grad_(), (q, k, v, g, beta, h0)) diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index d2b5654..b9b59f9 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -30,6 +30,7 @@ pytestmark = pytest.mark.sm90_only +@pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) @pytest.mark.parametrize( ( "B", @@ -71,6 +72,7 @@ def test_safe_gate_chunk( use_gate_in_kernel: bool, safe_gate: bool, dtype: torch.dtype, + beta_dtype: torch.dtype, ): from fla.ops.kda.gate import naive_kda_lowerbound_gate @@ -96,7 +98,7 @@ def test_safe_gate_chunk( lower_bound = None naive_kda_gate_fn = naive_kda_gate - beta = torch.randn(B, T, H, dtype=torch.float32).sigmoid() + beta = torch.randn(B, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) h0 = torch.randn(B, H, D, D, dtype=torch.float32) # NOTE: for inference scenarios, we only use transposed state layout for better decoding performance h0_vk = h0.transpose(-1, -2).contiguous() @@ -171,6 +173,7 @@ def test_safe_gate_chunk( assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005) +@pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) @pytest.mark.parametrize( ("H", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"), [ @@ -224,6 +227,7 @@ def test_safe_gate_chunk_varlen( cu_seqlens: list[int], dtype: torch.dtype, safe_gate: bool, + beta_dtype: torch.dtype, ): cula_kda_fused_fwd = get_kda_fused_fwd(device) @@ -242,7 +246,7 @@ def test_safe_gate_chunk_varlen( if safe_gate: g = g.clamp(-5, 0) - beta = torch.randn(1, T, H, dtype=torch.float32).sigmoid() + beta = torch.randn(1, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) h0 = torch.randn((N, H, D, D), dtype=torch.float32) # NOTE: for inference scenarios, we only use transposed state layout for better decoding performance h0_vk = h0.transpose(-1, -2).contiguous()