[Fix] Resolve performance drop in speculative decoding aiter backend (#11087)
This commit is contained in:
@@ -619,7 +619,11 @@ class AiterAttnBackend(AttentionBackend):
|
||||
assert len(k.shape) == 3
|
||||
assert len(v.shape) == 3
|
||||
|
||||
- if forward_batch.forward_mode.is_extend():
|
||||
+ if (
|
||||
+ forward_batch.forward_mode.is_extend()
|
||||
+ and not forward_batch.forward_mode.is_target_verify()
|
||||
+ and not forward_batch.forward_mode.is_draft_extend()
|
||||
+ ):
|
||||
if kv_indices.shape[0] == 0:
|
||||
o = flash_attn_varlen_func(
|
||||
q,
|
||||
|
||||
Reference in New Issue
Block a user