[quantization] Enable aiter mxfp4 fused_moe for Quark (#10048)
Co-authored-by: HaiShaw <hixiao@gmail.com>
This commit is contained in:
@@ -12,7 +12,7 @@ from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
||||
-from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
|
||||
+from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
@@ -23,6 +23,8 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
+_is_hip = is_hip()
|
||||
|
||||
__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
|
||||
|
||||
OCP_MX_BLOCK_SIZE = 32
|
||||
@@ -182,6 +184,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
topk_output = dispatch_output.topk_output
|
||||
moe_runner_config = self.moe_runner_config
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
+if _is_hip:
|
||||
+topk_weights = topk_weights.to(
|
||||
+torch.float32
|
||||
+) # aiter's moe_sorting requires topk_weights to be FP32
|
||||
|
||||
if hasattr(torch, "float4_e2m1fn_x2"):
|
||||
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
|
||||
|
||||
Reference in New Issue
Block a user