diff --git a/docs/advanced_features/attention_backend.md b/docs/advanced_features/attention_backend.md index 9aff14c58..e4c56ea53 100644 --- a/docs/advanced_features/attention_backend.md +++ b/docs/advanced_features/attention_backend.md @@ -60,6 +60,11 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code ``` +- TRTLLM MLA with FP8 KV Cache (Higher concurrency, lower memory footprint) +```bash +python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --kv-cache-dtype fp8_e4m3 --trust-remote-code +``` + - Ascend ```bash python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index f255f9ce2..d4ea74bf4 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -287,38 +287,135 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): ) forward_batch.decode_trtllm_mla_metadata = self.forward_metadata + def quantize_and_rope_for_fp8( + self, + q_nope: torch.Tensor, + q_rope: torch.Tensor, + k_nope: torch.Tensor, + k_rope: torch.Tensor, + forward_batch: ForwardBatch, + cos_sin_cache: torch.Tensor, + is_neox: bool, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Quantize and apply RoPE for FP8 attention path. + + This function handles the FP8 quantization and RoPE application for MLA attention. + It takes separate query/key nope and rope components, applies RoPE to the rope parts, + quantizes all components to FP8, and merges the query components into a single tensor. + + Args: + q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank] + - expected dtype: torch.bfloat16 + q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim] + - expected dtype: torch.bfloat16 + k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank] + - expected dtype: torch.bfloat16 + k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim] + - expected dtype: torch.bfloat16 + forward_batch: Forward batch containing position information + cos_sin_cache: Precomputed cosine/sine cache for RoPE + - expected dtype: matches q_/k_ input dtype (torch.bfloat16) + is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation) + + Returns: + tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8 + - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn + - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn + - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn + """ + attn_dtype = torch.float8_e4m3fn + q_len, num_heads = q_rope.shape[0], q_rope.shape[1] + + # Allocate output tensors with FP8 dtype + # Query output will contain merged nope + rope components + q_out = q_rope.new_empty( + q_len, + num_heads, + self.kv_lora_rank + self.qk_rope_head_dim, + dtype=attn_dtype, + ) + + # Key outputs maintain original shapes but with FP8 dtype + k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype) + k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype) + + # Apply RoPE and quantize all components in a single fused kernel call + # This kernel handles: + # 1. 
RoPE application to q_rope and k_rope using cos_sin_cache and positions
+        #    2. Quantization of all components to FP8 format
+        #    3. Output placement into pre-allocated tensors
+        flashinfer.rope.mla_rope_quantize_fp8(
+            q_rope=q_rope,
+            k_rope=k_rope,
+            q_nope=q_nope,
+            k_nope=k_nope,
+            cos_sin_cache=cos_sin_cache,
+            pos_ids=forward_batch.positions,
+            is_neox=is_neox,
+            quantize_dtype=attn_dtype,
+            # Output tensor slicing: q_out contains [nope_part, rope_part]
+            q_rope_out=q_out[..., self.kv_lora_rank :],  # RoPE part goes to end
+            k_rope_out=k_rope_out,
+            q_nope_out=q_out[..., : self.kv_lora_rank],  # Nope part goes to beginning
+            k_nope_out=k_nope_out,
+            # Quantization scales (set to 1.0 for no additional scaling)
+            quant_scale_q=1.0,
+            quant_scale_kv=1.0,
+        )
+
+        return q_out, k_nope_out, k_rope_out
+
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
+        q: torch.Tensor,  # q_nope
+        k: torch.Tensor,  # k_nope
+        v: torch.Tensor,  # not used in this backend
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: bool = False,
     ) -> torch.Tensor:
        """Run forward for decode using TRTLLM MLA kernel."""
+        merge_query = q_rope is not None
+        if self.data_type == torch.float8_e4m3fn:
+            # For the FP8 path, apply RoPE and quantize the nope/rope query parts, merging them into a single tensor.
+            # Note: RoPE application in deepseek_v2.py:forward_absorb_prepare is skipped for the FP8 decode path of this trtllm_mla backend.
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "The FP8 path uses flashinfer.rope.mla_rope_quantize_fp8 and requires q_rope, k_rope, and cos_sin_cache to all be non-None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
-        if k is not None and save_kv_cache:
-            cache_loc = forward_batch.out_cache_loc
-            if k_rope is not None:
-                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
-                    layer, cache_loc, k, k_rope
-                )
-            elif v is not None:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating the trtllm_mla KV cache, both k_nope and k_rope must not be None."
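+            # set_mla_kv_buffer writes the compressed KV (kv_lora_rank wide) and
+            # the RoPE key (qk_rope_head_dim wide) side by side into the paged MLA
+            # cache, matching the kv_cache_dim layout read by the kernel below.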
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer( + layer, forward_batch.out_cache_loc, k, k_rope + ) # Prepare query tensor inline - if q_rope is not None: - # q contains NOPE part (v_head_dim) + if merge_query: + # For FP16 path, we merge the query and rope parts into a single tensor q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) q_rope_reshaped = q_rope.view( -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim ) query = torch.cat([q_nope, q_rope_reshaped], dim=-1) else: - # q already has both parts + # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function query = q.view(-1, layer.tp_q_head_num, layer.head_dim) # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1 @@ -327,9 +424,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): # Prepare KV cache inline k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) - # TRT-LLM expects single KV data with extra dimension - kv_cache = pages.unsqueeze(1) + kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1) # Get metadata metadata = ( @@ -337,11 +432,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): or self.forward_metadata ) - # Scale computation for TRTLLM MLA kernel: - # - BMM1 scale = q_scale * k_scale * softmax_scale - # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling - # - k_scale is read from model checkpoint if available - # TODO: Change once fp8 path is supported + # Scale computation for TRTLLM MLA kernel BMM1 operation: + # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale + # Scale components: + # - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths) + # - k_scale: Key scaling factor from model checkpoint (defaults to 1.0 if not available) + # - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling + # This unified approach works for both FP16 and FP8 quantized attention paths. q_scale = 1.0 k_scale = ( layer.k_scale_float diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index f59d7f248..235718ded 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1196,6 +1196,16 @@ class DeepseekV2AttentionMLA(nn.Module): output, _ = self.o_proj(attn_output) return output + def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool: + """ + Check if we should skip rope and do fused rope+quantize for TRTLLM MLA decode in fp8_e4m3 path. 
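+        When this returns True, forward_absorb_prepare leaves q_pe/k_pe unrotated
+        and the attention backend applies RoPE fused with FP8 quantization inside
+        forward_decode via quantize_and_rope_for_fp8.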
+ """ + return ( + self.current_attention_backend == "trtllm_mla" + and forward_batch.forward_mode.is_decode_or_idle() + and forward_batch.attn_backend.data_type == torch.float8_e4m3fn + ) + def forward_absorb_prepare( self, positions: torch.Tensor, @@ -1275,7 +1285,9 @@ class DeepseekV2AttentionMLA(nn.Module): q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc) q_nope_out = q_nope_out.transpose(0, 1) - q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + if not self._fuse_rope_for_trtllm_mla(forward_batch): + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator @@ -1288,8 +1300,20 @@ class DeepseekV2AttentionMLA(nn.Module): or self.current_attention_backend == "cutlass_mla" or self.current_attention_backend == "trtllm_mla" ): + extra_args = {} + if self._fuse_rope_for_trtllm_mla(forward_batch): + extra_args = { + "cos_sin_cache": self.rotary_emb.cos_sin_cache, + "is_neox": self.rotary_emb.is_neox_style, + } attn_output = self.attn_mqa( - q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe + q_nope_out, + k_nope, + k_nope, + forward_batch, + q_rope=q_pe, + k_rope=k_pe, + **extra_args, ) else: q = torch.cat([q_nope_out, q_pe], dim=-1) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b4fd46748..9df37d168 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -432,7 +432,10 @@ class ServerArgs: ) self.page_size = 128 - if self.attention_backend == "trtllm_mla": + if ( + self.attention_backend == "trtllm_mla" + or self.decode_attention_backend == "trtllm_mla" + ): if not is_sm100_supported(): raise ValueError( "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend." @@ -443,11 +446,17 @@ class ServerArgs: f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64." ) self.page_size = 64 + if self.speculative_algorithm is not None: raise ValueError( "trtllm_mla backend does not support speculative decoding yet." ) + if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]: + raise ValueError( + "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto." 
+ ) + if ( self.attention_backend == "trtllm_mha" or self.decode_attention_backend == "trtllm_mha" diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py index be3ed08f4..18a7f77ea 100755 --- a/python/sglang/test/attention/test_trtllm_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -43,6 +43,37 @@ DEFAULT_CONFIG = { "layer_id": 0, } +ROPE_BASE = 10000 +ROPE_SCALING_CONFIG = { + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + "rope_type": "deepseek_yarn", +} + + +def build_rotary_emb(config, device=None): + from sglang.srt.layers.rotary_embedding import get_rope_wrapper + + dev = device or config["device"] + rope_scaling = config.get("rope_scaling", ROPE_SCALING_CONFIG) + rotary = get_rope_wrapper( + head_size=config["qk_rope_head_dim"], + rotary_dim=config["qk_rope_head_dim"], + max_position=config["context_len"], + base=ROPE_BASE, + rope_scaling=rope_scaling, + is_neox_style=False, + device=dev, + ) + rotary.cos_sin_cache = rotary.cos_sin_cache.to(dev) + return rotary + + # Centralized test cases for different test scenarios TEST_CASES = { "basic_functionality": [ @@ -63,18 +94,36 @@ TEST_CASES = { ], "decode_output_match": [ { - "name": "single", + "name": "single_fp16", "batch_size": 1, "max_seq_len": 64, "page_size": 32, - "description": "Single vs reference", + "description": "Single FP16 vs reference", }, { - "name": "batch", + "name": "single_fp8", + "batch_size": 1, + "max_seq_len": 64, + "page_size": 64, + "tolerance": 1e-1, + "kv_cache_dtype": torch.float8_e4m3fn, + "description": "Single FP8 vs reference", + }, + { + "name": "batch_fp16", "batch_size": 32, "max_seq_len": 64, "page_size": 32, - "description": "Batch vs reference", + "description": "Batch FP16 vs reference", + }, + { + "name": "batch_fp8", + "batch_size": 32, + "max_seq_len": 64, + "page_size": 64, + "tolerance": 1e-1, + "kv_cache_dtype": torch.float8_e4m3fn, + "description": "Batch FP8 vs reference", }, ], "page_size_consistency": [ @@ -293,26 +342,52 @@ class TestTRTLLMMLA(CustomTestCase): layer, ) - def _create_qkv_tensors(self, batch_size, config): - """Create Q, K, V tensors for testing.""" - head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"] - device = config["device"] - dtype = config["dtype"] + def _create_qkv_tensors(self, batch_size, config, dtype_override=None): + """Create Q, K, V random tensors for given batch size with separate MLA components. - q = torch.randn( - (batch_size, config["num_attention_heads"], head_dim), - dtype=dtype, + Args: + batch_size: Batch size. + config: Configuration dict with model dims and device. + dtype_override: Optional torch dtype to override config["dtype"]. 
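+                Falls back to config["dtype"] when None.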
+
+        Returns:
+            Tuple of (q_nope, q_rope, k_nope, k_rope, v)
+        """
+        device = config["device"]
+        target_dtype = dtype_override or config["dtype"]
+
+        # Create separate nope and rope components for Q
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
+            dtype=target_dtype,
             device=device,
         )
-        k = torch.randn(
-            (batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
+        q_rope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["qk_rope_head_dim"]),
+            dtype=target_dtype,
+            device=device,
         )
+
+        # Create separate nope and rope components for K
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
+            dtype=target_dtype,
+            device=device,
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=target_dtype,
+            device=device,
+        )
+
+        # V tensor (unchanged)
         v = torch.randn(
             (batch_size, config["num_kv_heads"], config["v_head_dim"]),
-            dtype=dtype,
+            dtype=target_dtype,
             device=device,
         )
-        return q, k, v
+
+        return q_nope, q_rope, k_nope, k_rope, v
 
     def _create_forward_batch(
         self, batch_size, seq_lens, backend, model_runner, config
@@ -331,6 +406,10 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         fb.req_to_token_pool = model_runner.req_to_token_pool
         fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+        # Add position information for RoPE
+        fb.positions = torch.arange(batch_size, device=config["device"])
+
         return fb
 
     def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
@@ -344,7 +423,7 @@
             for token_idx in range(seq_len - 1):
                 # Create random K components for MLA
                 cache_k_nope = torch.randn(
-                    (1, config["qk_nope_head_dim"]),
+                    (1, config["kv_lora_rank"]),
                     dtype=config["dtype"],
                     device=config["device"],
                 )
@@ -411,12 +490,16 @@
             batch_size, seq_lens, [model_runner_trtllm], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = trtllm_backend.forward_decode(q, k, v, layer, fb)
+        # Run forward decode with separate MLA components
+        output = trtllm_backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Basic checks
         expected_shape = (
@@ -439,6 +522,7 @@
         config = self._merge_config(test_case)
         batch_size = config["batch_size"]
         max_seq_len = config["max_seq_len"]
+        use_fp8 = config.get("kv_cache_dtype") == torch.float8_e4m3fn
 
         # Create components
         (
@@ -487,19 +571,66 @@
 
         # Create Q, K, V tensors for current decode step
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+
+        q_nope_ref, q_rope_ref, k_nope_ref, k_rope_ref, v_ref = (
+            self._create_qkv_tensors(batch_size, config)
+        )
+        q_nope_trt, q_rope_trt, k_nope_trt, k_rope_trt, v_trt = (
+            q_nope_ref.clone(),
+            q_rope_ref.clone(),
+            k_nope_ref.clone(),
+            k_rope_ref.clone(),
+            v_ref.clone(),
+        )
+        tolerance = config["tolerance"]
+
+        extra_args = {}
+        if use_fp8:
+            # The TRT kernel applies RoPE + FP8 quantization internally, so we
+            # pre-apply RoPE on the reference (FlashInfer) path here; both paths
+            # then share the same RoPE params/cache while the TRT path receives
+            # unrotated inputs.
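+            # The float8 round-trip casts further below emulate the FP8
+            # quantization error on the reference inputs, keeping its output
+            # comparable to the TRT kernel's FP8 compute within tolerance.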
+ rotary_emb = build_rotary_emb(config) + q_rope_ref, k_rope_ref = rotary_emb( + fb_reference.positions, q_rope_ref, k_rope_ref + ) + extra_args = { + "cos_sin_cache": rotary_emb.cos_sin_cache, + "is_neox": rotary_emb.is_neox_style, + } + + dtype = q_rope_ref.dtype + q_rope_ref = q_rope_ref.to(torch.float8_e4m3fn).to(dtype) + q_nope_ref = q_nope_ref.to(torch.float8_e4m3fn).to(dtype) + k_rope_ref = k_rope_ref.to(torch.float8_e4m3fn).to(dtype) + k_nope_ref = k_nope_ref.to(torch.float8_e4m3fn).to(dtype) # Run forward decode on both backends out_trtllm = trtllm_backend.forward_decode( - q.clone(), k.clone(), v.clone(), layer, fb_trtllm + q_nope_trt, + k_nope_trt, + None, + layer, + fb_trtllm, + q_rope=q_rope_trt, + k_rope=k_rope_trt, + **extra_args, ) + + # Reference backend should also take separate components, not concatenated out_reference = reference_backend.forward_decode( - q.clone(), k.clone(), v.clone(), layer, fb_reference + q_nope_ref, + k_nope_ref, + v_ref, + layer, + fb_reference, + q_rope=q_rope_ref, + k_rope=k_rope_ref, ) # Compare outputs comparison_passed = compare_outputs( - out_trtllm, out_reference, tolerance=config["tolerance"] + out_trtllm, out_reference, tolerance=tolerance ) self.assertTrue( @@ -544,12 +675,16 @@ class TestTRTLLMMLA(CustomTestCase): batch_size, seq_lens, [model_runner], layer, config ) - # Create Q, K, V tensors + # Create Q, K, V tensors with separate MLA components torch.manual_seed(config["seed_qkv"]) - q, k, v = self._create_qkv_tensors(batch_size, config) + q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors( + batch_size, config + ) - # Run forward decode - output = backend.forward_decode(q, k, v, layer, fb) + # Run forward decode with separate MLA components + output = backend.forward_decode( + q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope + ) expected_shape = ( batch_size, @@ -591,23 +726,38 @@ class TestTRTLLMMLA(CustomTestCase): ) backend.init_forward_metadata(fb) - # Create Q, K, V tensors + # Create Q, K, V tensors with separate MLA components torch.manual_seed(config["seed_qkv"]) - head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"] - q = torch.randn( - (batch_size, config["num_attention_heads"], head_dim), + q_nope = torch.randn( + (batch_size, config["num_attention_heads"], config["kv_lora_rank"]), dtype=config["dtype"], device=config["device"], ) - k = torch.randn( - (batch_size, config["num_kv_heads"], head_dim), + k_nope = torch.randn( + (batch_size, config["num_kv_heads"], config["kv_lora_rank"]), dtype=config["dtype"], device=config["device"], ) - v = None + q_rope = torch.randn( + ( + batch_size, + config["num_attention_heads"], + config["qk_rope_head_dim"], + ), + dtype=config["dtype"], + device=config["device"], + ) + k_rope = torch.randn( + (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]), + dtype=config["dtype"], + device=config["device"], + ) + v = None # Test with None v # Run forward decode - output = backend.forward_decode(q, k, v, layer, fb) + output = backend.forward_decode( + q_nope, k_nope, v, layer, fb, q_rope=q_rope, k_rope=k_rope + ) # Shape and sanity checks expected_shape = (