Fix "Avoid computing lse in Ragged Prefill when there's no prefix match" (#5555)

Authored by Wenxuan Tan on 2025-05-05 12:32:17 -05:00, committed by GitHub
commit 22da3d978f (parent b8559764f6)
3 changed files with 23 additions and 15 deletions
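
Background, as the diff suggests: in the ragged prefill path, a request with a matched prefix has its attention computed in two pieces, ragged attention over the newly extended tokens plus paged attention over the cached prefix KV, and the two pieces are combined using each piece's log-sum-exp (LSE) of the attention logits. When forward_metadata.extend_no_prefix is set there is no second piece to merge with, so asking the ragged kernel to also return the LSE (forward_return_lse) is wasted work and plain forward suffices. A minimal torch-only sketch of the merge step, with illustrative names and shapes rather than the repository's code:

    import torch

    def merge_attention_pieces(o1, s1, o2, s2):
        # o1, o2: [tokens, heads, head_dim] partial attention outputs over two
        # disjoint key sets; s1, s2: [tokens, heads] log-sum-exp of the logits
        # over each set. Each piece is weighted by its share of the softmax mass.
        s_max = torch.maximum(s1, s2)
        w1 = torch.exp(s1 - s_max).unsqueeze(-1)
        w2 = torch.exp(s2 - s_max).unsqueeze(-1)
        return (o1 * w1 + o2 * w2) / (w1 + w2)

With no cached prefix there is only one piece, the weights collapse to 1, and the LSE is never consumed, which is exactly the case this commit special-cases.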

@@ -418,6 +418,7 @@ class FlashInferAttnBackend(AttentionBackend):
         logits_soft_cap = layer.logit_cap
+        q = q.contiguous()
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
@@ -427,7 +428,7 @@ class FlashInferAttnBackend(AttentionBackend):
                     )
             o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
@@ -437,20 +438,27 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
             if self.forward_metadata.extend_no_prefix:
-                o = o1
+                o = self.prefill_wrapper_ragged.forward(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
             else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                     causal=False,
                     sm_scale=layer.scaling,

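The hunk above is truncated just before the two partial results are combined. For completeness, a sketch of the step that follows in the prefix-match branch, assuming the backend merges the pieces with flashinfer's merge_state (the exact call in the repository may differ):

    from flashinfer import merge_state

    # o1, s1: ragged attention over the new tokens; o2, s2: paged attention over
    # the cached prefix KV. merge_state weights each partial output by exp(lse).
    o, _ = merge_state(o1, s1, o2, s2)

The second changed file applies the same idea to the MLA backend's ragged prefill path, where the returned LSE was simply discarded: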

@@ -355,7 +355,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if self.forward_metadata.use_ragged:
             # ragged prefill
-            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+            o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
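
Here the LSE was computed only to be bound to "_" and thrown away, so switching from forward_return_lse to forward drops the unused work; assuming the two methods return the same attention output, with the latter additionally returning the LSE as in the non-MLA backend above, the result is unchanged.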