diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
index 963ccfa23d9..e25816fcbb3 100644
--- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
+++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
@@ -146,10 +146,10 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
                              rope_3d);
   } else {
     if (rotary_dim < dim_head) {
-      auto* kernelFn =
-          append_decode_cache_T_neox_partial_rope_kernel<T,
-                                                         PackSize,
-                                                         true>;
+      auto* kernelFn = append_decode_cache_T_neox_partial_rope_kernel<
+          T,
+          PackSize,
+          false>;  // GLM uses EnforceFmulRN=false
       launchWithPdlWhenEnabled(kernelFn,
                                grid_size,
                                blocksize,
diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh
index 0cdea537327..60d5d34bf48 100644
--- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh
@@ -2543,10 +2543,10 @@ void gqa_rotary_qk_variable(
   }
   const int pack_num_new = elem_nums / PackSize;
   GetNumBlocks<128>(pack_num_new, &grid_size);
-  auto *kernelFn =
-      GQANeoxVariableLengthPartialRotaryKernel<T,
-                                               PackSize,
-                                               true>;
+  auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel<
+      T,
+      PackSize,
+      false>;  // GLM uses EnforceFmulRN=false
   launchWithPdlWhenEnabled(kernelFn,
                            grid_size,
                            blocksize,
diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
index e4d0554fea6..c86ec27dca8 100644
--- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
+++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
@@ -387,30 +387,32 @@ void gqa_neox_partial_rotary_qk_split_variable(
   const float *cos_emb = rotary_emb;
   const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2;
-  launchWithPdlWhenEnabled(
-      GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize, true>,
-      grid_size,
-      block_size,
-      0,
-      stream,
-      qkv_input,
-      cos_emb,
-      sin_emb,
-      batch_id_per_token,
-      cu_seqlens_q,
-      seq_lens_encoder,
-      seq_lens_decoder,
-      cu_seqlens_k,
-      qkv_out,
-      q,
-      k,
-      v,
-      elem_nums,
-      num_heads,
-      kv_num_heads,
-      max_model_len,
-      head_dim,
-      rotary_dim);
+  launchWithPdlWhenEnabled(GQAVariableLengthNeoxPartialRotarySplitKernel<
+                               T,
+                               PackSize,
+                               false>,  // GLM uses EnforceFmulRN=false
+                           grid_size,
+                           block_size,
+                           0,
+                           stream,
+                           qkv_input,
+                           cos_emb,
+                           sin_emb,
+                           batch_id_per_token,
+                           cu_seqlens_q,
+                           seq_lens_encoder,
+                           seq_lens_decoder,
+                           cu_seqlens_k,
+                           qkv_out,
+                           q,
+                           k,
+                           v,
+                           elem_nums,
+                           num_heads,
+                           kv_num_heads,
+                           max_model_len,
+                           head_dim,
+                           rotary_dim);
 }
@@ ... @@ template <...>
-  append_speculate_cache_neox_partial_rope_kernel<T, PackSize, QKV_TYPE, true>
+  append_speculate_cache_neox_partial_rope_kernel<
+      T,
+      PackSize,
+      QKV_TYPE,
+      false>  // GLM uses EnforceFmulRN=false
       <<<...>>>(
           qkv,  // [token_num, num_heads + 2 * gqa_group_size, head_size]
           key_cache,
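Note: judging by its name and the comments above, the `EnforceFmulRN` template flag selects whether multiplies are forced through CUDA's `__fmul_rn` intrinsic, a separately rounded multiply that the compiler cannot contract into a fused multiply-add; the GLM instantiations pass `false`, opting out of that enforcement. The sketch below only illustrates why fused versus separate rounding can change bits, using hand-picked values; it is not code from this PR:

```python
import numpy as np

# Illustrative values chosen so the two rounding strategies disagree.
a = np.float32(1.0) + np.float32(2.0) ** -12      # 1 + 2^-12
b = a
c = -(np.float32(1.0) + np.float32(2.0) ** -11)   # -(1 + 2^-11)

# Separate rounding, as with __fmul_rn: the product rounds to float32
# first, then the addition rounds again.
separate = np.float32(np.float32(a * b) + c)

# Single rounding, as with an FMA contraction: the float64 product of two
# float32 values is exact, so only the final cast back to float32 rounds.
fused = np.float32(np.float64(a) * np.float64(b) + np.float64(c))

print(separate, fused)  # 0.0 vs ~5.96e-08: same math, different bits
```

Both results are valid float32 answers; the flag just pins down which rounding sequence is used, which matters when kernel outputs must match a reference implementation exactly.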
diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py
index af7203ed6f1..dd77cf2bc0d 100644
--- a/fastdeploy/model_executor/layers/rotary_embedding.py
+++ b/fastdeploy/model_executor/layers/rotary_embedding.py
@@ -20,6 +20,7 @@
 import paddle
 from paddle import nn
 
+from fastdeploy import envs
 from fastdeploy.config import ModelConfig
 from fastdeploy.platforms import current_platform
 
@@ -87,8 +88,13 @@ def __init__(self, rotary_dim, base, partial_rotary_factor):
 
     def __call__(self, position_ids):
         bsz, max_seq_len = position_ids.shape[:2]
-        inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
-        freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
+        if envs.FD_ENABLE_RL == 1:
+            idx = paddle.arange(0, self.rotary_dim, 2, dtype=paddle.int64).astype(paddle.float32)
+            inv_freq = 1.0 / (self.base ** (idx / self.rotary_dim))
+            freqs = paddle.outer(position_ids.astype(inv_freq.dtype), inv_freq)
+        else:
+            inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
+            freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
         # shape: [B, S, D/2]
         rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
         emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
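Note: the `FD_ENABLE_RL` branch computes the same inverse frequencies as the existing path, just in a different op order (`1.0 / base ** (idx / dim)` from an int64 arange, versus `base ** (-idx / dim)` from a float32 arange), and builds `freqs` with `paddle.outer` (which flattens its inputs, giving `[B*S, D/2]` that the later reshape restores) instead of `einsum`. The diff does not state the motivation, but the usual reason for such a branch is bit-for-bit alignment with a training-side RoPE implementation when RL is enabled. A quick float32 check of how closely the two orderings agree, with arbitrary example values:

```python
import paddle

base, rotary_dim = 10000.0, 64  # arbitrary example values

# RL path from the diff: int64 arange cast to float32, then the
# reciprocal of a positive power.
idx = paddle.arange(0, rotary_dim, 2, dtype=paddle.int64).astype(paddle.float32)
inv_freq_rl = 1.0 / (base ** (idx / rotary_dim))

# Default path from the diff: float32 arange with a negated exponent.
inv_freq_default = base ** (-paddle.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim)

# Algebraically identical; any nonzero difference is float32 rounding
# introduced by the different op order.
print(float(paddle.max(paddle.abs(inv_freq_rl - inv_freq_default))))
```

The printed difference is typically zero or on the order of one ULP, which is exactly the scale of discrepancy the branch appears intended to eliminate.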