diff --git a/docs/backend/attention_backend.md b/docs/backend/attention_backend.md index caf23446f..3dfe6cb3d 100644 --- a/docs/backend/attention_backend.md +++ b/docs/backend/attention_backend.md @@ -9,8 +9,12 @@ | **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ | | **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ | | **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ | +| **TRTLLM MLA** | ✅ | ❌ | ✅ | ✅ | ❌ | | **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ | +**Notes:** +- TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to FlashInfer MLA backend. + Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`. This is because a page size of 16 can be converted to a page size of 1 in the kernel backend. The "❌" and "✅" symbols in the table above under "Page Size > 1" indicate whether the kernel actually operates with a page size greater than 1, rather than treating a page size of 16 as a page size of 1. @@ -48,6 +52,11 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code ``` +- TRTLLM MLA (Optimized for Blackwell Architecture, e.g., B200) +```bash +python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code +``` + - Ascend ```bash python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 8b6d688d1..af5e38677 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -90,7 +90,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. -- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 provides good performance across wide workloads. +- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), **TRTLLM MLA** (optimized for Blackwell architecture), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 provides good performance across wide workloads. - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption. @@ -104,7 +104,7 @@ Overall, with these optimizations, we have achieved up to **7x** acceleration in Multi-head Latent Attention for DeepSeek Series Models

-**Usage**: MLA optimization is enabled by default.
+**Usage**: MLA optimization is enabled by default. For MLA models on Blackwell architecture (e.g., B200), the default backend is FlashInfer. To use the optimized TRTLLM MLA backend for decode operations, explicitly specify `--attention-backend trtllm_mla`. Note that TRTLLM MLA only optimizes decode operations; prefill operations (including multimodal inputs) fall back to the FlashInfer MLA backend.
 
 **Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details.
 
@@ -161,7 +161,7 @@ Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculati
 python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8
 ```
 - The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
-- FlashAttention3 FlashMLA and Triton backend fully supports MTP usage. For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding,`--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA backend is still under development.
+- FlashAttention3, FlashMLA, and Triton backends fully support MTP usage. For the FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA and TRTLLM MLA backends is still under development.
 - To enable DeepSeek MTP for large batch sizes (>32), there are some parameters should be changed (Reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
   - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default value.
   - Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it.
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
new file mode 100755
index 000000000..d33201442
--- /dev/null
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -0,0 +1,372 @@
+from __future__ import annotations
+
+"""
+Support attention backend for TRTLLM MLA kernels from flashinfer.
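+
+This backend only implements the decode path; prefill (including multimodal
+inputs) falls back to the FlashInfer MLA backend. It targets Blackwell (SM100)
+GPUs and requires a KV cache page size of 32 or 64 (enforced in server_args).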
+""" + +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +import torch +import triton + +from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend +from sglang.srt.layers.attention.utils import ( + TRITON_PAD_NUM_PAGE_PER_BLOCK, + create_flashmla_kv_indices_triton, +) +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import is_flashinfer_available + +if is_flashinfer_available(): + import flashinfer + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInfo + +# Constants +DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB + +# Block constraint from flashinfer requirements +# From flashinfer.decode._check_trtllm_gen_mla_shape: +# block_num % (128 / block_size) == 0 +# This imposes that the total number of blocks must be divisible by +# (128 / block_size). We capture the 128 constant here so we can +# compute the LCM with other padding constraints. +TRTLLM_BLOCK_CONSTRAINT = 128 + + +@dataclass +class TRTLLMMLADecodeMetadata: + """Metadata for TRTLLM MLA decode operations.""" + + workspace: Optional[torch.Tensor] = None + block_kv_indices: Optional[torch.Tensor] = None + + +class TRTLLMMLABackend(FlashInferMLAAttnBackend): + """TRTLLM MLA attention kernel from flashinfer.""" + + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + q_indptr_decode_buf: Optional[torch.Tensor] = None, + ): + super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf) + + config = model_runner.model_config + + # Model parameters + self.num_q_heads = config.num_attention_heads // get_attention_tp_size() + self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size()) + self.num_local_heads = config.num_attention_heads // get_attention_tp_size() + + # MLA-specific dimensions + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim + + # Runtime parameters + self.scaling = config.scaling + self.data_type = model_runner.kv_cache_dtype + self.q_data_type = model_runner.dtype + self.page_size = model_runner.page_size + self.req_to_token = model_runner.req_to_token_pool.req_to_token + + # Workspace allocation + self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024 + self.workspace_buffer = torch.empty( + self.workspace_size, dtype=torch.int8, device=self.device + ) + + # CUDA graph state + self.decode_cuda_graph_metadata = {} + self.cuda_graph_kv_indices = None + self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None + + def _calc_padded_blocks(self, max_seq_len: int) -> int: + """ + Calculate padded block count that satisfies both TRT-LLM and Triton constraints. + + Args: + max_seq_len: Maximum sequence length in tokens + + Returns: + Number of blocks padded to satisfy all constraints + """ + blocks = triton.cdiv(max_seq_len, self.page_size) + + # Apply dual constraints (take LCM to satisfy both): + # 1. TRT-LLM: block_num % (128 / page_size) == 0 + # 2. 
Triton: page table builder uses 64-index bursts, needs multiple of 64 + trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size + constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK) + + if blocks % constraint_lcm != 0: + blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm + return blocks + + def _create_block_kv_indices( + self, + batch_size: int, + max_blocks: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + device: torch.device, + ) -> torch.Tensor: + """ + Create block KV indices tensor using Triton kernel. + + Args: + batch_size: Batch size + max_blocks: Maximum number of blocks per sequence + req_pool_indices: Request pool indices + seq_lens: Sequence lengths + device: Target device + + Returns: + Block KV indices tensor + """ + block_kv_indices = torch.full( + (batch_size, max_blocks), -1, dtype=torch.int32, device=device + ) + + create_flashmla_kv_indices_triton[(batch_size,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_blocks, + TRITON_PAD_NUM_PAGE_PER_BLOCK, + self.page_size, + ) + + return block_kv_indices + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ): + """Initialize CUDA graph state for TRTLLM MLA.""" + max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len) + + self.cuda_graph_kv_indices = torch.full( + (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device + ) + self.cuda_graph_workspace = torch.empty( + self.workspace_size, dtype=torch.int8, device=self.device + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + ): + """Initialize metadata for CUDA graph capture.""" + # Delegate to parent for non-decode modes or when speculative execution is used. + if not (forward_mode.is_decode_or_idle() and spec_info is None): + return super().init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) + + # Custom fast-path for decode/idle without speculative execution. + max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item()) + block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad] + + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_seqlen_pad, + TRITON_PAD_NUM_PAGE_PER_BLOCK, + self.page_size, + ) + + metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices) + self.decode_cuda_graph_metadata[bs] = metadata + self.forward_metadata = metadata + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + seq_lens_cpu: Optional[torch.Tensor], + ): + """Replay CUDA graph with new inputs.""" + # Delegate to parent for non-decode modes or when speculative execution is used. 
+ if not (forward_mode.is_decode_or_idle() and spec_info is None): + return super().init_forward_metadata_replay_cuda_graph( + bs, + req_pool_indices, + seq_lens, + seq_lens_sum, + encoder_lens, + forward_mode, + spec_info, + seq_lens_cpu, + ) + + metadata = self.decode_cuda_graph_metadata[bs] + + # Update block indices for new sequences. + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens[:bs], + None, + metadata.block_kv_indices, + self.req_to_token.stride(0), + metadata.block_kv_indices.shape[1], + TRITON_PAD_NUM_PAGE_PER_BLOCK, + self.page_size, + ) + + def get_cuda_graph_seq_len_fill_value(self) -> int: + """Get the fill value for sequence lengths in CUDA graph.""" + return 1 + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Initialize the metadata for a forward pass.""" + # Delegate to parent for non-decode modes or when speculative execution is used. + if not ( + forward_batch.forward_mode.is_decode_or_idle() + and forward_batch.spec_info is None + ): + return super().init_forward_metadata(forward_batch) + + bs = forward_batch.batch_size + + # Get maximum sequence length. + if getattr(forward_batch, "seq_lens_cpu", None) is not None: + max_seq = forward_batch.seq_lens_cpu.max().item() + else: + max_seq = forward_batch.seq_lens.max().item() + + max_seqlen_pad = self._calc_padded_blocks(max_seq) + block_kv_indices = self._create_block_kv_indices( + bs, + max_seqlen_pad, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens.device, + ) + + self.forward_metadata = TRTLLMMLADecodeMetadata( + self.workspace_buffer, block_kv_indices + ) + forward_batch.decode_trtllm_mla_metadata = self.forward_metadata + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run forward for decode using TRTLLM MLA kernel.""" + # Save KV cache if requested + if k is not None and save_kv_cache: + cache_loc = forward_batch.out_cache_loc + if k_rope is not None: + forward_batch.token_to_kv_pool.set_mla_kv_buffer( + layer, cache_loc, k, k_rope + ) + elif v is not None: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + + # Prepare query tensor inline + if q_rope is not None: + # q contains NOPE part (v_head_dim) + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope_reshaped = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) + query = torch.cat([q_nope, q_rope_reshaped], dim=-1) + else: + # q already has both parts + query = q.view(-1, layer.tp_q_head_num, layer.head_dim) + + # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1 + if query.dim() == 3: + query = query.unsqueeze(1) + + # Prepare KV cache inline + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) + # TRT-LLM expects single KV data with extra dimension + kv_cache = pages.unsqueeze(1) + + # Get metadata + metadata = ( + getattr(forward_batch, "decode_trtllm_mla_metadata", None) + or self.forward_metadata + ) + + # Scale computation for TRTLLM MLA kernel: + # - BMM1 scale = q_scale * k_scale * softmax_scale + # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling + # - k_scale is read from model 
checkpoint if available + # TODO: Change once fp8 path is supported + q_scale = 1.0 + k_scale = ( + layer.k_scale_float + if getattr(layer, "k_scale_float", None) is not None + else 1.0 + ) + + bmm1_scale = q_scale * k_scale * layer.scaling + + # Call TRT-LLM kernel + raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + query=query, + kv_cache=kv_cache, + workspace_buffer=metadata.workspace, + qk_nope_head_dim=self.qk_nope_head_dim, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + block_tables=metadata.block_kv_indices, + seq_lens=forward_batch.seq_lens.to(torch.int32), + max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), + bmm1_scale=bmm1_scale, + ) + + # Extract value projection part and reshape + raw_out_v = raw_out[..., : layer.v_head_dim].contiguous() + output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + return output diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py index 71633d12d..e8cd2e158 100644 --- a/python/sglang/srt/layers/attention/utils.py +++ b/python/sglang/srt/layers/attention/utils.py @@ -1,6 +1,11 @@ import triton import triton.language as tl +# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`. +# Number of pages that the kernel writes per iteration. +# Exposed here so other Python modules can import it instead of hard-coding 64. +TRITON_PAD_NUM_PAGE_PER_BLOCK = 64 + @triton.jit def create_flashinfer_kv_indices_triton( @@ -50,10 +55,10 @@ def create_flashmla_kv_indices_triton( kv_indices_ptr, req_to_token_ptr_stride: tl.constexpr, kv_indices_ptr_stride: tl.constexpr, + NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK, PAGED_SIZE: tl.constexpr = 64, ): BLOCK_SIZE: tl.constexpr = 4096 - NUM_PAGE_PER_BLOCK: tl.constexpr = 64 pid = tl.program_id(axis=0) # find the req pool idx, this is for batch to token diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 41016c3d9..d04b3c47d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -436,6 +436,7 @@ class ModelRunner: "triton", "flashmla", "cutlass_mla", + "trtllm_mla", "ascend", ]: logger.info( @@ -1437,6 +1438,12 @@ class ModelRunner: ) return CutlassMLABackend(self) + elif self.server_args.attention_backend == "trtllm_mla": + if not self.use_mla_backend: + raise ValueError("trtllm_mla backend can only be used with MLA models.") + from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend + + return TRTLLMMLABackend(self) elif self.server_args.attention_backend == "intel_amx": from sglang.srt.layers.attention.intel_amx_backend import ( IntelAMXAttnBackend, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index ace06cb7b..bd0e35a2e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1259,6 +1259,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.current_attention_backend == "fa3" or self.current_attention_backend == "flashinfer" or self.current_attention_backend == "cutlass_mla" + or self.current_attention_backend == "trtllm_mla" ): attn_output = self.attn_mqa( q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 507fb7121..c4a520f1c 100644 --- a/python/sglang/srt/server_args.py +++ 
b/python/sglang/srt/server_args.py @@ -24,6 +24,7 @@ import tempfile from typing import List, Literal, Optional, Union from sglang.srt.hf_transformers_utils import check_gguf_file, get_config +from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( @@ -402,6 +403,22 @@ class ServerArgs: ) self.page_size = 128 + if self.attention_backend == "trtllm_mla": + if not is_sm100_supported(): + raise ValueError( + "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend." + ) + + if self.page_size not in [32, 64]: + logger.warning( + f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64." + ) + self.page_size = 64 + if self.speculative_algorithm is not None: + raise ValueError( + "trtllm_mla backend does not support speculative decoding yet." + ) + # Set page size if self.page_size is None: self.page_size = 1 @@ -1225,6 +1242,7 @@ class ServerArgs: "torch_native", "ascend", "triton", + "trtllm_mla", ], default=ServerArgs.attention_backend, help="Choose the kernels for attention layers.", diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py new file mode 100755 index 000000000..be3ed08f4 --- /dev/null +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -0,0 +1,945 @@ +import math +import unittest + +import numpy as np +import torch + +from sglang.srt.layers import dp_attention as _dp_attn + +# Patch DP-attention globals before importing backends +# TODO: change the interface of both trtllm_mla and flashinfer backends to take tp_size as an argument instead of patching +_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test + +from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend +from sglang.srt.layers.attention.trtllm_mla_backend import ( + TRTLLMMLABackend, + TRTLLMMLADecodeMetadata, +) +from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import is_flashinfer_available +from sglang.test.test_utils import CustomTestCase + +# Global configuration for all tests +DEFAULT_CONFIG = { + "device": "cuda", + "dtype": torch.bfloat16, + "kv_cache_dtype": torch.bfloat16, + "context_len": 2048, + "max_bs": 64, + "tolerance": 1e-2, + "seed_cache": 42, + "seed_qkv": 123, + # MLA model config (TRTLLM MLA has fixed constraints) + "num_attention_heads": 128, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 512, + "num_kv_heads": 1, + "layer_id": 0, +} + +# Centralized test cases for different test scenarios +TEST_CASES = { + "basic_functionality": [ + { + "name": "single", + "batch_size": 1, + "max_seq_len": 32, + "page_size": 32, + "description": "Minimal smoke test", + }, + { + "name": "batch", + "batch_size": 32, + "max_seq_len": 128, + "page_size": 32, + "description": "Medium-scale batch", + }, + ], + "decode_output_match": [ + { + "name": "single", + "batch_size": 1, + "max_seq_len": 64, + "page_size": 32, + "description": "Single vs reference", + }, + { + "name": "batch", + "batch_size": 32, + 
"max_seq_len": 64, + "page_size": 32, + "description": "Batch vs reference", + }, + ], + "page_size_consistency": [ + # Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel + { + "name": "page_32", + "batch_size": 8, + "max_seq_len": 128, + "page_size": 32, + "description": "32-token pages", + }, + { + "name": "page_64", + "batch_size": 8, + "max_seq_len": 128, + "page_size": 64, + "description": "64-token pages", + }, + ], + "shape_sanity_tests": [ + { + "name": "basic", + "batch_size": 1, + "max_seq_len": 128, + "page_size": 32, + "description": "Single sequence", + }, + { + "name": "basic_different_pagesize", + "batch_size": 1, + "max_seq_len": 128, + "page_size": 64, + "description": "Different page size", + }, + { + "name": "batch", + "batch_size": 8, + "max_seq_len": 128, + "page_size": 32, + "description": "Batch shapes", + }, + ], + "metadata_tests": [ + { + "name": "single_sequence", + "batch_size": 1, + "max_seq_len": 64, + "page_size": 32, + "description": "Single sequence metadata", + }, + { + "name": "batch_mixed_lengths", + "batch_size": 8, + "max_seq_len": 128, + "page_size": 32, + "description": "Mixed sequence lengths", + }, + { + "name": "large_batch", + "batch_size": 32, + "max_seq_len": 256, + "page_size": 64, + "description": "Large batch stress test", + }, + { + "name": "edge_case_short", + "batch_size": 4, + "max_seq_len": 16, + "page_size": 32, + "description": "Sub-page sequences", + }, + ], +} + + +class MockModelRunner: + """Minimal fake ModelRunner for testing MLA backends.""" + + def __init__(self, config): + self.device = config["device"] + self.dtype = config["dtype"] + self.kv_cache_dtype = config["kv_cache_dtype"] + self.page_size = config["page_size"] + + # Model-config stub with MLA attributes + self.model_config = type( + "ModelConfig", + (), + { + "context_len": config["context_len"], + "attention_arch": AttentionArch.MLA, + "num_attention_heads": config["num_attention_heads"], + "kv_lora_rank": config["kv_lora_rank"], + "qk_nope_head_dim": config["qk_nope_head_dim"], + "qk_rope_head_dim": config["qk_rope_head_dim"], + "v_head_dim": config["v_head_dim"], + "scaling": 1.0 + / ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5), + "get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]), + }, + ) + + # Req-to-token pool + max_bs = config["max_bs"] + max_ctx = self.model_config.context_len + req_to_token = torch.arange( + max_bs * max_ctx, dtype=torch.int32, device=self.device + ).reshape(max_bs, max_ctx) + self.req_to_token_pool = type( + "TokenPool", + (), + { + "size": max_bs, + "req_to_token": req_to_token, + }, + ) + + # KV-token pool (MLA) + self.token_to_kv_pool = MLATokenToKVPool( + size=max_bs * max_ctx, + page_size=config["page_size"], + dtype=self.kv_cache_dtype, + kv_lora_rank=config["kv_lora_rank"], + qk_rope_head_dim=config["qk_rope_head_dim"], + layer_num=1, + device=self.device, + enable_memory_saver=False, + ) + + +def compare_outputs(trtllm_out, reference_out, tolerance=1e-2): + """Compare outputs with detailed analysis.""" + + # Basic checks + assert ( + trtllm_out.shape == reference_out.shape + ), f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}" + assert ( + trtllm_out.dtype == reference_out.dtype + ), f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}" + + # Check for NaN/Inf + assert not torch.isnan(trtllm_out).any(), "TRTLLM output contains NaN" + assert not torch.isnan(reference_out).any(), "Reference output contains NaN" + assert not torch.isinf(trtllm_out).any(), 
"TRTLLM output contains Inf" + assert not torch.isinf(reference_out).any(), "Reference output contains Inf" + + # Element-wise differences + diff = (trtllm_out - reference_out).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + # Check numerical equivalence + all_close = torch.allclose( + trtllm_out, reference_out, rtol=tolerance, atol=tolerance + ) + + if not all_close: + print( + f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}" + ) + # Find top differences for debugging + flat_diff = diff.flatten() + top_diff_indices = torch.topk(flat_diff, k=min(5, flat_diff.numel())).indices + print("Top 5 differences:") + for i, idx in enumerate(top_diff_indices): + idx_tuple = np.unravel_index(idx.cpu().numpy(), trtllm_out.shape) + trt_val = trtllm_out[idx_tuple].item() + ref_val = reference_out[idx_tuple].item() + print( + f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}" + ) + + return all_close + + +@unittest.skipIf( + not torch.cuda.is_available() or not is_flashinfer_available(), + "CUDA + flashinfer required", +) +class TestTRTLLMMLA(CustomTestCase): + """Test suite for TRTLLM MLA backend with centralized configuration.""" + + def _merge_config(self, test_case): + """Merge test case with default configuration.""" + config = DEFAULT_CONFIG.copy() + config.update(test_case) + return config + + def _create_model_components(self, config): + """Create model runners, backends, and layer for testing.""" + # Create model runners + model_runner_trtllm = MockModelRunner(config) + model_runner_reference = MockModelRunner(config) + + # Create backends + trtllm_backend = TRTLLMMLABackend(model_runner_trtllm) + reference_backend = FlashInferMLAAttnBackend(model_runner_reference) + + # Create RadixAttention layer + layer = RadixAttention( + num_heads=config["num_attention_heads"], + head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"], + scaling=model_runner_trtllm.model_config.scaling, + num_kv_heads=config["num_kv_heads"], + layer_id=config["layer_id"], + v_head_dim=config["v_head_dim"], + prefix="attn_mqa", + ) + + return ( + model_runner_trtllm, + model_runner_reference, + trtllm_backend, + reference_backend, + layer, + ) + + def _create_qkv_tensors(self, batch_size, config): + """Create Q, K, V tensors for testing.""" + head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"] + device = config["device"] + dtype = config["dtype"] + + q = torch.randn( + (batch_size, config["num_attention_heads"], head_dim), + dtype=dtype, + device=device, + ) + k = torch.randn( + (batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device + ) + v = torch.randn( + (batch_size, config["num_kv_heads"], config["v_head_dim"]), + dtype=dtype, + device=device, + ) + return q, k, v + + def _create_forward_batch( + self, batch_size, seq_lens, backend, model_runner, config + ): + """Create a forward batch for the given backend.""" + fb = ForwardBatch( + batch_size=batch_size, + input_ids=torch.randint(0, 100, (batch_size, 1), device=config["device"]), + out_cache_loc=torch.arange(batch_size, device=config["device"]), + seq_lens_sum=int(seq_lens.sum().item()), + forward_mode=ForwardMode.DECODE, + req_pool_indices=torch.arange(batch_size, device=config["device"]), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.cpu(), + attn_backend=backend, + ) + fb.req_to_token_pool = model_runner.req_to_token_pool + fb.token_to_kv_pool = model_runner.token_to_kv_pool + return fb + + def 
_populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config): + """Populate KV cache with identical data for both backends.""" + torch.manual_seed(config["seed_cache"]) # Fixed seed for reproducible cache + + for model_runner in model_runners: + torch.manual_seed(config["seed_cache"]) # Reset seed for each backend + for i in range(batch_size): + seq_len = int(seq_lens[i].item()) + for token_idx in range(seq_len - 1): + # Create random K components for MLA + cache_k_nope = torch.randn( + (1, config["qk_nope_head_dim"]), + dtype=config["dtype"], + device=config["device"], + ) + cache_k_rope = torch.randn( + (1, config["qk_rope_head_dim"]), + dtype=config["dtype"], + device=config["device"], + ) + + # Calculate cache location + cache_loc = model_runner.req_to_token_pool.req_to_token[ + i, token_idx + ] + + # Save to KV cache + model_runner.token_to_kv_pool.set_mla_kv_buffer( + layer, + cache_loc.unsqueeze(0), + cache_k_nope.squeeze(0), + cache_k_rope.squeeze(0), + ) + + def test_basic_functionality(self): + """Test basic functionality with minimal setup.""" + print(f"\nRunning basic functionality tests...") + + for test_case in TEST_CASES["basic_functionality"]: + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] + + # Create components + model_runner_trtllm, _, trtllm_backend, _, layer = ( + self._create_model_components(config) + ) + + # Create sequence lengths - properly handle different batch sizes + if batch_size == 2: + seq_lens = torch.tensor( + [max_seq_len, max_seq_len // 2], device=config["device"] + ) + else: + # For larger batch sizes, create varied sequence lengths + torch.manual_seed(config["seed_cache"]) + seq_lens = torch.randint( + max_seq_len // 2, + max_seq_len + 1, + (batch_size,), + device=config["device"], + ) + seq_lens[0] = max_seq_len # Ensure at least one max length + + # Create forward batch + fb = self._create_forward_batch( + batch_size, seq_lens, trtllm_backend, model_runner_trtllm, config + ) + trtllm_backend.init_forward_metadata(fb) + + # Populate KV cache + self._populate_kv_cache( + batch_size, seq_lens, [model_runner_trtllm], layer, config + ) + + # Create Q, K, V tensors + torch.manual_seed(config["seed_qkv"]) + q, k, v = self._create_qkv_tensors(batch_size, config) + + # Run forward decode + output = trtllm_backend.forward_decode(q, k, v, layer, fb) + + # Basic checks + expected_shape = ( + batch_size, + config["num_attention_heads"] * config["v_head_dim"], + ) + self.assertEqual(output.shape, expected_shape) + self.assertEqual(output.dtype, config["dtype"]) + self.assertFalse(torch.isnan(output).any()) + self.assertFalse(torch.isinf(output).any()) + + def test_decode_output_match(self): + """Test that TRTLLM and FlashInfer MLA backends produce matching outputs.""" + print(f"\nRunning decode output matching tests...") + + for test_case in TEST_CASES["decode_output_match"]: + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] + + # Create components + ( + model_runner_trtllm, + model_runner_reference, + trtllm_backend, + reference_backend, + layer, + ) = self._create_model_components(config) + + # Create identical sequence lengths for both backends + 
torch.manual_seed(config["seed_cache"]) + seq_lens = torch.randint( + 1, max_seq_len, (batch_size,), device=config["device"] + ) + seq_lens[0] = max_seq_len # Ensure at least one max length + + # Create forward batches with identical inputs + fb_trtllm = self._create_forward_batch( + batch_size, + seq_lens.clone(), + trtllm_backend, + model_runner_trtllm, + config, + ) + fb_reference = self._create_forward_batch( + batch_size, + seq_lens.clone(), + reference_backend, + model_runner_reference, + config, + ) + + # Initialize metadata for both backends + trtllm_backend.init_forward_metadata(fb_trtllm) + reference_backend.init_forward_metadata(fb_reference) + + # Populate both KV caches identically + self._populate_kv_cache( + batch_size, + seq_lens, + [model_runner_trtllm, model_runner_reference], + layer, + config, + ) + + # Create Q, K, V tensors for current decode step + torch.manual_seed(config["seed_qkv"]) + q, k, v = self._create_qkv_tensors(batch_size, config) + + # Run forward decode on both backends + out_trtllm = trtllm_backend.forward_decode( + q.clone(), k.clone(), v.clone(), layer, fb_trtllm + ) + out_reference = reference_backend.forward_decode( + q.clone(), k.clone(), v.clone(), layer, fb_reference + ) + + # Compare outputs + comparison_passed = compare_outputs( + out_trtllm, out_reference, tolerance=config["tolerance"] + ) + + self.assertTrue( + comparison_passed, + f"TRTLLM and Reference outputs differ beyond tolerance. " + f"Config: {test_case['name']}, " + f"Max diff: {(out_trtllm - out_reference).abs().max().item()}", + ) + + def test_page_size_consistency(self): + """Test output consistency across different page sizes.""" + print(f"\nRunning page size consistency tests...") + + for test_case in TEST_CASES["page_size_consistency"]: + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] + + # Create components + model_runner, _, backend, _, layer = self._create_model_components( + config + ) + + # Create sequence lengths + torch.manual_seed(config["seed_cache"]) + seq_lens = torch.randint( + 1, max_seq_len, (batch_size,), device=config["device"] + ) + seq_lens[0] = max_seq_len + + # Create forward batch + fb = self._create_forward_batch( + batch_size, seq_lens, backend, model_runner, config + ) + backend.init_forward_metadata(fb) + + # Populate KV cache + self._populate_kv_cache( + batch_size, seq_lens, [model_runner], layer, config + ) + + # Create Q, K, V tensors + torch.manual_seed(config["seed_qkv"]) + q, k, v = self._create_qkv_tensors(batch_size, config) + + # Run forward decode + output = backend.forward_decode(q, k, v, layer, fb) + + expected_shape = ( + batch_size, + config["num_attention_heads"] * config["v_head_dim"], + ) + self.assertEqual( + output.shape, + expected_shape, + f"Output shape mismatch: {output.shape} vs {expected_shape}", + ) + self.assertFalse(torch.isnan(output).any(), "Output contains NaN") + self.assertFalse(torch.isinf(output).any(), "Output contains Inf") + + def test_shape_sanity(self): + """Smoke test decode across several configurations.""" + print(f"\nRunning shape sanity tests...") + + for test_case in TEST_CASES["shape_sanity_tests"]: + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = 
config["max_seq_len"] + + model_runner, _, backend, _, layer = self._create_model_components( + config + ) + + # Random seq lens (ensure one matches max) + torch.manual_seed(config["seed_cache"]) + seq_lens = torch.randint( + 1, max_seq_len, (batch_size,), device=config["device"] + ) + seq_lens[0] = max_seq_len + + fb = self._create_forward_batch( + batch_size, seq_lens, backend, model_runner, config + ) + backend.init_forward_metadata(fb) + + # Create Q, K, V tensors + torch.manual_seed(config["seed_qkv"]) + head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"] + q = torch.randn( + (batch_size, config["num_attention_heads"], head_dim), + dtype=config["dtype"], + device=config["device"], + ) + k = torch.randn( + (batch_size, config["num_kv_heads"], head_dim), + dtype=config["dtype"], + device=config["device"], + ) + v = None + + # Run forward decode + output = backend.forward_decode(q, k, v, layer, fb) + + # Shape and sanity checks + expected_shape = ( + batch_size, + config["num_attention_heads"] * config["v_head_dim"], + ) + self.assertEqual( + output.shape, + expected_shape, + f"Output shape mismatch for {test_case['name']}", + ) + self.assertEqual(output.dtype, config["dtype"]) + self.assertEqual(output.device.type, "cuda") + self.assertFalse( + torch.isnan(output).any(), + f"Output contains NaN for {test_case['name']}", + ) + self.assertFalse( + torch.isinf(output).any(), + f"Output contains Inf for {test_case['name']}", + ) + + def test_metadata_initialization(self): + """Test TRTLLM MLA metadata initialization and structure.""" + print(f"\nRunning metadata initialization tests...") + + for test_case in TEST_CASES["metadata_tests"]: + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] + + # Create components + model_runner, _, backend, _, layer = self._create_model_components( + config + ) + + # Create varied sequence lengths + torch.manual_seed(config["seed_cache"]) + if batch_size == 1: + seq_lens = torch.tensor([max_seq_len], device=config["device"]) + else: + seq_lens = torch.randint( + max(1, max_seq_len // 4), + max_seq_len + 1, + (batch_size,), + device=config["device"], + ) + seq_lens[0] = max_seq_len # Ensure at least one max length + + # Create forward batch + fb = self._create_forward_batch( + batch_size, seq_lens, backend, model_runner, config + ) + + # Initialize metadata + backend.init_forward_metadata(fb) + + # Verify metadata exists + self.assertIsNotNone(backend.forward_metadata) + self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata) + + # Test metadata structure + metadata = backend.forward_metadata + self.assertIsNotNone( + metadata.workspace, "Workspace should be allocated" + ) + self.assertIsNotNone( + metadata.block_kv_indices, "Block KV indices should be created" + ) + + # Test workspace properties + self.assertEqual(metadata.workspace.device.type, "cuda") + self.assertEqual(metadata.workspace.dtype, torch.int8) + self.assertGreater( + metadata.workspace.numel(), 0, "Workspace should have non-zero size" + ) + + # Test block KV indices properties + self.assertEqual(metadata.block_kv_indices.device.type, "cuda") + self.assertEqual(metadata.block_kv_indices.dtype, torch.int32) + self.assertEqual(metadata.block_kv_indices.shape[0], batch_size) + + # Verify block indices are valid (>= -1, since -1 is padding) + self.assertTrue( + (metadata.block_kv_indices 
>= -1).all(), + "All block indices should be >= -1 (with -1 as padding)", + ) + + def test_metadata_block_calculation(self): + """Test block count calculation logic.""" + print(f"\nRunning metadata block calculation tests...") + + test_scenarios = [ + {"seq_len": 31, "page_size": 32, "expected_min_blocks": 1}, + {"seq_len": 32, "page_size": 32, "expected_min_blocks": 1}, + {"seq_len": 33, "page_size": 32, "expected_min_blocks": 2}, + {"seq_len": 128, "page_size": 32, "expected_min_blocks": 4}, + {"seq_len": 128, "page_size": 64, "expected_min_blocks": 2}, + ] + + for scenario in test_scenarios: + with self.subTest(scenario=scenario): + config = self._merge_config( + { + "batch_size": 1, + "max_seq_len": scenario["seq_len"], + "page_size": scenario["page_size"], + } + ) + + model_runner, _, backend, _, _ = self._create_model_components(config) + + # Test internal block calculation + calculated_blocks = backend._calc_padded_blocks(scenario["seq_len"]) + + # Should be at least the minimum required + self.assertGreaterEqual( + calculated_blocks, + scenario["expected_min_blocks"], + f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})", + ) + + # Should satisfy page_size constraint + total_tokens = calculated_blocks * scenario["page_size"] + self.assertGreaterEqual( + total_tokens, + scenario["seq_len"], + f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})", + ) + + # Should satisfy TRT-LLM and Triton constraints + trtllm_constraint = 128 // scenario["page_size"] + constraint_lcm = math.lcm( + trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK + ) + self.assertEqual( + calculated_blocks % constraint_lcm, + 0, + f"Block count should be multiple of LCM of constraints ({constraint_lcm})", + ) + + def test_metadata_kv_indices_correctness(self): + """Test KV indices creation and correctness.""" + print(f"\nRunning KV indices correctness tests...") + + for test_case in TEST_CASES["metadata_tests"][ + :2 + ]: # Test subset for performance + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] + + model_runner, _, backend, _, layer = self._create_model_components( + config + ) + + # Create known sequence lengths + torch.manual_seed(config["seed_cache"]) + if batch_size == 1: + seq_lens = torch.tensor([max_seq_len], device=config["device"]) + else: + seq_lens = torch.randint( + max_seq_len // 2, + max_seq_len + 1, + (batch_size,), + device=config["device"], + ) + + fb = self._create_forward_batch( + batch_size, seq_lens, backend, model_runner, config + ) + + # Populate some KV cache to have valid indices + self._populate_kv_cache( + batch_size, seq_lens, [model_runner], layer, config + ) + + # Initialize metadata + backend.init_forward_metadata(fb) + metadata = backend.forward_metadata + + # Verify KV indices structure + block_kv_indices = metadata.block_kv_indices + + for i in range(batch_size): + seq_len = seq_lens[i].item() + expected_blocks = backend._calc_padded_blocks(seq_len) + + # Count valid (non -1) indices for this sequence + valid_indices = (block_kv_indices[i] >= 0).sum().item() + + # Should have at least enough blocks for the sequence + min_required_blocks = (seq_len + config["page_size"] - 1) // config[ + "page_size" + ] + self.assertGreaterEqual( + valid_indices, + min_required_blocks, + f"Sequence {i} should have at least 
{min_required_blocks} valid blocks, got {valid_indices}", + ) + + # Verify indices are within valid range + valid_block_indices = block_kv_indices[i][block_kv_indices[i] >= 0] + if len(valid_block_indices) > 0: + max_possible_blocks = ( + model_runner.token_to_kv_pool.size // config["page_size"] + ) + self.assertTrue( + (valid_block_indices < max_possible_blocks).all(), + f"All block indices should be < {max_possible_blocks}", + ) + + def test_metadata_cuda_graph_compatibility(self): + """Test metadata compatibility with CUDA graph capture/replay.""" + print(f"\nRunning CUDA graph compatibility tests...") + + config = self._merge_config( + {"batch_size": 4, "max_seq_len": 64, "page_size": 32} + ) + + model_runner, _, backend, _, layer = self._create_model_components(config) + batch_size = config["batch_size"] + + # Initialize CUDA graph state + backend.init_cuda_graph_state( + max_bs=batch_size, max_num_tokens=config["max_seq_len"] * batch_size + ) + + # Verify CUDA graph buffers are allocated + self.assertIsNotNone(backend.cuda_graph_kv_indices) + self.assertIsNotNone(backend.cuda_graph_workspace) + + # Test capture metadata + seq_lens = torch.full( + (batch_size,), config["max_seq_len"], device=config["device"] + ) + req_pool_indices = torch.arange(batch_size, device=config["device"]) + + backend.init_forward_metadata_capture_cuda_graph( + bs=batch_size, + num_tokens=batch_size, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=None, + ) + + # Verify capture metadata + self.assertIn(batch_size, backend.decode_cuda_graph_metadata) + capture_metadata = backend.decode_cuda_graph_metadata[batch_size] + + self.assertIsNotNone(capture_metadata.workspace) + self.assertIsNotNone(capture_metadata.block_kv_indices) + + # Test replay with different sequence lengths + new_seq_lens = torch.randint( + config["max_seq_len"] // 2, + config["max_seq_len"] + 1, + (batch_size,), + device=config["device"], + ) + + backend.init_forward_metadata_replay_cuda_graph( + bs=batch_size, + req_pool_indices=req_pool_indices, + seq_lens=new_seq_lens, + seq_lens_sum=new_seq_lens.sum().item(), + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=None, + seq_lens_cpu=new_seq_lens.cpu(), + ) + + # Verify replay updated the metadata + replay_metadata = backend.forward_metadata + self.assertIsNotNone(replay_metadata) + self.assertEqual( + replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr() + ) + + def test_metadata_consistency_across_calls(self): + """Test metadata consistency across multiple forward calls.""" + print(f"\nRunning metadata consistency tests...") + + config = self._merge_config( + {"batch_size": 2, "max_seq_len": 64, "page_size": 32} + ) + + model_runner, _, backend, _, layer = self._create_model_components(config) + + # First call + seq_lens_1 = torch.tensor([32, 48], device=config["device"]) + fb_1 = self._create_forward_batch( + config["batch_size"], seq_lens_1, backend, model_runner, config + ) + backend.init_forward_metadata(fb_1) + metadata_1 = backend.forward_metadata + + # Second call with same sequence lengths + seq_lens_2 = torch.tensor([32, 48], device=config["device"]) + fb_2 = self._create_forward_batch( + config["batch_size"], seq_lens_2, backend, model_runner, config + ) + backend.init_forward_metadata(fb_2) + metadata_2 = backend.forward_metadata + + # Metadata structure should be consistent + self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape) + 
self.assertEqual( + metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape + ) + + # Third call with different sequence lengths + seq_lens_3 = torch.tensor([16, 64], device=config["device"]) + fb_3 = self._create_forward_batch( + config["batch_size"], seq_lens_3, backend, model_runner, config + ) + backend.init_forward_metadata(fb_3) + metadata_3 = backend.forward_metadata + + # Should still have valid structure + self.assertIsNotNone(metadata_3.workspace) + self.assertIsNotNone(metadata_3.block_kv_indices) + self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"]) + + +if __name__ == "__main__": + unittest.main()
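For reference, the block-count padding performed by `TRTLLMMLABackend._calc_padded_blocks` can be reproduced standalone. The sketch below is not part of the diff; it mirrors the two constraints called out in the backend's comments (flashinfer's `block_num % (128 / block_size) == 0` shape check and the Triton page-table builder's 64-page write granularity), using plain ceiling division in place of `triton.cdiv`:

```python
import math

# Constants mirrored from the diff above.
TRTLLM_BLOCK_CONSTRAINT = 128       # flashinfer: block_num % (128 / block_size) == 0
TRITON_PAD_NUM_PAGE_PER_BLOCK = 64  # pages written per iteration by the Triton kernel


def calc_padded_blocks(max_seq_len: int, page_size: int) -> int:
    """Pad the per-request block count so both kernel constraints hold."""
    blocks = -(-max_seq_len // page_size)  # ceil(max_seq_len / page_size)
    constraint_lcm = math.lcm(
        TRTLLM_BLOCK_CONSTRAINT // page_size, TRITON_PAD_NUM_PAGE_PER_BLOCK
    )
    if blocks % constraint_lcm != 0:
        blocks = -(-blocks // constraint_lcm) * constraint_lcm
    return blocks


print(calc_padded_blocks(200, 32))  # ceil(200/32) = 7 -> padded to lcm(4, 64) = 64
print(calc_padded_blocks(200, 64))  # ceil(200/64) = 4 -> padded to lcm(2, 64) = 64
```

For both supported page sizes (32 and 64) the combined constraint works out to 64, so every request's row in `block_kv_indices` is padded to a multiple of 64 entries, with `-1` used as the padding value.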