TRTLLM-MLA FP8 path (#8638)
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
@@ -43,6 +43,37 @@ DEFAULT_CONFIG = {
     "layer_id": 0,
 }
 
+ROPE_BASE = 10000
+ROPE_SCALING_CONFIG = {
+    "beta_fast": 32,
+    "beta_slow": 1,
+    "factor": 40,
+    "mscale": 1.0,
+    "mscale_all_dim": 1.0,
+    "original_max_position_embeddings": 4096,
+    "type": "yarn",
+    "rope_type": "deepseek_yarn",
+}
+
+
+def build_rotary_emb(config, device=None):
+    from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+
+    dev = device or config["device"]
+    rope_scaling = config.get("rope_scaling", ROPE_SCALING_CONFIG)
+    rotary = get_rope_wrapper(
+        head_size=config["qk_rope_head_dim"],
+        rotary_dim=config["qk_rope_head_dim"],
+        max_position=config["context_len"],
+        base=ROPE_BASE,
+        rope_scaling=rope_scaling,
+        is_neox_style=False,
+        device=dev,
+    )
+    rotary.cos_sin_cache = rotary.cos_sin_cache.to(dev)
+    return rotary
+
+
 # Centralized test cases for different test scenarios
 TEST_CASES = {
     "basic_functionality": [
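Note: build_rotary_emb routes the tests through SGLang's get_rope_wrapper with DeepSeek-style YaRN scaling ("deepseek_yarn"), so both attention backends can share a single cos/sin cache. A minimal usage sketch, assuming an SGLang install and illustrative config values (the tests read the real keys from DEFAULT_CONFIG):

    import torch

    # Illustrative stand-in for DEFAULT_CONFIG; only the keys that
    # build_rotary_emb actually reads are shown.
    config = {
        "qk_rope_head_dim": 64,  # rotary dim of the MLA rope slice
        "context_len": 2048,     # positions covered by the cos/sin cache
        "device": "cuda",
    }

    rotary = build_rotary_emb(config)  # falls back to ROPE_SCALING_CONFIG
    positions = torch.arange(4, device=config["device"])
    # rotary(positions, q_rope, k_rope) rotates only the rope slices;
    # rotary.cos_sin_cache can be handed to kernels that fuse RoPE.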
@@ -63,18 +94,36 @@ TEST_CASES = {
     ],
     "decode_output_match": [
         {
-            "name": "single",
+            "name": "single_fp16",
             "batch_size": 1,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Single vs reference",
+            "description": "Single FP16 vs reference",
         },
         {
-            "name": "batch",
+            "name": "single_fp8",
+            "batch_size": 1,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Single FP8 vs reference",
+        },
+        {
+            "name": "batch_fp16",
             "batch_size": 32,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Batch vs reference",
+            "description": "Batch FP16 vs reference",
         },
+        {
+            "name": "batch_fp8",
+            "batch_size": 32,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Batch FP8 vs reference",
+        },
     ],
     "page_size_consistency": [
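The FP8 cases use page_size 64 and relax the tolerance to 1e-1 because an e4m3 KV cache keeps only a 3-bit mantissa. A self-contained sketch of the precision loss that motivates the looser bound (exact numbers vary by seed):

    import torch

    # Round-trip a tensor through float8_e4m3fn, as an FP8 KV cache does.
    x = torch.randn(1024, dtype=torch.bfloat16)
    x_rt = x.to(torch.float8_e4m3fn).to(torch.bfloat16)
    rel_err = ((x - x_rt).abs() / x.abs().clamp(min=1e-3)).mean()
    print(f"mean relative round-trip error: {rel_err.item():.3f}")  # a few percent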
@@ -293,26 +342,52 @@ class TestTRTLLMMLA(CustomTestCase):
             layer,
         )
 
-    def _create_qkv_tensors(self, batch_size, config):
-        """Create Q, K, V tensors for testing."""
-        head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
-        device = config["device"]
-        dtype = config["dtype"]
-
-        q = torch.randn(
-            (batch_size, config["num_attention_heads"], head_dim),
-            dtype=dtype,
-            device=device,
-        )
-        k = torch.randn(
-            (batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
-        )
-        v = torch.randn(
-            (batch_size, config["num_kv_heads"], config["v_head_dim"]),
-            dtype=dtype,
-            device=device,
-        )
-        return q, k, v
+    def _create_qkv_tensors(self, batch_size, config, dtype_override=None):
+        """Create random Q, K, V tensors for a given batch size with separate MLA components.
+
+        Args:
+            batch_size: Batch size.
+            config: Configuration dict with model dims and device.
+            dtype_override: Optional torch dtype to override config["dtype"].
+
+        Returns:
+            Tuple of (q_nope, q_rope, k_nope, k_rope, v).
+        """
+        device = config["device"]
+        target_dtype = dtype_override or config["dtype"]
+
+        # Create separate nope and rope components for Q
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
+            dtype=target_dtype,
+            device=device,
+        )
+        q_rope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["qk_rope_head_dim"]),
+            dtype=target_dtype,
+            device=device,
+        )
+
+        # Create separate nope and rope components for K
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
+            dtype=target_dtype,
+            device=device,
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=target_dtype,
+            device=device,
+        )
+
+        # V tensor (shape unchanged from the old helper)
+        v = torch.randn(
+            (batch_size, config["num_kv_heads"], config["v_head_dim"]),
+            dtype=target_dtype,
+            device=device,
+        )
+
+        return q_nope, q_rope, k_nope, k_rope, v
 
     def _create_forward_batch(
         self, batch_size, seq_lens, backend, model_runner, config
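The helper now returns the MLA components separately instead of one concatenated head: q_nope/k_nope carry the compressed latent (kv_lora_rank wide) and q_rope/k_rope the rotary slice (qk_rope_head_dim wide). A sketch of how the old single-tensor layout relates to the new split, with illustrative dims:

    import torch

    batch, heads, kv_lora_rank, rope_dim = 2, 16, 512, 64
    q_nope = torch.randn(batch, heads, kv_lora_rank)
    q_rope = torch.randn(batch, heads, rope_dim)

    # The old helper passed the equivalent of this concatenation as one q:
    q_concat = torch.cat([q_nope, q_rope], dim=-1)
    assert q_concat.shape == (batch, heads, kv_lora_rank + rope_dim)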
@@ -331,6 +406,10 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         fb.req_to_token_pool = model_runner.req_to_token_pool
         fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+        # Add position information for RoPE
+        fb.positions = torch.arange(batch_size, device=config["device"])
+
         return fb
 
     def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
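Note: torch.arange(batch_size) is a test simplification that just gives each request a distinct RoPE position; in a real decode step each request's position would be its current sequence length minus one. For these comparison tests, only agreement between the two backends on the same positions matters.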
@@ -344,7 +423,7 @@ class TestTRTLLMMLA(CustomTestCase):
             for token_idx in range(seq_len - 1):
                 # Create random K components for MLA
                 cache_k_nope = torch.randn(
-                    (1, config["qk_nope_head_dim"]),
+                    (1, config["kv_lora_rank"]),
                     dtype=config["dtype"],
                     device=config["device"],
                 )
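This dimension fix matters for correctness: the MLA KV cache stores the compressed KV latent, which is kv_lora_rank wide, not the per-head qk_nope_head_dim slice, so cache entries populated with the old width would not match the KV pool layout both backends read from.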
@@ -411,12 +490,16 @@ class TestTRTLLMMLA(CustomTestCase):
             batch_size, seq_lens, [model_runner_trtllm], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = trtllm_backend.forward_decode(q, k, v, layer, fb)
+        # Run forward decode with separate MLA components
+        output = trtllm_backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Basic checks
         expected_shape = (
@@ -439,6 +522,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config = self._merge_config(test_case)
         batch_size = config["batch_size"]
         max_seq_len = config["max_seq_len"]
+        use_fp8 = config["kv_cache_dtype"] == torch.float8_e4m3fn
 
         # Create components
         (
@@ -487,19 +571,66 @@ class TestTRTLLMMLA(CustomTestCase):
 
         # Create Q, K, V tensors for current decode step
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
-
+        q_nope_ref, q_rope_ref, k_nope_ref, k_rope_ref, v_ref = (
+            self._create_qkv_tensors(batch_size, config)
+        )
+        q_nope_trt, q_rope_trt, k_nope_trt, k_rope_trt, v_trt = (
+            q_nope_ref.clone(),
+            q_rope_ref.clone(),
+            k_nope_ref.clone(),
+            k_rope_ref.clone(),
+            v_ref.clone(),
+        )
+        tolerance = config["tolerance"]
+
+        extra_args = {}
+        if use_fp8:
+            # The TRT kernel applies RoPE + FP8 quantization internally;
+            # pre-apply RoPE on the reference (FlashInfer) path here so
+            # both paths share the same rope params/cache while keeping
+            # the TRT path unrotated.
+            rotary_emb = build_rotary_emb(config)
+            q_rope_ref, k_rope_ref = rotary_emb(
+                fb_reference.positions, q_rope_ref, k_rope_ref
+            )
+            extra_args = {
+                "cos_sin_cache": rotary_emb.cos_sin_cache,
+                "is_neox": rotary_emb.is_neox_style,
+            }
+
+            dtype = q_rope_ref.dtype
+            q_rope_ref = q_rope_ref.to(torch.float8_e4m3fn).to(dtype)
+            q_nope_ref = q_nope_ref.to(torch.float8_e4m3fn).to(dtype)
+            k_rope_ref = k_rope_ref.to(torch.float8_e4m3fn).to(dtype)
+            k_nope_ref = k_nope_ref.to(torch.float8_e4m3fn).to(dtype)
+
         # Run forward decode on both backends
         out_trtllm = trtllm_backend.forward_decode(
-            q.clone(), k.clone(), v.clone(), layer, fb_trtllm
+            q_nope_trt,
+            k_nope_trt,
+            None,
+            layer,
+            fb_trtllm,
+            q_rope=q_rope_trt,
+            k_rope=k_rope_trt,
+            **extra_args,
         )
+
+        # Reference backend should also take separate components, not concatenated
         out_reference = reference_backend.forward_decode(
-            q.clone(), k.clone(), v.clone(), layer, fb_reference
+            q_nope_ref,
+            k_nope_ref,
+            v_ref,
+            layer,
+            fb_reference,
+            q_rope=q_rope_ref,
+            k_rope=k_rope_ref,
         )
 
         # Compare outputs
         comparison_passed = compare_outputs(
-            out_trtllm, out_reference, tolerance=config["tolerance"]
+            out_trtllm, out_reference, tolerance=tolerance
         )
 
         self.assertTrue(
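The reference path emulates what the TRT-LLM kernel does internally with a fake-quantization round trip: tensors are cast to e4m3 and straight back, so both backends attend over identically quantized values even though the reference math stays in the working dtype. A minimal sketch of the pattern (the helper name is ours, not from this change):

    import torch

    def fake_quant_e4m3(t: torch.Tensor) -> torch.Tensor:
        # Quantize to FP8 and dequantize back: keeps dtype, drops precision.
        return t.to(torch.float8_e4m3fn).to(t.dtype)

    q = torch.randn(2, 16, 64, dtype=torch.bfloat16)
    assert fake_quant_e4m3(q).dtype == torch.bfloat16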
@@ -544,12 +675,16 @@ class TestTRTLLMMLA(CustomTestCase):
             batch_size, seq_lens, [model_runner], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
        torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = backend.forward_decode(q, k, v, layer, fb)
+        # Run forward decode with separate MLA components
+        output = backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         expected_shape = (
             batch_size,
@@ -591,23 +726,38 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         backend.init_forward_metadata(fb)
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
-        q = torch.randn(
-            (batch_size, config["num_attention_heads"], head_dim),
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-        k = torch.randn(
-            (batch_size, config["num_kv_heads"], head_dim),
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-        v = None
+        q_rope = torch.randn(
+            (
+                batch_size,
+                config["num_attention_heads"],
+                config["qk_rope_head_dim"],
+            ),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        v = None  # Test with None v
 
         # Run forward decode
-        output = backend.forward_decode(q, k, v, layer, fb)
+        output = backend.forward_decode(
+            q_nope, k_nope, v, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Shape and sanity checks
         expected_shape = (
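Keeping v = None here is deliberate: on the MLA decode path the value is reconstructed from the compressed KV latent rather than passed in, so forward_decode is expected to tolerate a missing v as long as q_rope/k_rope are supplied.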