diff --git a/python/sglang/srt/layers/fused_moe/fused_moe.py b/python/sglang/srt/layers/fused_moe/fused_moe.py
index 3e8c2eae0..4d1c98c23 100644
--- a/python/sglang/srt/layers/fused_moe/fused_moe.py
+++ b/python/sglang/srt/layers/fused_moe/fused_moe.py
@@ -250,9 +250,12 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    padded_size = padding_size
     if not use_fp8:
         assert A_scale is None
         assert B_scale is None
+        # MOE_PADDING FP8 only
+        padded_size = 0
     else:
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
@@ -262,7 +265,7 @@
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )
 
-    K = B.shape[2] - padding_size
+    K = B.shape[2] - padded_size
     if K % config["BLOCK_SIZE_K"] == 0:
         even_ks = True
     else:
@@ -279,7 +282,7 @@
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2] - padding_size,
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -480,8 +483,12 @@ def fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8:
+        # MOE_PADDING FP8 only
+        padded_size = 0
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -498,7 +505,7 @@
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
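
Note (not part of the patch): a minimal, self-contained sketch of the shape bookkeeping this diff introduces. MOE weight padding along the trailing K dimension is applied only to FP8-quantized expert weights, so the effective hidden size subtracts padding_size only on the FP8 path; for non-FP8 weights padded_size is forced to 0. The tensor shapes and the padding_size value below are illustrative placeholders, not the SGLang values.

import torch

# Hypothetical pad applied along the K (hidden) dimension of FP8 expert weights.
padding_size = 16


def effective_hidden_size(w1: torch.Tensor, use_fp8: bool) -> int:
    """Usable hidden size of expert weights w1 with shape [E, N, K (+ pad)]."""
    # Mirrors the diff: padding is present only for FP8 weights ("MOE_PADDING FP8 only").
    padded_size = padding_size if use_fp8 else 0
    return w1.shape[2] - padded_size


if __name__ == "__main__":
    num_experts, intermediate, hidden = 4, 64, 256

    # Unpadded bf16 expert weights: the trailing dim equals the true hidden size.
    w1_bf16 = torch.zeros(num_experts, 2 * intermediate, hidden, dtype=torch.bfloat16)
    assert effective_hidden_size(w1_bf16, use_fp8=False) == hidden

    # FP8 weights padded along K; uint8 stands in for an fp8 storage dtype here.
    w1_fp8 = torch.zeros(
        num_experts, 2 * intermediate, hidden + padding_size, dtype=torch.uint8
    )
    assert effective_hidden_size(w1_fp8, use_fp8=True) == hidden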