avoid cudaStreamSynchronize in DeepSeekV2AttentionMLA (#4577)
Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
@@ -658,7 +658,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
and forward_batch.forward_mode.is_extend()
|
and forward_batch.forward_mode.is_extend()
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
and forward_batch.extend_prefix_lens.sum() == 0
|
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||||
@@ -666,7 +666,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
forward_batch.forward_mode.is_extend()
|
forward_batch.forward_mode.is_extend()
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
and forward_batch.extend_prefix_lens.sum() == 0
|
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Reference in New Issue
Block a user