Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
rope_3d);
} else {
if (rotary_dim < dim_head) {
auto* kernelFn =
append_decode_cache_T_neox_partial_rope_kernel<T,
PackSize,
EnforceFmulRN>;
auto* kernelFn = append_decode_cache_T_neox_partial_rope_kernel<
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议:硬编码 `EnforceFmulRN = false` 会影响所有使用这些 kernel 的模型(包括 DeepSeek V3 等使用 partial rotary embedding 的模型),而不仅仅是 GLM 模型。

建议改为通过环境变量 `FD_ENABLE_RL` 控制此参数,以保持与 Python 层的一致性。

修改方式可参考 helper.h:746 中 `fmul_func` 的定义,在 kernel launch 时根据环境变量选择不同的模板实例。

T,
PackSize,
false>; // GLM use EnforceFmulRN=false
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2543,10 +2543,10 @@ void gqa_rotary_qk_variable(
}
const int pack_num_new = elem_nums / PackSize;
GetNumBlocks<128>(pack_num_new, &grid_size);
auto *kernelFn =
GQANeoxVariableLengthPartialRotaryKernel<T,
PackSize,
EnforceFmulRN>;
auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel<
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 同上,硬编码会影响所有模型,建议改为通过环境变量控制。

T,
PackSize,
false>; // GLM use EnforceFmulRN=false
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
Expand Down
50 changes: 26 additions & 24 deletions custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
Original file line number Diff line number Diff line change
Expand Up @@ -387,30 +387,32 @@ void gqa_neox_partial_rotary_qk_split_variable(

const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2;
launchWithPdlWhenEnabled(
GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize, EnforceFmulRN>,
grid_size,
block_size,
0,
stream,
qkv_input,
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_k,
qkv_out,
q,
k,
v,
elem_nums,
num_heads,
kv_num_heads,
max_model_len,
head_dim,
rotary_dim);
launchWithPdlWhenEnabled(GQAVariableLengthNeoxPartialRotarySplitKernel<
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 同上,硬编码会影响所有模型,建议改为通过环境变量控制。

T,
PackSize,
false>, // GLM use EnforceFmulRN=false

This comment was marked as outdated.

grid_size,
block_size,
0,
stream,
qkv_input,
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_k,
qkv_out,
q,
k,
v,
elem_nums,
num_heads,
kv_num_heads,
max_model_len,
head_dim,
rotary_dim);
}

template <typename T,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,11 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
GetNumBlocks(pack_num, &grid_size);
if (use_neox_style) {
if (rotary_dim < dim_head) {
append_speculate_cache_neox_partial_rope_kernel<T,
PackSize,
QKV_TYPE,
EnforceFmulRN>
append_speculate_cache_neox_partial_rope_kernel<
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 同上,硬编码会影响所有模型,建议改为通过环境变量控制。

T,
PackSize,
QKV_TYPE,
false> // GLM use EnforceFmulRN=false
<<<grid_size, threads_per_block, 0, stream>>>(
qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size]
key_cache,
Expand Down
10 changes: 8 additions & 2 deletions fastdeploy/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import paddle
from paddle import nn

from fastdeploy import envs
from fastdeploy.config import ModelConfig
from fastdeploy.platforms import current_platform

Expand Down Expand Up @@ -87,8 +88,13 @@ def __init__(self, rotary_dim, base, partial_rotary_factor):

def __call__(self, position_ids):
bsz, max_seq_len = position_ids.shape[:2]
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
if envs.FD_ENABLE_RL == 1:
idx = paddle.arange(0, self.rotary_dim, 2, dtype=paddle.int64).astype(paddle.float32)
inv_freq = 1.0 / (self.base ** (idx / self.rotary_dim))
freqs = paddle.outer(position_ids.astype(inv_freq.dtype), inv_freq)
else:
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
Expand Down
Loading