TRTLLM-MLA FP8 path (#8638)
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
@@ -60,6 +60,11 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
|
|||||||
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code
|
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- TRTLLM MLA with FP8 KV Cache (Higher concurrency, lower memory footprint)
|
||||||
|
```bash
|
||||||
|
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --kv-cache-dtype fp8_e4m3 --trust-remote-code
|
||||||
|
```
|
||||||
|
|
||||||
- Ascend
|
- Ascend
|
||||||
```bash
|
```bash
|
||||||
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
|
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
|
||||||
|
|||||||
@@ -287,38 +287,135 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
)
|
)
|
||||||
forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
|
forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
|
||||||
|
|
||||||
|
def quantize_and_rope_for_fp8(
|
||||||
|
self,
|
||||||
|
q_nope: torch.Tensor,
|
||||||
|
q_rope: torch.Tensor,
|
||||||
|
k_nope: torch.Tensor,
|
||||||
|
k_rope: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
cos_sin_cache: torch.Tensor,
|
||||||
|
is_neox: bool,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""Quantize and apply RoPE for FP8 attention path.
|
||||||
|
|
||||||
|
This function handles the FP8 quantization and RoPE application for MLA attention.
|
||||||
|
It takes separate query/key nope and rope components, applies RoPE to the rope parts,
|
||||||
|
quantizes all components to FP8, and merges the query components into a single tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]
|
||||||
|
- expected dtype: torch.bfloat16
|
||||||
|
q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]
|
||||||
|
- expected dtype: torch.bfloat16
|
||||||
|
k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]
|
||||||
|
- expected dtype: torch.bfloat16
|
||||||
|
k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]
|
||||||
|
- expected dtype: torch.bfloat16
|
||||||
|
forward_batch: Forward batch containing position information
|
||||||
|
cos_sin_cache: Precomputed cosine/sine cache for RoPE
|
||||||
|
- expected dtype: matches q_/k_ input dtype (torch.bfloat16)
|
||||||
|
is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8
|
||||||
|
- merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn
|
||||||
|
- k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn
|
||||||
|
- k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn
|
||||||
|
"""
|
||||||
|
attn_dtype = torch.float8_e4m3fn
|
||||||
|
q_len, num_heads = q_rope.shape[0], q_rope.shape[1]
|
||||||
|
|
||||||
|
# Allocate output tensors with FP8 dtype
|
||||||
|
# Query output will contain merged nope + rope components
|
||||||
|
q_out = q_rope.new_empty(
|
||||||
|
q_len,
|
||||||
|
num_heads,
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
dtype=attn_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Key outputs maintain original shapes but with FP8 dtype
|
||||||
|
k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)
|
||||||
|
k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)
|
||||||
|
|
||||||
|
# Apply RoPE and quantize all components in a single fused kernel call
|
||||||
|
# This kernel handles:
|
||||||
|
# 1. RoPE application to q_rope and k_rope using cos_sin_cache and positions
|
||||||
|
# 2. Quantization of all components to FP8 format
|
||||||
|
# 3. Output placement into pre-allocated tensors
|
||||||
|
flashinfer.rope.mla_rope_quantize_fp8(
|
||||||
|
q_rope=q_rope,
|
||||||
|
k_rope=k_rope,
|
||||||
|
q_nope=q_nope,
|
||||||
|
k_nope=k_nope,
|
||||||
|
cos_sin_cache=cos_sin_cache,
|
||||||
|
pos_ids=forward_batch.positions,
|
||||||
|
is_neox=is_neox,
|
||||||
|
quantize_dtype=attn_dtype,
|
||||||
|
# Output tensor slicing: q_out contains [nope_part, rope_part]
|
||||||
|
q_rope_out=q_out[..., self.kv_lora_rank :], # RoPE part goes to end
|
||||||
|
k_rope_out=k_rope_out,
|
||||||
|
q_nope_out=q_out[..., : self.kv_lora_rank], # Nope part goes to beginning
|
||||||
|
k_nope_out=k_nope_out,
|
||||||
|
# Quantization scales (set to 1.0 for no additional scaling)
|
||||||
|
quant_scale_q=1.0,
|
||||||
|
quant_scale_kv=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return q_out, k_nope_out, k_rope_out
|
||||||
|
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor, # q_nope
|
||||||
k: torch.Tensor,
|
k: torch.Tensor, # k_nope
|
||||||
v: torch.Tensor,
|
v: torch.Tensor, # not used in this backend
|
||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache: bool = True,
|
save_kv_cache: bool = True,
|
||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
k_rope: Optional[torch.Tensor] = None,
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
|
cos_sin_cache: Optional[torch.Tensor] = None,
|
||||||
|
is_neox: Optional[bool] = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Run forward for decode using TRTLLM MLA kernel."""
|
"""Run forward for decode using TRTLLM MLA kernel."""
|
||||||
# Save KV cache if requested
|
merge_query = q_rope is not None
|
||||||
if k is not None and save_kv_cache:
|
if self.data_type == torch.float8_e4m3fn:
|
||||||
cache_loc = forward_batch.out_cache_loc
|
# For FP8 path, we quantize the query and rope parts and merge them into a single tensor
|
||||||
if k_rope is not None:
|
# Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
|
||||||
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
assert all(
|
||||||
layer, cache_loc, k, k_rope
|
x is not None for x in [q_rope, k_rope, cos_sin_cache]
|
||||||
|
), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
|
||||||
|
q, k, k_rope = self.quantize_and_rope_for_fp8(
|
||||||
|
q,
|
||||||
|
q_rope,
|
||||||
|
k.squeeze(1),
|
||||||
|
k_rope.squeeze(1),
|
||||||
|
forward_batch,
|
||||||
|
cos_sin_cache,
|
||||||
|
is_neox,
|
||||||
|
)
|
||||||
|
merge_query = False
|
||||||
|
|
||||||
|
# Save KV cache if requested
|
||||||
|
if save_kv_cache:
|
||||||
|
assert (
|
||||||
|
k is not None and k_rope is not None
|
||||||
|
), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
|
||||||
|
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
||||||
|
layer, forward_batch.out_cache_loc, k, k_rope
|
||||||
)
|
)
|
||||||
elif v is not None:
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
|
||||||
|
|
||||||
# Prepare query tensor inline
|
# Prepare query tensor inline
|
||||||
if q_rope is not None:
|
if merge_query:
|
||||||
# q contains NOPE part (v_head_dim)
|
# For FP16 path, we merge the query and rope parts into a single tensor
|
||||||
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||||
q_rope_reshaped = q_rope.view(
|
q_rope_reshaped = q_rope.view(
|
||||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||||
)
|
)
|
||||||
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
|
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
|
||||||
else:
|
else:
|
||||||
# q already has both parts
|
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
|
||||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
|
||||||
# Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
|
# Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
|
||||||
@@ -327,9 +424,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
|
|
||||||
# Prepare KV cache inline
|
# Prepare KV cache inline
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
|
kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
|
||||||
# TRT-LLM expects single KV data with extra dimension
|
|
||||||
kv_cache = pages.unsqueeze(1)
|
|
||||||
|
|
||||||
# Get metadata
|
# Get metadata
|
||||||
metadata = (
|
metadata = (
|
||||||
@@ -337,11 +432,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
or self.forward_metadata
|
or self.forward_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
# Scale computation for TRTLLM MLA kernel:
|
# Scale computation for TRTLLM MLA kernel BMM1 operation:
|
||||||
# - BMM1 scale = q_scale * k_scale * softmax_scale
|
# The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
|
||||||
# - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
|
# Scale components:
|
||||||
# - k_scale is read from model checkpoint if available
|
# - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths)
|
||||||
# TODO: Change once fp8 path is supported
|
# - k_scale: Key scaling factor from model checkpoint (defaults to 1.0 if not available)
|
||||||
|
# - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling
|
||||||
|
# This unified approach works for both FP16 and FP8 quantized attention paths.
|
||||||
q_scale = 1.0
|
q_scale = 1.0
|
||||||
k_scale = (
|
k_scale = (
|
||||||
layer.k_scale_float
|
layer.k_scale_float
|
||||||
|
|||||||
@@ -1196,6 +1196,16 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
|
||||||
|
"""
|
||||||
|
Check if we should skip rope and do fused rope+quantize for TRTLLM MLA decode in fp8_e4m3 path.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
self.current_attention_backend == "trtllm_mla"
|
||||||
|
and forward_batch.forward_mode.is_decode_or_idle()
|
||||||
|
and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
|
||||||
def forward_absorb_prepare(
|
def forward_absorb_prepare(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@@ -1275,6 +1285,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
||||||
|
|
||||||
q_nope_out = q_nope_out.transpose(0, 1)
|
q_nope_out = q_nope_out.transpose(0, 1)
|
||||||
|
|
||||||
|
if not self._fuse_rope_for_trtllm_mla(forward_batch):
|
||||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||||
|
|
||||||
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
||||||
@@ -1288,8 +1300,20 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
or self.current_attention_backend == "cutlass_mla"
|
or self.current_attention_backend == "cutlass_mla"
|
||||||
or self.current_attention_backend == "trtllm_mla"
|
or self.current_attention_backend == "trtllm_mla"
|
||||||
):
|
):
|
||||||
|
extra_args = {}
|
||||||
|
if self._fuse_rope_for_trtllm_mla(forward_batch):
|
||||||
|
extra_args = {
|
||||||
|
"cos_sin_cache": self.rotary_emb.cos_sin_cache,
|
||||||
|
"is_neox": self.rotary_emb.is_neox_style,
|
||||||
|
}
|
||||||
attn_output = self.attn_mqa(
|
attn_output = self.attn_mqa(
|
||||||
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
q_nope_out,
|
||||||
|
k_nope,
|
||||||
|
k_nope,
|
||||||
|
forward_batch,
|
||||||
|
q_rope=q_pe,
|
||||||
|
k_rope=k_pe,
|
||||||
|
**extra_args,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||||
|
|||||||
@@ -432,7 +432,10 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
self.page_size = 128
|
self.page_size = 128
|
||||||
|
|
||||||
if self.attention_backend == "trtllm_mla":
|
if (
|
||||||
|
self.attention_backend == "trtllm_mla"
|
||||||
|
or self.decode_attention_backend == "trtllm_mla"
|
||||||
|
):
|
||||||
if not is_sm100_supported():
|
if not is_sm100_supported():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
|
"TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
|
||||||
@@ -443,11 +446,17 @@ class ServerArgs:
|
|||||||
f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
|
f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
|
||||||
)
|
)
|
||||||
self.page_size = 64
|
self.page_size = 64
|
||||||
|
|
||||||
if self.speculative_algorithm is not None:
|
if self.speculative_algorithm is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"trtllm_mla backend does not support speculative decoding yet."
|
"trtllm_mla backend does not support speculative decoding yet."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
|
||||||
|
raise ValueError(
|
||||||
|
"TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.attention_backend == "trtllm_mha"
|
self.attention_backend == "trtllm_mha"
|
||||||
or self.decode_attention_backend == "trtllm_mha"
|
or self.decode_attention_backend == "trtllm_mha"
|
||||||
|
|||||||
@@ -43,6 +43,37 @@ DEFAULT_CONFIG = {
|
|||||||
"layer_id": 0,
|
"layer_id": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ROPE_BASE = 10000
|
||||||
|
ROPE_SCALING_CONFIG = {
|
||||||
|
"beta_fast": 32,
|
||||||
|
"beta_slow": 1,
|
||||||
|
"factor": 40,
|
||||||
|
"mscale": 1.0,
|
||||||
|
"mscale_all_dim": 1.0,
|
||||||
|
"original_max_position_embeddings": 4096,
|
||||||
|
"type": "yarn",
|
||||||
|
"rope_type": "deepseek_yarn",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_rotary_emb(config, device=None):
|
||||||
|
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
|
||||||
|
|
||||||
|
dev = device or config["device"]
|
||||||
|
rope_scaling = config.get("rope_scaling", ROPE_SCALING_CONFIG)
|
||||||
|
rotary = get_rope_wrapper(
|
||||||
|
head_size=config["qk_rope_head_dim"],
|
||||||
|
rotary_dim=config["qk_rope_head_dim"],
|
||||||
|
max_position=config["context_len"],
|
||||||
|
base=ROPE_BASE,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
is_neox_style=False,
|
||||||
|
device=dev,
|
||||||
|
)
|
||||||
|
rotary.cos_sin_cache = rotary.cos_sin_cache.to(dev)
|
||||||
|
return rotary
|
||||||
|
|
||||||
|
|
||||||
# Centralized test cases for different test scenarios
|
# Centralized test cases for different test scenarios
|
||||||
TEST_CASES = {
|
TEST_CASES = {
|
||||||
"basic_functionality": [
|
"basic_functionality": [
|
||||||
@@ -63,18 +94,36 @@ TEST_CASES = {
|
|||||||
],
|
],
|
||||||
"decode_output_match": [
|
"decode_output_match": [
|
||||||
{
|
{
|
||||||
"name": "single",
|
"name": "single_fp16",
|
||||||
"batch_size": 1,
|
"batch_size": 1,
|
||||||
"max_seq_len": 64,
|
"max_seq_len": 64,
|
||||||
"page_size": 32,
|
"page_size": 32,
|
||||||
"description": "Single vs reference",
|
"description": "Single FP16 vs reference",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "batch",
|
"name": "single_fp8",
|
||||||
|
"batch_size": 1,
|
||||||
|
"max_seq_len": 64,
|
||||||
|
"page_size": 64,
|
||||||
|
"tolerance": 1e-1,
|
||||||
|
"kv_cache_dtype": torch.float8_e4m3fn,
|
||||||
|
"description": "Single FP8 vs reference",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "batch_fp16",
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
"max_seq_len": 64,
|
"max_seq_len": 64,
|
||||||
"page_size": 32,
|
"page_size": 32,
|
||||||
"description": "Batch vs reference",
|
"description": "Batch FP16 vs reference",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "batch_fp8",
|
||||||
|
"batch_size": 32,
|
||||||
|
"max_seq_len": 64,
|
||||||
|
"page_size": 64,
|
||||||
|
"tolerance": 1e-1,
|
||||||
|
"kv_cache_dtype": torch.float8_e4m3fn,
|
||||||
|
"description": "Batch FP8 vs reference",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"page_size_consistency": [
|
"page_size_consistency": [
|
||||||
@@ -293,26 +342,52 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
layer,
|
layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_qkv_tensors(self, batch_size, config):
|
def _create_qkv_tensors(self, batch_size, config, dtype_override=None):
|
||||||
"""Create Q, K, V tensors for testing."""
|
"""Create Q, K, V random tensors for given batch size with separate MLA components.
|
||||||
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
|
|
||||||
device = config["device"]
|
|
||||||
dtype = config["dtype"]
|
|
||||||
|
|
||||||
q = torch.randn(
|
Args:
|
||||||
(batch_size, config["num_attention_heads"], head_dim),
|
batch_size: Batch size.
|
||||||
dtype=dtype,
|
config: Configuration dict with model dims and device.
|
||||||
|
dtype_override: Optional torch dtype to override config["dtype"].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (q_nope, q_rope, k_nope, k_rope, v, cos_sin_cache)
|
||||||
|
"""
|
||||||
|
device = config["device"]
|
||||||
|
target_dtype = dtype_override or config["dtype"]
|
||||||
|
|
||||||
|
# Create separate nope and rope components for Q
|
||||||
|
q_nope = torch.randn(
|
||||||
|
(batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
|
||||||
|
dtype=config["dtype"],
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
k = torch.randn(
|
q_rope = torch.randn(
|
||||||
(batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
|
(batch_size, config["num_attention_heads"], config["qk_rope_head_dim"]),
|
||||||
|
dtype=config["dtype"],
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create separate nope and rope components for K
|
||||||
|
k_nope = torch.randn(
|
||||||
|
(batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
|
||||||
|
dtype=config["dtype"],
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
k_rope = torch.randn(
|
||||||
|
(batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
|
||||||
|
dtype=config["dtype"],
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# V tensor (unchanged)
|
||||||
v = torch.randn(
|
v = torch.randn(
|
||||||
(batch_size, config["num_kv_heads"], config["v_head_dim"]),
|
(batch_size, config["num_kv_heads"], config["v_head_dim"]),
|
||||||
dtype=dtype,
|
dtype=config["dtype"],
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
return q, k, v
|
|
||||||
|
return q_nope, q_rope, k_nope, k_rope, v
|
||||||
|
|
||||||
def _create_forward_batch(
|
def _create_forward_batch(
|
||||||
self, batch_size, seq_lens, backend, model_runner, config
|
self, batch_size, seq_lens, backend, model_runner, config
|
||||||
@@ -331,6 +406,10 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
)
|
)
|
||||||
fb.req_to_token_pool = model_runner.req_to_token_pool
|
fb.req_to_token_pool = model_runner.req_to_token_pool
|
||||||
fb.token_to_kv_pool = model_runner.token_to_kv_pool
|
fb.token_to_kv_pool = model_runner.token_to_kv_pool
|
||||||
|
|
||||||
|
# Add position information for RoPE
|
||||||
|
fb.positions = torch.arange(batch_size, device=config["device"])
|
||||||
|
|
||||||
return fb
|
return fb
|
||||||
|
|
||||||
def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
|
def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
|
||||||
@@ -344,7 +423,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
for token_idx in range(seq_len - 1):
|
for token_idx in range(seq_len - 1):
|
||||||
# Create random K components for MLA
|
# Create random K components for MLA
|
||||||
cache_k_nope = torch.randn(
|
cache_k_nope = torch.randn(
|
||||||
(1, config["qk_nope_head_dim"]),
|
(1, config["kv_lora_rank"]),
|
||||||
dtype=config["dtype"],
|
dtype=config["dtype"],
|
||||||
device=config["device"],
|
device=config["device"],
|
||||||
)
|
)
|
||||||
@@ -411,12 +490,16 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
batch_size, seq_lens, [model_runner_trtllm], layer, config
|
batch_size, seq_lens, [model_runner_trtllm], layer, config
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create Q, K, V tensors
|
# Create Q, K, V tensors with separate MLA components
|
||||||
torch.manual_seed(config["seed_qkv"])
|
torch.manual_seed(config["seed_qkv"])
|
||||||
q, k, v = self._create_qkv_tensors(batch_size, config)
|
q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
|
||||||
|
batch_size, config
|
||||||
|
)
|
||||||
|
|
||||||
# Run forward decode
|
# Run forward decode with separate MLA components
|
||||||
output = trtllm_backend.forward_decode(q, k, v, layer, fb)
|
output = trtllm_backend.forward_decode(
|
||||||
|
q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
|
||||||
|
)
|
||||||
|
|
||||||
# Basic checks
|
# Basic checks
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
@@ -439,6 +522,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
config = self._merge_config(test_case)
|
config = self._merge_config(test_case)
|
||||||
batch_size = config["batch_size"]
|
batch_size = config["batch_size"]
|
||||||
max_seq_len = config["max_seq_len"]
|
max_seq_len = config["max_seq_len"]
|
||||||
|
use_fp8 = config["kv_cache_dtype"] == torch.float8_e4m3fn
|
||||||
|
|
||||||
# Create components
|
# Create components
|
||||||
(
|
(
|
||||||
@@ -487,19 +571,66 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
|
|
||||||
# Create Q, K, V tensors for current decode step
|
# Create Q, K, V tensors for current decode step
|
||||||
torch.manual_seed(config["seed_qkv"])
|
torch.manual_seed(config["seed_qkv"])
|
||||||
q, k, v = self._create_qkv_tensors(batch_size, config)
|
|
||||||
|
q_nope_ref, q_rope_ref, k_nope_ref, k_rope_ref, v_ref = (
|
||||||
|
self._create_qkv_tensors(batch_size, config)
|
||||||
|
)
|
||||||
|
q_nope_trt, q_rope_trt, k_nope_trt, k_rope_trt, v_trt = (
|
||||||
|
q_nope_ref.clone(),
|
||||||
|
q_rope_ref.clone(),
|
||||||
|
k_nope_ref.clone(),
|
||||||
|
k_rope_ref.clone(),
|
||||||
|
v_ref.clone(),
|
||||||
|
)
|
||||||
|
tolerance = config["tolerance"]
|
||||||
|
|
||||||
|
extra_args = {}
|
||||||
|
if use_fp8:
|
||||||
|
# TRT kernel applies RoPE + FP8 quantization internally
|
||||||
|
# pre-apply RoPE on the reference (FlashInfer) path here so
|
||||||
|
# both paths share the same rope params/cache while keeping
|
||||||
|
# the TRT path unrotated.
|
||||||
|
rotary_emb = build_rotary_emb(config)
|
||||||
|
q_rope_ref, k_rope_ref = rotary_emb(
|
||||||
|
fb_reference.positions, q_rope_ref, k_rope_ref
|
||||||
|
)
|
||||||
|
extra_args = {
|
||||||
|
"cos_sin_cache": rotary_emb.cos_sin_cache,
|
||||||
|
"is_neox": rotary_emb.is_neox_style,
|
||||||
|
}
|
||||||
|
|
||||||
|
dtype = q_rope_ref.dtype
|
||||||
|
q_rope_ref = q_rope_ref.to(torch.float8_e4m3fn).to(dtype)
|
||||||
|
q_nope_ref = q_nope_ref.to(torch.float8_e4m3fn).to(dtype)
|
||||||
|
k_rope_ref = k_rope_ref.to(torch.float8_e4m3fn).to(dtype)
|
||||||
|
k_nope_ref = k_nope_ref.to(torch.float8_e4m3fn).to(dtype)
|
||||||
|
|
||||||
# Run forward decode on both backends
|
# Run forward decode on both backends
|
||||||
out_trtllm = trtllm_backend.forward_decode(
|
out_trtllm = trtllm_backend.forward_decode(
|
||||||
q.clone(), k.clone(), v.clone(), layer, fb_trtllm
|
q_nope_trt,
|
||||||
|
k_nope_trt,
|
||||||
|
None,
|
||||||
|
layer,
|
||||||
|
fb_trtllm,
|
||||||
|
q_rope=q_rope_trt,
|
||||||
|
k_rope=k_rope_trt,
|
||||||
|
**extra_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reference backend should also take separate components, not concatenated
|
||||||
out_reference = reference_backend.forward_decode(
|
out_reference = reference_backend.forward_decode(
|
||||||
q.clone(), k.clone(), v.clone(), layer, fb_reference
|
q_nope_ref,
|
||||||
|
k_nope_ref,
|
||||||
|
v_ref,
|
||||||
|
layer,
|
||||||
|
fb_reference,
|
||||||
|
q_rope=q_rope_ref,
|
||||||
|
k_rope=k_rope_ref,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compare outputs
|
# Compare outputs
|
||||||
comparison_passed = compare_outputs(
|
comparison_passed = compare_outputs(
|
||||||
out_trtllm, out_reference, tolerance=config["tolerance"]
|
out_trtllm, out_reference, tolerance=tolerance
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
@@ -544,12 +675,16 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
batch_size, seq_lens, [model_runner], layer, config
|
batch_size, seq_lens, [model_runner], layer, config
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create Q, K, V tensors
|
# Create Q, K, V tensors with separate MLA components
|
||||||
torch.manual_seed(config["seed_qkv"])
|
torch.manual_seed(config["seed_qkv"])
|
||||||
q, k, v = self._create_qkv_tensors(batch_size, config)
|
q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
|
||||||
|
batch_size, config
|
||||||
|
)
|
||||||
|
|
||||||
# Run forward decode
|
# Run forward decode with separate MLA components
|
||||||
output = backend.forward_decode(q, k, v, layer, fb)
|
output = backend.forward_decode(
|
||||||
|
q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
|
||||||
|
)
|
||||||
|
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
batch_size,
|
batch_size,
|
||||||
@@ -591,23 +726,38 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
)
|
)
|
||||||
backend.init_forward_metadata(fb)
|
backend.init_forward_metadata(fb)
|
||||||
|
|
||||||
# Create Q, K, V tensors
|
# Create Q, K, V tensors with separate MLA components
|
||||||
torch.manual_seed(config["seed_qkv"])
|
torch.manual_seed(config["seed_qkv"])
|
||||||
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
|
q_nope = torch.randn(
|
||||||
q = torch.randn(
|
(batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
|
||||||
(batch_size, config["num_attention_heads"], head_dim),
|
|
||||||
dtype=config["dtype"],
|
dtype=config["dtype"],
|
||||||
device=config["device"],
|
device=config["device"],
|
||||||
)
|
)
|
||||||
k = torch.randn(
|
k_nope = torch.randn(
|
||||||
(batch_size, config["num_kv_heads"], head_dim),
|
(batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
|
||||||
dtype=config["dtype"],
|
dtype=config["dtype"],
|
||||||
device=config["device"],
|
device=config["device"],
|
||||||
)
|
)
|
||||||
v = None
|
q_rope = torch.randn(
|
||||||
|
(
|
||||||
|
batch_size,
|
||||||
|
config["num_attention_heads"],
|
||||||
|
config["qk_rope_head_dim"],
|
||||||
|
),
|
||||||
|
dtype=config["dtype"],
|
||||||
|
device=config["device"],
|
||||||
|
)
|
||||||
|
k_rope = torch.randn(
|
||||||
|
(batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
|
||||||
|
dtype=config["dtype"],
|
||||||
|
device=config["device"],
|
||||||
|
)
|
||||||
|
v = None # Test with None v
|
||||||
|
|
||||||
# Run forward decode
|
# Run forward decode
|
||||||
output = backend.forward_decode(q, k, v, layer, fb)
|
output = backend.forward_decode(
|
||||||
|
q_nope, k_nope, v, layer, fb, q_rope=q_rope, k_rope=k_rope
|
||||||
|
)
|
||||||
|
|
||||||
# Shape and sanity checks
|
# Shape and sanity checks
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
|
|||||||
Reference in New Issue
Block a user