Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions custom_ops/gpu_ops/fused_rotary_position_encoding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@ __global__ void apply_rotary_embedding_kernel(
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
const int head_size,
const int num_tokens) { // 新增 num_tokens 参数用于边界检查

// 用2D grid表示token_idx,突破65535限制
const int token_idx = blockIdx.x + blockIdx.y * gridDim.x;
if (token_idx >= num_tokens) return; // 边界保护

int pos = position_ids[token_idx];
const T* cache_ptr = cos_sin_cache + pos * rot_dim;

Expand Down Expand Up @@ -99,13 +103,13 @@ void FusedRotaryPositionEncoding(
int64_t query_stride = num_heads * head_size;
int64_t key_stride = num_kv_heads * head_size;

if (num_tokens > 65535) {
PD_THROW(
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
}

dim3 grid(num_tokens);
// 拆成2D grid:每维最大65535,总计支持 65535*65535 >> 1024*1024
constexpr int MAX_GRID_X = 65535;
int grid_x = std::min<int64_t>(num_tokens, MAX_GRID_X);
int grid_y = (num_tokens + MAX_GRID_X - 1) / MAX_GRID_X;
dim3 grid(grid_x, grid_y);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));

PD_DISPATCH_FLOATING_AND_HALF_TYPES(
query.dtype(), "apply_rotary_embedding_kernel", [&] {
if (is_neox) {
Expand All @@ -119,7 +123,8 @@ void FusedRotaryPositionEncoding(
key_stride,
num_heads,
num_kv_heads,
head_size);
head_size,
num_tokens);
} else {
apply_rotary_embedding_kernel<data_t, false>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
Expand All @@ -131,7 +136,8 @@ void FusedRotaryPositionEncoding(
key_stride,
num_heads,
num_kv_heads,
head_size);
head_size,
num_tokens);
}
});
}
Expand Down
77 changes: 46 additions & 31 deletions custom_ops/gpu_ops/merge_prefill_decode_output.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,49 @@ __global__ void FillEncoderDecoderResKernel(T *encoder_res_data,
return;
}

