TRTLLM-MLA FP8 path (#8638)

Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
Faraz
2025-08-11 17:02:13 -04:00
committed by GitHub
parent 44e86480e8
commit f508cd3cb7
5 changed files with 347 additions and 62 deletions

View File

@@ -43,6 +43,37 @@ DEFAULT_CONFIG = {
"layer_id": 0,
}
ROPE_BASE = 10000
ROPE_SCALING_CONFIG = {
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
"rope_type": "deepseek_yarn",
}
def build_rotary_emb(config, device=None):
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
dev = device or config["device"]
rope_scaling = config.get("rope_scaling", ROPE_SCALING_CONFIG)
rotary = get_rope_wrapper(
head_size=config["qk_rope_head_dim"],
rotary_dim=config["qk_rope_head_dim"],
max_position=config["context_len"],
base=ROPE_BASE,
rope_scaling=rope_scaling,
is_neox_style=False,
device=dev,
)
rotary.cos_sin_cache = rotary.cos_sin_cache.to(dev)
return rotary
# Centralized test cases for different test scenarios
TEST_CASES = {
"basic_functionality": [
@@ -63,18 +94,36 @@ TEST_CASES = {
],
"decode_output_match": [
{
"name": "single",
"name": "single_fp16",
"batch_size": 1,
"max_seq_len": 64,
"page_size": 32,
"description": "Single vs reference",
"description": "Single FP16 vs reference",
},
{
"name": "batch",
"name": "single_fp8",
"batch_size": 1,
"max_seq_len": 64,
"page_size": 64,
"tolerance": 1e-1,
"kv_cache_dtype": torch.float8_e4m3fn,
"description": "Single FP8 vs reference",
},
{
"name": "batch_fp16",
"batch_size": 32,
"max_seq_len": 64,
"page_size": 32,
"description": "Batch vs reference",
"description": "Batch FP16 vs reference",
},
{
"name": "batch_fp8",
"batch_size": 32,
"max_seq_len": 64,
"page_size": 64,
"tolerance": 1e-1,
"kv_cache_dtype": torch.float8_e4m3fn,
"description": "Batch FP8 vs reference",
},
],
"page_size_consistency": [
@@ -293,26 +342,52 @@ class TestTRTLLMMLA(CustomTestCase):
layer,
)
def _create_qkv_tensors(self, batch_size, config):
"""Create Q, K, V tensors for testing."""
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
device = config["device"]
dtype = config["dtype"]
def _create_qkv_tensors(self, batch_size, config, dtype_override=None):
"""Create Q, K, V random tensors for given batch size with separate MLA components.
q = torch.randn(
(batch_size, config["num_attention_heads"], head_dim),
dtype=dtype,
Args:
batch_size: Batch size.
config: Configuration dict with model dims and device.
dtype_override: Optional torch dtype to override config["dtype"].
Returns:
Tuple of (q_nope, q_rope, k_nope, k_rope, v, cos_sin_cache)
"""
device = config["device"]
target_dtype = dtype_override or config["dtype"]
# Create separate nope and rope components for Q
q_nope = torch.randn(
(batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
dtype=config["dtype"],
device=device,
)
k = torch.randn(
(batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
q_rope = torch.randn(
(batch_size, config["num_attention_heads"], config["qk_rope_head_dim"]),
dtype=config["dtype"],
device=device,
)
# Create separate nope and rope components for K
k_nope = torch.randn(
(batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
dtype=config["dtype"],
device=device,
)
k_rope = torch.randn(
(batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
dtype=config["dtype"],
device=device,
)
# V tensor (unchanged)
v = torch.randn(
(batch_size, config["num_kv_heads"], config["v_head_dim"]),
dtype=dtype,
dtype=config["dtype"],
device=device,
)
return q, k, v
return q_nope, q_rope, k_nope, k_rope, v
def _create_forward_batch(
self, batch_size, seq_lens, backend, model_runner, config
@@ -331,6 +406,10 @@ class TestTRTLLMMLA(CustomTestCase):
)
fb.req_to_token_pool = model_runner.req_to_token_pool
fb.token_to_kv_pool = model_runner.token_to_kv_pool
# Add position information for RoPE
fb.positions = torch.arange(batch_size, device=config["device"])
return fb
def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
@@ -344,7 +423,7 @@ class TestTRTLLMMLA(CustomTestCase):
for token_idx in range(seq_len - 1):
# Create random K components for MLA
cache_k_nope = torch.randn(
(1, config["qk_nope_head_dim"]),
(1, config["kv_lora_rank"]),
dtype=config["dtype"],
device=config["device"],
)
@@ -411,12 +490,16 @@ class TestTRTLLMMLA(CustomTestCase):
batch_size, seq_lens, [model_runner_trtllm], layer, config
)
# Create Q, K, V tensors
# Create Q, K, V tensors with separate MLA components
torch.manual_seed(config["seed_qkv"])
q, k, v = self._create_qkv_tensors(batch_size, config)
q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
batch_size, config
)
# Run forward decode
output = trtllm_backend.forward_decode(q, k, v, layer, fb)
# Run forward decode with separate MLA components
output = trtllm_backend.forward_decode(
q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
)
# Basic checks
expected_shape = (
@@ -439,6 +522,7 @@ class TestTRTLLMMLA(CustomTestCase):
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
use_fp8 = config["kv_cache_dtype"] == torch.float8_e4m3fn
# Create components
(
@@ -487,19 +571,66 @@ class TestTRTLLMMLA(CustomTestCase):
# Create Q, K, V tensors for current decode step
torch.manual_seed(config["seed_qkv"])
q, k, v = self._create_qkv_tensors(batch_size, config)
q_nope_ref, q_rope_ref, k_nope_ref, k_rope_ref, v_ref = (
self._create_qkv_tensors(batch_size, config)
)
q_nope_trt, q_rope_trt, k_nope_trt, k_rope_trt, v_trt = (
q_nope_ref.clone(),
q_rope_ref.clone(),
k_nope_ref.clone(),
k_rope_ref.clone(),
v_ref.clone(),
)
tolerance = config["tolerance"]
extra_args = {}
if use_fp8:
# TRT kernel applies RoPE + FP8 quantization internally
# pre-apply RoPE on the reference (FlashInfer) path here so
# both paths share the same rope params/cache while keeping
# the TRT path unrotated.
rotary_emb = build_rotary_emb(config)
q_rope_ref, k_rope_ref = rotary_emb(
fb_reference.positions, q_rope_ref, k_rope_ref
)
extra_args = {
"cos_sin_cache": rotary_emb.cos_sin_cache,
"is_neox": rotary_emb.is_neox_style,
}
dtype = q_rope_ref.dtype
q_rope_ref = q_rope_ref.to(torch.float8_e4m3fn).to(dtype)
q_nope_ref = q_nope_ref.to(torch.float8_e4m3fn).to(dtype)
k_rope_ref = k_rope_ref.to(torch.float8_e4m3fn).to(dtype)
k_nope_ref = k_nope_ref.to(torch.float8_e4m3fn).to(dtype)
# Run forward decode on both backends
out_trtllm = trtllm_backend.forward_decode(
q.clone(), k.clone(), v.clone(), layer, fb_trtllm
q_nope_trt,
k_nope_trt,
None,
layer,
fb_trtllm,
q_rope=q_rope_trt,
k_rope=k_rope_trt,
**extra_args,
)
# Reference backend should also take separate components, not concatenated
out_reference = reference_backend.forward_decode(
q.clone(), k.clone(), v.clone(), layer, fb_reference
q_nope_ref,
k_nope_ref,
v_ref,
layer,
fb_reference,
q_rope=q_rope_ref,
k_rope=k_rope_ref,
)
# Compare outputs
comparison_passed = compare_outputs(
out_trtllm, out_reference, tolerance=config["tolerance"]
out_trtllm, out_reference, tolerance=tolerance
)
self.assertTrue(
@@ -544,12 +675,16 @@ class TestTRTLLMMLA(CustomTestCase):
batch_size, seq_lens, [model_runner], layer, config
)
# Create Q, K, V tensors
# Create Q, K, V tensors with separate MLA components
torch.manual_seed(config["seed_qkv"])
q, k, v = self._create_qkv_tensors(batch_size, config)
q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
batch_size, config
)
# Run forward decode
output = backend.forward_decode(q, k, v, layer, fb)
# Run forward decode with separate MLA components
output = backend.forward_decode(
q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
)
expected_shape = (
batch_size,
@@ -591,23 +726,38 @@ class TestTRTLLMMLA(CustomTestCase):
)
backend.init_forward_metadata(fb)
# Create Q, K, V tensors
# Create Q, K, V tensors with separate MLA components
torch.manual_seed(config["seed_qkv"])
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
q = torch.randn(
(batch_size, config["num_attention_heads"], head_dim),
q_nope = torch.randn(
(batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
dtype=config["dtype"],
device=config["device"],
)
k = torch.randn(
(batch_size, config["num_kv_heads"], head_dim),
k_nope = torch.randn(
(batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
dtype=config["dtype"],
device=config["device"],
)
v = None
q_rope = torch.randn(
(
batch_size,
config["num_attention_heads"],
config["qk_rope_head_dim"],
),
dtype=config["dtype"],
device=config["device"],
)
k_rope = torch.randn(
(batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
dtype=config["dtype"],
device=config["device"],
)
v = None # Test with None v
# Run forward decode
output = backend.forward_decode(q, k, v, layer, fb)
output = backend.forward_decode(
q_nope, k_nope, v, layer, fb, q_rope=q_rope, k_rope=k_rope
)
# Shape and sanity checks
expected_shape = (