From 063c3791fe2ecf8c175677d79dc19db38d4d07dc Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Tue, 23 Sep 2025 13:47:49 +0800
Subject: [PATCH] Fix trtllm_mla slow concat kernel in MTP (#10777)

---
 .../srt/layers/attention/trtllm_mla_backend.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 3613afd17..7a3f31128 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -505,10 +505,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
-                query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
-            else:
-                query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+            query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
         else:
             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
@@ -591,7 +588,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            q = torch.cat([q, q_rope], dim=-1)
+            q = _concat_mla_absorb_q_general(q, q_rope)
 
         q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
@@ -716,3 +713,10 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
                 kv_indptr_buf=self.kv_indptr[i],
                 q_indptr_decode_buf=self.q_indptr_decode,
             )
+
+
+def _concat_mla_absorb_q_general(q_nope, q_rope):
+    if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
+        return concat_mla_absorb_q(q_nope, q_rope)
+    else:
+        return torch.cat([q_nope, q_rope], dim=-1)
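
The patch factors the fast-path dispatch into a module-level helper, `_concat_mla_absorb_q_general`, so that both the decode path and the MTP/draft extend path (which previously always used a plain `torch.cat`) pick the fused `concat_mla_absorb_q` kernel whenever the absorbed-MLA head dims (512 nope + 64 rope) match. The following is a minimal standalone sketch of that dispatch, not part of the patch itself: `_is_cuda` and `concat_mla_absorb_q` here are local stand-ins for the symbols the real file imports, and the shapes are illustrative.

```python
import torch

# Stand-in for the real flag imported in trtllm_mla_backend.py.
_is_cuda = torch.cuda.is_available()


def concat_mla_absorb_q(q_nope, q_rope):
    # Placeholder for the fused sgl-kernel op used on the CUDA fast path;
    # functionally it produces the same result as a concat along the last dim.
    return torch.cat([q_nope, q_rope], dim=-1)


def _concat_mla_absorb_q_general(q_nope, q_rope):
    # Use the fused kernel only for the exact shapes it supports;
    # otherwise fall back to a plain torch.cat.
    if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
        return concat_mla_absorb_q(q_nope, q_rope)
    return torch.cat([q_nope, q_rope], dim=-1)


if __name__ == "__main__":
    tokens, heads = 4, 16
    q_nope = torch.randn(tokens, heads, 512)
    q_rope = torch.randn(tokens, heads, 64)
    out = _concat_mla_absorb_q_general(q_nope, q_rope)
    print(out.shape)  # torch.Size([4, 16, 576])
```

Centralizing the branch in one helper keeps the decode and MTP call sites identical, so any future change to the fast-path condition only needs to be made in one place.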