Enable TRT-LLM MLA prefix extend (#10526)
This commit is contained in:
@@ -553,7 +553,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
save_kv_cache: bool = True,
|
save_kv_cache: bool = True,
|
||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
k_rope: Optional[torch.Tensor] = None,
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
):
|
||||||
if (
|
if (
|
||||||
forward_batch.forward_mode.is_target_verify()
|
forward_batch.forward_mode.is_target_verify()
|
||||||
or forward_batch.forward_mode.is_draft_extend()
|
or forward_batch.forward_mode.is_draft_extend()
|
||||||
@@ -591,10 +591,45 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
return_lse=forward_batch.mha_return_lse,
|
return_lse=forward_batch.mha_return_lse,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# replace with trtllm ragged attention once accuracy is resolved.
|
if not (
|
||||||
output = super().forward_extend(
|
forward_batch.attn_attend_prefix_cache is not None
|
||||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
and forward_batch.mha_return_lse
|
||||||
)
|
):
|
||||||
|
output = super().forward_extend(
|
||||||
|
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# MHA over the chunked prefix KV cache when running a model with MLA
|
||||||
|
assert forward_batch.prefix_chunk_idx is not None
|
||||||
|
assert forward_batch.prefix_chunk_cu_seq_lens is not None
|
||||||
|
assert q_rope is None
|
||||||
|
assert k_rope is None
|
||||||
|
chunk_idx = forward_batch.prefix_chunk_idx
|
||||||
|
|
||||||
|
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
|
||||||
|
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
|
||||||
|
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
|
||||||
|
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
||||||
|
query=q,
|
||||||
|
key=k,
|
||||||
|
value=v,
|
||||||
|
workspace_buffer=self.workspace_buffer,
|
||||||
|
seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
|
||||||
|
max_q_len=self.forward_prefill_metadata.max_seq_len,
|
||||||
|
max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
|
||||||
|
bmm1_scale=layer.scaling,
|
||||||
|
bmm2_scale=1.0,
|
||||||
|
o_sf_scale=-1.0,
|
||||||
|
batch_size=forward_batch.batch_size,
|
||||||
|
window_left=-1,
|
||||||
|
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
|
||||||
|
cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
|
||||||
|
enable_pdl=False,
|
||||||
|
is_causal=False,
|
||||||
|
return_lse=True,
|
||||||
|
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
|
||||||
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user