diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 8d07d9933..188d772c7 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -18,7 +18,10 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
-from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
+        self.enable_dp_attention = is_dp_attention_enabled()
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-                    forward_batch.extend_prefix_lens,
-                    sum(forward_batch.extend_prefix_lens_cpu),
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
                     forward_batch.extend_seq_lens,
-                    max(forward_batch.extend_seq_lens_cpu),
-                    forward_batch.seq_lens_cpu.max().item(),
+                    forward_batch.extend_seq_lens.max().item(),
+                    forward_batch.seq_lens.max().item(),
                     spec_info=None,
                 )
-                self.mla_indices_updater_prefill.kv_indptr += (
-                    self.mla_indices_updater_prefill.qo_indptr
-                )
+
+                kv_indices = self.mla_indices_updater_prefill.kv_indices
+
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
+                    kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
                     self.kv_last_page_len[:bs],
                     self.mla_indices_updater_prefill.max_q_len,
@@ -614,66 +619,86 @@ class AiterAttnBackend(AttentionBackend):
             assert len(k.shape) == 3
             assert len(v.shape) == 3
 
-            if kv_indices.shape[0] == 0:
-                o = flash_attn_varlen_func(
-                    q,
-                    k,
-                    v,
-                    qo_indptr,
-                    qo_indptr,
-                    max_q_len,
-                    max_q_len,
-                    softmax_scale=layer.scaling,
-                    causal=True,
-                )
-                return o
-            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
-                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
-                kvc, k_pe = torch.split(
-                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
-                )
-                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
+            if forward_batch.forward_mode.is_extend():
+                if kv_indices.shape[0] == 0:
+                    o = flash_attn_varlen_func(
+                        q,
+                        k,
+                        v,
+                        qo_indptr,
+                        qo_indptr,
+                        max_q_len,
+                        max_q_len,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                    )
+                    return o
+                elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                    K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                    kvc, k_pe = torch.split(
+                        K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                    )
+                    kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
 
-                kvprefix = kvprefix.view(
-                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
-                )
-                k_prefix, v_prefix = torch.split(
-                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
-                )
-                k_prefix = torch.cat(
-                    [
-                        k_prefix,
-                        torch.broadcast_to(
-                            k_pe,
-                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
-                        ),
-                    ],
-                    dim=-1,
-                )
-                assert (
-                    forward_batch.extend_prefix_lens.shape
-                    == forward_batch.extend_seq_lens.shape
-                )
-                k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
-                k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
-                assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
-                k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
-                v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
-                v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
-                v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
+                    kvprefix = kvprefix.view(
+                        -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                    )
+                    k_prefix, v_prefix = torch.split(
+                        kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                    )
+                    k_prefix = torch.cat(
+                        [
+                            k_prefix,
+                            torch.broadcast_to(
+                                k_pe,
+                                (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                            ),
+                        ],
+                        dim=-1,
+                    )
+                    assert (
+                        forward_batch.extend_prefix_lens.shape
+                        == forward_batch.extend_seq_lens.shape
+                    )
 
-            o = flash_attn_varlen_func(
-                q,
-                k,
-                v,
-                qo_indptr,
-                kv_indptr,
-                max_q_len,
-                max_kv_len,
-                softmax_scale=layer.scaling,
-                causal=True,
-            )
-            return o
+                    k = k_prefix
+                    v = v_prefix
+
+                    o = flash_attn_varlen_func(
+                        q,
+                        k,
+                        v,
+                        qo_indptr,
+                        kv_indptr,
+                        max_q_len,
+                        max_kv_len,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                    )
+                    return o
+
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        o = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        o = torch.empty_like(q)
+
+                    mla_prefill_fwd(
+                        q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                        K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                        o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                        qo_indptr,
+                        kv_indptr,
+                        kv_indices,
+                        self.forward_metadata.kv_last_page_len,
+                        self.forward_metadata.max_q_len,
+                        layer.scaling,
+                        layer.logit_cap,
+                    )
+                    K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+                    return o
             elif forward_batch.forward_mode.is_target_verify():
                 o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
                 mla_decode_fwd(
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index a2296b569..32726d11b 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1085,7 +1085,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
-                return AttnForwardMethod.MHA
+                if is_dp_attention_enabled():
+                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                        return AttnForwardMethod.MHA
+                    else:
+                        return AttnForwardMethod.MLA
+                else:
+                    return AttnForwardMethod.MHA
             else:
                 return AttnForwardMethod.MLA
         else:
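
Note: the sketch below is only an illustration of the dispatch rule introduced by the deepseek_v2.py hunk; it is not part of the patch. The function and parameter names (choose_forward_method, dp_attention_enabled, has_prefix) are hypothetical stand-ins for the real is_dp_attention_enabled() check and sum(forward_batch.extend_prefix_lens_cpu).

# Illustrative sketch only (not part of the patch): the prefill dispatch added for
# the aiter backend. dp_attention_enabled and has_prefix stand in for
# is_dp_attention_enabled() and sum(forward_batch.extend_prefix_lens_cpu) > 0.
def choose_forward_method(
    is_plain_extend: bool, dp_attention_enabled: bool, has_prefix: bool
) -> str:
    if is_plain_extend:  # extend, and neither target-verify nor draft-extend
        if dp_attention_enabled and has_prefix:
            # Prefix-cache hits under DP attention now take the MLA path, which
            # aiter_backend serves with mla_prefill_fwd over the paged KV cache.
            return "MLA"
        # No cached prefix (or DP attention disabled): keep the ragged MHA prefill.
        return "MHA"
    return "MLA"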