From 8ebf72fef335c963bce7d4867ef8dcb6cdc96263 Mon Sep 17 00:00:00 2001
From: kk <43161300+kkHuang-amd@users.noreply.github.com>
Date: Sat, 27 Sep 2025 13:12:22 +0800
Subject: [PATCH] [Fix] RuntimeError: get_cfg Unsupported input_type:Float4_e2m1fn_x2 in using aiter-mxfp4-moe (#10981)

Co-authored-by: wunhuang
---
 python/sglang/srt/layers/quantization/mxfp4.py     | 12 ++++++++++--
 .../srt/layers/quantization/quark/quark_moe.py     | 11 +++++++++--
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py
index 8643a3e36..caf323950 100644
--- a/python/sglang/srt/layers/quantization/mxfp4.py
+++ b/python/sglang/srt/layers/quantization/mxfp4.py
@@ -843,10 +843,18 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         topk_weights = topk_weights.to(
             torch.float32
         )  # aiter's moe_sorting requires topk_weights to be FP32
+
+        if hasattr(torch, "float4_e2m1fn_x2"):
+            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
+            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
+        else:
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
+
         output = fused_moe(
             x,
-            layer.w13_weight,
-            layer.w2_weight,
+            w13_weight,
+            w2_weight,
             topk_weights,
             topk_ids,
             quant_type=QuantType.per_1x32,
diff --git a/python/sglang/srt/layers/quantization/quark/quark_moe.py b/python/sglang/srt/layers/quantization/quark/quark_moe.py
index f6e750a2c..1f8a1abfe 100644
--- a/python/sglang/srt/layers/quantization/quark/quark_moe.py
+++ b/python/sglang/srt/layers/quantization/quark/quark_moe.py
@@ -183,10 +183,17 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         moe_runner_config = self.moe_runner_config
         topk_weights, topk_ids, _ = topk_output
 
+        if hasattr(torch, "float4_e2m1fn_x2"):
+            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
+            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
+        else:
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
+
         output = fused_moe(
             x,
-            layer.w13_weight,
-            layer.w2_weight,
+            w13_weight,
+            w2_weight,
             topk_weights,
             topk_ids,
             quant_type=QuantType.per_1x32,
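
For context, both hunks apply the same guard before calling fused_moe. The sketch below isolates that pattern; it is illustrative only and not part of the patch. The function name as_fp4_if_supported and the uint8 stand-in weight are assumptions: the idea is simply to reinterpret a packed MXFP4 weight as torch.float4_e2m1fn_x2 when the running PyTorch build exposes that dtype, and to fall back to the stored representation on older builds.

import torch


def as_fp4_if_supported(packed_weight: torch.Tensor) -> torch.Tensor:
    # torch.float4_e2m1fn_x2 (two e2m1 values packed per byte) only exists in
    # newer PyTorch builds, so probe for it with hasattr() before using it.
    if hasattr(torch, "float4_e2m1fn_x2"):
        # .view(dtype) reinterprets the existing storage without copying,
        # mirroring the layer.w13_weight.view(...) calls in the patch.
        return packed_weight.view(torch.float4_e2m1fn_x2)
    # Older PyTorch: hand the weight to the kernel in its stored dtype.
    return packed_weight


if __name__ == "__main__":
    # Hypothetical stand-in for a packed MXFP4 weight (uint8, two fp4 values per byte).
    w = torch.randint(0, 256, (128, 64), dtype=torch.uint8)
    print(as_fp4_if_supported(w).dtype)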