Fix TRTLLM MLA Cuda KV Blocks Causing accuracy drop (#9675)
This commit is contained in:
@@ -51,6 +51,7 @@ class TRTLLMMLADecodeMetadata:
|
||||
|
||||
workspace: Optional[torch.Tensor] = None
|
||||
block_kv_indices: Optional[torch.Tensor] = None
|
||||
max_seq_len: Optional[int] = None
|
||||
|
||||
|
||||
class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
@@ -207,8 +208,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
)
|
||||
|
||||
# Custom fast-path for decode/idle.
|
||||
max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
|
||||
block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
|
||||
# Capture with full width so future longer sequences are safe during replay
|
||||
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
|
||||
block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks_per_seq]
|
||||
|
||||
create_flashmla_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
@@ -217,13 +219,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
None,
|
||||
block_kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
max_seqlen_pad,
|
||||
max_blocks_per_seq,
|
||||
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
PAGED_SIZE=self.page_size,
|
||||
)
|
||||
|
||||
# Record the true maximum sequence length for this capture batch so that
|
||||
# the kernel launch path (which requires an int not a tensor) can reuse
|
||||
# it safely during both capture and replay.
|
||||
max_seq_len_val = int(seq_lens.max().item())
|
||||
|
||||
metadata = TRTLLMMLADecodeMetadata(
|
||||
self.decode_cuda_graph_workspace, block_kv_indices
|
||||
self.decode_cuda_graph_workspace,
|
||||
block_kv_indices,
|
||||
max_seq_len_val,
|
||||
)
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
self.forward_metadata = metadata
|
||||
@@ -268,6 +277,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
PAGED_SIZE=self.page_size,
|
||||
)
|
||||
|
||||
# Update stored max_seq_len so subsequent kernel calls use the correct value
|
||||
# Prefer CPU tensor to avoid GPU synchronization when available.
|
||||
if seq_lens_cpu is not None:
|
||||
metadata.max_seq_len = int(seq_lens_cpu.max().item())
|
||||
else:
|
||||
metadata.max_seq_len = int(seq_lens.max().item())
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self) -> int:
|
||||
"""Get the fill value for sequence lengths in CUDA graph."""
|
||||
return 1
|
||||
@@ -295,8 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
forward_batch.seq_lens.device,
|
||||
)
|
||||
|
||||
max_seq_len_val = int(max_seq)
|
||||
self.forward_metadata = TRTLLMMLADecodeMetadata(
|
||||
self.workspace_buffer, block_kv_indices
|
||||
self.workspace_buffer, block_kv_indices, max_seq_len_val
|
||||
)
|
||||
forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
|
||||
|
||||
@@ -471,14 +488,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=metadata.block_kv_indices,
|
||||
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
||||
max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
|
||||
max_seq_len=metadata.max_seq_len,
|
||||
bmm1_scale=bmm1_scale,
|
||||
)
|
||||
|
||||
# Extract value projection part and reshape
|
||||
raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
|
||||
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
|
||||
# Reshape output directly without slicing
|
||||
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
@@ -208,6 +208,15 @@ class MockModelRunner:
|
||||
self.kv_cache_dtype = config["kv_cache_dtype"]
|
||||
self.page_size = config["page_size"]
|
||||
|
||||
# Server args stub - needed by attention backends
|
||||
self.server_args = type(
|
||||
"ServerArgs",
|
||||
(),
|
||||
{
|
||||
"enable_dp_attention": False, # Default value for testing
|
||||
},
|
||||
)
|
||||
|
||||
# Model-config stub with MLA attributes
|
||||
self.model_config = type(
|
||||
"ModelConfig",
|
||||
@@ -833,7 +842,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
|
||||
# Test workspace properties
|
||||
self.assertEqual(metadata.workspace.device.type, "cuda")
|
||||
self.assertEqual(metadata.workspace.dtype, torch.int8)
|
||||
self.assertEqual(metadata.workspace.dtype, torch.uint8)
|
||||
self.assertGreater(
|
||||
metadata.workspace.numel(), 0, "Workspace should have non-zero size"
|
||||
)
|
||||
@@ -993,8 +1002,8 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
)
|
||||
|
||||
# Verify CUDA graph buffers are allocated
|
||||
self.assertIsNotNone(backend.cuda_graph_kv_indices)
|
||||
self.assertIsNotNone(backend.cuda_graph_workspace)
|
||||
self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
|
||||
self.assertIsNotNone(backend.decode_cuda_graph_workspace)
|
||||
|
||||
# Test capture metadata
|
||||
seq_lens = torch.full(
|
||||
|
||||
Reference in New Issue
Block a user