-
Notifications
You must be signed in to change notification settings - Fork 737
[Optimization] [OP] [Models] dsk del prefill mask #7313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,6 +72,7 @@ | |
| from fastdeploy.model_executor.ops.gpu import ( | ||
| cp_gather_indexer_k_quant_cache, | ||
| indexer_k_quant_and_cache, | ||
| merge_prefill_decode_output, | ||
| radix_topk_ragged_transform, | ||
| ) | ||
|
|
||
|
|
@@ -398,7 +399,6 @@ def forward( | |
| fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim]) | ||
| fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim] | ||
| fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) | ||
| fmha_out_prefill = fmha_out_prefill * forward_meta.mask_encoder_batch.cast(fmha_out_prefill.dtype) | ||
| fmha_out = fmha_out_prefill | ||
|
|
||
| if need_do_decode: # max_dec_len_this_time | ||
|
|
@@ -433,7 +433,17 @@ def forward( | |
| ) | ||
|
|
||
| if need_do_prefill: | ||
| fmha_out += fmha_out_decode | ||
| merge_prefill_decode_output( | ||
| fmha_out, | ||
| fmha_out_decode, | ||
| forward_meta.seq_lens_encoder, | ||
| forward_meta.seq_lens_decoder, | ||
| forward_meta.seq_lens_this_time, | ||
| forward_meta.cu_seqlens_q, | ||
| self.num_attention_heads_tp, | ||
| self.v_head_dim, | ||
| 1, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问:在 CUDA kernel 中:

const int warps = head_dim / 32;
const int tokens_block = (max_token + warps - 1) / warps;
dim3 grid_dims(batch_size, head_num, tokens_block);

当 …(原评论此处内容在页面提取时丢失)需要确认:…

建议:如果 …(原评论此处内容在页面提取时丢失) |
||
| ) | ||
| else: | ||
| fmha_out = fmha_out_decode | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -116,18 +116,21 @@ def test_neox_mode(self): | |
| self._check_correctness(num_tokens=3, num_heads=2, num_kv_heads=2, head_size=8, rot_dim=8, is_neox=True) | ||
|
|
||
| def test_large_num_tokens(self): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

🟡 建议:建议添加正确性验证(类似其他测试用例使用的验证方式):

def test_large_num_tokens(self):
num_tokens, num_heads, head_size = 65537, 1, 4
num_kv_heads, rot_dim = 1, 4
query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
key_np = np.random.rand(num_tokens, num_kv_heads, head_size).astype("float32")
position_ids_np = np.arange(num_tokens, dtype="int32")
cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)
query_out, key_out = self._run_op(
query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False
)
# 添加正确性验证
query_ref, key_ref = self._ref_rotary(
query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False
)
np.testing.assert_allclose(query_out, query_ref, rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(key_out, key_ref, rtol=1e-5, atol=1e-6) |
||
| self._check_correctness(num_tokens=10, num_heads=2, num_kv_heads=2, head_size=4, rot_dim=4, is_neox=False) | ||
|
|
||
| def test_exceed_max_tokens(self): | ||
| """ | ||
| 测试算子支持大量 tokens(超过 65535) | ||
| 算子使用 2D grid,理论上可支持 65535*65535 个 tokens | ||
| """ | ||
| num_tokens, num_heads, head_size = 65537, 1, 4 | ||
| num_kv_heads, rot_dim = 1, 4 | ||
| query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32") | ||
| key_np = np.random.rand(num_tokens, num_kv_heads, head_size).astype("float32") | ||
| position_ids_np = np.arange(num_tokens, dtype="int32") | ||
| cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim) | ||
|
|
||
| with self.assertRaises(Exception): | ||
| self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False) | ||
| # 不应该抛出异常,算子应该能处理大量 tokens | ||
| query_out, key_out = self._run_op( | ||
| query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.