diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index b5b884e59..fe6b00685 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -68,6 +68,7 @@ from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import add_prefix, make_layers
 
 logger = logging.getLogger(__name__)
@@ -442,12 +443,22 @@ class Qwen2MoeModel(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):
-            with get_global_expert_distribution_recorder().with_current_layer(i):
-                layer = self.layers[i]
-                hidden_states, residual = layer(
-                    positions, hidden_states, forward_batch, residual
-                )
+        if forward_batch.can_run_tbo:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers,
+                enable_tbo=True,
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
+        else:
+            for i in range(self.start_layer, self.end_layer):
+                with get_global_expert_distribution_recorder().with_current_layer(i):
+                    layer = self.layers[i]
+                    hidden_states, residual = layer(
+                        positions, hidden_states, forward_batch, residual
+                    )
         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
                 {
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 9b96574b6..27e4ae62c 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -68,6 +68,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -79,6 +82,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
 
 Qwen3MoeConfig = None
@@ -137,7 +141,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.top_k = config.num_experts_per_tok
         self.renormalize = config.norm_topk_prob
 
-        self.deepep_dispatcher = DeepEPDispatcher(
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
             group=parallel_state.get_tp_group().device_group,
             router_topk=self.top_k,
             permute_fusion=True,
@@ -217,9 +221,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 masked_m,
                 expected_m,
             ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
                 forward_mode=forward_mode,
             )
         final_hidden_states = self.experts(
@@ -235,13 +239,105 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode,
+                hidden_states=final_hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
             )
         return final_hidden_states
 
+    def op_gate(self, state):
+        if is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
+        ):
+            # router_logits: (num_tokens, n_experts)
+            state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
+        else:
+            state.router_logits = None
+
+    def op_select_experts(self, state):
+        router_logits = state.pop("router_logits")
+        hidden_states = state.hidden_states_mlp_input
+        if router_logits is not None:
+            state.topk_weights_local, state.topk_idx_local = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=False,
+                renormalize=self.renormalize,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            state.topk_idx_local = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            state.topk_weights_local = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+
+    def op_dispatch_a(self, state):
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            self.deepep_dispatcher.dispatch_a(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                topk_idx=state.pop("topk_idx_local"),
+                topk_weights=state.pop("topk_weights_local"),
+                forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_dispatch_b(self, state):
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                (
+                    state.hidden_states_experts_input,
+                    state.topk_idx_dispatched,
+                    state.topk_weights_dispatched,
+                    state.reorder_topk_ids,
+                    state.num_recv_tokens_per_expert,
+                    state.seg_indptr,
+                    state.masked_m,
+                    state.expected_m,
+                ) = self.deepep_dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
+
+    def op_experts(self, state):
+        state.hidden_states_experts_output = self.experts(
+            hidden_states=state.pop("hidden_states_experts_input"),
+            topk_idx=state.topk_idx_dispatched,
+            topk_weights=state.topk_weights_dispatched,
+            reorder_topk_ids=state.pop("reorder_topk_ids"),
+            seg_indptr=state.pop("seg_indptr"),
+            masked_m=state.pop("masked_m"),
+            expected_m=state.pop("expected_m"),
+            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+            forward_mode=state.forward_batch.forward_mode,
+        )
+
+    def op_combine_a(self, state):
+        if self.ep_size > 1:
+            self.deepep_dispatcher.combine_a(
+                hidden_states=state.pop("hidden_states_experts_output"),
+                topk_idx=state.pop("topk_idx_dispatched"),
+                topk_weights=state.pop("topk_weights_dispatched"),
+                forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_combine_b(self, state):
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_output(self, state):
+        state.hidden_states_mlp_output = state.pop("hidden_states_after_combine")
+
 
 class Qwen3MoeAttention(nn.Module):
     def __init__(
@@ -339,19 +435,53 @@ class Qwen3MoeAttention(nn.Module):
             k = k_by_head.view(k.shape)
         return q, k
 
+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
+    def forward_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states, forward_batch, None
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self._apply_qk_norm(q, k)
+        q, k = self.rotary_emb(positions, q, k)
+        inner_state = q, k, v, forward_batch
+        return None, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, forward_batch, inner_state = intermediate_state
+        if inner_state is None:
+            return hidden_states
+        attn_output = self.attn(*inner_state)
+        output, _ = self.o_proj(attn_output)
+        return output
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, forward_batch)
-        output, _ = self.o_proj(attn_output)
-        return output
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        return self.forward_core(s)
 
 
 class Qwen3MoeDecoderLayer(nn.Module):
@@ -462,6 +592,65 @@ class Qwen3MoeDecoderLayer(nn.Module):
 
         return hidden_states, residual
 
+    def op_comm_prepare_attn(
+        self,
+        state,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+        tbo_subbatch_index: Optional[int] = None,
+    ):
+        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
+        )
+        state.update(
+            dict(
+                forward_batch=forward_batch,
+                positions=positions,
+                tbo_subbatch_index=tbo_subbatch_index,
+            )
+        )
+
+    def op_comm_prepare_mlp(self, state):
+        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+            self.layer_communicator.prepare_mlp(
+                state.pop("hidden_states_after_attn"),
+                state.pop("residual_after_input_ln"),
+                state.forward_batch,
+            )
+        )
+
+    def op_mlp(self, state):
+        hidden_states = state.pop("hidden_states_mlp_input")
+        state.hidden_states_mlp_output = self.mlp(
+            hidden_states, state.forward_batch.forward_mode
+        )
+
+    def op_comm_postprocess_layer(self, state):
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            state.pop("hidden_states_mlp_output"),
+            state.pop("residual_after_comm_pre_mlp"),
+            state.forward_batch,
+        )
+
+        output = dict(
+            positions=state.positions,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=state.forward_batch,
+            tbo_subbatch_index=state.tbo_subbatch_index,
+        )
+
+        state.clear(
+            expect_keys={
+                "positions",
+                "forward_batch",
+                "tbo_subbatch_index",
+            }
+        )
+        return output
+
 
 class Qwen3MoeModel(Qwen2MoeModel):
     def __init__(
diff --git a/python/sglang/srt/operations_strategy.py b/python/sglang/srt/operations_strategy.py
index b8e0eaef0..6fd32e66a 100644
--- a/python/sglang/srt/operations_strategy.py
+++ b/python/sglang/srt/operations_strategy.py
@@ -32,12 +32,27 @@ class OperationsStrategy:
         layers: torch.nn.ModuleList,
         forward_mode: ForwardMode,
     ) -> "OperationsStrategy":
-        return OperationsStrategy.concat(
-            [
-                _compute_layer_operations_strategy_tbo(layer, forward_mode)
-                for layer in layers
-            ]
-        )
+        layer_name = layers[0].__class__.__name__
+        if layer_name == "DeepseekV2DecoderLayer":
+            return OperationsStrategy.concat(
+                [
+                    _compute_moe_deepseek_layer_operations_strategy_tbo(
+                        layer, forward_mode
+                    )
+                    for layer in layers
+                ]
+            )
+        elif layer_name == "Qwen3MoeDecoderLayer":
+            return OperationsStrategy.concat(
+                [
+                    _compute_moe_qwen3_layer_operations_strategy_tbo(
+                        layer, forward_mode
+                    )
+                    for layer in layers
+                ]
+            )
+        else:
+            raise NotImplementedError
 
 
 def _assert_all_same(items: List):
@@ -45,8 +60,11 @@ def _assert_all_same(items: List):
     return items[0]
 
 
+# -------------------------------- Strategy for DeepSeek ---------------------------------------
+
+
 # TODO can refactor to make it more fancy if we have more complex strategies
-def _compute_layer_operations_strategy_tbo(
+def _compute_moe_deepseek_layer_operations_strategy_tbo(
     layer: torch.nn.Module,
     forward_mode: ForwardMode,
 ) -> OperationsStrategy:
@@ -114,3 +132,76 @@ def _compute_moe_deepseek_blog_decode(layer):
             operations.YieldOperation(),
         ],
     )
+
+
+# -------------------------------- Strategy for Qwen3 ---------------------------------------
+
+
+# TODO: unstable, current strategy is almost the same as DeepSeek, keep redundant code here for
+# convenience to adjust strategy
+def _compute_moe_qwen3_layer_operations_strategy_tbo(
+    layer: torch.nn.Module,
+    forward_mode: ForwardMode,
+) -> OperationsStrategy:
+    assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
+    if forward_mode == ForwardMode.EXTEND:
+        return _compute_moe_qwen3_prefill(layer)
+    elif forward_mode == ForwardMode.DECODE:
+        return _compute_moe_qwen3_decode(layer)
+    else:
+        raise NotImplementedError(f"Unsupported {forward_mode=}")
+
+
+def _compute_moe_qwen3_prefill(layer):
+    device_properties = torch.cuda.get_device_properties(device="cuda")
+    total_num_sms = device_properties.multi_processor_count
+    deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
+
+    return OperationsStrategy(
+        deep_gemm_num_sms=deep_gemm_num_sms,
+        tbo_delta_stages=0,
+        operations=[
+            layer.op_comm_prepare_attn,
+            layer.self_attn.op_prepare,
+            layer.self_attn.op_core,
+            layer.op_comm_prepare_mlp,
+            layer.mlp.op_gate,
+            layer.mlp.op_select_experts,
+            layer.mlp.op_dispatch_a,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_b,
+            layer.mlp.op_experts,
+            layer.mlp.op_combine_a,
+            operations.YieldOperation(),
+            layer.mlp.op_combine_b,
+            layer.mlp.op_output,
+            layer.op_comm_postprocess_layer,
+        ],
+    )
+
+
+def _compute_moe_qwen3_decode(layer):
+    return OperationsStrategy(
+        deep_gemm_num_sms=None,
+        tbo_delta_stages=2,
+        operations=[
+            layer.op_comm_prepare_attn,
+            layer.self_attn.op_prepare,
+            operations.YieldOperation(),
+            layer.self_attn.op_core,
+            layer.op_comm_prepare_mlp,
+            layer.mlp.op_gate,
+            layer.mlp.op_select_experts,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_a,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_b,
+            layer.mlp.op_experts,
+            layer.mlp.op_combine_a,
+            operations.YieldOperation(),
+            layer.mlp.op_combine_b,
+            layer.mlp.op_output,
+            layer.op_comm_postprocess_layer,
+            operations.YieldOperation(),
+        ],
+    )
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index c507c2637..afdb5fce0 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -356,14 +356,14 @@ def model_forward_maybe_tbo(
     forward_batch: ForwardBatch,
     hidden_states: torch.Tensor,
     residual: Optional[torch.Tensor],
-    zero_allocator: BumpAllocator,
+    zero_allocator: Optional[BumpAllocator] = None,
 ):
     inputs = dict(
         positions=positions,
         hidden_states=hidden_states,
         forward_batch=forward_batch,
        residual=residual,
-        zero_allocator=zero_allocator,
+        **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
     )
     operations_strategy = OperationsStrategy.init_new_tbo(
         layers, forward_batch.global_forward_mode
@@ -401,7 +401,7 @@ def _model_forward_tbo_split_inputs(
     residual: torch.Tensor,
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
-    zero_allocator: BumpAllocator,
+    zero_allocator: Optional[BumpAllocator] = None,
 ) -> List[Dict]:
     return [
         dict(
@@ -412,7 +412,11 @@ def _model_forward_tbo_split_inputs(
                 output_forward_batch=output_forward_batch,
                 tbo_subbatch_index=tbo_subbatch_index,
             ),
-            zero_allocator=zero_allocator,
+            **(
+                dict(zero_allocator=zero_allocator)
+                if zero_allocator is not None
+                else {}
+            ),
         )
         for tbo_subbatch_index, output_forward_batch in enumerate(
             forward_batch.tbo_children
diff --git a/test/srt/test_two_batch_overlap.py b/test/srt/test_two_batch_overlap.py
index 765679fc3..02633b78a 100644
--- a/test/srt/test_two_batch_overlap.py
+++ b/test/srt/test_two_batch_overlap.py
@@ -9,6 +9,7 @@ from sglang.srt.two_batch_overlap import compute_split_seq_index
 from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
+    DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
     DEFAULT_MLA_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
@@ -104,5 +105,32 @@ class TestTwoBatchOverlapUnitTest(unittest.TestCase):
         self.assertEqual(actual, expect)
 
 
+class TestQwen3TwoBatchOverlap(TestTwoBatchOverlap):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-1234"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "2",
+                "--dp",
+                "2",
+                "--enable-dp-attention",
+                "--enable-deepep-moe",
+                "--deepep-mode",
+                "normal",
+                "--disable-cuda-graph",  # DeepEP normal does not support CUDA Graph
+                "--enable-two-batch-overlap",
+            ],
+            env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
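
Reviewer note (illustrative sketch, not part of the patch): the operation lists registered in operations_strategy.py are consumed by running two sub-batches through the same per-layer sequence and switching between them at each operations.YieldOperation() boundary, so one sub-batch's DeepEP dispatch/combine communication overlaps the other's attention and expert compute. That is also why each communication step is split into an issue phase and a completion phase (dispatch_a/dispatch_b, combine_a/combine_b): a yield can sit between them. The minimal Python sketch below shows only that scheduling idea under these assumptions; all names in it (Yield, run_interleaved, op) are hypothetical and are not sglang's actual executor API.

from typing import Callable, Dict, List, Union


class Yield:
    """Marker: suspend this sub-batch so the other one can run."""


Op = Union[Callable[[Dict], None], Yield]


def _as_generator(ops: List[Op], state: Dict):
    # Run ops sequentially, pausing at every Yield marker.
    for op in ops:
        if isinstance(op, Yield):
            yield
        else:
            op(state)


def run_interleaved(ops: List[Op], states: List[Dict], delta_stages: int = 0):
    # Round-robin the sub-batches between Yield boundaries. With
    # delta_stages > 0, sub-batch 0 runs that many yield-segments before
    # sub-batch 1 starts (cf. tbo_delta_stages=2 in the decode strategy).
    gens = [_as_generator(ops, s) for s in states]
    for _ in range(delta_stages):
        next(gens[0], None)
    alive = [True] * len(gens)
    while any(alive):
        for i, gen in enumerate(gens):
            if not alive[i]:
                continue
            try:
                next(gen)  # runs this sub-batch until its next Yield (or the end)
            except StopIteration:
                alive[i] = False


if __name__ == "__main__":
    def op(name):
        return lambda state: print(f"subbatch {state['idx']}: {name}")

    # While one sub-batch waits on "communication", the other computes.
    ops = [
        op("attn + gate + dispatch_a (issue send)"),
        Yield(),
        op("dispatch_b (wait recv) + experts + combine_a (issue send)"),
        Yield(),
        op("combine_b (wait recv) + postprocess"),
    ]
    run_interleaved(ops, [{"idx": 0}, {"idx": 1}])

Running the toy example prints the two sub-batches alternating segment by segment, which is the overlap pattern the prefill/decode strategies above encode; the real executor additionally pins DeepGEMM to a reduced SM count (deep_gemm_num_sms) so communication kernels keep some SMs for themselves.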