diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 7d80794fc..b0a3f20f6 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -95,6 +95,7 @@ from sglang.srt.utils import ( get_int_env_var, is_cuda, is_hip, + is_non_idle_and_non_empty, log_info_on_rank0, ) @@ -206,14 +207,6 @@ class MoEGate(nn.Module): return logits -def is_non_idle_and_non_empty(forward_mode, hidden_states): - return ( - (forward_mode is not None) - and not forward_mode.is_idle() - and hidden_states.shape[0] > 0 - ) - - class DeepseekV2MoE(nn.Module): def __init__( diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 63b124fd7..b5b884e59 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -32,6 +32,7 @@ from sglang.srt.distributed import ( tensor_model_parallel_all_reduce, ) from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import ( attn_tp_all_gather, attn_tp_reduce_scatter, @@ -49,7 +50,7 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput -from sglang.srt.layers.moe.ep_moe.layer import EPMoE +from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention @@ -114,22 +115,22 @@ class Qwen2MoeMLP(nn.Module): class Qwen2MoeSparseMoeBlock(nn.Module): def __init__( self, + layer_id: int, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() - + self.layer_id = layer_id if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}." ) - MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE - - self.experts = MoEImpl( + self.experts = get_moe_impl_class()( + layer_id=self.layer_id, num_experts=config.num_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -159,7 +160,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module): self.shared_expert = None self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None + ) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) shared_output = None @@ -276,19 +279,6 @@ class Qwen2MoeAttention(nn.Module): return output -class _FFNInputMode(Enum): - # The MLP sublayer requires 1/tp_size tokens as input - SCATTERED = auto() - # The MLP sublayer requires all tokens as input - FULL = auto() - - -@dataclass -class _DecoderLayerInfo: - is_sparse: bool - ffn_input_mode: _FFNInputMode - - class Qwen2MoeDecoderLayer(nn.Module): def __init__( self, @@ -298,6 +288,7 @@ class Qwen2MoeDecoderLayer(nn.Module): prefix: str = "", ) -> None: super().__init__() + self.config = config self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -322,16 +313,20 @@ class Qwen2MoeDecoderLayer(nn.Module): self.attn_tp_rank = get_attention_tp_rank() self.local_dp_size = get_local_attention_dp_size() - self.info = self._compute_info(config, layer_id=layer_id) - previous_layer_info = self._compute_info(config, layer_id=layer_id - 1) - self.input_is_scattered = ( - layer_id > 0 - and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED - ) - self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 + # Qwen2MoE all layers are sparse and have no nextn now + self.is_layer_sparse = True + is_previous_layer_sparse = True - if self.info.is_sparse: + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + ) + + if self.is_layer_sparse: self.mlp = Qwen2MoeSparseMoeBlock( + layer_id=layer_id, config=config, quant_config=quant_config, prefix=add_prefix("mlp", prefix), @@ -348,27 +343,11 @@ class Qwen2MoeDecoderLayer(nn.Module): self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) - - @staticmethod - def _enable_moe_dense_fully_dp(): - return global_server_args_dict["moe_dense_tp_size"] == 1 - - @staticmethod - def _compute_info(config: PretrainedConfig, layer_id: int): - # WARN: Qwen2MOE has no dense_layer, it is only for compatibility. - mlp_only_layers = ( - [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, ) - is_sparse = (layer_id not in mlp_only_layers) and ( - config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 - ) - ffn_input_mode = ( - _FFNInputMode.SCATTERED - if (global_server_args_dict["enable_deepep_moe"] and is_sparse) - or (Qwen2MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) - else _FFNInputMode.FULL - ) - return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode) def forward( self, @@ -377,108 +356,11 @@ class Qwen2MoeDecoderLayer(nn.Module): forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: - if self.info.ffn_input_mode == _FFNInputMode.SCATTERED: - return self.forward_ffn_with_scattered_input( - positions, hidden_states, forward_batch, residual - ) - elif self.info.ffn_input_mode == _FFNInputMode.FULL: - return self.forward_ffn_with_full_input( - positions, hidden_states, forward_batch, residual - ) - else: - raise NotImplementedError - def forward_ffn_with_full_input( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - if hidden_states.shape[0] == 0: - residual = hidden_states - else: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch + ) - # Self Attention - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) - # Gather - if get_tensor_model_parallel_world_size() > 1: - # all gather and all reduce - if self.local_dp_size != 1: - if self.attn_tp_rank == 0: - hidden_states += residual - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer, - hidden_states, - ) - dp_gather_partial(hidden_states, local_hidden_states, forward_batch) - dp_scatter(residual, hidden_states, forward_batch) - # TODO extract this bugfix - if hidden_states.shape[0] != 0: - hidden_states = self.post_attention_layernorm(hidden_states) - else: - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - # TODO extract this bugfix - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) - elif hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) - - # Fully Connected - hidden_states = self.mlp(hidden_states) - - # TODO: use reduce-scatter in MLP to avoid this scatter - # Scatter - if self.local_dp_size != 1: - # important: forward batch.gathered_buffer is used both after scatter and after gather. - # be careful about this! - hidden_states, global_hidden_states = ( - forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], - hidden_states, - ) - dp_scatter(hidden_states, global_hidden_states, forward_batch) - - return hidden_states, residual - - def forward_ffn_with_scattered_input( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - if hidden_states.shape[0] == 0: - residual = hidden_states - else: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) - - if self.attn_tp_size != 1 and self.input_is_scattered: - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], - hidden_states, - ) - attn_tp_all_gather( - list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states - ) - - # Self Attention if hidden_states.shape[0] != 0: hidden_states = self.self_attn( positions=positions, @@ -486,47 +368,15 @@ class Qwen2MoeDecoderLayer(nn.Module): forward_batch=forward_batch, ) - if self.attn_tp_size != 1: - if self.input_is_scattered: - tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) - hidden_states = tensor_list[self.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) - else: - if self.attn_tp_rank == 0: - hidden_states += residual - tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) - hidden_states = tensor_list[self.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) - residual = hidden_states - if hidden_states.shape[0] != 0: - hidden_states = self.post_attention_layernorm(hidden_states) - else: - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) + hidden_states, residual = self.layer_communicator.prepare_mlp( + hidden_states, residual, forward_batch + ) - if not ( - self._enable_moe_dense_fully_dp() - and (not self.info.is_sparse) - and hidden_states.shape[0] == 0 - ): - hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) + hidden_states = self.mlp(hidden_states, forward_batch) - if self.is_last_layer and self.attn_tp_size != 1: - hidden_states += residual - residual = None - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], - hidden_states, - ) - attn_tp_all_gather( - list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states - ) + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) return hidden_states, residual diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 36ab20a8e..9b96574b6 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -38,6 +38,7 @@ from sglang.srt.distributed import ( tensor_model_parallel_all_reduce, ) from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import ( attn_tp_all_gather, attn_tp_reduce_scatter, @@ -78,7 +79,7 @@ from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP from sglang.srt.models.qwen2_moe import Qwen2MoeModel -from sglang.srt.utils import DeepEPMode, add_prefix +from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty Qwen3MoeConfig = None @@ -150,13 +151,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ) def forward( - self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None + self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None ) -> torch.Tensor: if not global_server_args_dict["enable_deepep_moe"]: return self.forward_normal(hidden_states) else: - return self.forward_deepep(hidden_states, forward_mode) + return self.forward_deepep(hidden_states, forward_batch) def get_moe_weights(self): return [ @@ -180,13 +181,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module): return final_hidden_states.view(num_tokens, hidden_dim) def forward_deepep( - self, hidden_states: torch.Tensor, forward_mode: ForwardMode + self, hidden_states: torch.Tensor, forward_batch: ForwardBatch ) -> torch.Tensor: - if ( - forward_mode is not None - and not forward_mode.is_idle() - and hidden_states.shape[0] > 0 - ): + forward_mode = forward_batch.forward_mode + if is_non_idle_and_non_empty(forward_mode, hidden_states): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) @@ -356,19 +354,6 @@ class Qwen3MoeAttention(nn.Module): return output -class _FFNInputMode(Enum): - # The MLP sublayer requires 1/tp_size tokens as input - SCATTERED = auto() - # The MLP sublayer requires all tokens as input - FULL = auto() - - -@dataclass -class _DecoderLayerInfo: - is_sparse: bool - ffn_input_mode: _FFNInputMode - - class Qwen3MoeDecoderLayer(nn.Module): def __init__( self, @@ -378,6 +363,7 @@ class Qwen3MoeDecoderLayer(nn.Module): prefix: str = "", ) -> None: super().__init__() + self.config = config self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -408,15 +394,18 @@ class Qwen3MoeDecoderLayer(nn.Module): self.attn_tp_rank = get_attention_tp_rank() self.local_dp_size = get_local_attention_dp_size() - self.info = self._compute_info(config, layer_id=layer_id) - previous_layer_info = self._compute_info(config, layer_id=layer_id - 1) - self.input_is_scattered = ( - layer_id > 0 - and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED - ) - self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 + # Qwen3MoE all layers are sparse and have no nextn now + self.is_layer_sparse = True + is_previous_layer_sparse = True - if self.info.is_sparse: + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + ) + + if self.is_layer_sparse: self.mlp = Qwen3MoeSparseMoeBlock( layer_id=self.layer_id, config=config, @@ -436,26 +425,11 @@ class Qwen3MoeDecoderLayer(nn.Module): config.hidden_size, eps=config.rms_norm_eps ) - @staticmethod - def _enable_moe_dense_fully_dp(): - return global_server_args_dict["moe_dense_tp_size"] == 1 - - @staticmethod - def _compute_info(config: PretrainedConfig, layer_id: int): - # WARN: Qwen3MOE has no dense_layer, it is only for compatibility. - mlp_only_layers = ( - [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, ) - is_sparse = (layer_id not in mlp_only_layers) and ( - config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 - ) - ffn_input_mode = ( - _FFNInputMode.SCATTERED - if (global_server_args_dict["enable_deepep_moe"] and is_sparse) - or (Qwen3MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) - else _FFNInputMode.FULL - ) - return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode) def forward( self, @@ -464,105 +438,11 @@ class Qwen3MoeDecoderLayer(nn.Module): forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: - if self.info.ffn_input_mode == _FFNInputMode.SCATTERED: - return self.forward_ffn_with_scattered_input( - positions, hidden_states, forward_batch, residual - ) - elif self.info.ffn_input_mode == _FFNInputMode.FULL: - return self.forward_ffn_with_full_input( - positions, hidden_states, forward_batch, residual - ) - else: - raise NotImplementedError - def forward_ffn_with_full_input( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - if hidden_states.shape[0] == 0: - residual = hidden_states - else: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch + ) - # Self Attention - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) - # Gather - if get_tensor_model_parallel_world_size() > 1: - if self.local_dp_size != 1: - if self.attn_tp_rank == 0: - hidden_states += residual - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer, - hidden_states, - ) - dp_gather_partial(hidden_states, local_hidden_states, forward_batch) - dp_scatter(residual, hidden_states, forward_batch) - hidden_states = self.post_attention_layernorm(hidden_states) - else: - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - # TODO extract this bugfix - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) - elif hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) - - # Fully Connected - hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) - - # TODO: use reduce-scatter in MLP to avoid this scatter - # Scatter - if self.local_dp_size != 1: - # important: forward batch.gathered_buffer is used both after scatter and after gather. - # be careful about this! - hidden_states, global_hidden_states = ( - forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], - hidden_states, - ) - dp_scatter(hidden_states, global_hidden_states, forward_batch) - - return hidden_states, residual - - def forward_ffn_with_scattered_input( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - if hidden_states.shape[0] == 0: - residual = hidden_states - else: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) - - if self.attn_tp_size != 1 and self.input_is_scattered: - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], - hidden_states, - ) - attn_tp_all_gather( - list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states - ) - - # Self Attention if hidden_states.shape[0] != 0: hidden_states = self.self_attn( positions=positions, @@ -570,47 +450,15 @@ class Qwen3MoeDecoderLayer(nn.Module): forward_batch=forward_batch, ) - if self.attn_tp_size != 1: - if self.input_is_scattered: - tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) - hidden_states = tensor_list[self.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) - else: - if self.attn_tp_rank == 0: - hidden_states += residual - tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) - hidden_states = tensor_list[self.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) - residual = hidden_states - if hidden_states.shape[0] != 0: - hidden_states = self.post_attention_layernorm(hidden_states) - else: - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) + hidden_states, residual = self.layer_communicator.prepare_mlp( + hidden_states, residual, forward_batch + ) - if not ( - self._enable_moe_dense_fully_dp() - and (not self.info.is_sparse) - and hidden_states.shape[0] == 0 - ): - hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) + hidden_states = self.mlp(hidden_states, forward_batch) - if self.is_last_layer and self.attn_tp_size != 1: - hidden_states += residual - residual = None - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], - hidden_states, - ) - attn_tp_all_gather( - list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states - ) + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) return hidden_states, residual diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 0b19272e1..83780ec3e 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2026,6 +2026,14 @@ class DeepEPMode(Enum): return DeepEPMode.normal +def is_non_idle_and_non_empty(forward_mode, hidden_states): + return ( + (forward_mode is not None) + and not forward_mode.is_idle() + and hidden_states.shape[0] > 0 + ) + + def fast_topk(values, topk, dim): if topk == 1: # Use max along the specified dimension to get both value and index diff --git a/test/srt/test_disaggregation.py b/test/srt/test_disaggregation.py index 0ae99547e..757c33b8f 100644 --- a/test/srt/test_disaggregation.py +++ b/test/srt/test_disaggregation.py @@ -146,7 +146,7 @@ class TestDisaggregationAccuracy(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.62) def test_logprob(self): - prompt = "The capital of taiwan is " + prompt = "The capital of france is " response = requests.post( self.lb_url + "/generate", json={