Integrate trtllm ragged attention for prefill self-attention (#9801)
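This change routes prefill self-attention in the TRT-LLM MLA backend through flashinfer.prefill.trtllm_ragged_attention_deepseek. To support that, it adds a TRTLLMMLAPrefillMetadata dataclass (max_seq_len, cum_seq_lens, seq_lens), splits the backend's forward_metadata into forward_prefill_metadata and forward_decode_metadata, adds a disable_flashinfer_ragged flag to FlashInferMLAAttnBackend.init_mha_chunk_metadata so the TRT-LLM backend can skip the flashinfer ragged prefill wrapper, and makes DeepseekV2AttentionMLA dispatch trtllm_mla extend batches (excluding target-verify and draft-extend) to MHA_CHUNKED_KV. Attention over the chunked prefix cache still falls back to the parent implementation until the ragged kernel's accuracy there is resolved.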
@@ -96,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
     def update_wrapper(
         self,
         forward_batch: ForwardBatch,
+        disable_flashinfer_ragged: bool = False,
     ):
         assert forward_batch.num_prefix_chunks is not None
         num_prefix_chunks = forward_batch.num_prefix_chunks
@@ -128,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
                 causal=False,
             )
         # ragged prefill
-        self.ragged_wrapper.begin_forward(
-            qo_indptr=qo_indptr,
-            kv_indptr=qo_indptr,
-            num_qo_heads=self.num_local_heads,
-            num_kv_heads=self.num_local_heads,
-            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-            head_dim_vo=self.v_head_dim,
-            q_data_type=self.q_data_type,
-            causal=True,
-        )
+        if not disable_flashinfer_ragged:
+            self.ragged_wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=True,
+            )

     def forward(
         self,
@@ -491,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+    def init_mha_chunk_metadata(
+        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
+    ):
         """Init the metadata for a forward pass."""
-        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)

     def forward_extend(
         self,
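For orientation, a minimal self-contained sketch (the class and helper names here are simplified stand-ins, not the real sglang classes) of how the new disable_flashinfer_ragged flag threads from init_mha_chunk_metadata into the chunk-KV runner, letting a subclass keep the chunked prefix-KV wrappers while skipping the flashinfer ragged prefill wrapper:

class ChunkKVRunnerSketch:
    def update_wrapper(self, forward_batch, disable_flashinfer_ragged: bool = False):
        # Chunked prefix-KV wrappers are always planned (causal=False per chunk).
        print("plan chunk wrappers")
        if not disable_flashinfer_ragged:
            # The ragged (causal) prefill wrapper is planned only when requested.
            print("plan flashinfer ragged wrapper")


class MLABackendSketch:
    def __init__(self):
        self.mha_chunk_kv_cache = ChunkKVRunnerSketch()

    def init_mha_chunk_metadata(self, forward_batch, disable_flashinfer_ragged=False):
        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)


class TRTLLMBackendSketch(MLABackendSketch):
    def init_mha_chunk_metadata(self, forward_batch):
        # The TRT-LLM MLA backend runs its own ragged kernel in forward_extend,
        # so it opts out of the flashinfer ragged wrapper.
        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)


TRTLLMBackendSketch().init_mha_chunk_metadata(forward_batch=None)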
@@ -45,6 +45,15 @@ TRTLLM_BLOCK_CONSTRAINT = 128
 global_zero_init_workspace_buffer = None


+@dataclass
+class TRTLLMMLAPrefillMetadata:
+    """Metadata for TRTLLM MLA prefill operations."""
+
+    max_seq_len: int
+    cum_seq_lens: torch.Tensor
+    seq_lens: torch.Tensor
+
+
 @dataclass
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
@@ -101,7 +110,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
         self.decode_cuda_graph_kv_indices = None
-        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
@@ -235,7 +245,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
-        self.forward_metadata = metadata
+        self.forward_decode_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -291,31 +301,52 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes.
-        if not forward_batch.forward_mode.is_decode_or_idle():
-            return super().init_forward_metadata(forward_batch)
-
-        bs = forward_batch.batch_size
-
-        # Get maximum sequence length.
-        if getattr(forward_batch, "seq_lens_cpu", None) is not None:
-            max_seq = forward_batch.seq_lens_cpu.max().item()
-        else:
-            max_seq = forward_batch.seq_lens.max().item()
-
-        max_seqlen_pad = self._calc_padded_blocks(max_seq)
-        block_kv_indices = self._create_block_kv_indices(
-            bs,
-            max_seqlen_pad,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.seq_lens.device,
-        )
-
-        max_seq_len_val = int(max_seq)
-        self.forward_metadata = TRTLLMMLADecodeMetadata(
-            self.workspace_buffer, block_kv_indices, max_seq_len_val
-        )
-        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+            cum_seq_lens_q = torch.cat(
+                (
+                    torch.tensor([0], device=forward_batch.seq_lens.device),
+                    torch.cumsum(seq_lens, dim=0),
+                )
+            ).int()
+            max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                max_seq_len,
+                cum_seq_lens_q,
+                seq_lens,
+            )
+        elif forward_batch.forward_mode.is_decode_or_idle():
+            bs = forward_batch.batch_size
+
+            # Get maximum sequence length.
+            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                max_seq = forward_batch.seq_lens_cpu.max().item()
+            else:
+                max_seq = forward_batch.seq_lens.max().item()
+
+            max_seqlen_pad = self._calc_padded_blocks(max_seq)
+            block_kv_indices = self._create_block_kv_indices(
+                bs,
+                max_seqlen_pad,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens.device,
+            )
+
+            max_seq_len_val = int(max_seq)
+            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                self.workspace_buffer, block_kv_indices, max_seq_len_val
+            )
+            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
+        else:
+            return super().init_forward_metadata(forward_batch)
+
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)

     def quantize_and_rope_for_fp8(
         self,
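As a concrete illustration of the extend branch above, a small sketch (toy lengths on CPU; the real code subtracts extend_prefix_lens and uses the batch's device and extend_seq_lens_cpu) of how the TRTLLMMLAPrefillMetadata fields are derived:

import torch

# Toy new-token counts for an extend (prefill) batch of 3 requests.
seq_lens = torch.tensor([5, 3, 7])

# Zero-prefixed cumulative lengths delimit each request in the packed (ragged) layout.
cum_seq_lens_q = torch.cat((torch.tensor([0]), torch.cumsum(seq_lens, dim=0))).int()
max_seq_len = int(seq_lens.max())

print(cum_seq_lens_q.tolist())  # [0, 5, 8, 15]
print(max_seq_len)              # 7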
@@ -459,7 +490,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Get metadata
         metadata = (
             getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-            or self.forward_metadata
+            or self.forward_decode_metadata
         )

         # Scale computation for TRTLLM MLA kernel BMM1 operation:
@@ -496,6 +527,55 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output

+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
+        if not forward_batch.attn_attend_prefix_cache:
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=self.workspace_buffer,
+                seq_lens=self.forward_prefill_metadata.seq_lens,
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                bmm1_scale=layer.scaling,
+                bmm2_scale=1.0,
+                o_sf_scale=1.0,
+                batch_size=forward_batch.batch_size,
+                window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                enable_pdl=False,
+                is_causal=True,
+                return_lse=forward_batch.mha_return_lse,
+            )
+        else:
+            # replace with trtllm ragged attention once accuracy is resolved.
+            output = super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        return output
+

 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
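To clarify what the ragged kernel is being asked to compute on the no-prefix-cache path above, here is a plain-PyTorch reference sketch (illustrative and unoptimized, fp32 on CPU; it ignores kernel-specific arguments such as workspace_buffer, o_sf_scale, and enable_pdl) of causal self-attention over a packed batch delimited by cum_seq_lens, with a bmm1-style scale applied to QK^T:

import torch

def ragged_causal_attention(q, k, v, cum_seq_lens, scale):
    # q, k: [total_tokens, H, Dqk], v: [total_tokens, H, Dv]; requests packed back to back.
    out = torch.empty(q.shape[0], q.shape[1], v.shape[2])
    for i in range(len(cum_seq_lens) - 1):
        s, e = int(cum_seq_lens[i]), int(cum_seq_lens[i + 1])
        qi = q[s:e].transpose(0, 1)                          # [H, L, Dqk]
        ki = k[s:e].transpose(0, 1)                          # [H, L, Dqk]
        vi = v[s:e].transpose(0, 1)                          # [H, L, Dv]
        scores = qi @ ki.transpose(-1, -2) * scale           # [H, L, L]
        mask = torch.triu(torch.ones(e - s, e - s, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))     # causal within each request
        out[s:e] = (torch.softmax(scores, dim=-1) @ vi).transpose(0, 1)
    return out

# Example: 2 requests of lengths 3 and 2, 4 heads, Dqk=6, Dv=4.
cum = torch.tensor([0, 3, 5])
q = torch.randn(5, 4, 6); k = torch.randn(5, 4, 6); v = torch.randn(5, 4, 4)
o = ragged_causal_attention(q, k, v, cum, scale=6 ** -0.5)
print(o.shape)  # torch.Size([5, 4, 4])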
@@ -1050,7 +1050,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1079,6 +1078,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
         elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()