From 3b9d97f3350daf130bd79a0d51f97c52bf7edf8f Mon Sep 17 00:00:00 2001 From: Yuan Luo Date: Fri, 10 Oct 2025 22:24:45 +0800 Subject: [PATCH] perf: optimize qwen-vl with symm mem allreduce (#11381) Co-authored-by: luoyuan.luo --- .../device_communicators/all_reduce_utils.py | 8 ++-- .../sglang/srt/distributed/parallel_state.py | 3 ++ python/sglang/srt/layers/rotary_embedding.py | 35 ++++++++++---- python/sglang/srt/managers/schedule_batch.py | 6 ++- python/sglang/srt/models/qwen2.py | 47 +++++++++++++++++-- 5 files changed, 82 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py index 99d6ebf2e..62e342a8e 100644 --- a/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py @@ -3,13 +3,13 @@ MiB = 1024 * 1024 SYMM_MEM_ALL_REDUCE_MAX_SIZES = { 9: { 2: 64 * MiB, # 64 MB - 4: 32 * MiB, # 32 MB - 6: 64 * MiB, # 64 MB - 8: 64 * MiB, # 64 MB + 4: 64 * MiB, # 64 MB + 6: 128 * MiB, # 128 MB + 8: 128 * MiB, # 128 MB }, 10: { 2: 64 * MiB, # 64 MB - 4: 32 * MiB, # 32 MB + 4: 64 * MiB, # 64 MB 6: 128 * MiB, # 128 MB 8: 128 * MiB, # 128 MB }, diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 009aba52e..7a88911a7 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -603,8 +603,11 @@ class GroupCoordinator: def _all_reduce_in_place(self, input_: torch.Tensor) -> None: pynccl_comm = self.pynccl_comm + symm_mem_comm = self.symm_mem_comm if pynccl_comm is not None and not pynccl_comm.disabled: pynccl_comm.all_reduce(input_) + elif symm_mem_comm is not None and not symm_mem_comm.disabled: + symm_mem_comm.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=self.device_group) diff --git 
a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index b86d1e9de..4dc8474b5 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1008,6 +1008,17 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): return cache +def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THWTHWTHW...TT], preserving frequency continuity. + """ + x_t = x[0].clone() + x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3] + x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3] + return x_t + + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -1020,12 +1031,14 @@ class MRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, dtype: torch.dtype, mrope_section: Optional[List[int]] = None, + mrope_interleaved: bool = False, ) -> None: super().__init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) self.mrope_section = mrope_section + self.mrope_interleaved = mrope_interleaved if self.mrope_section: expected_sum = rotary_dim // 2 actual_sum = sum(self.mrope_section) @@ -1086,15 +1099,18 @@ class MRotaryEmbedding(RotaryEmbedding): cos, sin = cos_sin.chunk(2, dim=-1) if positions.ndim == 2: assert self.mrope_section - - cos = torch.cat( - [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], - dim=-1, - ) - sin = torch.cat( - [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], - dim=-1, - ) + if self.mrope_interleaved: + cos = apply_interleaved_rope(cos, self.mrope_section) + sin = apply_interleaved_rope(sin, self.mrope_section) + else: + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i] for 
i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) @@ -1773,6 +1789,7 @@ def get_rope( is_neox_style, dtype, mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", False), ) else: rotary_emb = RotaryEmbedding( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3a10ab4b4..b0377bd98 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1766,7 +1766,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] self.out_cache_loc = None - self.seq_lens_sum = self.seq_lens.sum().item() + if isinstance(self.seq_lens_cpu, torch.Tensor): + # CPU tensor + self.seq_lens_sum = int(self.seq_lens_cpu.sum().item()) + else: + self.seq_lens_sum = int(np.asarray(self.seq_lens_cpu).sum()) self.output_ids = self.output_ids[keep_indices_device] self.return_logprob = any(req.return_logprob for req in self.reqs) if self.return_logprob: diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 531f5b6e9..8b4d84cd3 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -27,6 +27,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -88,10 +89,17 @@ class Qwen2MLP(nn.Module): ) self.act_fn = SiluAndMul() - def forward(self, x): + def forward( + self, + x, + should_allreduce_fusion: bool = False, + ): gate_up, _ = self.gate_up_proj(x) x 
= self.act_fn(gate_up) - x, _ = self.down_proj(x) + x, _ = self.down_proj( + x, + skip_all_reduce=should_allreduce_fusion, + ) return x @@ -109,9 +117,11 @@ class Qwen2Attention(nn.Module): quant_config: Optional[QuantizationConfig] = None, dual_chunk_attention_config: Optional[dict[str, Any]] = None, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() self.hidden_size = hidden_size + tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 @@ -143,6 +153,8 @@ class Qwen2Attention(nn.Module): self.total_num_kv_heads, bias=True, quant_config=quant_config, + tp_rank=tp_rank, + tp_size=tp_size, prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( @@ -150,6 +162,8 @@ class Qwen2Attention(nn.Module): hidden_size, bias=False, quant_config=quant_config, + tp_rank=tp_rank, + tp_size=tp_size, prefix=add_prefix("o_proj", prefix), ) @@ -195,6 +209,7 @@ class Qwen2DecoderLayer(nn.Module): alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() + self.config = config self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) @@ -216,6 +231,18 @@ class Qwen2DecoderLayer(nn.Module): dual_chunk_attention_config=dual_chunk_attention_config, prefix=add_prefix("self_attn", prefix), ) + + self.layer_id = layer_id + self.is_layer_sparse = False + is_previous_layer_sparse = False + + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + ) + self.mlp = Qwen2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, @@ -228,6 +255,14 @@ class Qwen2DecoderLayer(nn.Module): config.hidden_size, eps=config.rms_norm_eps ) + 
self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, + allow_reduce_scatter=True, + is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1), + ) + def forward( self, positions: torch.Tensor, @@ -249,7 +284,13 @@ class Qwen2DecoderLayer(nn.Module): # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - hidden_states = self.mlp(hidden_states) + should_allreduce_fusion = ( + self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer( + forward_batch + ) + ) + + hidden_states = self.mlp(hidden_states, should_allreduce_fusion) return hidden_states, residual