From 311de47bb7a8baf646f58a472ea1b6712bd51ff6 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Wed, 17 Sep 2025 06:49:22 +0800
Subject: [PATCH] [2/2] Speed up trtllm_mla attention backend (#10474)

---
 .../srt/layers/attention/trtllm_mla_backend.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 9b6309d4c..ea316150e 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -22,7 +22,7 @@ from sglang.srt.layers.attention.utils import (
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available
 
 if is_flashinfer_available():
     import flashinfer
@@ -32,6 +32,11 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInfo
 
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import concat_mla_absorb_q
+
 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 
@@ -482,7 +487,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+            if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
+                query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
+            else:
+                query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
         else:
             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)