diff --git a/python/sglang/srt/layers/quantization/quark/quark_moe.py b/python/sglang/srt/layers/quantization/quark/quark_moe.py
index 1f8a1abfe..d1ad13f48 100644
--- a/python/sglang/srt/layers/quantization/quark/quark_moe.py
+++ b/python/sglang/srt/layers/quantization/quark/quark_moe.py
@@ -12,7 +12,7 @@ from aiter.utility.fp4_utils import e8m0_shuffle
 
 from sglang.srt.layers.moe import MoeRunnerConfig
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
@@ -23,6 +23,8 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+_is_hip = is_hip()
+
 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 
 OCP_MX_BLOCK_SIZE = 32
@@ -182,6 +184,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         topk_output = dispatch_output.topk_output
         moe_runner_config = self.moe_runner_config
         topk_weights, topk_ids, _ = topk_output
+        if _is_hip:
+            topk_weights = topk_weights.to(
+                torch.float32
+            )  # aiter's moe_sorting requires topk_weights to be FP32
         if hasattr(torch, "float4_e2m1fn_x2"):
             w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
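
For context, the change is a platform-gated dtype coercion: on ROCm, aiter's moe_sorting kernel consumes `topk_weights` as FP32, while the router's top-k output is typically bf16/fp16. Below is a minimal standalone sketch of the same pattern, not the actual sglang code: `is_hip` here is a hypothetical stand-in for the real helper in `sglang.srt.utils`, and the tensor shapes are illustrative only.

```python
import torch


def is_hip() -> bool:
    # Hypothetical stand-in for sglang.srt.utils.is_hip:
    # torch.version.hip is non-None only on ROCm builds of PyTorch.
    return torch.version.hip is not None


# Cache the platform check at import time, mirroring the diff's module-level
# `_is_hip = is_hip()`.
_is_hip = is_hip()


def prepare_topk_weights(topk_weights: torch.Tensor) -> torch.Tensor:
    # Mirror the diff's apply() change: upcast only on HIP, since aiter's
    # moe_sorting requires FP32 topk_weights; other platforms keep the
    # original dtype untouched.
    if _is_hip:
        topk_weights = topk_weights.to(torch.float32)
    return topk_weights


# Illustrative usage: top-k router weights are commonly bf16/fp16.
weights = torch.rand(4, 2, dtype=torch.bfloat16)
print(prepare_topk_weights(weights).dtype)  # float32 on ROCm, bfloat16 elsewhere
```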