TRTLLM-MLA FP8 path (#8638)
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
@@ -60,6 +60,11 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
 python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code
 ```
 
+- TRTLLM MLA with FP8 KV Cache (Higher concurrency, lower memory footprint)
+```bash
+python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --kv-cache-dtype fp8_e4m3 --trust-remote-code
+```
+
 - Ascend
 ```bash
 python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
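
As a quick smoke test after launching with the FP8 KV cache flag added above, a minimal client call might look like the sketch below (assuming the default sglang port 30000 and its OpenAI-compatible completions endpoint; not part of this diff):

```python
# Minimal smoke test for a server launched with --kv-cache-dtype fp8_e4m3.
# Assumes the default sglang port (30000) and its OpenAI-compatible REST API.
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "deepseek-ai/DeepSeek-R1",
        "prompt": "The capital of France is",
        "max_tokens": 8,
    },
)
print(resp.json()["choices"][0]["text"])
```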
@@ -287,38 +287,135 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
 
+    def quantize_and_rope_for_fp8(
+        self,
+        q_nope: torch.Tensor,
+        q_rope: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_rope: torch.Tensor,
+        forward_batch: ForwardBatch,
+        cos_sin_cache: torch.Tensor,
+        is_neox: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Quantize and apply RoPE for the FP8 attention path.
+
+        This function handles FP8 quantization and RoPE application for MLA
+        attention. It takes separate query/key nope and rope components, applies
+        RoPE to the rope parts, quantizes all components to FP8, and merges the
+        query components into a single tensor.
+
+        Args:
+            q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            forward_batch: Forward batch containing position information
+            cos_sin_cache: Precomputed cosine/sine cache for RoPE
+                - expected dtype: matches the q_/k_ input dtype (torch.bfloat16)
+            is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation)
+
+        Returns:
+            tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8
+                - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn
+                - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn
+                - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn
+        """
+        attn_dtype = torch.float8_e4m3fn
+        q_len, num_heads = q_rope.shape[0], q_rope.shape[1]
+
+        # Allocate output tensors with FP8 dtype.
+        # The query output will contain the merged nope + rope components.
+        q_out = q_rope.new_empty(
+            q_len,
+            num_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            dtype=attn_dtype,
+        )
+
+        # Key outputs keep their original shapes, but with FP8 dtype.
+        k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)
+        k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)
+
+        # Apply RoPE and quantize all components in a single fused kernel call.
+        # This kernel handles:
+        #   1. RoPE application to q_rope and k_rope using cos_sin_cache and positions
+        #   2. Quantization of all components to FP8 format
+        #   3. Output placement into the pre-allocated tensors
+        flashinfer.rope.mla_rope_quantize_fp8(
+            q_rope=q_rope,
+            k_rope=k_rope,
+            q_nope=q_nope,
+            k_nope=k_nope,
+            cos_sin_cache=cos_sin_cache,
+            pos_ids=forward_batch.positions,
+            is_neox=is_neox,
+            quantize_dtype=attn_dtype,
+            # Output tensor slicing: q_out contains [nope_part, rope_part].
+            q_rope_out=q_out[..., self.kv_lora_rank :],  # rope part goes at the end
+            k_rope_out=k_rope_out,
+            q_nope_out=q_out[..., : self.kv_lora_rank],  # nope part goes at the beginning
+            k_nope_out=k_nope_out,
+            # Quantization scales (set to 1.0 for no additional scaling).
+            quant_scale_q=1.0,
+            quant_scale_kv=1.0,
+        )
+
+        return q_out, k_nope_out, k_rope_out
+
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
+        q: torch.Tensor,  # q_nope
+        k: torch.Tensor,  # k_nope
+        v: torch.Tensor,  # not used in this backend
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MLA kernel."""
+        merge_query = q_rope is not None
+        if self.data_type == torch.float8_e4m3fn:
+            # For the FP8 path, quantize the query nope and rope parts and merge
+            # them into a single tensor. Note: RoPE application in
+            # deepseek_v2.py:forward_absorb_prepare is skipped for the FP8 decode
+            # path of this trtllm_mla backend.
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For the FP8 path using flashinfer.rope.mla_rope_quantize, q_rope, k_rope, and cos_sin_cache must all be non-None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
-        if k is not None and save_kv_cache:
-            cache_loc = forward_batch.out_cache_loc
-            if k_rope is not None:
-                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
-                    layer, cache_loc, k, k_rope
-                )
-            elif v is not None:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating the trtllm_mla KV cache, both k_nope and k_rope must be non-None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
 
         # Prepare query tensor inline
-        if q_rope is not None:
-            # q contains NOPE part (v_head_dim)
+        if merge_query:
+            # For the FP16 path, merge the query nope and rope parts into a single tensor.
             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
             query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
         else:
-            # q already has both parts
+            # For the FP8 path, the nope and rope parts were already merged by
+            # quantize_and_rope_for_fp8 above.
            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
         # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
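
For intuition, the fused `flashinfer.rope.mla_rope_quantize_fp8` call above is roughly equivalent to the unfused sketch below (illustrative only: GPT-style half rotation, per-position `cos`/`sin` already gathered from the cache, and unit quantization scales are assumed):

```python
import torch

def mla_rope_quantize_ref(q_nope, q_rope, k_nope, k_rope, cos, sin):
    """Unfused stand-in: rotate the rope parts, cast everything to FP8,
    and merge the query as [nope | rope] along the last dim."""
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    q_rope, k_rope = rotate_half(q_rope), rotate_half(k_rope)
    fp8 = torch.float8_e4m3fn
    q_out = torch.cat([q_nope, q_rope], dim=-1).to(fp8)
    return q_out, k_nope.to(fp8), k_rope.to(fp8)

# Toy shapes: [seq, heads, dim]; cos/sin broadcast over heads.
seq, heads, rot = 4, 2, 8
cos, sin = torch.ones(seq, 1, rot // 2), torch.zeros(seq, 1, rot // 2)
q_out, _, _ = mla_rope_quantize_ref(
    torch.randn(seq, heads, 16), torch.randn(seq, heads, rot),
    torch.randn(seq, heads, 16), torch.randn(seq, heads, rot), cos, sin)
print(q_out.shape, q_out.dtype)  # torch.Size([4, 2, 24]) torch.float8_e4m3fn
```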
@@ -327,9 +424,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # Prepare KV cache inline
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
-        # TRT-LLM expects single KV data with extra dimension
-        kv_cache = pages.unsqueeze(1)
+        kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
 
         # Get metadata
         metadata = (
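
The reshape above packs the paged MLA cache into the layout the TRT-LLM kernel expects (a single KV tensor with an extra unit dimension). A shape sketch with DeepSeek-style dimensions — kv_lora_rank=512, qk_rope_head_dim=64, so kv_cache_dim=576; sizes are illustrative, not read from this diff:

```python
import torch

num_pages, page_size, kv_cache_dim = 8, 64, 576  # 576 = 512 (nope) + 64 (rope)
k_cache = torch.empty(num_pages * page_size, kv_cache_dim)
kv_cache = k_cache.view(-1, page_size, kv_cache_dim).unsqueeze(1)
print(kv_cache.shape)  # torch.Size([8, 1, 64, 576]) -> [pages, 1, page_size, dim]
```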
@@ -337,11 +432,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             or self.forward_metadata
         )
 
-        # Scale computation for TRTLLM MLA kernel:
-        # - BMM1 scale = q_scale * k_scale * softmax_scale
-        # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
-        # - k_scale is read from model checkpoint if available
-        # TODO: Change once fp8 path is supported
+        # Scale computation for TRTLLM MLA kernel BMM1 operation:
+        # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
+        # Scale components:
+        # - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths)
+        # - k_scale: Key scaling factor from the model checkpoint (defaults to 1.0 if not available)
+        # - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling
+        # This unified approach works for both FP16 and FP8 quantized attention paths.
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
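
Reduced to arithmetic, the BMM1 scale described in the comment is simply the product below (`bmm1_scale` is an illustrative helper, not a function in this diff):

```python
import math

def bmm1_scale(head_dim: int, k_scale_float: float | None) -> float:
    q_scale = 1.0                               # identical on FP16 and FP8 paths
    k_scale = 1.0 if k_scale_float is None else k_scale_float
    softmax_scale = 1.0 / math.sqrt(head_dim)   # pre-computed as layer.scaling
    return q_scale * k_scale * softmax_scale
```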
@@ -1196,6 +1196,16 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
+        """
+        Check whether to skip RoPE here and use the fused RoPE+quantize path of
+        the TRTLLM MLA decode backend with the fp8_e4m3 KV cache.
+        """
+        return (
+            self.current_attention_backend == "trtllm_mla"
+            and forward_batch.forward_mode.is_decode_or_idle()
+            and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
+        )
+
     def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
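
The predicate gates fusion on exactly three conditions; a standalone sketch (a hypothetical free function mirroring the method above):

```python
import torch

def fuse_rope_for_trtllm_mla(backend: str, is_decode_or_idle: bool,
                             kv_dtype: torch.dtype) -> bool:
    return (
        backend == "trtllm_mla"
        and is_decode_or_idle
        and kv_dtype == torch.float8_e4m3fn
    )

assert fuse_rope_for_trtllm_mla("trtllm_mla", True, torch.float8_e4m3fn)
assert not fuse_rope_for_trtllm_mla("trtllm_mla", True, torch.bfloat16)       # FP16 path
assert not fuse_rope_for_trtllm_mla("flashinfer", True, torch.float8_e4m3fn)  # other backend
```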
@@ -1275,7 +1285,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
 
         q_nope_out = q_nope_out.transpose(0, 1)
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+        if not self._fuse_rope_for_trtllm_mla(forward_batch):
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
         return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
@@ -1288,8 +1300,20 @@ class DeepseekV2AttentionMLA(nn.Module):
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
         ):
+            extra_args = {}
+            if self._fuse_rope_for_trtllm_mla(forward_batch):
+                extra_args = {
+                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
+                    "is_neox": self.rotary_emb.is_neox_style,
+                }
             attn_output = self.attn_mqa(
-                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
+                q_nope_out,
+                k_nope,
+                k_nope,
+                forward_batch,
+                q_rope=q_pe,
+                k_rope=k_pe,
+                **extra_args,
             )
         else:
             q = torch.cat([q_nope_out, q_pe], dim=-1)
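
The `**extra_args` expansion is what lets one `attn_mqa` call serve both paths: the FP8 decode path passes the two RoPE kwargs the backend needs, the FP16 path passes neither. A toy sketch of the forwarding (stand-in stub, not the real API):

```python
def attn_mqa_stub(q, k, v, forward_batch, q_rope=None, k_rope=None,
                  cos_sin_cache=None, is_neox=False):
    # The backend fuses RoPE + quantization only when a cos/sin cache arrives.
    return "fused fp8 decode" if cos_sin_cache is not None else "unfused fp16 decode"

print(attn_mqa_stub(None, None, None, None))                 # unfused fp16 decode
print(attn_mqa_stub(None, None, None, None,
                    cos_sin_cache=object(), is_neox=False))  # fused fp8 decode
```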
@@ -432,7 +432,10 @@ class ServerArgs:
             )
             self.page_size = 128
 
-        if self.attention_backend == "trtllm_mla":
+        if (
+            self.attention_backend == "trtllm_mla"
+            or self.decode_attention_backend == "trtllm_mla"
+        ):
             if not is_sm100_supported():
                 raise ValueError(
                     "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
@@ -443,11 +446,17 @@ class ServerArgs:
                     f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
                 )
                 self.page_size = 64
 
             if self.speculative_algorithm is not None:
                 raise ValueError(
                     "trtllm_mla backend does not support speculative decoding yet."
                 )
+
+            if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
+                raise ValueError(
+                    "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
+                )
+
         if (
             self.attention_backend == "trtllm_mha"
             or self.decode_attention_backend == "trtllm_mha"
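
Collapsed into one place, the trtllm_mla checks amount to the following (hypothetical free function; the real logic runs inside ServerArgs initialization, and the Blackwell/SM100 hardware check is elided here):

```python
def validate_trtllm_mla(attention_backend, decode_attention_backend,
                        page_size, speculative_algorithm, kv_cache_dtype):
    if "trtllm_mla" not in (attention_backend, decode_attention_backend):
        return page_size
    if page_size not in (32, 64):
        page_size = 64  # the real code logs a warning before coercing
    if speculative_algorithm is not None:
        raise ValueError("trtllm_mla backend does not support speculative decoding yet.")
    if kv_cache_dtype not in ("fp8_e4m3", "auto"):
        raise ValueError("TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto.")
    return page_size
```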
@@ -43,6 +43,37 @@ DEFAULT_CONFIG = {
     "layer_id": 0,
 }
 
+ROPE_BASE = 10000
+ROPE_SCALING_CONFIG = {
+    "beta_fast": 32,
+    "beta_slow": 1,
+    "factor": 40,
+    "mscale": 1.0,
+    "mscale_all_dim": 1.0,
+    "original_max_position_embeddings": 4096,
+    "type": "yarn",
+    "rope_type": "deepseek_yarn",
+}
+
+
+def build_rotary_emb(config, device=None):
+    from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+
+    dev = device or config["device"]
+    rope_scaling = config.get("rope_scaling", ROPE_SCALING_CONFIG)
+    rotary = get_rope_wrapper(
+        head_size=config["qk_rope_head_dim"],
+        rotary_dim=config["qk_rope_head_dim"],
+        max_position=config["context_len"],
+        base=ROPE_BASE,
+        rope_scaling=rope_scaling,
+        is_neox_style=False,
+        device=dev,
+    )
+    rotary.cos_sin_cache = rotary.cos_sin_cache.to(dev)
+    return rotary
+
+
 # Centralized test cases for different test scenarios
 TEST_CASES = {
     "basic_functionality": [
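
Usage sketch for the helper above, mirroring how the FP8 decode-output test drives it (assumes this test file's `DEFAULT_CONFIG` carries the referenced keys):

```python
import torch

config = DEFAULT_CONFIG
rotary = build_rotary_emb(config)
bs = 4
positions = torch.arange(bs, device=config["device"])
q_rope = torch.randn(bs, config["num_attention_heads"], config["qk_rope_head_dim"],
                     dtype=config["dtype"], device=config["device"])
k_rope = torch.randn(bs, config["num_kv_heads"], config["qk_rope_head_dim"],
                     dtype=config["dtype"], device=config["device"])
q_rope, k_rope = rotary(positions, q_rope, k_rope)  # reference-path rotation
cos_sin_cache = rotary.cos_sin_cache                # handed to the fused TRT path
```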
@@ -63,18 +94,36 @@ TEST_CASES = {
     ],
     "decode_output_match": [
         {
-            "name": "single",
+            "name": "single_fp16",
             "batch_size": 1,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Single vs reference",
+            "description": "Single FP16 vs reference",
+        },
+        {
+            "name": "single_fp8",
+            "batch_size": 1,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Single FP8 vs reference",
         },
         {
-            "name": "batch",
+            "name": "batch_fp16",
             "batch_size": 32,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Batch vs reference",
+            "description": "Batch FP16 vs reference",
+        },
+        {
+            "name": "batch_fp8",
+            "batch_size": 32,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Batch FP8 vs reference",
         },
     ],
     "page_size_consistency": [
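
Each case dict overlays the shared defaults, so the FP8 rows only add the knobs that differ (a sketch; `_merge_config` is assumed to do a shallow merge over `DEFAULT_CONFIG`):

```python
import torch

case = {"name": "batch_fp8", "batch_size": 32, "max_seq_len": 64, "page_size": 64,
        "tolerance": 1e-1, "kv_cache_dtype": torch.float8_e4m3fn}
config = {**DEFAULT_CONFIG, **case}                        # shallow overlay
use_fp8 = config["kv_cache_dtype"] == torch.float8_e4m3fn  # drives the test branches
```

Note that the FP8 rows use a looser tolerance (1e-1) than the FP16 rows, since quantization error dominates the comparison.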
@@ -293,26 +342,52 @@ class TestTRTLLMMLA(CustomTestCase):
             layer,
         )
 
-    def _create_qkv_tensors(self, batch_size, config):
-        """Create Q, K, V tensors for testing."""
-        head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
-        device = config["device"]
-        dtype = config["dtype"]
-
-        q = torch.randn(
-            (batch_size, config["num_attention_heads"], head_dim),
-            dtype=dtype,
-            device=device,
-        )
-        k = torch.randn(
-            (batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
-        )
-        v = torch.randn(
-            (batch_size, config["num_kv_heads"], config["v_head_dim"]),
-            dtype=dtype,
-            device=device,
-        )
-        return q, k, v
+    def _create_qkv_tensors(self, batch_size, config, dtype_override=None):
+        """Create random Q, K, V tensors with separate MLA components.
+
+        Args:
+            batch_size: Batch size.
+            config: Configuration dict with model dims and device.
+            dtype_override: Optional torch dtype to override config["dtype"].
+
+        Returns:
+            Tuple of (q_nope, q_rope, k_nope, k_rope, v)
+        """
+        device = config["device"]
+        target_dtype = dtype_override or config["dtype"]
+
+        # Create separate nope and rope components for Q
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
+            dtype=target_dtype,
+            device=device,
+        )
+        q_rope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["qk_rope_head_dim"]),
+            dtype=target_dtype,
+            device=device,
+        )
+
+        # Create separate nope and rope components for K
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
+            dtype=target_dtype,
+            device=device,
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=target_dtype,
+            device=device,
+        )
+
+        # V tensor (same shape as in the old helper)
+        v = torch.randn(
+            (batch_size, config["num_kv_heads"], config["v_head_dim"]),
+            dtype=target_dtype,
+            device=device,
+        )
+
+        return q_nope, q_rope, k_nope, k_rope, v
 
     def _create_forward_batch(
         self, batch_size, seq_lens, backend, model_runner, config
@@ -331,6 +406,10 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         fb.req_to_token_pool = model_runner.req_to_token_pool
         fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+        # Add position information for RoPE
+        fb.positions = torch.arange(batch_size, device=config["device"])
+
         return fb
 
     def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
@@ -344,7 +423,7 @@ class TestTRTLLMMLA(CustomTestCase):
             for token_idx in range(seq_len - 1):
                 # Create random K components for MLA
                 cache_k_nope = torch.randn(
-                    (1, config["qk_nope_head_dim"]),
+                    (1, config["kv_lora_rank"]),
                     dtype=config["dtype"],
                     device=config["device"],
                 )
@@ -411,12 +490,16 @@ class TestTRTLLMMLA(CustomTestCase):
             batch_size, seq_lens, [model_runner_trtllm], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = trtllm_backend.forward_decode(q, k, v, layer, fb)
+        # Run forward decode with separate MLA components
+        output = trtllm_backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Basic checks
         expected_shape = (
@@ -439,6 +522,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config = self._merge_config(test_case)
         batch_size = config["batch_size"]
         max_seq_len = config["max_seq_len"]
+        use_fp8 = config["kv_cache_dtype"] == torch.float8_e4m3fn
 
         # Create components
         (
@@ -487,19 +571,66 @@ class TestTRTLLMMLA(CustomTestCase):
 
         # Create Q, K, V tensors for current decode step
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
 
+        q_nope_ref, q_rope_ref, k_nope_ref, k_rope_ref, v_ref = (
+            self._create_qkv_tensors(batch_size, config)
+        )
+        q_nope_trt, q_rope_trt, k_nope_trt, k_rope_trt, v_trt = (
+            q_nope_ref.clone(),
+            q_rope_ref.clone(),
+            k_nope_ref.clone(),
+            k_rope_ref.clone(),
+            v_ref.clone(),
+        )
+        tolerance = config["tolerance"]
+
+        extra_args = {}
+        if use_fp8:
+            # The TRT kernel applies RoPE + FP8 quantization internally, so we
+            # pre-apply RoPE on the reference (FlashInfer) path here. Both paths
+            # then share the same rope params/cache while the TRT path stays
+            # unrotated.
+            rotary_emb = build_rotary_emb(config)
+            q_rope_ref, k_rope_ref = rotary_emb(
+                fb_reference.positions, q_rope_ref, k_rope_ref
+            )
+            extra_args = {
+                "cos_sin_cache": rotary_emb.cos_sin_cache,
+                "is_neox": rotary_emb.is_neox_style,
+            }
+
+            dtype = q_rope_ref.dtype
+            q_rope_ref = q_rope_ref.to(torch.float8_e4m3fn).to(dtype)
+            q_nope_ref = q_nope_ref.to(torch.float8_e4m3fn).to(dtype)
+            k_rope_ref = k_rope_ref.to(torch.float8_e4m3fn).to(dtype)
+            k_nope_ref = k_nope_ref.to(torch.float8_e4m3fn).to(dtype)
+
         # Run forward decode on both backends
         out_trtllm = trtllm_backend.forward_decode(
-            q.clone(), k.clone(), v.clone(), layer, fb_trtllm
+            q_nope_trt,
+            k_nope_trt,
+            None,
+            layer,
+            fb_trtllm,
+            q_rope=q_rope_trt,
+            k_rope=k_rope_trt,
+            **extra_args,
         )
+
+        # Reference backend should also take separate components, not concatenated
         out_reference = reference_backend.forward_decode(
-            q.clone(), k.clone(), v.clone(), layer, fb_reference
+            q_nope_ref,
+            k_nope_ref,
+            v_ref,
+            layer,
+            fb_reference,
+            q_rope=q_rope_ref,
+            k_rope=k_rope_ref,
         )
 
         # Compare outputs
         comparison_passed = compare_outputs(
-            out_trtllm, out_reference, tolerance=config["tolerance"]
+            out_trtllm, out_reference, tolerance=tolerance
         )
 
         self.assertTrue(
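
The `.to(torch.float8_e4m3fn).to(dtype)` round trip above deliberately bakes FP8 quantization error into the reference inputs, so both backends compare like-for-like. In isolation:

```python
import torch

x = torch.randn(8, dtype=torch.bfloat16)
x_sim = x.to(torch.float8_e4m3fn).to(torch.bfloat16)  # quantize -> dequantize
print((x - x_sim).abs().max())  # small, bounded by the e4m3 quantization step
```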
@@ -544,12 +675,16 @@ class TestTRTLLMMLA(CustomTestCase):
             batch_size, seq_lens, [model_runner], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
        torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = backend.forward_decode(q, k, v, layer, fb)
+        # Run forward decode with separate MLA components
+        output = backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         expected_shape = (
             batch_size,
@@ -591,23 +726,38 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         backend.init_forward_metadata(fb)
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
-        q = torch.randn(
-            (batch_size, config["num_attention_heads"], head_dim),
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-        k = torch.randn(
-            (batch_size, config["num_kv_heads"], head_dim),
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-        v = None
+        q_rope = torch.randn(
+            (
+                batch_size,
+                config["num_attention_heads"],
+                config["qk_rope_head_dim"],
+            ),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        v = None  # Test with None v
 
         # Run forward decode
-        output = backend.forward_decode(q, k, v, layer, fb)
+        output = backend.forward_decode(
+            q_nope, k_nope, v, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Shape and sanity checks
         expected_shape = (