[quantization] Enable aiter mxfp4 fused_moe for Quark (#10048)
Co-authored-by: HaiShaw <hixiao@gmail.com>
This commit is contained in:
@@ -12,7 +12,7 @@ from aiter.utility.fp4_utils import e8m0_shuffle
|
|||||||
|
|
||||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||||
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
||||||
from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
|
from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
@@ -23,6 +23,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
|
__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
|
||||||
|
|
||||||
OCP_MX_BLOCK_SIZE = 32
|
OCP_MX_BLOCK_SIZE = 32
|
||||||
@@ -182,6 +184,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
|||||||
topk_output = dispatch_output.topk_output
|
topk_output = dispatch_output.topk_output
|
||||||
moe_runner_config = self.moe_runner_config
|
moe_runner_config = self.moe_runner_config
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
if _is_hip:
|
||||||
|
topk_weights = topk_weights.to(
|
||||||
|
torch.float32
|
||||||
|
) # aiter's moe_sorting requires topk_weights to be FP32
|
||||||
|
|
||||||
if hasattr(torch, "float4_e2m1fn_x2"):
|
if hasattr(torch, "float4_e2m1fn_x2"):
|
||||||
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
|
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
|
||||||
|
|||||||
Reference in New Issue
Block a user