Integrate trtllm ragged attention for prefill self-attention (#9801)

Elfie Guo
2025-09-05 02:18:00 -07:00
committed by GitHub
parent f98366604b
commit bebd0576e5
4 changed files with 300 additions and 44 deletions


@@ -96,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
     def update_wrapper(
         self,
         forward_batch: ForwardBatch,
+        disable_flashinfer_ragged: bool = False,
     ):
         assert forward_batch.num_prefix_chunks is not None
         num_prefix_chunks = forward_batch.num_prefix_chunks
@@ -128,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
                 causal=False,
             )
         # ragged prefill
-        self.ragged_wrapper.begin_forward(
-            qo_indptr=qo_indptr,
-            kv_indptr=qo_indptr,
-            num_qo_heads=self.num_local_heads,
-            num_kv_heads=self.num_local_heads,
-            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-            head_dim_vo=self.v_head_dim,
-            q_data_type=self.q_data_type,
-            causal=True,
-        )
+        if not disable_flashinfer_ragged:
+            self.ragged_wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=True,
+            )

     def forward(
         self,
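For reference, passing the same indptr as both qo_indptr and kv_indptr makes the ragged prefill a pure self-attention over each request's new tokens. A minimal sketch of that layout (toy lengths; the tensor construction mirrors the pattern used elsewhere in this commit):

    import torch

    # New-token counts for a 3-request extend batch.
    extend_lens = torch.tensor([3, 5, 2])
    # Exclusive prefix sum: request i's tokens live in [indptr[i], indptr[i+1]).
    qo_indptr = torch.cat((torch.tensor([0]), torch.cumsum(extend_lens, dim=0))).int()
    print(qo_indptr)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)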
@@ -491,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+    def init_mha_chunk_metadata(
+        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
+    ):
         """Init the metadata for a forward pass."""
-        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)

     def forward_extend(
         self,
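The new flag exists so a subclass can keep the parent's chunked-KV wrapper setup while skipping the FlashInfer ragged wrapper. Condensed from the TRT-LLM hunk later in this commit:

    class TRTLLMMLABackend(FlashInferMLAAttnBackend):
        def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
            # Ragged prefill self-attention is handled by the TRT-LLM kernel,
            # so the FlashInfer ragged wrapper setup is skipped.
            super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)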


@@ -45,6 +45,15 @@ TRTLLM_BLOCK_CONSTRAINT = 128
 global_zero_init_workspace_buffer = None


+@dataclass
+class TRTLLMMLAPrefillMetadata:
+    """Metadata for TRTLLM MLA prefill operations."""
+
+    max_seq_len: int
+    cum_seq_lens: torch.Tensor
+    seq_lens: torch.Tensor
+
+
 @dataclass
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
@@ -101,7 +110,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
         self.decode_cuda_graph_kv_indices = None
-        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
@@ -235,7 +245,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 max_seq_len_val,
             )
             self.decode_cuda_graph_metadata[bs] = metadata
-        self.forward_metadata = metadata
+        self.forward_decode_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -291,31 +301,52 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
         # Delegate to parent for non-decode modes.
-        if not forward_batch.forward_mode.is_decode_or_idle():
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+            cum_seq_lens_q = torch.cat(
+                (
+                    torch.tensor([0], device=forward_batch.seq_lens.device),
+                    torch.cumsum(seq_lens, dim=0),
+                )
+            ).int()
+            max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                max_seq_len,
+                cum_seq_lens_q,
+                seq_lens,
+            )
+        elif forward_batch.forward_mode.is_decode_or_idle():
+            bs = forward_batch.batch_size
+            # Get maximum sequence length.
+            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                max_seq = forward_batch.seq_lens_cpu.max().item()
+            else:
+                max_seq = forward_batch.seq_lens.max().item()
+            max_seqlen_pad = self._calc_padded_blocks(max_seq)
+            block_kv_indices = self._create_block_kv_indices(
+                bs,
+                max_seqlen_pad,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens.device,
+            )
+            max_seq_len_val = int(max_seq)
+            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                self.workspace_buffer, block_kv_indices, max_seq_len_val
+            )
+            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
+        else:
             return super().init_forward_metadata(forward_batch)
-        bs = forward_batch.batch_size
-        # Get maximum sequence length.
-        if getattr(forward_batch, "seq_lens_cpu", None) is not None:
-            max_seq = forward_batch.seq_lens_cpu.max().item()
-        else:
-            max_seq = forward_batch.seq_lens.max().item()
-        max_seqlen_pad = self._calc_padded_blocks(max_seq)
-        block_kv_indices = self._create_block_kv_indices(
-            bs,
-            max_seqlen_pad,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.seq_lens.device,
-        )
-        max_seq_len_val = int(max_seq)
-        self.forward_metadata = TRTLLMMLADecodeMetadata(
-            self.workspace_buffer, block_kv_indices, max_seq_len_val
-        )
-        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)

     def quantize_and_rope_for_fp8(
         self,
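To make the prefill-metadata math above easy to check in isolation, here is a self-contained version with mocked batch fields (toy values; the ForwardBatch attributes are replaced by plain tensors):

    import torch

    seq_lens_total = torch.tensor([10, 7])       # prefix + new tokens per request
    extend_prefix_lens = torch.tensor([7, 2])    # tokens already in the KV cache
    seq_lens = seq_lens_total - extend_prefix_lens  # new tokens: tensor([3, 5])
    cum_seq_lens_q = torch.cat(
        (torch.tensor([0]), torch.cumsum(seq_lens, dim=0))
    ).int()                                      # tensor([0, 3, 8], dtype=torch.int32)
    max_seq_len = int(seq_lens.max())            # 5, matches max(extend_seq_lens_cpu)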
@@ -459,7 +490,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Get metadata
         metadata = (
             getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-            or self.forward_metadata
+            or self.forward_decode_metadata
         )

         # Scale computation for TRTLLM MLA kernel BMM1 operation:
@@ -496,6 +527,55 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output

+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        if not forward_batch.attn_attend_prefix_cache:
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=self.workspace_buffer,
+                seq_lens=self.forward_prefill_metadata.seq_lens,
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                bmm1_scale=layer.scaling,
+                bmm2_scale=1.0,
+                o_sf_scale=1.0,
+                batch_size=forward_batch.batch_size,
+                window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                enable_pdl=False,
+                is_causal=True,
+                return_lse=forward_batch.mha_return_lse,
+            )
+        else:
+            # replace with trtllm ragged attention once accuracy is resolved.
+            output = super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        return output
+

 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""

@@ -1050,7 +1050,6 @@ class DeepseekV2AttentionMLA(nn.Module):
attention_backend == "flashinfer"
or attention_backend == "fa3"
or attention_backend == "flashmla"
or attention_backend == "trtllm_mla"
or attention_backend == "cutlass_mla"
):
# Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1079,6 +1078,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
         elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
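Taken together with the backend-list change above, the trtllm_mla backend now routes ordinary prefill through MHA_CHUNKED_KV (and thus the new ragged kernel), while target-verify, draft-extend, and decode keep the MLA path. A standalone restatement of that rule as a pure function (Mode and the return strings are stand-ins for the real sglang types, not code from this PR):

    from dataclasses import dataclass

    @dataclass
    class Mode:
        extend: bool = False
        target_verify: bool = False
        draft_extend: bool = False

    def pick_forward_method(backend: str, mode: Mode) -> str:
        # Mirrors the new trtllm_mla branch: plain extend (prefill) uses
        # chunked-KV MHA; everything else falls back to the MLA subtype.
        if backend == "trtllm_mla" and mode.extend and not (
            mode.target_verify or mode.draft_extend
        ):
            return "MHA_CHUNKED_KV"
        return "MLA_SUBTYPE"

    assert pick_forward_method("trtllm_mla", Mode(extend=True)) == "MHA_CHUNKED_KV"
    assert pick_forward_method("trtllm_mla", Mode(extend=True, target_verify=True)) == "MLA_SUBTYPE"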