From c49c1d9226ad0a380aa854f60ea7daf5db191477 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Sat, 14 Jun 2025 06:19:31 +0800
Subject: [PATCH] Remove 200us slow concat kernel (part 2: srt) (#7020)

---
 .../layers/attention/cutlass_mla_backend.py   | 44 ++++++++++++++-----
 python/sglang/srt/models/deepseek_v2.py       |  6 ++-
 2 files changed, 39 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/layers/attention/cutlass_mla_backend.py b/python/sglang/srt/layers/attention/cutlass_mla_backend.py
index afa03434c..a49ed1ab5 100644
--- a/python/sglang/srt/layers/attention/cutlass_mla_backend.py
+++ b/python/sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -233,25 +233,49 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc
 
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
-        bs = forward_batch.batch_size
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+
+        # Reshape inputs
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
+        q_nope = q_nope.to(self.q_data_type)
+        q_rope = q_rope.to(self.q_data_type)
+
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-        o = cutlass_mla_decode(
-            q_nope_and_q_pe=reshape_q.to(self.q_data_type),
+
+        o = cutlass_mla_decode(
+            q_nope=q_nope,
+            q_pe=q_rope,
             kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
             seq_lens=forward_batch.seq_lens.to(torch.int32),
             page_table=self.forward_metadata.block_kv_indices,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 0239736c7..c5c05b016 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1013,7 +1013,11 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_core(
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
-        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
+        if (
+            self.attention_backend == "fa3"
+            or self.attention_backend == "flashinfer"
+            or self.attention_backend == "cutlass_mla"
+        ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )