diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index ac5371871..862561804 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -673,66 +673,6 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
-class FlashInferEPMoE(EPMoE):
-    def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-        super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.use_flashinfer_trtllm_moe
-        assert (
-            self.activation == "silu"
-        ), "Only silu is supported for flashinfer blockscale fp8 moe"
-        assert (
-            self.renormalize
-        ), "Renormalize is required for flashinfer blockscale fp8 moe"
-        assert (
-            self.num_fused_shared_experts == 0
-        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
-        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
-        # NOTE: scales of hidden states have to be transposed!
-        a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
-
-        return trtllm_fp8_block_scale_moe(
-            routing_logits=router_logits.to(torch.float32),
-            routing_bias=self.correction_bias.to(hidden_states.dtype),
-            hidden_states=a_q,
-            hidden_states_scale=a_sf_t,
-            gemm1_weights=self.w13_weight,
-            gemm1_weights_scale=self.w13_weight_scale_inv,
-            gemm2_weights=self.w2_weight,
-            gemm2_weights_scale=self.w2_weight_scale_inv,
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            n_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            intermediate_size=self.w2_weight.shape[2],
-            local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=get_tile_tokens_dim(
-                hidden_states.shape[0], self.top_k, self.num_experts
-            ),
-            routing_method_type=2,  # DeepSeek-styled routing method
-            use_shuffled_weight=False,
-        )
-
-
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
@@ -752,8 +692,10 @@ def get_moe_impl_class():
         except:
             pass
 
+    if should_use_flashinfer_trtllm_moe():
+        return FlashInferFusedMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
-        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
-    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
+        return EPMoE
+    return FusedMoE
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index c30535d7f..74558fd9b 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -763,8 +763,13 @@ class FlashInferFusedMoE(FusedMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+        assert self.use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
         assert self.quant_method is not None
         assert (
             self.renormalize
@@ -772,6 +777,14 @@ class FlashInferFusedMoE(FusedMoE):
         assert (
             self.num_fused_shared_experts == 0
         ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+        _, router_logits = topk_output
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,