From baee08601bea983335cf7c06da2f341157cdc225 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Sun, 5 Oct 2025 19:51:34 -0700 Subject: [PATCH] [quantization] Enable aiter mxfp4 fused_moe for Quark (#10048) Co-authored-by: HaiShaw --- python/sglang/srt/layers/quantization/quark/quark_moe.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/quark/quark_moe.py b/python/sglang/srt/layers/quantization/quark/quark_moe.py index 1f8a1abfe..d1ad13f48 100644 --- a/python/sglang/srt/layers/quantization/quark/quark_moe.py +++ b/python/sglang/srt/layers/quantization/quark/quark_moe.py @@ -12,7 +12,7 @@ from aiter.utility.fp4_utils import e8m0_shuffle from sglang.srt.layers.moe import MoeRunnerConfig from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase -from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs +from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs if TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import ( @@ -23,6 +23,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +_is_hip = is_hip() + __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"] OCP_MX_BLOCK_SIZE = 32 @@ -182,6 +184,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): topk_output = dispatch_output.topk_output moe_runner_config = self.moe_runner_config topk_weights, topk_ids, _ = topk_output + if _is_hip: + topk_weights = topk_weights.to( + torch.float32 + ) # aiter's moe_sorting requires topk_weights to be FP32 if hasattr(torch, "float4_e2m1fn_x2"): w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)