From 0b6f535f66e092cf96f7b17092c35a3bdf801ed3 Mon Sep 17 00:00:00 2001 From: Yuan Luo Date: Mon, 13 Oct 2025 17:51:25 +0800 Subject: [PATCH] [Reland] perf: optimize qwen-vl with symm mem allreduce (#11457) Co-authored-by: luoyuan.luo --- .../device_communicators/all_reduce_utils.py | 8 ++--- .../sglang/srt/distributed/parallel_state.py | 3 ++ python/sglang/srt/layers/rotary_embedding.py | 35 ++++++++++++++----- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py index 99d6ebf2e..62e342a8e 100644 --- a/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py @@ -3,13 +3,13 @@ MiB = 1024 * 1024 SYMM_MEM_ALL_REDUCE_MAX_SIZES = { 9: { 2: 64 * MiB, # 64 MB - 4: 32 * MiB, # 32 MB - 6: 64 * MiB, # 64 MB - 8: 64 * MiB, # 64 MB + 4: 64 * MiB, # 64 MB + 6: 128 * MiB, # 128 MB + 8: 128 * MiB, # 128 MB }, 10: { 2: 64 * MiB, # 64 MB - 4: 32 * MiB, # 32 MB + 4: 64 * MiB, # 64 MB 6: 128 * MiB, # 128 MB 8: 128 * MiB, # 128 MB }, diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 78e3f2b9a..775e98ca0 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -615,8 +615,11 @@ class GroupCoordinator: def _all_reduce_in_place(self, input_: torch.Tensor) -> None: pynccl_comm = self.pynccl_comm + symm_mem_comm = self.symm_mem_comm if pynccl_comm is not None and not pynccl_comm.disabled: pynccl_comm.all_reduce(input_) + elif symm_mem_comm is not None and not symm_mem_comm.disabled: + symm_mem_comm.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=self.device_group) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 91e58f6a0..55c3121ba 100644 
--- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1008,6 +1008,17 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): return cache +def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THWTHWTHW...TT], preserving frequency continuity. + """ + x_t = x[0].clone() + x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3] + x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3] + return x_t + + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -1020,12 +1031,14 @@ class MRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, dtype: torch.dtype, mrope_section: Optional[List[int]] = None, + mrope_interleaved: bool = False, ) -> None: super().__init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) self.mrope_section = mrope_section + self.mrope_interleaved = mrope_interleaved if self.mrope_section: expected_sum = rotary_dim // 2 actual_sum = sum(self.mrope_section) @@ -1086,15 +1099,18 @@ class MRotaryEmbedding(RotaryEmbedding): cos, sin = cos_sin.chunk(2, dim=-1) if positions.ndim == 2: assert self.mrope_section - - cos = torch.cat( - [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], - dim=-1, - ) - sin = torch.cat( - [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], - dim=-1, - ) + if self.mrope_interleaved: + cos = apply_interleaved_rope(cos, self.mrope_section) + sin = apply_interleaved_rope(sin, self.mrope_section) + else: + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) query_shape = query.shape query =
query.view(num_tokens, -1, self.head_size) @@ -1768,6 +1784,7 @@ def get_rope( is_neox_style, dtype, mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", False), ) else: rotary_emb = RotaryEmbedding(