TRTLLM Gen MLA Decode Kernel Integration (same as #7938) (#8632)

Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Faraz
2025-07-31 19:03:40 -04:00
committed by GitHub
parent 3dde86194a
commit 4b04998d38
8 changed files with 1361 additions and 4 deletions


@@ -0,0 +1,372 @@
from __future__ import annotations
"""
Support attention backend for TRTLLM MLA kernels from flashinfer.
"""
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
import torch
import triton
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import (
TRITON_PAD_NUM_PAGE_PER_BLOCK,
create_flashmla_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
import flashinfer
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
# Constants
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
# Block constraint from flashinfer requirements
# From flashinfer.decode._check_trtllm_gen_mla_shape:
# block_num % (128 / block_size) == 0
# This means the total number of blocks must be divisible by
# (128 / block_size). We capture the 128 constant here so we can
# compute the LCM with other padding constraints.
TRTLLM_BLOCK_CONSTRAINT = 128
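# Illustrative example (values derived from the constraint above, not an extra
# code path): with block_size = 32, block_num must be a multiple of 128 / 32 = 4;
# with block_size = 64, a multiple of 128 / 64 = 2.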
@dataclass
class TRTLLMMLADecodeMetadata:
"""Metadata for TRTLLM MLA decode operations."""
workspace: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
class TRTLLMMLABackend(FlashInferMLAAttnBackend):
"""TRTLLM MLA attention kernel from flashinfer."""
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
q_indptr_decode_buf: Optional[torch.Tensor] = None,
):
super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
config = model_runner.model_config
# Model parameters
self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
self.num_local_heads = config.num_attention_heads // get_attention_tp_size()
# MLA-specific dimensions
self.kv_lora_rank = config.kv_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
# Runtime parameters
self.scaling = config.scaling
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.page_size = model_runner.page_size
self.req_to_token = model_runner.req_to_token_pool.req_to_token
# Workspace allocation
self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
self.workspace_buffer = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
# CUDA graph state
self.decode_cuda_graph_metadata = {}
self.cuda_graph_kv_indices = None
self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
def _calc_padded_blocks(self, max_seq_len: int) -> int:
"""
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
Args:
max_seq_len: Maximum sequence length in tokens
Returns:
Number of blocks padded to satisfy all constraints
"""
blocks = triton.cdiv(max_seq_len, self.page_size)
# Apply dual constraints (take LCM to satisfy both):
# 1. TRT-LLM: block_num % (128 / page_size) == 0
# 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
if blocks % constraint_lcm != 0:
blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
return blocks
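# Worked example (a sketch using the constants above): page_size = 32 gives
# trtllm_constraint = 128 // 32 = 4, and with TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
# the LCM is 64, so a sequence needing 5 blocks is padded up to 64 blocks
# (2048 tokens of page-table capacity).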
def _create_block_kv_indices(
self,
batch_size: int,
max_blocks: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
device: torch.device,
) -> torch.Tensor:
"""
Create block KV indices tensor using Triton kernel.
Args:
batch_size: Batch size
max_blocks: Maximum number of blocks per sequence
req_pool_indices: Request pool indices
seq_lens: Sequence lengths
device: Target device
Returns:
Block KV indices tensor
"""
block_kv_indices = torch.full(
(batch_size, max_blocks), -1, dtype=torch.int32, device=device
)
create_flashmla_kv_indices_triton[(batch_size,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_blocks,
TRITON_PAD_NUM_PAGE_PER_BLOCK,
self.page_size,
)
return block_kv_indices
def init_cuda_graph_state(
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
"""Initialize CUDA graph state for TRTLLM MLA."""
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
self.cuda_graph_kv_indices = torch.full(
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
)
self.cuda_graph_workspace = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
"""Initialize metadata for CUDA graph capture."""
# Delegate to parent for non-decode modes or when speculative execution is used.
if not (forward_mode.is_decode_or_idle() and spec_info is None):
return super().init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_mode,
spec_info,
)
# Custom fast-path for decode/idle without speculative execution.
max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
TRITON_PAD_NUM_PAGE_PER_BLOCK,
self.page_size,
)
metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
self.decode_cuda_graph_metadata[bs] = metadata
self.forward_metadata = metadata
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Replay CUDA graph with new inputs."""
# Delegate to parent for non-decode modes or when speculative execution is used.
if not (forward_mode.is_decode_or_idle() and spec_info is None):
return super().init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
seq_lens,
seq_lens_sum,
encoder_lens,
forward_mode,
spec_info,
seq_lens_cpu,
)
metadata = self.decode_cuda_graph_metadata[bs]
# Update block indices for new sequences.
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens[:bs],
None,
metadata.block_kv_indices,
self.req_to_token.stride(0),
metadata.block_kv_indices.shape[1],
TRITON_PAD_NUM_PAGE_PER_BLOCK,
self.page_size,
)
def get_cuda_graph_seq_len_fill_value(self) -> int:
"""Get the fill value for sequence lengths in CUDA graph."""
return 1
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize the metadata for a forward pass."""
# Delegate to parent for non-decode modes or when speculative execution is used.
if not (
forward_batch.forward_mode.is_decode_or_idle()
and forward_batch.spec_info is None
):
return super().init_forward_metadata(forward_batch)
bs = forward_batch.batch_size
# Get maximum sequence length.
if getattr(forward_batch, "seq_lens_cpu", None) is not None:
max_seq = forward_batch.seq_lens_cpu.max().item()
else:
max_seq = forward_batch.seq_lens.max().item()
max_seqlen_pad = self._calc_padded_blocks(max_seq)
block_kv_indices = self._create_block_kv_indices(
bs,
max_seqlen_pad,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens.device,
)
self.forward_metadata = TRTLLMMLADecodeMetadata(
self.workspace_buffer, block_kv_indices
)
forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Run forward for decode using TRTLLM MLA kernel."""
# Save KV cache if requested
if k is not None and save_kv_cache:
cache_loc = forward_batch.out_cache_loc
if k_rope is not None:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer, cache_loc, k, k_rope
)
elif v is not None:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
# Prepare query tensor inline
if q_rope is not None:
# q contains NOPE part (v_head_dim)
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope_reshaped = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
else:
# q already has both parts
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
# Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] (acc_q_len is 1 for decode)
if query.dim() == 3:
query = query.unsqueeze(1)
# Prepare KV cache inline
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
# TRT-LLM expects single KV data with extra dimension
kv_cache = pages.unsqueeze(1)
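# Resulting layout: [num_pages, 1, page_size, kv_lora_rank + qk_rope_head_dim].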
# Get metadata
metadata = (
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
or self.forward_metadata
)
# Scale computation for TRTLLM MLA kernel:
# - BMM1 scale = q_scale * k_scale * softmax_scale
# - For the FP16/BF16 path we keep q_scale = 1.0; softmax_scale = 1/sqrt(head_dim) is pre-computed as layer.scaling
# - k_scale is read from model checkpoint if available
# TODO: Change once fp8 path is supported
q_scale = 1.0
k_scale = (
layer.k_scale_float
if getattr(layer, "k_scale_float", None) is not None
else 1.0
)
bmm1_scale = q_scale * k_scale * layer.scaling
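# With q_scale = k_scale = 1.0 (the current non-FP8 path) this reduces to layer.scaling.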
# Call TRT-LLM kernel
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
kv_cache=kv_cache,
workspace_buffer=metadata.workspace,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=metadata.block_kv_indices,
seq_lens=forward_batch.seq_lens.to(torch.int32),
max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
bmm1_scale=bmm1_scale,
)
# Extract value projection part and reshape
raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return output


@@ -1,6 +1,11 @@
import triton
import triton.language as tl
# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
# Number of pages that the kernel writes per iteration.
# Exposed here so other Python modules can import it instead of hard-coding 64.
TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
@triton.jit
def create_flashinfer_kv_indices_triton(
@@ -50,10 +55,10 @@ def create_flashmla_kv_indices_triton(
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
kv_indices_ptr_stride: tl.constexpr,
NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
PAGED_SIZE: tl.constexpr = 64,
):
BLOCK_SIZE: tl.constexpr = 4096
NUM_PAGE_PER_BLOCK: tl.constexpr = 64
pid = tl.program_id(axis=0)
# find the req pool idx, this is for batch to token


@@ -436,6 +436,7 @@ class ModelRunner:
"triton",
"flashmla",
"cutlass_mla",
"trtllm_mla",
"ascend",
]:
logger.info(
@@ -1437,6 +1438,12 @@ class ModelRunner:
)
return CutlassMLABackend(self)
elif self.server_args.attention_backend == "trtllm_mla":
if not self.use_mla_backend:
raise ValueError("trtllm_mla backend can only be used with MLA models.")
from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
return TRTLLMMLABackend(self)
elif self.server_args.attention_backend == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend,


@@ -1259,6 +1259,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.current_attention_backend == "fa3"
or self.current_attention_backend == "flashinfer"
or self.current_attention_backend == "cutlass_mla"
or self.current_attention_backend == "trtllm_mla"
):
attn_output = self.attn_mqa(
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe


@@ -24,6 +24,7 @@ import tempfile
from typing import List, Literal, Optional, Union
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import (
@@ -402,6 +403,22 @@ class ServerArgs:
)
self.page_size = 128
if self.attention_backend == "trtllm_mla":
if not is_sm100_supported():
raise ValueError(
"TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
)
if self.page_size not in [32, 64]:
logger.warning(
f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
)
self.page_size = 64
if self.speculative_algorithm is not None:
raise ValueError(
"trtllm_mla backend does not support speculative decoding yet."
)
# Set page size
if self.page_size is None:
self.page_size = 1
@@ -1225,6 +1242,7 @@ class ServerArgs:
"torch_native",
"ascend",
"triton",
"trtllm_mla",
],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",


@@ -0,0 +1,945 @@
import math
import unittest
import numpy as np
import torch
from sglang.srt.layers import dp_attention as _dp_attn
# Patch DP-attention globals before importing backends
# TODO: change the interface of both trtllm_mla and flashinfer backends to take tp_size as an argument instead of patching
_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.trtllm_mla_backend import (
TRTLLMMLABackend,
TRTLLMMLADecodeMetadata,
)
from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
from sglang.test.test_utils import CustomTestCase
# Global configuration for all tests
DEFAULT_CONFIG = {
"device": "cuda",
"dtype": torch.bfloat16,
"kv_cache_dtype": torch.bfloat16,
"context_len": 2048,
"max_bs": 64,
"tolerance": 1e-2,
"seed_cache": 42,
"seed_qkv": 123,
# MLA model config (TRTLLM MLA has fixed constraints)
"num_attention_heads": 128,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 512,
"num_kv_heads": 1,
"layer_id": 0,
}
# Centralized test cases for different test scenarios
TEST_CASES = {
"basic_functionality": [
{
"name": "single",
"batch_size": 1,
"max_seq_len": 32,
"page_size": 32,
"description": "Minimal smoke test",
},
{
"name": "batch",
"batch_size": 32,
"max_seq_len": 128,
"page_size": 32,
"description": "Medium-scale batch",
},
],
"decode_output_match": [
{
"name": "single",
"batch_size": 1,
"max_seq_len": 64,
"page_size": 32,
"description": "Single vs reference",
},
{
"name": "batch",
"batch_size": 32,
"max_seq_len": 64,
"page_size": 32,
"description": "Batch vs reference",
},
],
"page_size_consistency": [
# Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
{
"name": "page_32",
"batch_size": 8,
"max_seq_len": 128,
"page_size": 32,
"description": "32-token pages",
},
{
"name": "page_64",
"batch_size": 8,
"max_seq_len": 128,
"page_size": 64,
"description": "64-token pages",
},
],
"shape_sanity_tests": [
{
"name": "basic",
"batch_size": 1,
"max_seq_len": 128,
"page_size": 32,
"description": "Single sequence",
},
{
"name": "basic_different_pagesize",
"batch_size": 1,
"max_seq_len": 128,
"page_size": 64,
"description": "Different page size",
},
{
"name": "batch",
"batch_size": 8,
"max_seq_len": 128,
"page_size": 32,
"description": "Batch shapes",
},
],
"metadata_tests": [
{
"name": "single_sequence",
"batch_size": 1,
"max_seq_len": 64,
"page_size": 32,
"description": "Single sequence metadata",
},
{
"name": "batch_mixed_lengths",
"batch_size": 8,
"max_seq_len": 128,
"page_size": 32,
"description": "Mixed sequence lengths",
},
{
"name": "large_batch",
"batch_size": 32,
"max_seq_len": 256,
"page_size": 64,
"description": "Large batch stress test",
},
{
"name": "edge_case_short",
"batch_size": 4,
"max_seq_len": 16,
"page_size": 32,
"description": "Sub-page sequences",
},
],
}
class MockModelRunner:
"""Minimal fake ModelRunner for testing MLA backends."""
def __init__(self, config):
self.device = config["device"]
self.dtype = config["dtype"]
self.kv_cache_dtype = config["kv_cache_dtype"]
self.page_size = config["page_size"]
# Model-config stub with MLA attributes
self.model_config = type(
"ModelConfig",
(),
{
"context_len": config["context_len"],
"attention_arch": AttentionArch.MLA,
"num_attention_heads": config["num_attention_heads"],
"kv_lora_rank": config["kv_lora_rank"],
"qk_nope_head_dim": config["qk_nope_head_dim"],
"qk_rope_head_dim": config["qk_rope_head_dim"],
"v_head_dim": config["v_head_dim"],
"scaling": 1.0
/ ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5),
"get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]),
},
)
# Req-to-token pool
max_bs = config["max_bs"]
max_ctx = self.model_config.context_len
req_to_token = torch.arange(
max_bs * max_ctx, dtype=torch.int32, device=self.device
).reshape(max_bs, max_ctx)
self.req_to_token_pool = type(
"TokenPool",
(),
{
"size": max_bs,
"req_to_token": req_to_token,
},
)
# KV-token pool (MLA)
self.token_to_kv_pool = MLATokenToKVPool(
size=max_bs * max_ctx,
page_size=config["page_size"],
dtype=self.kv_cache_dtype,
kv_lora_rank=config["kv_lora_rank"],
qk_rope_head_dim=config["qk_rope_head_dim"],
layer_num=1,
device=self.device,
enable_memory_saver=False,
)
def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
"""Compare outputs with detailed analysis."""
# Basic checks
assert (
trtllm_out.shape == reference_out.shape
), f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}"
assert (
trtllm_out.dtype == reference_out.dtype
), f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}"
# Check for NaN/Inf
assert not torch.isnan(trtllm_out).any(), "TRTLLM output contains NaN"
assert not torch.isnan(reference_out).any(), "Reference output contains NaN"
assert not torch.isinf(trtllm_out).any(), "TRTLLM output contains Inf"
assert not torch.isinf(reference_out).any(), "Reference output contains Inf"
# Element-wise differences
diff = (trtllm_out - reference_out).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
# Check numerical equivalence
all_close = torch.allclose(
trtllm_out, reference_out, rtol=tolerance, atol=tolerance
)
if not all_close:
print(
f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}"
)
# Find top differences for debugging
flat_diff = diff.flatten()
top_diff_indices = torch.topk(flat_diff, k=min(5, flat_diff.numel())).indices
print("Top 5 differences:")
for i, idx in enumerate(top_diff_indices):
idx_tuple = np.unravel_index(idx.cpu().numpy(), trtllm_out.shape)
trt_val = trtllm_out[idx_tuple].item()
ref_val = reference_out[idx_tuple].item()
print(
f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}"
)
return all_close
@unittest.skipIf(
not torch.cuda.is_available() or not is_flashinfer_available(),
"CUDA + flashinfer required",
)
class TestTRTLLMMLA(CustomTestCase):
"""Test suite for TRTLLM MLA backend with centralized configuration."""
def _merge_config(self, test_case):
"""Merge test case with default configuration."""
config = DEFAULT_CONFIG.copy()
config.update(test_case)
return config
def _create_model_components(self, config):
"""Create model runners, backends, and layer for testing."""
# Create model runners
model_runner_trtllm = MockModelRunner(config)
model_runner_reference = MockModelRunner(config)
# Create backends
trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
# Create RadixAttention layer
layer = RadixAttention(
num_heads=config["num_attention_heads"],
head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
scaling=model_runner_trtllm.model_config.scaling,
num_kv_heads=config["num_kv_heads"],
layer_id=config["layer_id"],
v_head_dim=config["v_head_dim"],
prefix="attn_mqa",
)
return (
model_runner_trtllm,
model_runner_reference,
trtllm_backend,
reference_backend,
layer,
)
def _create_qkv_tensors(self, batch_size, config):
"""Create Q, K, V tensors for testing."""
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
device = config["device"]
dtype = config["dtype"]
q = torch.randn(
(batch_size, config["num_attention_heads"], head_dim),
dtype=dtype,
device=device,
)
k = torch.randn(
(batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
)
v = torch.randn(
(batch_size, config["num_kv_heads"], config["v_head_dim"]),
dtype=dtype,
device=device,
)
return q, k, v
def _create_forward_batch(
self, batch_size, seq_lens, backend, model_runner, config
):
"""Create a forward batch for the given backend."""
fb = ForwardBatch(
batch_size=batch_size,
input_ids=torch.randint(0, 100, (batch_size, 1), device=config["device"]),
out_cache_loc=torch.arange(batch_size, device=config["device"]),
seq_lens_sum=int(seq_lens.sum().item()),
forward_mode=ForwardMode.DECODE,
req_pool_indices=torch.arange(batch_size, device=config["device"]),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(),
attn_backend=backend,
)
fb.req_to_token_pool = model_runner.req_to_token_pool
fb.token_to_kv_pool = model_runner.token_to_kv_pool
return fb
def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
"""Populate KV cache with identical data for both backends."""
torch.manual_seed(config["seed_cache"]) # Fixed seed for reproducible cache
for model_runner in model_runners:
torch.manual_seed(config["seed_cache"]) # Reset seed for each backend
for i in range(batch_size):
seq_len = int(seq_lens[i].item())
for token_idx in range(seq_len - 1):
# Create random K components for MLA
cache_k_nope = torch.randn(
(1, config["qk_nope_head_dim"]),
dtype=config["dtype"],
device=config["device"],
)
cache_k_rope = torch.randn(
(1, config["qk_rope_head_dim"]),
dtype=config["dtype"],
device=config["device"],
)
# Calculate cache location
cache_loc = model_runner.req_to_token_pool.req_to_token[
i, token_idx
]
# Save to KV cache
model_runner.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc.unsqueeze(0),
cache_k_nope.squeeze(0),
cache_k_rope.squeeze(0),
)
def test_basic_functionality(self):
"""Test basic functionality with minimal setup."""
print(f"\nRunning basic functionality tests...")
for test_case in TEST_CASES["basic_functionality"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
model_runner_trtllm, _, trtllm_backend, _, layer = (
self._create_model_components(config)
)
# Create sequence lengths - properly handle different batch sizes
if batch_size == 2:
seq_lens = torch.tensor(
[max_seq_len, max_seq_len // 2], device=config["device"]
)
else:
# For larger batch sizes, create varied sequence lengths
torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
max_seq_len // 2,
max_seq_len + 1,
(batch_size,),
device=config["device"],
)
seq_lens[0] = max_seq_len # Ensure at least one max length
# Create forward batch
fb = self._create_forward_batch(
batch_size, seq_lens, trtllm_backend, model_runner_trtllm, config
)
trtllm_backend.init_forward_metadata(fb)
# Populate KV cache
self._populate_kv_cache(
batch_size, seq_lens, [model_runner_trtllm], layer, config
)
# Create Q, K, V tensors
torch.manual_seed(config["seed_qkv"])
q, k, v = self._create_qkv_tensors(batch_size, config)
# Run forward decode
output = trtllm_backend.forward_decode(q, k, v, layer, fb)
# Basic checks
expected_shape = (
batch_size,
config["num_attention_heads"] * config["v_head_dim"],
)
self.assertEqual(output.shape, expected_shape)
self.assertEqual(output.dtype, config["dtype"])
self.assertFalse(torch.isnan(output).any())
self.assertFalse(torch.isinf(output).any())
def test_decode_output_match(self):
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
print(f"\nRunning decode output matching tests...")
for test_case in TEST_CASES["decode_output_match"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
(
model_runner_trtllm,
model_runner_reference,
trtllm_backend,
reference_backend,
layer,
) = self._create_model_components(config)
# Create identical sequence lengths for both backends
torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
1, max_seq_len, (batch_size,), device=config["device"]
)
seq_lens[0] = max_seq_len # Ensure at least one max length
# Create forward batches with identical inputs
fb_trtllm = self._create_forward_batch(
batch_size,
seq_lens.clone(),
trtllm_backend,
model_runner_trtllm,
config,
)
fb_reference = self._create_forward_batch(
batch_size,
seq_lens.clone(),
reference_backend,
model_runner_reference,
config,
)
# Initialize metadata for both backends
trtllm_backend.init_forward_metadata(fb_trtllm)
reference_backend.init_forward_metadata(fb_reference)
# Populate both KV caches identically
self._populate_kv_cache(
batch_size,
seq_lens,
[model_runner_trtllm, model_runner_reference],
layer,
config,
)
# Create Q, K, V tensors for current decode step
torch.manual_seed(config["seed_qkv"])
q, k, v = self._create_qkv_tensors(batch_size, config)
# Run forward decode on both backends
out_trtllm = trtllm_backend.forward_decode(
q.clone(), k.clone(), v.clone(), layer, fb_trtllm
)
out_reference = reference_backend.forward_decode(
q.clone(), k.clone(), v.clone(), layer, fb_reference
)
# Compare outputs
comparison_passed = compare_outputs(
out_trtllm, out_reference, tolerance=config["tolerance"]
)
self.assertTrue(
comparison_passed,
f"TRTLLM and Reference outputs differ beyond tolerance. "
f"Config: {test_case['name']}, "
f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
)
def test_page_size_consistency(self):
"""Test output consistency across different page sizes."""
print(f"\nRunning page size consistency tests...")
for test_case in TEST_CASES["page_size_consistency"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
model_runner, _, backend, _, layer = self._create_model_components(
config
)
# Create sequence lengths
torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
1, max_seq_len, (batch_size,), device=config["device"]
)
seq_lens[0] = max_seq_len
# Create forward batch
fb = self._create_forward_batch(
batch_size, seq_lens, backend, model_runner, config
)
backend.init_forward_metadata(fb)
# Populate KV cache
self._populate_kv_cache(
batch_size, seq_lens, [model_runner], layer, config
)
# Create Q, K, V tensors
torch.manual_seed(config["seed_qkv"])
q, k, v = self._create_qkv_tensors(batch_size, config)
# Run forward decode
output = backend.forward_decode(q, k, v, layer, fb)
expected_shape = (
batch_size,
config["num_attention_heads"] * config["v_head_dim"],
)
self.assertEqual(
output.shape,
expected_shape,
f"Output shape mismatch: {output.shape} vs {expected_shape}",
)
self.assertFalse(torch.isnan(output).any(), "Output contains NaN")
self.assertFalse(torch.isinf(output).any(), "Output contains Inf")
def test_shape_sanity(self):
"""Smoke test decode across several configurations."""
print(f"\nRunning shape sanity tests...")
for test_case in TEST_CASES["shape_sanity_tests"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
model_runner, _, backend, _, layer = self._create_model_components(
config
)
# Random seq lens (ensure one matches max)
torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
1, max_seq_len, (batch_size,), device=config["device"]
)
seq_lens[0] = max_seq_len
fb = self._create_forward_batch(
batch_size, seq_lens, backend, model_runner, config
)
backend.init_forward_metadata(fb)
# Create Q, K, V tensors
torch.manual_seed(config["seed_qkv"])
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
q = torch.randn(
(batch_size, config["num_attention_heads"], head_dim),
dtype=config["dtype"],
device=config["device"],
)
k = torch.randn(
(batch_size, config["num_kv_heads"], head_dim),
dtype=config["dtype"],
device=config["device"],
)
v = None
# Run forward decode
output = backend.forward_decode(q, k, v, layer, fb)
# Shape and sanity checks
expected_shape = (
batch_size,
config["num_attention_heads"] * config["v_head_dim"],
)
self.assertEqual(
output.shape,
expected_shape,
f"Output shape mismatch for {test_case['name']}",
)
self.assertEqual(output.dtype, config["dtype"])
self.assertEqual(output.device.type, "cuda")
self.assertFalse(
torch.isnan(output).any(),
f"Output contains NaN for {test_case['name']}",
)
self.assertFalse(
torch.isinf(output).any(),
f"Output contains Inf for {test_case['name']}",
)
def test_metadata_initialization(self):
"""Test TRTLLM MLA metadata initialization and structure."""
print(f"\nRunning metadata initialization tests...")
for test_case in TEST_CASES["metadata_tests"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
model_runner, _, backend, _, layer = self._create_model_components(
config
)
# Create varied sequence lengths
torch.manual_seed(config["seed_cache"])
if batch_size == 1:
seq_lens = torch.tensor([max_seq_len], device=config["device"])
else:
seq_lens = torch.randint(
max(1, max_seq_len // 4),
max_seq_len + 1,
(batch_size,),
device=config["device"],
)
seq_lens[0] = max_seq_len # Ensure at least one max length
# Create forward batch
fb = self._create_forward_batch(
batch_size, seq_lens, backend, model_runner, config
)
# Initialize metadata
backend.init_forward_metadata(fb)
# Verify metadata exists
self.assertIsNotNone(backend.forward_metadata)
self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)
# Test metadata structure
metadata = backend.forward_metadata
self.assertIsNotNone(
metadata.workspace, "Workspace should be allocated"
)
self.assertIsNotNone(
metadata.block_kv_indices, "Block KV indices should be created"
)
# Test workspace properties
self.assertEqual(metadata.workspace.device.type, "cuda")
self.assertEqual(metadata.workspace.dtype, torch.int8)
self.assertGreater(
metadata.workspace.numel(), 0, "Workspace should have non-zero size"
)
# Test block KV indices properties
self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
self.assertEqual(metadata.block_kv_indices.shape[0], batch_size)
# Verify block indices are valid (>= -1, since -1 is padding)
self.assertTrue(
(metadata.block_kv_indices >= -1).all(),
"All block indices should be >= -1 (with -1 as padding)",
)
def test_metadata_block_calculation(self):
"""Test block count calculation logic."""
print(f"\nRunning metadata block calculation tests...")
test_scenarios = [
{"seq_len": 31, "page_size": 32, "expected_min_blocks": 1},
{"seq_len": 32, "page_size": 32, "expected_min_blocks": 1},
{"seq_len": 33, "page_size": 32, "expected_min_blocks": 2},
{"seq_len": 128, "page_size": 32, "expected_min_blocks": 4},
{"seq_len": 128, "page_size": 64, "expected_min_blocks": 2},
]
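# Note: with TRITON_PAD_NUM_PAGE_PER_BLOCK = 64, every scenario above is padded
# up to 64 blocks by _calc_padded_blocks; expected_min_blocks only asserts the
# un-padded lower bound.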
for scenario in test_scenarios:
with self.subTest(scenario=scenario):
config = self._merge_config(
{
"batch_size": 1,
"max_seq_len": scenario["seq_len"],
"page_size": scenario["page_size"],
}
)
model_runner, _, backend, _, _ = self._create_model_components(config)
# Test internal block calculation
calculated_blocks = backend._calc_padded_blocks(scenario["seq_len"])
# Should be at least the minimum required
self.assertGreaterEqual(
calculated_blocks,
scenario["expected_min_blocks"],
f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})",
)
# Should satisfy page_size constraint
total_tokens = calculated_blocks * scenario["page_size"]
self.assertGreaterEqual(
total_tokens,
scenario["seq_len"],
f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})",
)
# Should satisfy TRT-LLM and Triton constraints
trtllm_constraint = 128 // scenario["page_size"]
constraint_lcm = math.lcm(
trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK
)
self.assertEqual(
calculated_blocks % constraint_lcm,
0,
f"Block count should be multiple of LCM of constraints ({constraint_lcm})",
)
def test_metadata_kv_indices_correctness(self):
"""Test KV indices creation and correctness."""
print(f"\nRunning KV indices correctness tests...")
for test_case in TEST_CASES["metadata_tests"][
:2
]: # Test subset for performance
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
model_runner, _, backend, _, layer = self._create_model_components(
config
)
# Create known sequence lengths
torch.manual_seed(config["seed_cache"])
if batch_size == 1:
seq_lens = torch.tensor([max_seq_len], device=config["device"])
else:
seq_lens = torch.randint(
max_seq_len // 2,
max_seq_len + 1,
(batch_size,),
device=config["device"],
)
fb = self._create_forward_batch(
batch_size, seq_lens, backend, model_runner, config
)
# Populate some KV cache to have valid indices
self._populate_kv_cache(
batch_size, seq_lens, [model_runner], layer, config
)
# Initialize metadata
backend.init_forward_metadata(fb)
metadata = backend.forward_metadata
# Verify KV indices structure
block_kv_indices = metadata.block_kv_indices
for i in range(batch_size):
seq_len = seq_lens[i].item()
expected_blocks = backend._calc_padded_blocks(seq_len)
# Count valid (non -1) indices for this sequence
valid_indices = (block_kv_indices[i] >= 0).sum().item()
# Should have at least enough blocks for the sequence
min_required_blocks = (seq_len + config["page_size"] - 1) // config[
"page_size"
]
self.assertGreaterEqual(
valid_indices,
min_required_blocks,
f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}",
)
# Verify indices are within valid range
valid_block_indices = block_kv_indices[i][block_kv_indices[i] >= 0]
if len(valid_block_indices) > 0:
max_possible_blocks = (
model_runner.token_to_kv_pool.size // config["page_size"]
)
self.assertTrue(
(valid_block_indices < max_possible_blocks).all(),
f"All block indices should be < {max_possible_blocks}",
)
def test_metadata_cuda_graph_compatibility(self):
"""Test metadata compatibility with CUDA graph capture/replay."""
print(f"\nRunning CUDA graph compatibility tests...")
config = self._merge_config(
{"batch_size": 4, "max_seq_len": 64, "page_size": 32}
)
model_runner, _, backend, _, layer = self._create_model_components(config)
batch_size = config["batch_size"]
# Initialize CUDA graph state
backend.init_cuda_graph_state(
max_bs=batch_size, max_num_tokens=config["max_seq_len"] * batch_size
)
# Verify CUDA graph buffers are allocated
self.assertIsNotNone(backend.cuda_graph_kv_indices)
self.assertIsNotNone(backend.cuda_graph_workspace)
# Test capture metadata
seq_lens = torch.full(
(batch_size,), config["max_seq_len"], device=config["device"]
)
req_pool_indices = torch.arange(batch_size, device=config["device"])
backend.init_forward_metadata_capture_cuda_graph(
bs=batch_size,
num_tokens=batch_size,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=None,
)
# Verify capture metadata
self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
self.assertIsNotNone(capture_metadata.workspace)
self.assertIsNotNone(capture_metadata.block_kv_indices)
# Test replay with different sequence lengths
new_seq_lens = torch.randint(
config["max_seq_len"] // 2,
config["max_seq_len"] + 1,
(batch_size,),
device=config["device"],
)
backend.init_forward_metadata_replay_cuda_graph(
bs=batch_size,
req_pool_indices=req_pool_indices,
seq_lens=new_seq_lens,
seq_lens_sum=new_seq_lens.sum().item(),
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=None,
seq_lens_cpu=new_seq_lens.cpu(),
)
# Verify replay updated the metadata
replay_metadata = backend.forward_metadata
self.assertIsNotNone(replay_metadata)
self.assertEqual(
replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
)
def test_metadata_consistency_across_calls(self):
"""Test metadata consistency across multiple forward calls."""
print(f"\nRunning metadata consistency tests...")
config = self._merge_config(
{"batch_size": 2, "max_seq_len": 64, "page_size": 32}
)
model_runner, _, backend, _, layer = self._create_model_components(config)
# First call
seq_lens_1 = torch.tensor([32, 48], device=config["device"])
fb_1 = self._create_forward_batch(
config["batch_size"], seq_lens_1, backend, model_runner, config
)
backend.init_forward_metadata(fb_1)
metadata_1 = backend.forward_metadata
# Second call with same sequence lengths
seq_lens_2 = torch.tensor([32, 48], device=config["device"])
fb_2 = self._create_forward_batch(
config["batch_size"], seq_lens_2, backend, model_runner, config
)
backend.init_forward_metadata(fb_2)
metadata_2 = backend.forward_metadata
# Metadata structure should be consistent
self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
self.assertEqual(
metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
)
# Third call with different sequence lengths
seq_lens_3 = torch.tensor([16, 64], device=config["device"])
fb_3 = self._create_forward_batch(
config["batch_size"], seq_lens_3, backend, model_runner, config
)
backend.init_forward_metadata(fb_3)
metadata_3 = backend.forward_metadata
# Should still have valid structure
self.assertIsNotNone(metadata_3.workspace)
self.assertIsNotNone(metadata_3.block_kv_indices)
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
if __name__ == "__main__":
unittest.main()