Fix "Avoid computing lse in Ragged Prefill when there's no prefix match" (#5555)
@@ -418,6 +418,7 @@ class FlashInferAttnBackend(AttentionBackend):

         logits_soft_cap = layer.logit_cap

+        q = q.contiguous()
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
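
The hoisted `q = q.contiguous()` goes with the `q.contiguous().view(...)` to `q.view(...)` changes in the hunks below: Tensor.view only works on contiguous memory, so making q contiguous once up front lets every later view reuse the same storage instead of paying for a potential copy at each call site. A standalone PyTorch illustration (the shapes here are made up):

    import torch

    q = torch.randn(16, 8, 128).transpose(0, 1)  # transpose leaves q non-contiguous
    # q.view(-1, 128) would raise a RuntimeError here: view needs contiguous memory
    q = q.contiguous()   # one copy; a no-op if q is already contiguous
    o = q.view(-1, 128)  # later views are now just metadata changes
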
@@ -427,7 +428,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 )

             o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
@@ -437,20 +438,27 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
             if self.forward_metadata.extend_no_prefix:
-                o = o1
+                o = self.prefill_wrapper_ragged.forward(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
+
             else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                     causal=False,
                     sm_scale=layer.scaling,
@@ -355,7 +355,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):

         if self.forward_metadata.use_ragged:
             # ragged prefill
-            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+            o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),