From df7014a8d23090c684d0f2d6c84019a1201569d7 Mon Sep 17 00:00:00 2001
From: strgrb
Date: Thu, 20 Mar 2025 01:02:26 +0800
Subject: [PATCH] avoid cudaStreamSynchronize in DeepSeekV2AttentionMLA (#4577)

Co-authored-by: Zhang Kaihong
---
 python/sglang/srt/models/deepseek_v2.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 27a12c627..ffcc9a955 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -658,7 +658,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
@@ -666,7 +666,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
 
     def forward(