-
Notifications
You must be signed in to change notification settings - Fork 737
[RL] change glm rope_emb calculation #7316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2543,10 +2543,10 @@ void gqa_rotary_qk_variable( | |
| } | ||
| const int pack_num_new = elem_nums / PackSize; | ||
| GetNumBlocks<128>(pack_num_new, &grid_size); | ||
| auto *kernelFn = | ||
| GQANeoxVariableLengthPartialRotaryKernel<T, | ||
| PackSize, | ||
| EnforceFmulRN>; | ||
| auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel< | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟡 建议 同上,硬编码会影响所有模型,建议改为通过环境变量控制。 |
||
| T, | ||
| PackSize, | ||
| false>; // GLM use EnforceFmulRN=false | ||
| launchWithPdlWhenEnabled(kernelFn, | ||
| grid_size, | ||
| blocksize, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -387,30 +387,32 @@ void gqa_neox_partial_rotary_qk_split_variable( | |
|
|
||
| const float *cos_emb = rotary_emb; | ||
| const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2; | ||
| launchWithPdlWhenEnabled( | ||
| GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize, EnforceFmulRN>, | ||
| grid_size, | ||
| block_size, | ||
| 0, | ||
| stream, | ||
| qkv_input, | ||
| cos_emb, | ||
| sin_emb, | ||
| batch_id_per_token, | ||
| cu_seqlens_q, | ||
| seq_lens_encoder, | ||
| seq_lens_decoder, | ||
| cu_seqlens_k, | ||
| qkv_out, | ||
| q, | ||
| k, | ||
| v, | ||
| elem_nums, | ||
| num_heads, | ||
| kv_num_heads, | ||
| max_model_len, | ||
| head_dim, | ||
| rotary_dim); | ||
| launchWithPdlWhenEnabled(GQAVariableLengthNeoxPartialRotarySplitKernel< | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟡 建议 同上,硬编码会影响所有模型,建议改为通过环境变量控制。 |
||
| T, | ||
| PackSize, | ||
| false>, // GLM use EnforceFmulRN=false | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| grid_size, | ||
| block_size, | ||
| 0, | ||
| stream, | ||
| qkv_input, | ||
| cos_emb, | ||
| sin_emb, | ||
| batch_id_per_token, | ||
| cu_seqlens_q, | ||
| seq_lens_encoder, | ||
| seq_lens_decoder, | ||
| cu_seqlens_k, | ||
| qkv_out, | ||
| q, | ||
| k, | ||
| v, | ||
| elem_nums, | ||
| num_heads, | ||
| kv_num_heads, | ||
| max_model_len, | ||
| head_dim, | ||
| rotary_dim); | ||
| } | ||
|
|
||
| template <typename T, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -130,10 +130,11 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, | |
| GetNumBlocks(pack_num, &grid_size); | ||
| if (use_neox_style) { | ||
| if (rotary_dim < dim_head) { | ||
| append_speculate_cache_neox_partial_rope_kernel<T, | ||
| PackSize, | ||
| QKV_TYPE, | ||
| EnforceFmulRN> | ||
| append_speculate_cache_neox_partial_rope_kernel< | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟡 建议 同上,硬编码会影响所有模型,建议改为通过环境变量控制。 |
||
| T, | ||
| PackSize, | ||
| QKV_TYPE, | ||
| false> // GLM use EnforceFmulRN=false | ||
| <<<grid_size, threads_per_block, 0, stream>>>( | ||
| qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size] | ||
| key_cache, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 建议 硬编码 `EnforceFmulRN = false` 会影响所有使用这些 kernel 的模型(包括 DeepSeek V3 等使用 partial rotary embedding 的模型),而不仅仅是 GLM 模型。建议改为通过环境变量 `FD_ENABLE_RL` 控制此参数,保持与 Python 层的一致性。修改方式参考 `helper.h:746` 中的 `fmul_func` 定义,可以在 kernel launch 时根据环境变量选择不同的模板实例。