diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 9b6309d4c..ea316150e 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -22,7 +22,7 @@ from sglang.srt.layers.attention.utils import ( from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda, is_flashinfer_available if is_flashinfer_available(): import flashinfer @@ -32,6 +32,11 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.speculative.spec_info import SpecInfo +_is_cuda = is_cuda() + +if _is_cuda: + from sgl_kernel import concat_mla_absorb_q + # Constants DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB @@ -482,7 +487,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): q_rope_reshaped = q_rope.view( -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim ) - query = torch.cat([q_nope, q_rope_reshaped], dim=-1) + if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64: + query = concat_mla_absorb_q(q_nope, q_rope_reshaped) + else: + query = torch.cat([q_nope, q_rope_reshaped], dim=-1) else: # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function query = q.view(-1, layer.tp_q_head_num, layer.head_dim)