Revert "Avoid computing lse in Ragged Prefill when there's no prefix.… (#5544)

This commit is contained in:
Yineng Zhang
2025-04-18 16:50:21 -07:00
committed by GitHub
parent 08b518d51f
commit a6f892e5d0
3 changed files with 12 additions and 19 deletions

View File

@@ -428,25 +428,18 @@ class FlashInferAttnBackend(AttentionBackend):
v_scale=v_scale,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
if self.forward_metadata.extend_no_prefix:
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
)
o = o1
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
o2, s2 = prefill_wrapper_paged.forward_return_lse(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),

View File

@@ -348,7 +348,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
if self.forward_metadata.use_ragged:
# ragged prefill
o = self.prefill_wrapper_ragged.forward(
o, _ = self.prefill_wrapper_ragged.forward_return_lse(
qall,
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_k_head_num, layer.v_head_dim),