Enable trtllm mla prefix extend (#10526)
@@ -553,7 +553,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ):
         if (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend()
@@ -591,10 +591,45 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 return_lse=forward_batch.mha_return_lse,
             )
         else:
-            # replace with trtllm ragged attention once accuracy is resolved.
-            output = super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-            )
+            if not (
+                forward_batch.attn_attend_prefix_cache is not None
+                and forward_batch.mha_return_lse
+            ):
+                output = super().forward_extend(
+                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+                )
+            else:
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert q_rope is None
+                assert k_rope is None
+                chunk_idx = forward_batch.prefix_chunk_idx
+
+                q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
+                v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
+                output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+                output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                    query=q,
+                    key=k,
+                    value=v,
+                    workspace_buffer=self.workspace_buffer,
+                    seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
+                    max_q_len=self.forward_prefill_metadata.max_seq_len,
+                    max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                    bmm1_scale=layer.scaling,
+                    bmm2_scale=1.0,
+                    o_sf_scale=-1.0,
+                    batch_size=forward_batch.batch_size,
+                    window_left=-1,
+                    cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                    cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    enable_pdl=False,
+                    is_causal=False,
+                    return_lse=True,
+                    out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
+                )
         return output
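For context, the chunked-prefix branch passes return_lse=True because attention over a single prefix KV chunk is only a partial result; the per-chunk outputs are later combined using their log-sum-exp (LSE) values, which is why forward_batch.mha_return_lse gates this path. The snippet below is a minimal, illustrative sketch of that merge in plain PyTorch; it is not part of this PR, and the function name merge_attention_states is hypothetical (the backend is expected to merge states elsewhere, e.g. via FlashInfer utilities).

import torch

def merge_attention_states(o_a: torch.Tensor, lse_a: torch.Tensor,
                           o_b: torch.Tensor, lse_b: torch.Tensor):
    # o_*  : [num_tokens, num_heads, head_dim] partial outputs over disjoint KV sets
    # lse_*: [num_tokens, num_heads] log-sum-exp of the attention logits for each partial output
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)  # renormalization weight for partial output A
    w_b = torch.exp(lse_b - lse_max)  # renormalization weight for partial output B
    denom = w_a + w_b
    merged = (o_a * w_a.unsqueeze(-1) + o_b * w_b.unsqueeze(-1)) / denom.unsqueeze(-1)
    merged_lse = lse_max + torch.log(denom)  # LSE of the combined softmax, for further merging
    return merged, merged_lse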