[Fix] Resolve performance drop in speculative decoding aiter backend (#11087)
This commit is contained in:
@@ -619,7 +619,11 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
assert len(k.shape) == 3
|
assert len(k.shape) == 3
|
||||||
assert len(v.shape) == 3
|
assert len(v.shape) == 3
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_extend():
|
if (
|
||||||
|
forward_batch.forward_mode.is_extend()
|
||||||
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
|
):
|
||||||
if kv_indices.shape[0] == 0:
|
if kv_indices.shape[0] == 0:
|
||||||
o = flash_attn_varlen_func(
|
o = flash_attn_varlen_func(
|
||||||
q,
|
q,
|
||||||
|
|||||||
Reference in New Issue
Block a user