diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 766f79404..98f89ab7f 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -932,11 +932,11 @@ class FlashInferFusedMoE(FusedMoE):
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.use_flashinfer_trtllm_moe
         assert (
-            self.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
         assert self.quant_method is not None
         assert (
-            self.renormalize
+            topk_output.topk_config.renormalize
         ), "Renormalize is required for flashinfer blockscale fp8 moe"
         assert (
             self.num_fused_shared_experts == 0
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 3b939bca8..479103e15 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -85,8 +85,8 @@ if _is_npu:
 class TopKConfig:
     top_k: int
     use_grouped_topk: bool = False
-    topk_group: int = 0
-    num_expert_group: int = 0
+    topk_group: Optional[int] = None
+    num_expert_group: Optional[int] = None
     renormalize: bool = True
     num_fused_shared_experts: int = 0
     custom_routing_function: Optional[Callable] = None
@@ -189,8 +189,8 @@ class TopK(CustomOp):
         top_k: int,
         *,
         use_grouped_topk: bool = False,
-        topk_group: int = 0,
-        num_expert_group: int = 0,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         renormalize: bool = True,
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
@@ -427,8 +427,8 @@ def grouped_topk_gpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -492,8 +492,8 @@ def grouped_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -522,8 +522,8 @@ def biased_grouped_topk_impl(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -615,8 +615,8 @@ def biased_grouped_topk_gpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -690,8 +690,8 @@ def biased_grouped_topk_cpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     compiled: bool = True,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 14ce92f36..f2e07b515 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -445,7 +445,6 @@ class Fp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         if self.use_marlin:
             return apply_fp8_marlin_linear(
                 input=x,
@@ -1087,7 +1086,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         topk_output: TopKOutput,
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-
         activation = moe_runner_config.activation
         routed_scaling_factor = moe_runner_config.routed_scaling_factor
 
@@ -1105,9 +1103,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             # NOTE: scales of hidden states have to be transposed!
             a_sf_t = a_sf.t().contiguous()
 
+            assert (
+                topk_config.num_expert_group is not None
+                and topk_config.topk_group is not None
+            ), "The trtllm_fp8_block_scale_moe kernel does not support num_expert_group or topk_group being None"
+
+            if topk_config.correction_bias is not None:
+                correction_bias = topk_config.correction_bias.to(x.dtype)
+            else:
+                correction_bias = None
             return trtllm_fp8_block_scale_moe(
                 routing_logits=router_logits.to(torch.float32),
-                routing_bias=layer.correction_bias.to(x.dtype),
+                routing_bias=correction_bias,
                 hidden_states=a_q,
                 hidden_states_scale=a_sf_t,
                 gemm1_weights=layer.w13_weight,
@@ -1121,9 +1128,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 intermediate_size=layer.w2_weight.shape[2],
                 local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
                 local_num_experts=layer.num_local_experts,
-                routed_scaling_factor=routed_scaling_factor,
+                routed_scaling_factor=(
+                    routed_scaling_factor if routed_scaling_factor is not None else 1.0
+                ),
                 tile_tokens_dim=get_tile_tokens_dim(
-                    x.shape[0], layer.top_k, layer.num_experts
+                    x.shape[0], topk_config.top_k, layer.num_experts
                 ),
                 routing_method_type=2,  # DeepSeek-styled routing method
                 use_shuffled_weight=False,