From 9e2f7252db8f2e1b903dba31484f7efb0b772c41 Mon Sep 17 00:00:00 2001 From: Yi Zhang <1109276519@qq.com> Date: Thu, 11 Sep 2025 03:49:43 +0800 Subject: [PATCH] add dual stream for qwen2_moe (#10252) --- python/sglang/srt/models/qwen2_moe.py | 68 ++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 194e513ac..ffb619940 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -65,10 +65,12 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.two_batch_overlap import model_forward_maybe_tbo -from sglang.srt.utils import add_prefix, make_layers +from sglang.srt.utils import add_prefix, is_cuda, make_layers logger = logging.getLogger(__name__) +_is_cuda = is_cuda() + class Qwen2MoeMLP(nn.Module): def __init__( @@ -122,11 +124,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module): layer_id: int, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + alt_stream: Optional[torch.cuda.Stream] = None, prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.layer_id = layer_id + self.alt_stream = alt_stream if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -168,6 +172,37 @@ class Qwen2MoeSparseMoeBlock(nn.Module): self.shared_expert = None self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + def _forward_shared_experts(self, hidden_states: torch.Tensor): + shared_output = None + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + if self.shared_expert_gate is not None: + shared_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output + ) + return shared_output + + def _forward_router_experts(self, hidden_states: torch.Tensor): + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) + return self.experts(hidden_states, topk_output) + + def forward_normal_dual_stream( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + shared_output = self._forward_shared_experts(hidden_states) + + with torch.cuda.stream(self.alt_stream): + router_output = self._forward_router_experts(hidden_states) + + current_stream.wait_stream(self.alt_stream) + + return router_output, shared_output + def forward( self, hidden_states: torch.Tensor, @@ -176,18 +211,20 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - shared_output = None - if self.shared_expert is not None: - shared_output = self.shared_expert(hidden_states) - if self.shared_expert_gate is not None: - shared_output = ( - F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output - ) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - topk_output = self.topk(hidden_states, router_logits) - final_hidden_states = self.experts(hidden_states, topk_output) + DUAL_STREAM_TOKEN_THRESHOLD = 1024 + if ( + self.alt_stream is not None + and hidden_states.shape[0] > 0 + and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD + ): + final_hidden_states, shared_output = self.forward_normal_dual_stream( + hidden_states + ) + else: + shared_output = self._forward_shared_experts(hidden_states) + final_hidden_states = self._forward_router_experts(hidden_states) + if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1 and not use_reduce_scatter: @@ -346,6 +383,7 @@ class Qwen2MoeDecoderLayer(nn.Module): layer_id=layer_id, config=config, quant_config=quant_config, + alt_stream=alt_stream, prefix=add_prefix("mlp", prefix), ) else: @@ -528,8 +566,12 @@ class Qwen2MoeForCausalLM(nn.Module): self.pp_group = get_pp_group() self.config = config self.quant_config = quant_config + alt_stream = torch.cuda.Stream() if _is_cuda else None self.model = Qwen2MoeModel( - config, quant_config, prefix=add_prefix("model", prefix) + config, + quant_config, + prefix=add_prefix("model", prefix), + alt_stream=alt_stream, ) self.lm_head = ParallelLMHead( config.vocab_size,