const int load_idx =
((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + land_id * 4;
const int base_idx =
((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim;

*reinterpret_cast<float2 *>(encoder_res_data + load_idx) =
*reinterpret_cast<float2 *>(decoder_res_data + load_idx);
if (head_dim == 128) {
const int load_idx = base_idx + land_id * 4;
*reinterpret_cast<float2 *>(encoder_res_data + load_idx) =
*reinterpret_cast<float2 *>(decoder_res_data + load_idx);
} else if (head_dim == 192) {
const int load_idx = base_idx + land_id * 4;
*reinterpret_cast<float2 *>(encoder_res_data + load_idx) =
*reinterpret_cast<float2 *>(decoder_res_data + load_idx);
if (land_id < 16) {

This comment was marked as outdated.

*reinterpret_cast<float2 *>(encoder_res_data + load_idx + 128) =
*reinterpret_cast<float2 *>(decoder_res_data + load_idx + 128);
}
} else if (head_dim == 256) {
// float4 = 单条LDG.128,性能最优
const int load_idx = base_idx + land_id * 8;
*reinterpret_cast<float4 *>(encoder_res_data + load_idx) =
*reinterpret_cast<float4 *>(decoder_res_data + load_idx);
}
}

// Launch helper for FillEncoderDecoderResKernel<WARPS>.
// Block size is head_dim (one thread per element of a head row), so the
// caller must guarantee head_dim == WARPS * 32 — the dispatch sites in this
// file pair 128/192/256 with WARPS 4/6/8 accordingly.
// encoder_res / decoder_res are const Tensors; their data pointers are
// const_cast'd because the kernel writes the merged result into encoder_res
// in place. Expects grid_dims and the named tensors/ints to be in scope at
// the expansion site.
#define LAUNCH_KERNEL(T, WARPS) \
FillEncoderDecoderResKernel<WARPS> \
<<<grid_dims, head_dim, 0, encoder_res.stream()>>>( \
const_cast<T *>(encoder_res.data<T>()), \
const_cast<T *>(decoder_res.data<T>()), \
seq_lens_encoder.data<int>(), \
seq_lens_decoder.data<int>(), \
seq_lens_this_time.data<int>(), \
cu_seq_q.data<int>(), \
head_num, \
head_dim)

// Dispatch LAUNCH_KERNEL on head_dim, mapping 128/192/256 to WARPS 4/6/8
// (i.e. WARPS = head_dim / 32, matching the kernel's per-head-dim paths).
// Wrapped in do { ... } while (0) so the macro expands to exactly one
// statement: the original brace-less if/else-if chain was a dangling-else /
// multi-statement hazard if ever expanded inside an unbraced if at a call
// site. Call sites invoke it as `LAUNCH_KERNEL_BY_HEAD_DIM(T);`, which is
// unchanged. An unsupported head_dim expands to a no-op; the host wrapper
// is expected to have rejected such values before reaching this dispatch.
#define LAUNCH_KERNEL_BY_HEAD_DIM(T) \
do { \
if (head_dim == 128) { \
LAUNCH_KERNEL(T, 4); \
} else if (head_dim == 192) { \
LAUNCH_KERNEL(T, 6); \
} else if (head_dim == 256) { \
LAUNCH_KERNEL(T, 8); \
} \
} while (0)

void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
Expand All @@ -60,41 +96,20 @@ void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res,
const int head_num,
const int head_dim,
const int max_token) {
if (head_dim != 128) {
PD_THROW("Only supported head_dim = 128");
if (head_dim != 128 && head_dim != 192 && head_dim != 256) {
PD_THROW("Only supported head_dim = 128, 192 or 256");
}
const int batch_size = seq_lens_encoder.shape()[0];
constexpr int warps = 4;
const int warps = head_dim / 32;
const int tokens_block = (max_token + warps - 1) / warps;
dim3 grid_dims;
grid_dims.x = batch_size;
grid_dims.y = head_num;
grid_dims.z = tokens_block;
dim3 grid_dims(batch_size, head_num, tokens_block);

if (encoder_res.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
FillEncoderDecoderResKernel<warps>
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
const_cast<T *>(encoder_res.data<T>()),
const_cast<T *>(decoder_res.data<T>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seq_q.data<int>(),
head_num,
head_dim);
LAUNCH_KERNEL_BY_HEAD_DIM(T);
} else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
FillEncoderDecoderResKernel<warps>
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
const_cast<T *>(encoder_res.data<T>()),
const_cast<T *>(decoder_res.data<T>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seq_q.data<int>(),
head_num,
head_dim);
LAUNCH_KERNEL_BY_HEAD_DIM(T);
}
}

Expand Down
14 changes: 12 additions & 2 deletions fastdeploy/model_executor/models/deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from fastdeploy.model_executor.ops.gpu import (
cp_gather_indexer_k_quant_cache,
indexer_k_quant_and_cache,
merge_prefill_decode_output,
radix_topk_ragged_transform,
)

Expand Down Expand Up @@ -398,7 +399,6 @@ def forward(
fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * forward_meta.mask_encoder_batch.cast(fmha_out_prefill.dtype)
fmha_out = fmha_out_prefill

if need_do_decode: # max_dec_len_this_time
Expand Down Expand Up @@ -433,7 +433,17 @@ def forward(
)

if need_do_prefill:
fmha_out += fmha_out_decode
merge_prefill_decode_output(
fmha_out,
fmha_out_decode,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
self.num_attention_heads_tp,
self.v_head_dim,
1,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 max_token=1 被硬编码传递给 merge_prefill_decode_output

在 CUDA kernel 中,max_token 用于计算 grid 的 z 维度:

const int warps = head_dim / 32;
const int tokens_block = (max_token + warps - 1) / warps;
dim3 grid_dims(batch_size, head_num, tokens_block);

max_token=1 时,tokens_block=1,token_id = warp_id(范围 0-31)。如果 seq_lens_this_time[bidb] > warps,会导致某些 token 无法被处理。

需要确认

  1. 在 DeepSeek V3 的使用场景中,seq_lens_this_time[bidb] 是否总是 ≤ warps?
  2. 是否有场景需要一次解码生成多个 token(如 speculative decoding)?

建议:如果 max_token=1 是针对当前场景的优化,请添加注释说明原因;否则考虑使用 forward_meta.max_len_tensor_cpu[2](max_dec_len_this_time)作为参数。

)
else:
fmha_out = fmha_out_decode

Expand Down
13 changes: 8 additions & 5 deletions tests/operators/test_fused_rotary_position_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,21 @@ def test_neox_mode(self):
self._check_correctness(num_tokens=3, num_heads=2, num_kv_heads=2, head_size=8, rot_dim=8, is_neox=True)

def test_large_num_tokens(self):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 test_large_num_tokens 只验证算子不抛出异常,没有验证输出结果的正确性。

建议添加正确性验证(类似其他测试用例使用 _check_correctness):

def test_large_num_tokens(self):
    num_tokens, num_heads, head_size = 65537, 1, 4
    num_kv_heads, rot_dim = 1, 4
    query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
    key_np = np.random.rand(num_tokens, num_kv_heads, head_size).astype("float32")
    position_ids_np = np.arange(num_tokens, dtype="int32")
    cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)
    
    query_out, key_out = self._run_op(
        query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False
    )
    # 添加正确性验证
    query_ref, key_ref = self._ref_rotary(
        query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False
    )
    np.testing.assert_allclose(query_out, query_ref, rtol=1e-5, atol=1e-6)
    np.testing.assert_allclose(key_out, key_ref, rtol=1e-5, atol=1e-6)

self._check_correctness(num_tokens=10, num_heads=2, num_kv_heads=2, head_size=4, rot_dim=4, is_neox=False)

def test_exceed_max_tokens(self):
"""
测试算子支持大量 tokens(超过 65535)
算子使用 2D grid,理论上可支持 65535*65535 个 tokens
"""
num_tokens, num_heads, head_size = 65537, 1, 4
num_kv_heads, rot_dim = 1, 4
query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
key_np = np.random.rand(num_tokens, num_kv_heads, head_size).astype("float32")
position_ids_np = np.arange(num_tokens, dtype="int32")
cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)

with self.assertRaises(Exception):
self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False)
# 不应该抛出异常,算子应该能处理大量 tokens
query_out, key_out = self._run_op(
query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False
)


if __name__ == "__main__":
Expand Down
Loading