From 646cef2e2ea553843d10e66409ba744973509376 Mon Sep 17 00:00:00 2001
From: Yi Zhang <1109276519@qq.com>
Date: Fri, 4 Jul 2025 00:58:20 +0800
Subject: [PATCH] support qwen3 dense model dp attention (#7681)

---
 python/sglang/srt/models/qwen2.py |  8 ++++-
 python/sglang/srt/models/qwen3.py | 58 ++++++++++++++++++++++---------
 2 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 714d53fe6..3f6dba752 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -264,6 +265,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
@@ -332,7 +334,11 @@ class Qwen2Model(nn.Module):
                 }
             )
         else:
-            hidden_states, _ = self.norm(hidden_states, residual)
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
     # If this function is called, it should always initialize KV cache scale
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index ae7bbfd4c..2035b6c11 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -14,6 +14,8 @@ from sglang.srt.distributed import (
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -54,18 +56,21 @@ class Qwen3Attention(nn.Module):
         self.hidden_size = hidden_size
         self.tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % self.tp_size == 0
-        self.num_heads = self.total_num_heads // self.tp_size
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= self.tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % self.tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -84,6 +89,8 @@
             self.total_num_kv_heads,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
@@ -91,6 +98,9 @@
             hidden_size,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
             prefix=add_prefix("o_proj", prefix),
         )
 
@@ -176,6 +186,18 @@ class Qwen3DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=False,
+            is_previous_layer_sparse=False,
+        )
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -184,20 +206,24 @@ class Qwen3DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
         )
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
         hidden_states = self.mlp(hidden_states)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
 
         return hidden_states, residual
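
Usage note (not part of the patch): the hidden_states.shape[0] != 0 guards appear to be there because, with DP attention enabled, a data-parallel rank can receive an empty local batch and should skip the norm/attention/MLP compute for it. Below is a minimal, hedged sketch of exercising this code path offline, assuming sglang's Engine forwards the usual ServerArgs fields (model_path, tp_size, dp_size, enable_dp_attention) as keyword arguments and that the attention TP size returned by get_attention_tp_size() is tp_size // dp_size; the checkpoint name and parallel sizes are placeholders, not values taken from this PR.

    # Hedged sketch: run a dense Qwen3 checkpoint with DP attention enabled.
    # Assumptions: Engine accepts these ServerArgs fields as kwargs, and
    # tp_size must be divisible by dp_size when enable_dp_attention is set.
    import sglang as sgl

    if __name__ == "__main__":
        llm = sgl.Engine(
            model_path="Qwen/Qwen3-8B",  # placeholder dense Qwen3 checkpoint
            tp_size=4,                   # placeholder parallel sizes
            dp_size=4,
            enable_dp_attention=True,    # exercises the DP-attention path added above
        )
        # Simple generation to confirm the model loads and runs end to end.
        out = llm.generate("The capital of France is", {"max_new_tokens": 16})
        print(out)
        llm.shutdown()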