Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
372
python/sglang/srt/layers/attention/trtllm_mla_backend.py
Executable file
372
python/sglang/srt/layers/attention/trtllm_mla_backend.py
Executable file
@@ -0,0 +1,372 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Support attention backend for TRTLLM MLA kernels from flashinfer.
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
|
||||
from sglang.srt.layers.attention.utils import (
|
||||
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
create_flashmla_kv_indices_triton,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if is_flashinfer_available():
|
||||
import flashinfer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
|
||||
# Constants
|
||||
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
||||
|
||||
# Block constraint from flashinfer requirements
|
||||
# From flashinfer.decode._check_trtllm_gen_mla_shape:
|
||||
# block_num % (128 / block_size) == 0
|
||||
# This imposes that the total number of blocks must be divisible by
|
||||
# (128 / block_size). We capture the 128 constant here so we can
|
||||
# compute the LCM with other padding constraints.
|
||||
TRTLLM_BLOCK_CONSTRAINT = 128
|
||||
|
||||
|
||||
@dataclass
|
||||
class TRTLLMMLADecodeMetadata:
|
||||
"""Metadata for TRTLLM MLA decode operations."""
|
||||
|
||||
workspace: Optional[torch.Tensor] = None
|
||||
block_kv_indices: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
"""TRTLLM MLA attention kernel from flashinfer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_runner: ModelRunner,
|
||||
skip_prefill: bool = False,
|
||||
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||
q_indptr_decode_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
|
||||
|
||||
config = model_runner.model_config
|
||||
|
||||
# Model parameters
|
||||
self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
|
||||
self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
|
||||
self.num_local_heads = config.num_attention_heads // get_attention_tp_size()
|
||||
|
||||
# MLA-specific dimensions
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
|
||||
|
||||
# Runtime parameters
|
||||
self.scaling = config.scaling
|
||||
self.data_type = model_runner.kv_cache_dtype
|
||||
self.q_data_type = model_runner.dtype
|
||||
self.page_size = model_runner.page_size
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
|
||||
# Workspace allocation
|
||||
self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
|
||||
self.workspace_buffer = torch.empty(
|
||||
self.workspace_size, dtype=torch.int8, device=self.device
|
||||
)
|
||||
|
||||
# CUDA graph state
|
||||
self.decode_cuda_graph_metadata = {}
|
||||
self.cuda_graph_kv_indices = None
|
||||
self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
|
||||
|
||||
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
||||
"""
|
||||
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
|
||||
|
||||
Args:
|
||||
max_seq_len: Maximum sequence length in tokens
|
||||
|
||||
Returns:
|
||||
Number of blocks padded to satisfy all constraints
|
||||
"""
|
||||
blocks = triton.cdiv(max_seq_len, self.page_size)
|
||||
|
||||
# Apply dual constraints (take LCM to satisfy both):
|
||||
# 1. TRT-LLM: block_num % (128 / page_size) == 0
|
||||
# 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
|
||||
trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
|
||||
constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
|
||||
|
||||
if blocks % constraint_lcm != 0:
|
||||
blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
|
||||
return blocks
|
||||
|
||||
def _create_block_kv_indices(
|
||||
self,
|
||||
batch_size: int,
|
||||
max_blocks: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Create block KV indices tensor using Triton kernel.
|
||||
|
||||
Args:
|
||||
batch_size: Batch size
|
||||
max_blocks: Maximum number of blocks per sequence
|
||||
req_pool_indices: Request pool indices
|
||||
seq_lens: Sequence lengths
|
||||
device: Target device
|
||||
|
||||
Returns:
|
||||
Block KV indices tensor
|
||||
"""
|
||||
block_kv_indices = torch.full(
|
||||
(batch_size, max_blocks), -1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
create_flashmla_kv_indices_triton[(batch_size,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
None,
|
||||
block_kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
max_blocks,
|
||||
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
return block_kv_indices
|
||||
|
||||
def init_cuda_graph_state(
|
||||
self,
|
||||
max_bs: int,
|
||||
max_num_tokens: int,
|
||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Initialize CUDA graph state for TRTLLM MLA."""
|
||||
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
|
||||
|
||||
self.cuda_graph_kv_indices = torch.full(
|
||||
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.cuda_graph_workspace = torch.empty(
|
||||
self.workspace_size, dtype=torch.int8, device=self.device
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
num_tokens: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
):
|
||||
"""Initialize metadata for CUDA graph capture."""
|
||||
# Delegate to parent for non-decode modes or when speculative execution is used.
|
||||
if not (forward_mode.is_decode_or_idle() and spec_info is None):
|
||||
return super().init_forward_metadata_capture_cuda_graph(
|
||||
bs,
|
||||
num_tokens,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
encoder_lens,
|
||||
forward_mode,
|
||||
spec_info,
|
||||
)
|
||||
|
||||
# Custom fast-path for decode/idle without speculative execution.
|
||||
max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
|
||||
block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
|
||||
|
||||
create_flashmla_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
None,
|
||||
block_kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
max_seqlen_pad,
|
||||
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
self.forward_metadata = metadata
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
"""Replay CUDA graph with new inputs."""
|
||||
# Delegate to parent for non-decode modes or when speculative execution is used.
|
||||
if not (forward_mode.is_decode_or_idle() and spec_info is None):
|
||||
return super().init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
encoder_lens,
|
||||
forward_mode,
|
||||
spec_info,
|
||||
seq_lens_cpu,
|
||||
)
|
||||
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
|
||||
# Update block indices for new sequences.
|
||||
create_flashmla_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices[:bs],
|
||||
seq_lens[:bs],
|
||||
None,
|
||||
metadata.block_kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
metadata.block_kv_indices.shape[1],
|
||||
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self) -> int:
|
||||
"""Get the fill value for sequence lengths in CUDA graph."""
|
||||
return 1
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Initialize the metadata for a forward pass."""
|
||||
# Delegate to parent for non-decode modes or when speculative execution is used.
|
||||
if not (
|
||||
forward_batch.forward_mode.is_decode_or_idle()
|
||||
and forward_batch.spec_info is None
|
||||
):
|
||||
return super().init_forward_metadata(forward_batch)
|
||||
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
# Get maximum sequence length.
|
||||
if getattr(forward_batch, "seq_lens_cpu", None) is not None:
|
||||
max_seq = forward_batch.seq_lens_cpu.max().item()
|
||||
else:
|
||||
max_seq = forward_batch.seq_lens.max().item()
|
||||
|
||||
max_seqlen_pad = self._calc_padded_blocks(max_seq)
|
||||
block_kv_indices = self._create_block_kv_indices(
|
||||
bs,
|
||||
max_seqlen_pad,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens.device,
|
||||
)
|
||||
|
||||
self.forward_metadata = TRTLLMMLADecodeMetadata(
|
||||
self.workspace_buffer, block_kv_indices
|
||||
)
|
||||
forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache: bool = True,
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Run forward for decode using TRTLLM MLA kernel."""
|
||||
# Save KV cache if requested
|
||||
if k is not None and save_kv_cache:
|
||||
cache_loc = forward_batch.out_cache_loc
|
||||
if k_rope is not None:
|
||||
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
||||
layer, cache_loc, k, k_rope
|
||||
)
|
||||
elif v is not None:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||
|
||||
# Prepare query tensor inline
|
||||
if q_rope is not None:
|
||||
# q contains NOPE part (v_head_dim)
|
||||
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||
q_rope_reshaped = q_rope.view(
|
||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||
)
|
||||
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
|
||||
else:
|
||||
# q already has both parts
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
|
||||
# Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
|
||||
if query.dim() == 3:
|
||||
query = query.unsqueeze(1)
|
||||
|
||||
# Prepare KV cache inline
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
|
||||
# TRT-LLM expects single KV data with extra dimension
|
||||
kv_cache = pages.unsqueeze(1)
|
||||
|
||||
# Get metadata
|
||||
metadata = (
|
||||
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
||||
or self.forward_metadata
|
||||
)
|
||||
|
||||
# Scale computation for TRTLLM MLA kernel:
|
||||
# - BMM1 scale = q_scale * k_scale * softmax_scale
|
||||
# - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
|
||||
# - k_scale is read from model checkpoint if available
|
||||
# TODO: Change once fp8 path is supported
|
||||
q_scale = 1.0
|
||||
k_scale = (
|
||||
layer.k_scale_float
|
||||
if getattr(layer, "k_scale_float", None) is not None
|
||||
else 1.0
|
||||
)
|
||||
|
||||
bmm1_scale = q_scale * k_scale * layer.scaling
|
||||
|
||||
# Call TRT-LLM kernel
|
||||
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=query,
|
||||
kv_cache=kv_cache,
|
||||
workspace_buffer=metadata.workspace,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=metadata.block_kv_indices,
|
||||
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
||||
max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
|
||||
bmm1_scale=bmm1_scale,
|
||||
)
|
||||
|
||||
# Extract value projection part and reshape
|
||||
raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
|
||||
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
|
||||
return output
|
||||
@@ -1,6 +1,11 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
|
||||
# Number of pages that the kernel writes per iteration.
|
||||
# Exposed here so other Python modules can import it instead of hard-coding 64.
|
||||
TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
|
||||
|
||||
|
||||
@triton.jit
|
||||
def create_flashinfer_kv_indices_triton(
|
||||
@@ -50,10 +55,10 @@ def create_flashmla_kv_indices_triton(
|
||||
kv_indices_ptr,
|
||||
req_to_token_ptr_stride: tl.constexpr,
|
||||
kv_indices_ptr_stride: tl.constexpr,
|
||||
NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
PAGED_SIZE: tl.constexpr = 64,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 4096
|
||||
NUM_PAGE_PER_BLOCK: tl.constexpr = 64
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
# find the req pool idx, this is for batch to token
|
||||
|
||||
@@ -436,6 +436,7 @@ class ModelRunner:
|
||||
"triton",
|
||||
"flashmla",
|
||||
"cutlass_mla",
|
||||
"trtllm_mla",
|
||||
"ascend",
|
||||
]:
|
||||
logger.info(
|
||||
@@ -1437,6 +1438,12 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
return CutlassMLABackend(self)
|
||||
elif self.server_args.attention_backend == "trtllm_mla":
|
||||
if not self.use_mla_backend:
|
||||
raise ValueError("trtllm_mla backend can only be used with MLA models.")
|
||||
from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
|
||||
|
||||
return TRTLLMMLABackend(self)
|
||||
elif self.server_args.attention_backend == "intel_amx":
|
||||
from sglang.srt.layers.attention.intel_amx_backend import (
|
||||
IntelAMXAttnBackend,
|
||||
|
||||
@@ -1259,6 +1259,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.current_attention_backend == "fa3"
|
||||
or self.current_attention_backend == "flashinfer"
|
||||
or self.current_attention_backend == "cutlass_mla"
|
||||
or self.current_attention_backend == "trtllm_mla"
|
||||
):
|
||||
attn_output = self.attn_mqa(
|
||||
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
||||
|
||||
@@ -24,6 +24,7 @@ import tempfile
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.lora.lora_registry import LoRARef
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.utils import (
|
||||
@@ -402,6 +403,22 @@ class ServerArgs:
|
||||
)
|
||||
self.page_size = 128
|
||||
|
||||
if self.attention_backend == "trtllm_mla":
|
||||
if not is_sm100_supported():
|
||||
raise ValueError(
|
||||
"TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
|
||||
)
|
||||
|
||||
if self.page_size not in [32, 64]:
|
||||
logger.warning(
|
||||
f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
|
||||
)
|
||||
self.page_size = 64
|
||||
if self.speculative_algorithm is not None:
|
||||
raise ValueError(
|
||||
"trtllm_mla backend does not support speculative decoding yet."
|
||||
)
|
||||
|
||||
# Set page size
|
||||
if self.page_size is None:
|
||||
self.page_size = 1
|
||||
@@ -1225,6 +1242,7 @@ class ServerArgs:
|
||||
"torch_native",
|
||||
"ascend",
|
||||
"triton",
|
||||
"trtllm_mla",
|
||||
],
|
||||
default=ServerArgs.attention_backend,
|
||||
help="Choose the kernels for attention layers.",
|
||||
|
||||
945
python/sglang/test/attention/test_trtllm_mla_backend.py
Executable file
945
python/sglang/test/attention/test_trtllm_mla_backend.py
Executable file
@@ -0,0 +1,945 @@
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers import dp_attention as _dp_attn
|
||||
|
||||
# Patch DP-attention globals before importing backends
|
||||
# TODO: change the interface of both trtllm_mla and flashinfer backends to take tp_size as an argument instead of patching
|
||||
_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test
|
||||
|
||||
from sglang.srt.configs.model_config import AttentionArch
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
|
||||
from sglang.srt.layers.attention.trtllm_mla_backend import (
|
||||
TRTLLMMLABackend,
|
||||
TRTLLMMLADecodeMetadata,
|
||||
)
|
||||
from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
# Global configuration for all tests
|
||||
DEFAULT_CONFIG = {
|
||||
"device": "cuda",
|
||||
"dtype": torch.bfloat16,
|
||||
"kv_cache_dtype": torch.bfloat16,
|
||||
"context_len": 2048,
|
||||
"max_bs": 64,
|
||||
"tolerance": 1e-2,
|
||||
"seed_cache": 42,
|
||||
"seed_qkv": 123,
|
||||
# MLA model config (TRTLLM MLA has fixed constraints)
|
||||
"num_attention_heads": 128,
|
||||
"kv_lora_rank": 512,
|
||||
"qk_nope_head_dim": 128,
|
||||
"qk_rope_head_dim": 64,
|
||||
"v_head_dim": 512,
|
||||
"num_kv_heads": 1,
|
||||
"layer_id": 0,
|
||||
}
|
||||
|
||||
# Centralized test cases for different test scenarios
|
||||
TEST_CASES = {
|
||||
"basic_functionality": [
|
||||
{
|
||||
"name": "single",
|
||||
"batch_size": 1,
|
||||
"max_seq_len": 32,
|
||||
"page_size": 32,
|
||||
"description": "Minimal smoke test",
|
||||
},
|
||||
{
|
||||
"name": "batch",
|
||||
"batch_size": 32,
|
||||
"max_seq_len": 128,
|
||||
"page_size": 32,
|
||||
"description": "Medium-scale batch",
|
||||
},
|
||||
],
|
||||
"decode_output_match": [
|
||||
{
|
||||
"name": "single",
|
||||
"batch_size": 1,
|
||||
"max_seq_len": 64,
|
||||
"page_size": 32,
|
||||
"description": "Single vs reference",
|
||||
},
|
||||
{
|
||||
"name": "batch",
|
||||
"batch_size": 32,
|
||||
"max_seq_len": 64,
|
||||
"page_size": 32,
|
||||
"description": "Batch vs reference",
|
||||
},
|
||||
],
|
||||
"page_size_consistency": [
|
||||
# Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
|
||||
{
|
||||
"name": "page_32",
|
||||
"batch_size": 8,
|
||||
"max_seq_len": 128,
|
||||
"page_size": 32,
|
||||
"description": "32-token pages",
|
||||
},
|
||||
{
|
||||
"name": "page_64",
|
||||
"batch_size": 8,
|
||||
"max_seq_len": 128,
|
||||
"page_size": 64,
|
||||
"description": "64-token pages",
|
||||
},
|
||||
],
|
||||
"shape_sanity_tests": [
|
||||
{
|
||||
"name": "basic",
|
||||
"batch_size": 1,
|
||||
"max_seq_len": 128,
|
||||
"page_size": 32,
|
||||
"description": "Single sequence",
|
||||
},
|
||||
{
|
||||
"name": "basic_different_pagesize",
|
||||
"batch_size": 1,
|
||||
"max_seq_len": 128,
|
||||
"page_size": 64,
|
||||
"description": "Different page size",
|
||||
},
|
||||
{
|
||||
"name": "batch",
|
||||
"batch_size": 8,
|
||||
"max_seq_len": 128,
|
||||
"page_size": 32,
|
||||
"description": "Batch shapes",
|
||||
},
|
||||
],
|
||||
"metadata_tests": [
|
||||
{
|
||||
"name": "single_sequence",
|
||||
"batch_size": 1,
|
||||
"max_seq_len": 64,
|
||||
"page_size": 32,
|
||||
"description": "Single sequence metadata",
|
||||
},
|
||||
{
|
||||
"name": "batch_mixed_lengths",
|
||||
"batch_size": 8,
|
||||
"max_seq_len": 128,
|
||||
"page_size": 32,
|
||||
"description": "Mixed sequence lengths",
|
||||
},
|
||||
{
|
||||
"name": "large_batch",
|
||||
"batch_size": 32,
|
||||
"max_seq_len": 256,
|
||||
"page_size": 64,
|
||||
"description": "Large batch stress test",
|
||||
},
|
||||
{
|
||||
"name": "edge_case_short",
|
||||
"batch_size": 4,
|
||||
"max_seq_len": 16,
|
||||
"page_size": 32,
|
||||
"description": "Sub-page sequences",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class MockModelRunner:
|
||||
"""Minimal fake ModelRunner for testing MLA backends."""
|
||||
|
||||
def __init__(self, config):
|
||||
self.device = config["device"]
|
||||
self.dtype = config["dtype"]
|
||||
self.kv_cache_dtype = config["kv_cache_dtype"]
|
||||
self.page_size = config["page_size"]
|
||||
|
||||
# Model-config stub with MLA attributes
|
||||
self.model_config = type(
|
||||
"ModelConfig",
|
||||
(),
|
||||
{
|
||||
"context_len": config["context_len"],
|
||||
"attention_arch": AttentionArch.MLA,
|
||||
"num_attention_heads": config["num_attention_heads"],
|
||||
"kv_lora_rank": config["kv_lora_rank"],
|
||||
"qk_nope_head_dim": config["qk_nope_head_dim"],
|
||||
"qk_rope_head_dim": config["qk_rope_head_dim"],
|
||||
"v_head_dim": config["v_head_dim"],
|
||||
"scaling": 1.0
|
||||
/ ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5),
|
||||
"get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]),
|
||||
},
|
||||
)
|
||||
|
||||
# Req-to-token pool
|
||||
max_bs = config["max_bs"]
|
||||
max_ctx = self.model_config.context_len
|
||||
req_to_token = torch.arange(
|
||||
max_bs * max_ctx, dtype=torch.int32, device=self.device
|
||||
).reshape(max_bs, max_ctx)
|
||||
self.req_to_token_pool = type(
|
||||
"TokenPool",
|
||||
(),
|
||||
{
|
||||
"size": max_bs,
|
||||
"req_to_token": req_to_token,
|
||||
},
|
||||
)
|
||||
|
||||
# KV-token pool (MLA)
|
||||
self.token_to_kv_pool = MLATokenToKVPool(
|
||||
size=max_bs * max_ctx,
|
||||
page_size=config["page_size"],
|
||||
dtype=self.kv_cache_dtype,
|
||||
kv_lora_rank=config["kv_lora_rank"],
|
||||
qk_rope_head_dim=config["qk_rope_head_dim"],
|
||||
layer_num=1,
|
||||
device=self.device,
|
||||
enable_memory_saver=False,
|
||||
)
|
||||
|
||||
|
||||
def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
|
||||
"""Compare outputs with detailed analysis."""
|
||||
|
||||
# Basic checks
|
||||
assert (
|
||||
trtllm_out.shape == reference_out.shape
|
||||
), f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}"
|
||||
assert (
|
||||
trtllm_out.dtype == reference_out.dtype
|
||||
), f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}"
|
||||
|
||||
# Check for NaN/Inf
|
||||
assert not torch.isnan(trtllm_out).any(), "TRTLLM output contains NaN"
|
||||
assert not torch.isnan(reference_out).any(), "Reference output contains NaN"
|
||||
assert not torch.isinf(trtllm_out).any(), "TRTLLM output contains Inf"
|
||||
assert not torch.isinf(reference_out).any(), "Reference output contains Inf"
|
||||
|
||||
# Element-wise differences
|
||||
diff = (trtllm_out - reference_out).abs()
|
||||
max_diff = diff.max().item()
|
||||
mean_diff = diff.mean().item()
|
||||
|
||||
# Check numerical equivalence
|
||||
all_close = torch.allclose(
|
||||
trtllm_out, reference_out, rtol=tolerance, atol=tolerance
|
||||
)
|
||||
|
||||
if not all_close:
|
||||
print(
|
||||
f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}"
|
||||
)
|
||||
# Find top differences for debugging
|
||||
flat_diff = diff.flatten()
|
||||
top_diff_indices = torch.topk(flat_diff, k=min(5, flat_diff.numel())).indices
|
||||
print("Top 5 differences:")
|
||||
for i, idx in enumerate(top_diff_indices):
|
||||
idx_tuple = np.unravel_index(idx.cpu().numpy(), trtllm_out.shape)
|
||||
trt_val = trtllm_out[idx_tuple].item()
|
||||
ref_val = reference_out[idx_tuple].item()
|
||||
print(
|
||||
f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}"
|
||||
)
|
||||
|
||||
return all_close
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available() or not is_flashinfer_available(),
|
||||
"CUDA + flashinfer required",
|
||||
)
|
||||
class TestTRTLLMMLA(CustomTestCase):
|
||||
"""Test suite for TRTLLM MLA backend with centralized configuration."""
|
||||
|
||||
def _merge_config(self, test_case):
|
||||
"""Merge test case with default configuration."""
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config.update(test_case)
|
||||
return config
|
||||
|
||||
def _create_model_components(self, config):
|
||||
"""Create model runners, backends, and layer for testing."""
|
||||
# Create model runners
|
||||
model_runner_trtllm = MockModelRunner(config)
|
||||
model_runner_reference = MockModelRunner(config)
|
||||
|
||||
# Create backends
|
||||
trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
|
||||
reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
|
||||
|
||||
# Create RadixAttention layer
|
||||
layer = RadixAttention(
|
||||
num_heads=config["num_attention_heads"],
|
||||
head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
|
||||
scaling=model_runner_trtllm.model_config.scaling,
|
||||
num_kv_heads=config["num_kv_heads"],
|
||||
layer_id=config["layer_id"],
|
||||
v_head_dim=config["v_head_dim"],
|
||||
prefix="attn_mqa",
|
||||
)
|
||||
|
||||
return (
|
||||
model_runner_trtllm,
|
||||
model_runner_reference,
|
||||
trtllm_backend,
|
||||
reference_backend,
|
||||
layer,
|
||||
)
|
||||
|
||||
def _create_qkv_tensors(self, batch_size, config):
|
||||
"""Create Q, K, V tensors for testing."""
|
||||
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
|
||||
device = config["device"]
|
||||
dtype = config["dtype"]
|
||||
|
||||
q = torch.randn(
|
||||
(batch_size, config["num_attention_heads"], head_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
k = torch.randn(
|
||||
(batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
|
||||
)
|
||||
v = torch.randn(
|
||||
(batch_size, config["num_kv_heads"], config["v_head_dim"]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
return q, k, v
|
||||
|
||||
def _create_forward_batch(
|
||||
self, batch_size, seq_lens, backend, model_runner, config
|
||||
):
|
||||
"""Create a forward batch for the given backend."""
|
||||
fb = ForwardBatch(
|
||||
batch_size=batch_size,
|
||||
input_ids=torch.randint(0, 100, (batch_size, 1), device=config["device"]),
|
||||
out_cache_loc=torch.arange(batch_size, device=config["device"]),
|
||||
seq_lens_sum=int(seq_lens.sum().item()),
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
req_pool_indices=torch.arange(batch_size, device=config["device"]),
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu=seq_lens.cpu(),
|
||||
attn_backend=backend,
|
||||
)
|
||||
fb.req_to_token_pool = model_runner.req_to_token_pool
|
||||
fb.token_to_kv_pool = model_runner.token_to_kv_pool
|
||||
return fb
|
||||
|
||||
def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
|
||||
"""Populate KV cache with identical data for both backends."""
|
||||
torch.manual_seed(config["seed_cache"]) # Fixed seed for reproducible cache
|
||||
|
||||
for model_runner in model_runners:
|
||||
torch.manual_seed(config["seed_cache"]) # Reset seed for each backend
|
||||
for i in range(batch_size):
|
||||
seq_len = int(seq_lens[i].item())
|
||||
for token_idx in range(seq_len - 1):
|
||||
# Create random K components for MLA
|
||||
cache_k_nope = torch.randn(
|
||||
(1, config["qk_nope_head_dim"]),
|
||||
dtype=config["dtype"],
|
||||
device=config["device"],
|
||||
)
|
||||
cache_k_rope = torch.randn(
|
||||
(1, config["qk_rope_head_dim"]),
|
||||
dtype=config["dtype"],
|
||||
device=config["device"],
|
||||
)
|
||||
|
||||
# Calculate cache location
|
||||
cache_loc = model_runner.req_to_token_pool.req_to_token[
|
||||
i, token_idx
|
||||
]
|
||||
|
||||
# Save to KV cache
|
||||
model_runner.token_to_kv_pool.set_mla_kv_buffer(
|
||||
layer,
|
||||
cache_loc.unsqueeze(0),
|
||||
cache_k_nope.squeeze(0),
|
||||
cache_k_rope.squeeze(0),
|
||||
)
|
||||
|
||||
def test_basic_functionality(self):
|
||||
"""Test basic functionality with minimal setup."""
|
||||
print(f"\nRunning basic functionality tests...")
|
||||
|
||||
for test_case in TEST_CASES["basic_functionality"]:
|
||||
with self.subTest(test_case=test_case["name"]):
|
||||
print(f" Testing {test_case['name']}: {test_case['description']}")
|
||||
|
||||
config = self._merge_config(test_case)
|
||||
batch_size = config["batch_size"]
|
||||
max_seq_len = config["max_seq_len"]
|
||||
|
||||
# Create components
|
||||
model_runner_trtllm, _, trtllm_backend, _, layer = (
|
||||
self._create_model_components(config)
|
||||
)
|
||||
|
||||
# Create sequence lengths - properly handle different batch sizes
|
||||
if batch_size == 2:
|
||||
seq_lens = torch.tensor(
|
||||
[max_seq_len, max_seq_len // 2], device=config["device"]
|
||||
)
|
||||
else:
|
||||
# For larger batch sizes, create varied sequence lengths
|
||||
torch.manual_seed(config["seed_cache"])
|
||||
seq_lens = torch.randint(
|
||||
max_seq_len // 2,
|
||||
max_seq_len + 1,
|
||||
(batch_size,),
|
||||
device=config["device"],
|
||||
)
|
||||
seq_lens[0] = max_seq_len # Ensure at least one max length
|
||||
|
||||
# Create forward batch
|
||||
fb = self._create_forward_batch(
|
||||
batch_size, seq_lens, trtllm_backend, model_runner_trtllm, config
|
||||
)
|
||||
trtllm_backend.init_forward_metadata(fb)
|
||||
|
||||
# Populate KV cache
|
||||
self._populate_kv_cache(
|
||||
batch_size, seq_lens, [model_runner_trtllm], layer, config
|
||||
)
|
||||
|
||||
# Create Q, K, V tensors
|
||||
torch.manual_seed(config["seed_qkv"])
|
||||
q, k, v = self._create_qkv_tensors(batch_size, config)
|
||||
|
||||
# Run forward decode
|
||||
output = trtllm_backend.forward_decode(q, k, v, layer, fb)
|
||||
|
||||
# Basic checks
|
||||
expected_shape = (
|
||||
batch_size,
|
||||
config["num_attention_heads"] * config["v_head_dim"],
|
||||
)
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
self.assertEqual(output.dtype, config["dtype"])
|
||||
self.assertFalse(torch.isnan(output).any())
|
||||
self.assertFalse(torch.isinf(output).any())
|
||||
|
||||
    def test_decode_output_match(self):
        """Test that TRTLLM and FlashInfer MLA backends produce matching outputs.

        Runs the same decode step through both backends with identical seeds,
        sequence lengths, KV-cache contents, and Q/K/V inputs, then compares
        the outputs element-wise within the configured tolerance.
        """
        print(f"\nRunning decode output matching tests...")

        for test_case in TEST_CASES["decode_output_match"]:
            with self.subTest(test_case=test_case["name"]):
                print(f" Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                # Create components: one model runner + backend pair per
                # implementation under comparison, sharing a single layer.
                (
                    model_runner_trtllm,
                    model_runner_reference,
                    trtllm_backend,
                    reference_backend,
                    layer,
                ) = self._create_model_components(config)

                # Create identical sequence lengths for both backends.
                # Seeded so the random lengths are reproducible across runs.
                torch.manual_seed(config["seed_cache"])
                seq_lens = torch.randint(
                    1, max_seq_len, (batch_size,), device=config["device"]
                )
                seq_lens[0] = max_seq_len  # Ensure at least one max length

                # Create forward batches with identical inputs.
                # seq_lens is cloned so neither backend can mutate the other's copy.
                fb_trtllm = self._create_forward_batch(
                    batch_size,
                    seq_lens.clone(),
                    trtllm_backend,
                    model_runner_trtllm,
                    config,
                )
                fb_reference = self._create_forward_batch(
                    batch_size,
                    seq_lens.clone(),
                    reference_backend,
                    model_runner_reference,
                    config,
                )

                # Initialize metadata for both backends
                trtllm_backend.init_forward_metadata(fb_trtllm)
                reference_backend.init_forward_metadata(fb_reference)

                # Populate both KV caches identically (same seq_lens, same
                # helper call) so any output divergence comes from the
                # attention kernels themselves, not the cache contents.
                self._populate_kv_cache(
                    batch_size,
                    seq_lens,
                    [model_runner_trtllm, model_runner_reference],
                    layer,
                    config,
                )

                # Create Q, K, V tensors for current decode step
                torch.manual_seed(config["seed_qkv"])
                q, k, v = self._create_qkv_tensors(batch_size, config)

                # Run forward decode on both backends.
                # Inputs are cloned per call in case a backend mutates them in place.
                out_trtllm = trtllm_backend.forward_decode(
                    q.clone(), k.clone(), v.clone(), layer, fb_trtllm
                )
                out_reference = reference_backend.forward_decode(
                    q.clone(), k.clone(), v.clone(), layer, fb_reference
                )

                # Compare outputs within the per-test-case tolerance.
                comparison_passed = compare_outputs(
                    out_trtllm, out_reference, tolerance=config["tolerance"]
                )

                self.assertTrue(
                    comparison_passed,
                    f"TRTLLM and Reference outputs differ beyond tolerance. "
                    f"Config: {test_case['name']}, "
                    f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
                )
||||
    def test_page_size_consistency(self):
        """Test output consistency across different page sizes.

        NOTE(review): each test case runs a single page size and only checks
        output shape and NaN/Inf sanity; the cross-page-size coverage
        presumably comes from different page sizes listed in
        TEST_CASES["page_size_consistency"] — confirm against the test table.
        """
        print(f"\nRunning page size consistency tests...")

        for test_case in TEST_CASES["page_size_consistency"]:
            with self.subTest(test_case=test_case["name"]):
                print(f" Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                # Create components
                model_runner, _, backend, _, layer = self._create_model_components(
                    config
                )

                # Create sequence lengths (seeded for reproducibility),
                # forcing the first entry to the maximum length.
                torch.manual_seed(config["seed_cache"])
                seq_lens = torch.randint(
                    1, max_seq_len, (batch_size,), device=config["device"]
                )
                seq_lens[0] = max_seq_len

                # Create forward batch and initialize backend metadata for it.
                fb = self._create_forward_batch(
                    batch_size, seq_lens, backend, model_runner, config
                )
                backend.init_forward_metadata(fb)

                # Populate KV cache
                self._populate_kv_cache(
                    batch_size, seq_lens, [model_runner], layer, config
                )

                # Create Q, K, V tensors
                torch.manual_seed(config["seed_qkv"])
                q, k, v = self._create_qkv_tensors(batch_size, config)

                # Run forward decode
                output = backend.forward_decode(q, k, v, layer, fb)

                # Output is flattened per token: (batch, heads * v_head_dim).
                expected_shape = (
                    batch_size,
                    config["num_attention_heads"] * config["v_head_dim"],
                )
                self.assertEqual(
                    output.shape,
                    expected_shape,
                    f"Output shape mismatch: {output.shape} vs {expected_shape}",
                )
                self.assertFalse(torch.isnan(output).any(), "Output contains NaN")
                self.assertFalse(torch.isinf(output).any(), "Output contains Inf")
||||
    def test_shape_sanity(self):
        """Smoke test decode across several configurations.

        Builds Q/K inline (instead of the shared helper) and runs decode
        without populating the KV cache — only shape, dtype, device, and
        NaN/Inf sanity are asserted, not numerical correctness.
        """
        print(f"\nRunning shape sanity tests...")

        for test_case in TEST_CASES["shape_sanity_tests"]:
            with self.subTest(test_case=test_case["name"]):
                print(f" Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                model_runner, _, backend, _, layer = self._create_model_components(
                    config
                )

                # Random seq lens (ensure one matches max)
                torch.manual_seed(config["seed_cache"])
                seq_lens = torch.randint(
                    1, max_seq_len, (batch_size,), device=config["device"]
                )
                seq_lens[0] = max_seq_len

                fb = self._create_forward_batch(
                    batch_size, seq_lens, backend, model_runner, config
                )
                backend.init_forward_metadata(fb)

                # Create Q, K, V tensors.
                # MLA head dim combines the latent rank and the RoPE portion.
                torch.manual_seed(config["seed_qkv"])
                head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
                q = torch.randn(
                    (batch_size, config["num_attention_heads"], head_dim),
                    dtype=config["dtype"],
                    device=config["device"],
                )
                k = torch.randn(
                    (batch_size, config["num_kv_heads"], head_dim),
                    dtype=config["dtype"],
                    device=config["device"],
                )
                # v is passed as None — the decode path under test accepts it;
                # presumably MLA folds the value into the latent KV (confirm).
                v = None

                # Run forward decode
                output = backend.forward_decode(q, k, v, layer, fb)

                # Shape and sanity checks
                expected_shape = (
                    batch_size,
                    config["num_attention_heads"] * config["v_head_dim"],
                )
                self.assertEqual(
                    output.shape,
                    expected_shape,
                    f"Output shape mismatch for {test_case['name']}",
                )
                self.assertEqual(output.dtype, config["dtype"])
                self.assertEqual(output.device.type, "cuda")
                self.assertFalse(
                    torch.isnan(output).any(),
                    f"Output contains NaN for {test_case['name']}",
                )
                self.assertFalse(
                    torch.isinf(output).any(),
                    f"Output contains Inf for {test_case['name']}",
                )
||||
    def test_metadata_initialization(self):
        """Test TRTLLM MLA metadata initialization and structure.

        Checks that init_forward_metadata produces a TRTLLMMLADecodeMetadata
        with an allocated int8 CUDA workspace and an int32 CUDA block-index
        table of shape (batch_size, ...), where -1 marks padding entries.
        """
        print(f"\nRunning metadata initialization tests...")

        for test_case in TEST_CASES["metadata_tests"]:
            with self.subTest(test_case=test_case["name"]):
                print(f" Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                # Create components
                model_runner, _, backend, _, layer = self._create_model_components(
                    config
                )

                # Create varied sequence lengths; batch_size == 1 gets the
                # deterministic single max-length case.
                torch.manual_seed(config["seed_cache"])
                if batch_size == 1:
                    seq_lens = torch.tensor([max_seq_len], device=config["device"])
                else:
                    seq_lens = torch.randint(
                        max(1, max_seq_len // 4),
                        max_seq_len + 1,
                        (batch_size,),
                        device=config["device"],
                    )
                    seq_lens[0] = max_seq_len  # Ensure at least one max length

                # Create forward batch
                fb = self._create_forward_batch(
                    batch_size, seq_lens, backend, model_runner, config
                )

                # Initialize metadata
                backend.init_forward_metadata(fb)

                # Verify metadata exists and is of the expected dataclass type.
                self.assertIsNotNone(backend.forward_metadata)
                self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)

                # Test metadata structure
                metadata = backend.forward_metadata
                self.assertIsNotNone(
                    metadata.workspace, "Workspace should be allocated"
                )
                self.assertIsNotNone(
                    metadata.block_kv_indices, "Block KV indices should be created"
                )

                # Test workspace properties: raw byte buffer on the GPU.
                self.assertEqual(metadata.workspace.device.type, "cuda")
                self.assertEqual(metadata.workspace.dtype, torch.int8)
                self.assertGreater(
                    metadata.workspace.numel(), 0, "Workspace should have non-zero size"
                )

                # Test block KV indices properties
                self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
                self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
                self.assertEqual(metadata.block_kv_indices.shape[0], batch_size)

                # Verify block indices are valid (>= -1, since -1 is padding)
                self.assertTrue(
                    (metadata.block_kv_indices >= -1).all(),
                    "All block indices should be >= -1 (with -1 as padding)",
                )
||||
def test_metadata_block_calculation(self):
|
||||
"""Test block count calculation logic."""
|
||||
print(f"\nRunning metadata block calculation tests...")
|
||||
|
||||
test_scenarios = [
|
||||
{"seq_len": 31, "page_size": 32, "expected_min_blocks": 1},
|
||||
{"seq_len": 32, "page_size": 32, "expected_min_blocks": 1},
|
||||
{"seq_len": 33, "page_size": 32, "expected_min_blocks": 2},
|
||||
{"seq_len": 128, "page_size": 32, "expected_min_blocks": 4},
|
||||
{"seq_len": 128, "page_size": 64, "expected_min_blocks": 2},
|
||||
]
|
||||
|
||||
for scenario in test_scenarios:
|
||||
with self.subTest(scenario=scenario):
|
||||
config = self._merge_config(
|
||||
{
|
||||
"batch_size": 1,
|
||||
"max_seq_len": scenario["seq_len"],
|
||||
"page_size": scenario["page_size"],
|
||||
}
|
||||
)
|
||||
|
||||
model_runner, _, backend, _, _ = self._create_model_components(config)
|
||||
|
||||
# Test internal block calculation
|
||||
calculated_blocks = backend._calc_padded_blocks(scenario["seq_len"])
|
||||
|
||||
# Should be at least the minimum required
|
||||
self.assertGreaterEqual(
|
||||
calculated_blocks,
|
||||
scenario["expected_min_blocks"],
|
||||
f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})",
|
||||
)
|
||||
|
||||
# Should satisfy page_size constraint
|
||||
total_tokens = calculated_blocks * scenario["page_size"]
|
||||
self.assertGreaterEqual(
|
||||
total_tokens,
|
||||
scenario["seq_len"],
|
||||
f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})",
|
||||
)
|
||||
|
||||
# Should satisfy TRT-LLM and Triton constraints
|
||||
trtllm_constraint = 128 // scenario["page_size"]
|
||||
constraint_lcm = math.lcm(
|
||||
trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK
|
||||
)
|
||||
self.assertEqual(
|
||||
calculated_blocks % constraint_lcm,
|
||||
0,
|
||||
f"Block count should be multiple of LCM of constraints ({constraint_lcm})",
|
||||
)
|
||||
|
||||
    def test_metadata_kv_indices_correctness(self):
        """Test KV indices creation and correctness.

        After populating the cache and initializing metadata, each sequence's
        row in block_kv_indices must contain enough valid (>= 0) entries to
        cover its length, and every valid index must fall inside the pool.
        """
        print(f"\nRunning KV indices correctness tests...")

        for test_case in TEST_CASES["metadata_tests"][
            :2
        ]:  # Test subset for performance
            with self.subTest(test_case=test_case["name"]):
                print(f" Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                model_runner, _, backend, _, layer = self._create_model_components(
                    config
                )

                # Create known sequence lengths (seeded for reproducibility).
                torch.manual_seed(config["seed_cache"])
                if batch_size == 1:
                    seq_lens = torch.tensor([max_seq_len], device=config["device"])
                else:
                    seq_lens = torch.randint(
                        max_seq_len // 2,
                        max_seq_len + 1,
                        (batch_size,),
                        device=config["device"],
                    )

                fb = self._create_forward_batch(
                    batch_size, seq_lens, backend, model_runner, config
                )

                # Populate some KV cache to have valid indices
                self._populate_kv_cache(
                    batch_size, seq_lens, [model_runner], layer, config
                )

                # Initialize metadata
                backend.init_forward_metadata(fb)
                metadata = backend.forward_metadata

                # Verify KV indices structure
                block_kv_indices = metadata.block_kv_indices

                for i in range(batch_size):
                    seq_len = seq_lens[i].item()
                    expected_blocks = backend._calc_padded_blocks(seq_len)

                    # Count valid (non -1) indices for this sequence
                    valid_indices = (block_kv_indices[i] >= 0).sum().item()

                    # Should have at least enough blocks for the sequence
                    # (ceil division by page_size).
                    min_required_blocks = (seq_len + config["page_size"] - 1) // config[
                        "page_size"
                    ]
                    self.assertGreaterEqual(
                        valid_indices,
                        min_required_blocks,
                        f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}",
                    )

                    # Verify indices are within valid range: every non-padding
                    # index must address a page inside the KV pool.
                    valid_block_indices = block_kv_indices[i][block_kv_indices[i] >= 0]
                    if len(valid_block_indices) > 0:
                        max_possible_blocks = (
                            model_runner.token_to_kv_pool.size // config["page_size"]
                        )
                        self.assertTrue(
                            (valid_block_indices < max_possible_blocks).all(),
                            f"All block indices should be < {max_possible_blocks}",
                        )
||||
    def test_metadata_cuda_graph_compatibility(self):
        """Test metadata compatibility with CUDA graph capture/replay.

        Exercises the capture path (init_forward_metadata_capture_cuda_graph)
        followed by a replay with different sequence lengths, and verifies
        the replay reuses the exact workspace buffer captured earlier —
        required for CUDA graph replay, which cannot reallocate memory.
        """
        print(f"\nRunning CUDA graph compatibility tests...")

        config = self._merge_config(
            {"batch_size": 4, "max_seq_len": 64, "page_size": 32}
        )

        model_runner, _, backend, _, layer = self._create_model_components(config)
        batch_size = config["batch_size"]

        # Initialize CUDA graph state
        backend.init_cuda_graph_state(
            max_bs=batch_size, max_num_tokens=config["max_seq_len"] * batch_size
        )

        # Verify CUDA graph buffers are allocated
        self.assertIsNotNone(backend.cuda_graph_kv_indices)
        self.assertIsNotNone(backend.cuda_graph_workspace)

        # Test capture metadata: all sequences at max length during capture.
        seq_lens = torch.full(
            (batch_size,), config["max_seq_len"], device=config["device"]
        )
        req_pool_indices = torch.arange(batch_size, device=config["device"])

        backend.init_forward_metadata_capture_cuda_graph(
            bs=batch_size,
            num_tokens=batch_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=None,
            forward_mode=ForwardMode.DECODE,
            spec_info=None,
        )

        # Verify capture metadata is registered under this batch size.
        self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
        capture_metadata = backend.decode_cuda_graph_metadata[batch_size]

        self.assertIsNotNone(capture_metadata.workspace)
        self.assertIsNotNone(capture_metadata.block_kv_indices)

        # Test replay with different sequence lengths
        new_seq_lens = torch.randint(
            config["max_seq_len"] // 2,
            config["max_seq_len"] + 1,
            (batch_size,),
            device=config["device"],
        )

        backend.init_forward_metadata_replay_cuda_graph(
            bs=batch_size,
            req_pool_indices=req_pool_indices,
            seq_lens=new_seq_lens,
            seq_lens_sum=new_seq_lens.sum().item(),
            encoder_lens=None,
            forward_mode=ForwardMode.DECODE,
            spec_info=None,
            seq_lens_cpu=new_seq_lens.cpu(),
        )

        # Verify replay updated the metadata and, crucially, kept the same
        # workspace allocation (same data_ptr) as the captured one.
        replay_metadata = backend.forward_metadata
        self.assertIsNotNone(replay_metadata)
        self.assertEqual(
            replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
        )
||||
def test_metadata_consistency_across_calls(self):
|
||||
"""Test metadata consistency across multiple forward calls."""
|
||||
print(f"\nRunning metadata consistency tests...")
|
||||
|
||||
config = self._merge_config(
|
||||
{"batch_size": 2, "max_seq_len": 64, "page_size": 32}
|
||||
)
|
||||
|
||||
model_runner, _, backend, _, layer = self._create_model_components(config)
|
||||
|
||||
# First call
|
||||
seq_lens_1 = torch.tensor([32, 48], device=config["device"])
|
||||
fb_1 = self._create_forward_batch(
|
||||
config["batch_size"], seq_lens_1, backend, model_runner, config
|
||||
)
|
||||
backend.init_forward_metadata(fb_1)
|
||||
metadata_1 = backend.forward_metadata
|
||||
|
||||
# Second call with same sequence lengths
|
||||
seq_lens_2 = torch.tensor([32, 48], device=config["device"])
|
||||
fb_2 = self._create_forward_batch(
|
||||
config["batch_size"], seq_lens_2, backend, model_runner, config
|
||||
)
|
||||
backend.init_forward_metadata(fb_2)
|
||||
metadata_2 = backend.forward_metadata
|
||||
|
||||
# Metadata structure should be consistent
|
||||
self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
|
||||
self.assertEqual(
|
||||
metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
|
||||
)
|
||||
|
||||
# Third call with different sequence lengths
|
||||
seq_lens_3 = torch.tensor([16, 64], device=config["device"])
|
||||
fb_3 = self._create_forward_batch(
|
||||
config["batch_size"], seq_lens_3, backend, model_runner, config
|
||||
)
|
||||
backend.init_forward_metadata(fb_3)
|
||||
metadata_3 = backend.forward_metadata
|
||||
|
||||
# Should still have valid structure
|
||||
self.assertIsNotNone(metadata_3.workspace)
|
||||
self.assertIsNotNone(metadata_3.block_kv_indices)
|
||||
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
|
||||
|
||||
|
||||
# Allow running this test module directly (e.g. `python <this_file>.py`);
# unittest.main() discovers and runs all TestCase classes in the module.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user