update forward_return_lse (#3425)

This commit is contained in:
Yineng Zhang
2025-02-09 20:18:44 +08:00
committed by GitHub
parent 4d2dbeaca7
commit 014cab4dd2

View File

@@ -409,9 +409,9 @@ class FlashInferAttnBackend(AttentionBackend):
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
-v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
+q.view(-1, layer.tp_q_head_num, layer.head_dim),
+k.view(-1, layer.tp_k_head_num, layer.head_dim),
+v.view(-1, layer.tp_v_head_num, layer.head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,