Avoid computing the log-sum-exp (LSE) in ragged prefill when there is no prefix. (#5476)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
Wenxuan Tan
2025-04-18 03:13:57 -05:00
committed by GitHub
parent e465b08ddb
commit bfa3922451
3 changed files with 19 additions and 12 deletions

View File

@@ -425,18 +425,25 @@ class FlashInferAttnBackend(AttentionBackend):
v_scale=v_scale,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
if self.forward_metadata.extend_no_prefix:
o = o1
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
o2, s2 = prefill_wrapper_paged.forward_return_lse(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),

View File

@@ -348,7 +348,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
if self.forward_metadata.use_ragged:
# ragged prefill
o, _ = self.prefill_wrapper_ragged.forward_return_lse(
o = self.prefill_wrapper_ragged.forward(
qall,
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_k_head_num, layer.v_head_dim),