diff --git a/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py index 62e342a8e..99d6ebf2e 100644 --- a/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py @@ -3,13 +3,13 @@ MiB = 1024 * 1024 SYMM_MEM_ALL_REDUCE_MAX_SIZES = { 9: { 2: 64 * MiB, # 64 MB - 4: 64 * MiB, # 64 MB - 6: 128 * MiB, # 128 MB - 8: 128 * MiB, # 128 MB + 4: 32 * MiB, # 32 MB + 6: 64 * MiB, # 64 MB + 8: 64 * MiB, # 64 MB }, 10: { 2: 64 * MiB, # 64 MB - 4: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB 6: 128 * MiB, # 128 MB 8: 128 * MiB, # 128 MB }, diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 7a88911a7..009aba52e 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -603,11 +603,8 @@ class GroupCoordinator: def _all_reduce_in_place(self, input_: torch.Tensor) -> None: pynccl_comm = self.pynccl_comm - symm_mem_comm = self.symm_mem_comm if pynccl_comm is not None and not pynccl_comm.disabled: pynccl_comm.all_reduce(input_) - elif symm_mem_comm is not None and not symm_mem_comm.disabled: - symm_mem_comm.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=self.device_group) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 4dc8474b5..b86d1e9de 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1008,17 +1008,6 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): return cache -def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: - """Apply interleaved MRoPE to 3D rotary embeddings. - Reorganizes frequency layout from chunked [TTT...HHH...WWW] to - interleaved [THTHWHTHW...TT], preserving frequency continuity. - """ - x_t = x[0].clone() - x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3] - x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3] - return x_t - - class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -1031,14 +1020,12 @@ class MRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, dtype: torch.dtype, mrope_section: Optional[List[int]] = None, - mrope_interleaved: bool = False, ) -> None: super().__init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) self.mrope_section = mrope_section - self.mrope_interleaved = mrope_interleaved if self.mrope_section: expected_sum = rotary_dim // 2 actual_sum = sum(self.mrope_section) @@ -1099,18 +1086,15 @@ class MRotaryEmbedding(RotaryEmbedding): cos, sin = cos_sin.chunk(2, dim=-1) if positions.ndim == 2: assert self.mrope_section - if self.mrope_interleaved: - cos = apply_interleaved_rope(cos, self.mrope_section) - sin = apply_interleaved_rope(sin, self.mrope_section) - else: - cos = torch.cat( - [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], - dim=-1, - ) - sin = torch.cat( - [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], - dim=-1, - ) + + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) @@ -1789,7 +1773,6 @@ def get_rope( is_neox_style, dtype, mrope_section=rope_scaling["mrope_section"], - mrope_interleaved=rope_scaling.get("mrope_interleaved", False), ) else: rotary_emb = RotaryEmbedding( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b0377bd98..3a10ab4b4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1766,11 +1766,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] self.out_cache_loc = None - if isinstance(self.seq_lens_cpu, torch.Tensor): - # CPU tensor - self.seq_lens_sum = int(self.seq_lens_cpu.sum().item()) - else: - self.seq_lens_sum = int(np.asarray(self.seq_lens_cpu).sum()) + self.seq_lens_sum = self.seq_lens.sum().item() self.output_ids = self.output_ids[keep_indices_device] self.return_logprob = any(req.return_logprob for req in self.reqs) if self.return_logprob: diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 8b4d84cd3..531f5b6e9 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -27,7 +27,6 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -89,17 +88,10 @@ class Qwen2MLP(nn.Module): ) self.act_fn = SiluAndMul() - def forward( - self, - x, - should_allreduce_fusion: bool = False, - ): + def forward(self, x): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) - x, _ = self.down_proj( - x, - skip_all_reduce=should_allreduce_fusion, - ) + x, _ = self.down_proj(x) return x @@ -117,11 +109,9 @@ class Qwen2Attention(nn.Module): quant_config: Optional[QuantizationConfig] = None, dual_chunk_attention_config: Optional[dict[str, Any]] = None, prefix: str = "", - alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() self.hidden_size = hidden_size - tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 @@ -153,8 +143,6 @@ class Qwen2Attention(nn.Module): self.total_num_kv_heads, bias=True, quant_config=quant_config, - tp_rank=tp_rank, - tp_size=tp_size, prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( @@ -162,8 +150,6 @@ class Qwen2Attention(nn.Module): hidden_size, bias=False, quant_config=quant_config, - tp_rank=tp_rank, - tp_size=tp_size, prefix=add_prefix("o_proj", prefix), ) @@ -209,7 +195,6 @@ class Qwen2DecoderLayer(nn.Module): alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() - self.config = config self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) @@ -231,18 +216,6 @@ class Qwen2DecoderLayer(nn.Module): dual_chunk_attention_config=dual_chunk_attention_config, prefix=add_prefix("self_attn", prefix), ) - - self.layer_id = layer_id - self.is_layer_sparse = False - is_previous_layer_sparse = False - - self.layer_scatter_modes = LayerScatterModes.init_new( - layer_id=layer_id, - num_layers=config.num_hidden_layers, - is_layer_sparse=self.is_layer_sparse, - is_previous_layer_sparse=is_previous_layer_sparse, - ) - self.mlp = Qwen2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, @@ -255,14 +228,6 @@ class Qwen2DecoderLayer(nn.Module): config.hidden_size, eps=config.rms_norm_eps ) - self.layer_communicator = LayerCommunicator( - layer_scatter_modes=self.layer_scatter_modes, - input_layernorm=self.input_layernorm, - post_attention_layernorm=self.post_attention_layernorm, - allow_reduce_scatter=True, - is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1), - ) - def forward( self, positions: torch.Tensor, @@ -284,13 +249,7 @@ class Qwen2DecoderLayer(nn.Module): # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - should_allreduce_fusion = ( - self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer( - forward_batch - ) - ) - - hidden_states = self.mlp(hidden_states, should_allreduce_fusion) + hidden_states = self.mlp(hidden_states) return hidden_states, residual