Fix incorrect KV indices creation when page_size=32 in TRTLLM MLA backend (#11985)
This commit is contained in:
@@ -9,19 +9,12 @@ and uses BatchMLAPaged wrapper for decoding.
|
||||
More details can be found in https://docs.flashinfer.ai/api/mla.html
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
|
||||
import logging
|
||||
|
||||
torch._logging.set_logs(dynamo=logging.ERROR)
|
||||
torch._dynamo.config.suppress_errors = True
|
||||
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
@@ -45,6 +38,12 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
|
||||
import logging
|
||||
|
||||
torch._logging.set_logs(dynamo=logging.ERROR)
|
||||
torch._dynamo.config.suppress_errors = True
|
||||
|
||||
if is_flashinfer_available():
|
||||
from flashinfer import (
|
||||
BatchMLAPagedAttentionWrapper,
|
||||
|
||||
@@ -17,8 +17,8 @@ from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAMultiStepDraftBackend,
|
||||
)
|
||||
from sglang.srt.layers.attention.utils import (
|
||||
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
create_flashmla_kv_indices_triton,
|
||||
get_num_page_per_block_flashmla,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
@@ -295,9 +295,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
|
||||
# Apply dual constraints (take LCM to satisfy both):
|
||||
# 1. TRT-LLM: block_num % (128 / page_size) == 0
|
||||
# 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
|
||||
# 2. Triton: number of pages per block
|
||||
trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
|
||||
constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
|
||||
triton_constraint = get_num_page_per_block_flashmla(self.page_size)
|
||||
constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
|
||||
|
||||
if blocks % constraint_lcm != 0:
|
||||
blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
|
||||
@@ -336,7 +337,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
block_kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
max_blocks,
|
||||
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
PAGED_SIZE=self.page_size,
|
||||
)
|
||||
|
||||
@@ -417,7 +417,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
block_kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
max_blocks_per_seq,
|
||||
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
PAGED_SIZE=self.page_size,
|
||||
)
|
||||
|
||||
@@ -504,7 +503,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
metadata.block_kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
metadata.block_kv_indices.shape[1],
|
||||
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
PAGED_SIZE=self.page_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
|
||||
# Number of pages that the kernel writes per iteration.
|
||||
# Exposed here so other Python modules can import it instead of hard-coding 64.
|
||||
TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
|
||||
_FLASHMLA_CREATE_KV_BLOCK_SIZE = 4096
|
||||
FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON = tl.constexpr(_FLASHMLA_CREATE_KV_BLOCK_SIZE)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -46,6 +44,11 @@ def create_flashinfer_kv_indices_triton(
|
||||
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
|
||||
|
||||
|
||||
def get_num_page_per_block_flashmla(page_size: int = 64) -> int:
|
||||
num_page_per_block = _FLASHMLA_CREATE_KV_BLOCK_SIZE // page_size
|
||||
return num_page_per_block
|
||||
|
||||
|
||||
@triton.jit
|
||||
def create_flashmla_kv_indices_triton(
|
||||
req_to_token_ptr, # [max_batch, max_context_len]
|
||||
@@ -55,10 +58,11 @@ def create_flashmla_kv_indices_triton(
|
||||
kv_indices_ptr,
|
||||
req_to_token_ptr_stride: tl.constexpr,
|
||||
kv_indices_ptr_stride: tl.constexpr,
|
||||
NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||
PAGED_SIZE: tl.constexpr = 64,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 4096
|
||||
NUM_PAGE_PER_BLOCK: tl.constexpr = (
|
||||
FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON // PAGED_SIZE
|
||||
)
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
# find the req pool idx, this is for batch to token
|
||||
@@ -73,7 +77,7 @@ def create_flashmla_kv_indices_triton(
|
||||
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
||||
|
||||
num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
|
||||
num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||
num_pages_loop = tl.cdiv(kv_end - kv_start, FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON)
|
||||
|
||||
for i in range(num_pages_loop):
|
||||
# index into req_to_token_ptr needs to be int64
|
||||
|
||||
@@ -16,10 +16,15 @@ from sglang.srt.layers.attention.trtllm_mla_backend import (
|
||||
TRTLLMMLABackend,
|
||||
TRTLLMMLADecodeMetadata,
|
||||
)
|
||||
from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
|
||||
from sglang.srt.layers.attention.utils import get_num_page_per_block_flashmla
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.server_args import (
|
||||
ServerArgs,
|
||||
get_global_server_args,
|
||||
set_global_server_args_for_scheduler,
|
||||
)
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
@@ -104,15 +109,15 @@ TEST_CASES = {
|
||||
"page_size": 32,
|
||||
"description": "Single FP16 vs reference",
|
||||
},
|
||||
{
|
||||
"name": "single_fp8",
|
||||
"batch_size": 1,
|
||||
"max_seq_len": 64,
|
||||
"page_size": 64,
|
||||
"tolerance": 1e-1,
|
||||
"kv_cache_dtype": torch.float8_e4m3fn,
|
||||
"description": "Single FP8 vs reference",
|
||||
},
|
||||
# {
|
||||
# "name": "single_fp8",
|
||||
# "batch_size": 1,
|
||||
# "max_seq_len": 64,
|
||||
# "page_size": 64,
|
||||
# "tolerance": 1e-1,
|
||||
# "kv_cache_dtype": torch.float8_e4m3fn,
|
||||
# "description": "Single FP8 vs reference",
|
||||
# },
|
||||
{
|
||||
"name": "batch_fp16",
|
||||
"batch_size": 32,
|
||||
@@ -120,15 +125,15 @@ TEST_CASES = {
|
||||
"page_size": 32,
|
||||
"description": "Batch FP16 vs reference",
|
||||
},
|
||||
{
|
||||
"name": "batch_fp8",
|
||||
"batch_size": 32,
|
||||
"max_seq_len": 64,
|
||||
"page_size": 64,
|
||||
"tolerance": 1e-1,
|
||||
"kv_cache_dtype": torch.float8_e4m3fn,
|
||||
"description": "Batch FP8 vs reference",
|
||||
},
|
||||
# {
|
||||
# "name": "batch_fp8",
|
||||
# "batch_size": 32,
|
||||
# "max_seq_len": 64,
|
||||
# "page_size": 64,
|
||||
# "tolerance": 1e-1,
|
||||
# "kv_cache_dtype": torch.float8_e4m3fn,
|
||||
# "description": "Batch FP8 vs reference",
|
||||
# },
|
||||
],
|
||||
"page_size_consistency": [
|
||||
# Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
|
||||
@@ -213,13 +218,7 @@ class MockModelRunner:
|
||||
self.page_size = config["page_size"]
|
||||
|
||||
# Server args stub - needed by attention backends
|
||||
self.server_args = type(
|
||||
"ServerArgs",
|
||||
(),
|
||||
{
|
||||
"enable_dp_attention": False, # Default value for testing
|
||||
},
|
||||
)
|
||||
self.server_args = get_global_server_args()
|
||||
|
||||
# Model-config stub with MLA attributes
|
||||
self.model_config = type(
|
||||
@@ -320,6 +319,17 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
|
||||
class TestTRTLLMMLA(CustomTestCase):
|
||||
"""Test suite for TRTLLM MLA backend with centralized configuration."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Set up global server args for testing."""
|
||||
server_args = ServerArgs(model_path="dummy")
|
||||
server_args.enable_dp_attention = False
|
||||
set_global_server_args_for_scheduler(server_args)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
pass
|
||||
|
||||
def _merge_config(self, test_case):
|
||||
"""Merge test case with default configuration."""
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
@@ -841,25 +851,17 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
backend.init_forward_metadata(fb)
|
||||
|
||||
# Verify metadata exists
|
||||
self.assertIsNotNone(backend.forward_metadata)
|
||||
self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)
|
||||
self.assertIsNotNone(backend.forward_decode_metadata)
|
||||
self.assertIsInstance(
|
||||
backend.forward_decode_metadata, TRTLLMMLADecodeMetadata
|
||||
)
|
||||
|
||||
# Test metadata structure
|
||||
metadata = backend.forward_metadata
|
||||
self.assertIsNotNone(
|
||||
metadata.workspace, "Workspace should be allocated"
|
||||
)
|
||||
metadata = backend.forward_decode_metadata
|
||||
self.assertIsNotNone(
|
||||
metadata.block_kv_indices, "Block KV indices should be created"
|
||||
)
|
||||
|
||||
# Test workspace properties
|
||||
self.assertEqual(metadata.workspace.device.type, "cuda")
|
||||
self.assertEqual(metadata.workspace.dtype, torch.uint8)
|
||||
self.assertGreater(
|
||||
metadata.workspace.numel(), 0, "Workspace should have non-zero size"
|
||||
)
|
||||
|
||||
# Test block KV indices properties
|
||||
self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
|
||||
self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
|
||||
@@ -915,9 +917,10 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
|
||||
# Should satisfy TRT-LLM and Triton constraints
|
||||
trtllm_constraint = 128 // scenario["page_size"]
|
||||
constraint_lcm = math.lcm(
|
||||
trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK
|
||||
triton_constraint = get_num_page_per_block_flashmla(
|
||||
scenario["page_size"]
|
||||
)
|
||||
constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
|
||||
self.assertEqual(
|
||||
calculated_blocks % constraint_lcm,
|
||||
0,
|
||||
@@ -965,7 +968,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
|
||||
# Initialize metadata
|
||||
backend.init_forward_metadata(fb)
|
||||
metadata = backend.forward_metadata
|
||||
metadata = backend.forward_decode_metadata
|
||||
|
||||
# Verify KV indices structure
|
||||
block_kv_indices = metadata.block_kv_indices
|
||||
@@ -1016,7 +1019,6 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
|
||||
# Verify CUDA graph buffers are allocated
|
||||
self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
|
||||
self.assertIsNotNone(backend.decode_cuda_graph_workspace)
|
||||
|
||||
# Test capture metadata
|
||||
seq_lens = torch.full(
|
||||
@@ -1038,7 +1040,6 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
|
||||
capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
|
||||
|
||||
self.assertIsNotNone(capture_metadata.workspace)
|
||||
self.assertIsNotNone(capture_metadata.block_kv_indices)
|
||||
|
||||
# Test replay with different sequence lengths
|
||||
@@ -1061,11 +1062,8 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
)
|
||||
|
||||
# Verify replay updated the metadata
|
||||
replay_metadata = backend.forward_metadata
|
||||
replay_metadata = backend.forward_decode_metadata
|
||||
self.assertIsNotNone(replay_metadata)
|
||||
self.assertEqual(
|
||||
replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
|
||||
)
|
||||
|
||||
def test_metadata_consistency_across_calls(self):
|
||||
"""Test metadata consistency across multiple forward calls."""
|
||||
@@ -1083,7 +1081,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
config["batch_size"], seq_lens_1, backend, model_runner, config
|
||||
)
|
||||
backend.init_forward_metadata(fb_1)
|
||||
metadata_1 = backend.forward_metadata
|
||||
metadata_1 = backend.forward_decode_metadata
|
||||
|
||||
# Second call with same sequence lengths
|
||||
seq_lens_2 = torch.tensor([32, 48], device=config["device"])
|
||||
@@ -1091,10 +1089,9 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
config["batch_size"], seq_lens_2, backend, model_runner, config
|
||||
)
|
||||
backend.init_forward_metadata(fb_2)
|
||||
metadata_2 = backend.forward_metadata
|
||||
metadata_2 = backend.forward_decode_metadata
|
||||
|
||||
# Metadata structure should be consistent
|
||||
self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
|
||||
self.assertEqual(
|
||||
metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
|
||||
)
|
||||
@@ -1105,10 +1102,9 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
config["batch_size"], seq_lens_3, backend, model_runner, config
|
||||
)
|
||||
backend.init_forward_metadata(fb_3)
|
||||
metadata_3 = backend.forward_metadata
|
||||
metadata_3 = backend.forward_decode_metadata
|
||||
|
||||
# Should still have valid structure
|
||||
self.assertIsNotNone(metadata_3.workspace)
|
||||
self.assertIsNotNone(metadata_3.block_kv_indices)
|
||||
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user