avoid cudaStreamSynchronize in DeepSeekV2AttentionMLA (#4577)

Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
strgrb
2025-03-20 01:02:26 +08:00
committed by GitHub
parent 4942074174
commit df7014a8d2

View File

@@ -658,7 +658,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
@@ -666,7 +666,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )

     def forward(