diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 27a12c627..ffcc9a955 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -658,7 +658,7 @@ class DeepseekV2AttentionMLA(nn.Module): and forward_batch.forward_mode.is_extend() and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend() - and forward_batch.extend_prefix_lens.sum() == 0 + and sum(forward_batch.extend_prefix_lens_cpu) == 0 ) else: # Triton: Use normal computation for prefill and use weight absorption for extend/decode @@ -666,7 +666,7 @@ class DeepseekV2AttentionMLA(nn.Module): forward_batch.forward_mode.is_extend() and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend() - and forward_batch.extend_prefix_lens.sum() == 0 + and sum(forward_batch.extend_prefix_lens_cpu) == 0 ) def forward(