From 3c66a970f29a4a976bbb4836e333bc58995b4a9d Mon Sep 17 00:00:00 2001 From: Eric-dot <60131170+Eric-dot@users.noreply.github.com> Date: Mon, 2 Mar 2026 11:04:06 +0800 Subject: [PATCH] add mxfp8 moe quantization (#6670) ### What this PR does / why we need it? support mxfp8 quantization (Qwen MOE ) Using adaptor to make the hardware-specific behavior clearer and more maintainable ### How was this patch tested? - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/13397841ab469cecf1ed425c3f52a9ffc38139b5 --------- Signed-off-by: fangrongcan <17343701736@163.com> Signed-off-by: wangyao-i Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: Eric-dot <60131170+Eric-dot@users.noreply.github.com> Co-authored-by: fangrongcan Co-authored-by: wangyao-i Co-authored-by: linfeng-yuan <1102311262@qq.com> --- vllm_ascend/ascend_forward_context.py | 5 + vllm_ascend/device/device_op.py | 327 +++++++++++++++++- vllm_ascend/ops/fused_moe/moe_comm_method.py | 55 ++- vllm_ascend/ops/fused_moe/moe_mlp.py | 194 +++++++---- vllm_ascend/ops/fused_moe/prepare_finalize.py | 6 +- vllm_ascend/ops/fused_moe/token_dispatcher.py | 54 ++- vllm_ascend/quantization/methods/base.py | 1 + .../quantization/methods/w8a8_mxfp8.py | 144 +++++++- vllm_ascend/quantization/mxfp_compat.py | 43 +++ vllm_ascend/quantization/quant_parser.py | 73 ++++ 10 files changed, 802 insertions(+), 100 deletions(-) create mode 100644 vllm_ascend/quantization/mxfp_compat.py create mode 100644 vllm_ascend/quantization/quant_parser.py diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index a889b84c..936fa603 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -264,6 +264,11 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL elif soc_version in {AscendDeviceType._310P}: moe_comm_type = MoECommType.ALLGATHER + elif soc_version in {AscendDeviceType.A5}: + if num_tokens <= mc2_tokens_capacity and vllm_config.parallel_config.world_size_across_dp > 1: + moe_comm_type = MoECommType.MC2 + else: + moe_comm_type = MoECommType.ALLTOALL else: raise ValueError(f"Unsupported soc_version: {soc_version}") return moe_comm_type diff --git a/vllm_ascend/device/device_op.py b/vllm_ascend/device/device_op.py index 92e7e8fa..ed5d87d2 100644 --- a/vllm_ascend/device/device_op.py +++ b/vllm_ascend/device/device_op.py @@ -15,9 +15,14 @@ # limitations under the License. # This file is a part of the vllm-ascend project. 
# - +import torch import torch_npu +from vllm_ascend.quantization.mxfp_compat import ( + FLOAT4_E2M1FN_X2_DTYPE, + FLOAT8_E8M0FNU_DTYPE, + HIFLOAT8_DTYPE, +) from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type @@ -28,6 +33,126 @@ class BaseDeviceAdaptor: key=key, value=value, key_cache=key_cache, value_cache=value_cache, slot_indices=slot_mapping ) + @staticmethod + def npu_moe_init_routing( + hidden_states, + topk_ids, + *, + scale=None, + active_num: int, + expert_num: int, + expert_tokens_num_type: int = 1, + expert_tokens_num_flag: bool = True, + active_expert_range=None, + quant_mode: int = -1, + ): + return torch.ops._C_ascend.npu_moe_init_routing_custom( + hidden_states, + topk_ids, + scale=scale, + active_num=active_num, + expert_num=expert_num, + expert_tokens_num_type=expert_tokens_num_type, + expert_tokens_num_flag=expert_tokens_num_flag, + active_expert_range=active_expert_range, + quant_mode=quant_mode, + ) + + @staticmethod + def npu_dynamic_quant( + hidden_states: torch.Tensor, + dynamic_scale: torch.Tensor | None = None, + *, + act_quant_type=torch.float8_e4m3fn, + use_mxfp_quant: bool = False, + ): + if use_mxfp_quant: + raise RuntimeError("MXFP8 MoE quantization is only supported on Ascend A5.") + + if dynamic_scale is None: + return torch_npu.npu_dynamic_quant(hidden_states) + + return hidden_states, dynamic_scale + + @staticmethod + def npu_grouped_matmul_swiglu_quant( + *, + x: torch.Tensor, + weight: torch.Tensor, + group_list: torch.Tensor, + weight_scale: torch.Tensor, + x_scale: torch.Tensor, + bias=None, + use_mxfp_quant: bool = False, + ): + if use_mxfp_quant: + raise RuntimeError("MXFP8 MoE quantization is only supported on Ascend A5.") + + return torch_npu.npu_grouped_matmul_swiglu_quant( + x=x, + weight=weight, + bias=bias, + group_list=group_list, + weight_scale=weight_scale, + x_scale=x_scale, + ) + + @staticmethod + def get_quant_gmm2_kwargs( + *, + input_dtype: torch.dtype, + act_quant_type, + weight_quant_type, + scale_type, + per_token_scale_type, + use_bf16: bool = True, + use_mxfp_quant: bool = False, + ) -> dict: + if use_mxfp_quant: + raise RuntimeError("MXFP8 MoE quantization is only supported on Ascend A5.") + + return { + "output_dtype": input_dtype if input_dtype in [torch.bfloat16, torch.float16] else torch.bfloat16, + } + + @classmethod + def npu_grouped_matmul_gmm2( + cls, + *, + hidden_states: torch.Tensor, + weight: list[torch.Tensor] | torch.Tensor, + weight_scale: list[torch.Tensor] | torch.Tensor, + per_token_scale: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int, + input_dtype: torch.dtype, + act_quant_type, + weight_quant_type, + scale_type, + per_token_scale_type, + use_bf16: bool = True, + use_mxfp_quant: bool = False, + bias=None, + fallback_output_dtype: torch.dtype | None = None, + ) -> torch.Tensor: + if use_mxfp_quant: + raise RuntimeError("MXFP8 MoE quantization is only supported on Ascend A5.") + + if fallback_output_dtype is None: + fallback_output_dtype = weight_scale[0].dtype if isinstance(weight_scale, list) else weight_scale.dtype + return torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=weight, + scale=weight_scale, + bias=bias, + per_token_scale=[per_token_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=fallback_output_dtype, + )[0] + class A5DeviceAdaptor(BaseDeviceAdaptor): @classmethod @@ -36,12 +161,208 @@ class A5DeviceAdaptor(BaseDeviceAdaptor): key=key, value=value.contiguous(), key_cache=key_cache, 
value_cache=value_cache, slot_mapping=slot_mapping ) + @staticmethod + def npu_moe_init_routing( + hidden_states, + topk_ids, + *, + scale=None, + active_num: int, + expert_num: int, + expert_tokens_num_type: int = 1, + expert_tokens_num_flag: bool = True, + active_expert_range=None, + quant_mode: int = -1, + ): + return torch_npu.npu_moe_init_routing_v2( + hidden_states, + topk_ids, + scale=scale, + active_num=active_num, + expert_num=expert_num, + expert_tokens_num_type=expert_tokens_num_type, + expert_tokens_num_flag=expert_tokens_num_flag, + active_expert_range=active_expert_range, + quant_mode=quant_mode, + ) -def get_device_adaptor(): + @staticmethod + def npu_dynamic_quant( + hidden_states: torch.Tensor, + dynamic_scale: torch.Tensor | None = None, + *, + act_quant_type=torch.float8_e4m3fn, + use_mxfp_quant: bool = False, + ): + if not use_mxfp_quant: + return BaseDeviceAdaptor.npu_dynamic_quant( + hidden_states, + dynamic_scale, + act_quant_type=act_quant_type, + use_mxfp_quant=False, + ) + + if dynamic_scale is None: + return torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type) + + if dynamic_scale.ndim == 2: + dynamic_scale = dynamic_scale.reshape(dynamic_scale.shape[0], dynamic_scale.shape[1] // 2, 2) + + return hidden_states, dynamic_scale + + @staticmethod + def npu_grouped_matmul_swiglu_quant( + *, + x: torch.Tensor, + weight: torch.Tensor, + group_list: torch.Tensor, + weight_scale: torch.Tensor, + x_scale: torch.Tensor, + bias=None, + use_mxfp_quant: bool = False, + ): + if not use_mxfp_quant: + return BaseDeviceAdaptor.npu_grouped_matmul_swiglu_quant( + x=x, + weight=weight, + group_list=group_list, + weight_scale=weight_scale, + x_scale=x_scale, + bias=bias, + use_mxfp_quant=False, + ) + + out, out_scale = torch_npu.npu_grouped_matmul_swiglu_quant_v2( + x=x, + weight=[weight], + group_list=group_list, + weight_scale=[weight_scale], + x_scale=x_scale, + dequant_mode=2, + quant_mode=2, + dequant_dtype=torch.float32, + quant_dtype=torch.float8_e4m3fn, + weight_scale_dtype=FLOAT8_E8M0FNU_DTYPE, + x_scale_dtype=FLOAT8_E8M0FNU_DTYPE, + ) + return out, out_scale, None + + @staticmethod + def get_quant_gmm2_kwargs( + *, + input_dtype: torch.dtype, + act_quant_type, + weight_quant_type, + scale_type, + per_token_scale_type, + use_bf16: bool = True, + use_mxfp_quant: bool = False, + ) -> dict: + if not use_mxfp_quant: + return BaseDeviceAdaptor.get_quant_gmm2_kwargs( + input_dtype=input_dtype, + act_quant_type=act_quant_type, + weight_quant_type=weight_quant_type, + scale_type=scale_type, + per_token_scale_type=per_token_scale_type, + use_bf16=use_bf16, + use_mxfp_quant=False, + ) + + quant_dtypes = tuple(dtype for dtype in (FLOAT4_E2M1FN_X2_DTYPE, HIFLOAT8_DTYPE) if dtype is not None) + scale_dtypes = tuple(dtype for dtype in (FLOAT8_E8M0FNU_DTYPE,) if dtype is not None) + + output_dtype = ( + input_dtype + if input_dtype in [torch.bfloat16, torch.float16] + else (torch.bfloat16 if use_bf16 else torch.float16) + ) + + return { + "scale_dtype": scale_type if scale_type in scale_dtypes else None, + "per_token_scale_dtype": per_token_scale_type if per_token_scale_type in scale_dtypes else None, + "x_dtype": act_quant_type if act_quant_type in quant_dtypes else None, + "weight_dtype": weight_quant_type if weight_quant_type in quant_dtypes else None, + "output_dtype": output_dtype, + } + + @classmethod + def npu_grouped_matmul_gmm2( + cls, + *, + hidden_states: torch.Tensor, + weight: list[torch.Tensor] | torch.Tensor, + weight_scale: list[torch.Tensor] | 
torch.Tensor, + per_token_scale: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int, + input_dtype: torch.dtype, + act_quant_type, + weight_quant_type, + scale_type, + per_token_scale_type, + use_bf16: bool = True, + use_mxfp_quant: bool = False, + bias=None, + fallback_output_dtype: torch.dtype | None = None, + ) -> torch.Tensor: + if not use_mxfp_quant: + return BaseDeviceAdaptor.npu_grouped_matmul_gmm2( + hidden_states=hidden_states, + weight=weight, + weight_scale=weight_scale, + per_token_scale=per_token_scale, + group_list=group_list, + group_list_type=group_list_type, + input_dtype=input_dtype, + act_quant_type=act_quant_type, + weight_quant_type=weight_quant_type, + scale_type=scale_type, + per_token_scale_type=per_token_scale_type, + use_bf16=use_bf16, + use_mxfp_quant=False, + bias=bias, + fallback_output_dtype=fallback_output_dtype, + ) + + gmm2_kwargs = cls.get_quant_gmm2_kwargs( + input_dtype=input_dtype, + act_quant_type=act_quant_type, + weight_quant_type=weight_quant_type, + scale_type=scale_type, + per_token_scale_type=per_token_scale_type, + use_bf16=use_bf16, + use_mxfp_quant=True, + ) + output_dtype = gmm2_kwargs.pop("output_dtype") + + if isinstance(weight, list) and len(weight) != 1: + raise ValueError(f"w2 must have a single tensor in MXFP path, but got {len(weight)}.") + if isinstance(weight_scale, list) and len(weight_scale) != 1: + raise ValueError(f"w2_scale must have a single tensor in MXFP path, but got {len(weight_scale)}.") + gmm2_weight = weight if isinstance(weight, list) else [weight] + gmm2_scale = weight_scale if isinstance(weight_scale, list) else [weight_scale] + + return torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=gmm2_weight, + scale=gmm2_scale, + bias=bias, + per_token_scale=[per_token_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=output_dtype, + **gmm2_kwargs, + )[0] + + +def get_device_adaptor() -> type["BaseDeviceAdaptor"]: ascend_device_type = get_ascend_device_type() if ascend_device_type == AscendDeviceType.A5: return A5DeviceAdaptor return BaseDeviceAdaptor -DeviceOperator: type["BaseDeviceAdaptor"] | None = get_device_adaptor() +DeviceOperator: type["BaseDeviceAdaptor"] = get_device_adaptor() diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 9b27ed14..39e80342 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -38,6 +38,7 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import ( TokenDispatcherWithMC2, ) from vllm_ascend.quantization.methods.base import QuantType +from vllm_ascend.quantization.quant_parser import parse_mxfp_quant_params _MoECommMethods: dict[MoECommType | None, MoECommMethod] = {} @@ -129,6 +130,7 @@ class MoECommMethod(ABC): dynamic_eplb: bool = False, mc2_mask: torch.Tensor = None, pertoken_scale: torch.Tensor | None = None, + **kwargs, ): # Check constraints assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8] @@ -140,20 +142,36 @@ class MoECommMethod(ABC): # Apply log2phy if needed if log2phy is not None: topk_ids = log2phy[topk_ids] - - dispatch_results = self.token_dispatcher.token_dispatch( - hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - global_redundant_expert_num=self.moe_config.global_redundant_expert_num, - mc2_mask=mc2_mask, - apply_router_weight_on_input=apply_router_weight_on_input, - 
with_quant=use_int8_w8a8 or use_int4_w4a8, - dynamic_eplb=dynamic_eplb, - pertoken_scale=pertoken_scale, + # TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced + # by different quantization modes will be consolidated into a dataclass in a follow-up. + use_mxfp_quant = kwargs.get("use_mxfp_quant", False) + dispatch_with_quant = use_int8_w8a8 or use_int4_w4a8 or use_mxfp_quant + act_quant_type, weight_quant_type, scale_type, per_token_scale_type, round_mode = parse_mxfp_quant_params( + **kwargs ) + dispatch_kwargs = { + "hidden_states": hidden_states, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "expert_map": expert_map, + "global_redundant_expert_num": self.moe_config.global_redundant_expert_num, + "mc2_mask": mc2_mask, + "apply_router_weight_on_input": apply_router_weight_on_input, + "dynamic_eplb": dynamic_eplb, + "pertoken_scale": pertoken_scale, + } + + if isinstance(self.token_dispatcher, TokenDispatcherWithMC2): + dispatch_kwargs["with_quant"] = dispatch_with_quant + dispatch_kwargs["comm_quant_mode"] = kwargs.get("comm_quant_mode") + dispatch_kwargs["y_dtype"] = act_quant_type if use_mxfp_quant else None + dispatch_kwargs["use_mxfp_quant"] = use_mxfp_quant + else: + dispatch_kwargs["with_quant"] = use_int8_w8a8 or use_int4_w4a8 + + dispatch_results = self.token_dispatcher.token_dispatch(**dispatch_kwargs) + mlp_output = unified_apply_mlp( hidden_states=dispatch_results.hidden_states, w1=w1, @@ -171,10 +189,18 @@ class MoECommMethod(ABC): w1_offset=w1_offset, w2_offset=w2_offset, topk_scales=dispatch_results.topk_scales, - with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16, - fusion=use_int8_w8a8 and self.use_fusion_ops, + with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16 or use_mxfp_quant, + fusion=(use_int8_w8a8 or use_mxfp_quant) and self.use_fusion_ops, need_trans=need_trans, dynamic_eplb=dynamic_eplb, + use_mxfp_quant=use_mxfp_quant, + act_quant_type=act_quant_type, + weight_quant_type=weight_quant_type, + scale_type=scale_type, + per_token_scale_type=per_token_scale_type, + round_mode=round_mode, + use_bf16=(hidden_states.dtype == torch.bfloat16), + rollback_quant_config=kwargs.get("rollback_quant_config"), ) before_combine_evt = torch.npu.current_stream().record_event() @@ -317,6 +343,7 @@ class FusedMC2CommImpl(MoECommMethod): dynamic_eplb: bool = False, mc2_mask: torch.Tensor = None, pertoken_scale: torch.Tensor | None = None, + **kwargs, ): assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl." 
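A condensed sketch of the kwargs contract introduced in `fused_experts` above: the MXFP8 MoE scheme (added in `w8a8_mxfp8.py` further down in this diff) passes a small set of quantization flags, and `parse_mxfp_quant_params` (new in `quant_parser.py`) turns them back into the explicit dtype arguments forwarded to the dispatcher and `unified_apply_mlp`. The dictionary below is illustrative only; its values mirror what `AscendW8A8MXFP8DynamicFusedMoEMethod.apply` sets, and `FLOAT8_E8M0FNU_DTYPE` resolves to `None` on torch_npu builds that do not yet expose the E8M0 dtype.

```python
import torch

from vllm_ascend.quantization.mxfp_compat import FLOAT8_E8M0FNU_DTYPE
from vllm_ascend.quantization.quant_parser import parse_mxfp_quant_params

# Flags the W8A8_MXFP8 MoE scheme forwards into fused_experts(**kwargs).
mxfp8_kwargs = {
    "use_mxfp_quant": True,
    "act_quant_type": torch.float8_e4m3fn,        # activations quantized to FP8 (E4M3)
    "weight_quant_type": torch.float8_e4m3fn,     # weights stored as FP8 (E4M3)
    "scale_type": FLOAT8_E8M0FNU_DTYPE,           # MX block scales use E8M0
    "per_token_scale_type": FLOAT8_E8M0FNU_DTYPE,
}

# parse_mxfp_quant_params only reads the five quant-type keys and falls back to
# float8_e4m3fn / "rint" defaults when a flag is missing; unrelated keys are ignored.
act_t, weight_t, scale_t, per_token_scale_t, round_mode = parse_mxfp_quant_params(**mxfp8_kwargs)
assert round_mode == "rint"
```

MC2 dispatch additionally reads `comm_quant_mode` and `y_dtype` from the same kwargs, which is why the MC2 branch above forwards them while the other dispatchers only receive `with_quant`.
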
diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index 7b086b46..830bb6af 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -22,7 +22,11 @@ from vllm.forward_context import get_forward_context from vllm.triton_utils import HAS_TRITON from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.ops.activation import AscendSwigluOAIAndMul +from vllm_ascend.quantization.mxfp_compat import ( + ensure_mxfp8_moe_available, +) from vllm_ascend.utils import ( dispose_tensor, enable_custom_op, @@ -66,12 +70,22 @@ def cumsum_group_list( ) +def _require_single_tensor_for_swiglu_quant( + tensor_or_list: list[torch.Tensor] | torch.Tensor, *, name: str +) -> torch.Tensor: + if isinstance(tensor_or_list, list): + if len(tensor_or_list) != 1: + raise ValueError(f"{name} must be a tensor or a single-element list, but got {len(tensor_or_list)}.") + return tensor_or_list[0] + return tensor_or_list + + def quant_apply_mlp( hidden_states: torch.Tensor, - w1: list[torch.Tensor], - w1_scale: list[torch.Tensor], - w2: list[torch.Tensor], - w2_scale: list[torch.Tensor], + w1: list[torch.Tensor] | torch.Tensor, + w1_scale: list[torch.Tensor] | torch.Tensor, + w2: list[torch.Tensor] | torch.Tensor, + w2_scale: list[torch.Tensor] | torch.Tensor, group_list: torch.Tensor, group_list_type: int = 1, dynamic_scale: torch.Tensor = None, @@ -81,15 +95,45 @@ def quant_apply_mlp( w2_offset: torch.Tensor | None = None, fusion: bool = False, dynamic_eplb: bool = False, + **kwargs, ) -> torch.Tensor: + # TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different + # quantization modes will be consolidated into a dataclass in a follow-up. + use_mxfp_quant = kwargs.get("use_mxfp_quant", False) + act_quant_type = torch.float8_e4m3fn + weight_quant_type = None + scale_type = None + per_token_scale_type = None + use_bf16 = True + + input_hidden_dtype = hidden_states.dtype + use_gmm_swiglu_quant_fusion = use_mxfp_quant or (fusion and not dynamic_eplb) + + if use_mxfp_quant: + act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn) + weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn) + scale_type = kwargs.get("scale_type") + per_token_scale_type = kwargs.get("per_token_scale_type") + use_bf16 = kwargs.get("use_bf16", True) + + ensure_mxfp8_moe_available("MXFP MoE MLP path") + + if w1_scale_bias is not None or w2_scale_bias is not None: + raise NotImplementedError("MXFP path does not support scale_bias yet.") + if w1_offset is not None or w2_offset is not None: + raise NotImplementedError("MXFP path does not support antiquant offset yet.") + if w1_offset is not None: unquantized_hidden_states = hidden_states quantized_hidden_states = None elif dynamic_scale is None: unquantized_hidden_states = hidden_states - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) - # Dispose the original unquantized hidden states - # to save npu memory because they're no longer used. 
+ hidden_states, pertoken_scale = DeviceOperator.npu_dynamic_quant( + hidden_states=hidden_states, + dynamic_scale=None, + act_quant_type=act_quant_type, + use_mxfp_quant=use_mxfp_quant, + ) dispose_tensor(unquantized_hidden_states) quantized_hidden_states = None else: @@ -98,13 +142,14 @@ def quant_apply_mlp( quantized_hidden_states = hidden_states bias1, bias2 = None, None - _output_dtype = w2_scale[0].dtype + _output_dtype = w2_scale[0].dtype if isinstance(w2_scale, list) else w2_scale.dtype weight_prefetch_method = get_weight_prefetch_method() - weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states) + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states) is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 if w1_scale_bias is None and w1_offset is None and is_mc2: - if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list( x=hidden_states, @@ -113,14 +158,16 @@ def quant_apply_mlp( x_scale=pertoken_scale, group_list=cumsum_group_list(group_list, group_list_type, 0), ) - elif fusion and not dynamic_eplb: + elif use_gmm_swiglu_quant_fusion: # gmm1: gate_up_proj & act_fn: swiglu - hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( + hidden_states, swiglu_out_scale, _ = DeviceOperator.npu_grouped_matmul_swiglu_quant( x=hidden_states, - weight=w1[0], + weight=_require_single_tensor_for_swiglu_quant(w1, name="w1"), group_list=cumsum_group_list(group_list, group_list_type, 0), - weight_scale=w1_scale[0], + weight_scale=_require_single_tensor_for_swiglu_quant(w1_scale, name="w1_scale"), x_scale=pertoken_scale, + bias=None, + use_mxfp_quant=use_mxfp_quant, ) if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) @@ -152,17 +199,23 @@ def quant_apply_mlp( quant_mode=1, ) # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], + hidden_states = DeviceOperator.npu_grouped_matmul_gmm2( + hidden_states=hidden_states, weight=w2, - scale=w2_scale, - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, + weight_scale=w2_scale, + per_token_scale=swiglu_out_scale, group_list=group_list, - output_dtype=w2_scale[0].dtype, - )[0] + group_list_type=group_list_type, + input_dtype=input_hidden_dtype, + act_quant_type=act_quant_type, + weight_quant_type=weight_quant_type, + scale_type=scale_type, + per_token_scale_type=per_token_scale_type, + use_bf16=use_bf16, + use_mxfp_quant=use_mxfp_quant, + bias=None, + fallback_output_dtype=w2_scale[0].dtype if isinstance(w2_scale, list) else w2_scale.dtype, + ) elif w1_offset is not None: # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( @@ -201,7 +254,7 @@ def quant_apply_mlp( # TODO w4a8 scene: dynamic acquisition of dtype in the future _output_dtype = torch.bfloat16 - if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list( x=hidden_states, @@ -211,15 +264,15 @@ def quant_apply_mlp( group_list=cumsum_group_list(group_list, group_list_type, 0), bias=bias1, ) - elif fusion and not dynamic_eplb: - # gmm1: 
gate_up_proj & act_fn: swiglu - hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( + elif use_gmm_swiglu_quant_fusion: + hidden_states, swiglu_out_scale, _ = DeviceOperator.npu_grouped_matmul_swiglu_quant( x=hidden_states, - weight=w1[0], - bias=bias1, + weight=_require_single_tensor_for_swiglu_quant(w1, name="w1"), group_list=cumsum_group_list(group_list, group_list_type, 0), - weight_scale=w1_scale[0], + weight_scale=_require_single_tensor_for_swiglu_quant(w1_scale, name="w1_scale"), x_scale=pertoken_scale, + bias=bias1, + use_mxfp_quant=use_mxfp_quant, ) if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) @@ -251,18 +304,23 @@ def quant_apply_mlp( hidden_states = torch_npu.npu_swiglu(hidden_states) hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states) # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], + hidden_states = DeviceOperator.npu_grouped_matmul_gmm2( + hidden_states=hidden_states, weight=w2, - scale=w2_scale, - bias=bias2, - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, + weight_scale=w2_scale, + per_token_scale=swiglu_out_scale, group_list=group_list, - output_dtype=_output_dtype, - )[0] + group_list_type=group_list_type, + input_dtype=input_hidden_dtype, + act_quant_type=act_quant_type, + weight_quant_type=weight_quant_type, + scale_type=scale_type, + per_token_scale_type=per_token_scale_type, + use_bf16=use_bf16, + use_mxfp_quant=use_mxfp_quant, + bias=bias2, + fallback_output_dtype=_output_dtype, + ) return hidden_states @@ -334,26 +392,13 @@ def unified_apply_mlp( fusion: bool = False, need_trans: bool = True, dynamic_eplb: bool = False, + **kwargs, ) -> torch.Tensor: - if with_quant: - assert w1_scale is not None and w2_scale is not None - return quant_apply_mlp( - hidden_states=hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=group_list, - dynamic_scale=dynamic_scale, - group_list_type=group_list_type, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias, - w1_offset=w1_offset, - w2_offset=w2_offset, - fusion=fusion, - dynamic_eplb=dynamic_eplb, - ) - else: + """ + Unified MoE MLP entry. + Quant path is dispatched by DeviceOperator with explicit quant-type flags. + """ + if not with_quant: return unquant_apply_mlp( hidden_states=hidden_states, w1=w1, @@ -366,3 +411,34 @@ def unified_apply_mlp( topk_scales=topk_scales, need_trans=need_trans, ) + + assert w1_scale is not None and w2_scale is not None + # TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different + # quantization modes will be consolidated into a dataclass in a follow-up. 
+ act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn) + weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn) + scale_type = kwargs.get("scale_type") + per_token_scale_type = kwargs.get("per_token_scale_type") + use_mxfp_quant = kwargs.get("use_mxfp_quant", False) + return quant_apply_mlp( + hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + dynamic_scale=dynamic_scale, + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + w1_offset=w1_offset, + w2_offset=w2_offset, + fusion=fusion, + dynamic_eplb=dynamic_eplb, + act_quant_type=act_quant_type, + weight_quant_type=weight_quant_type, + scale_type=scale_type, + per_token_scale_type=per_token_scale_type, + use_mxfp_quant=use_mxfp_quant, + use_bf16=kwargs.get("use_bf16", True), + ) diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py index 9fe9239e..e7b4cf98 100644 --- a/vllm_ascend/ops/fused_moe/prepare_finalize.py +++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py @@ -76,7 +76,7 @@ class PrepareAndFinalize(ABC): router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts] enable_shared_expert_dp (bool): Skip DP communication for shared experts replace_allreduce (bool): Bypass default all-reduce behavior - quant_type: none, w8a8 or w4a8 + quant_type: none, w8a8, w4a8 or mxfp8 Returns: Tuple of: @@ -323,6 +323,10 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): pertoken_scale = None if quant_type == QuantType.W8A8: hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + elif quant_type == QuantType.MXFP8: + # TODO(linfeng): MXFP8 with AllGather+EP currently does not pre-quantize + # per-token activations in prepare. Keep quantization in the MoE MLP path. + pass if self.multistream_overlap_gate: assert PrepareAndFinalize.quant_stream is not None diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index d909ab89..bf4a5972 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -28,6 +28,7 @@ import torch_npu from vllm.config import get_current_vllm_config from vllm.distributed.parallel_state import get_ep_group +from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled @@ -103,8 +104,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): self.ep_rank_id = get_mc2_group().rank_in_group self.ep_world_size = get_mc2_group().world_size self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") - self.need_extra_args = get_ascend_device_type() == AscendDeviceType.A3 - + self.need_extra_args = get_ascend_device_type() in [AscendDeviceType.A3, AscendDeviceType.A5] + self.a5_need_extra_args = get_ascend_device_type() == AscendDeviceType.A5 # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # improve communication performance. 
@@ -136,8 +137,21 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): expert_map: torch.Tensor, mc2_mask: torch.Tensor, global_redundant_expert_num: int = 0, + **kwargs, ): - quant_mode = 2 if self.with_quant else 0 + use_mxfp_quant = kwargs.get("use_mxfp_quant", False) + comm_quant_mode = kwargs.get("comm_quant_mode") + # NOTE: quant_mode differs by quant feature: + # - Legacy int communication quantization uses quant_mode=2. + # - A5 MXFP8 communication uses quant_mode=4. + # TODO(linfeng): The quantization-related parameters need to be consolidated into a single + # dataclass, and the FP8 MoE code path should be integrated into it going forward. + if comm_quant_mode is not None: + quant_mode = comm_quant_mode + elif self.with_quant: + quant_mode = 4 if self.a5_need_extra_args and use_mxfp_quant else 2 + else: + quant_mode = 0 self.moe_expert_num = len(expert_map) + global_redundant_expert_num kwargs_mc2 = { "x": hidden_states, @@ -164,7 +178,12 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "tp_rank_id": 0, } ) - if self.need_expert_scale: + if self.a5_need_extra_args and use_mxfp_quant: + y_dtype = kwargs.get("y_dtype") + if self.with_quant: + y_dtype = torch.float8_e4m3fn if y_dtype is None else y_dtype + stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype}) + if self.need_expert_scale or self.a5_need_extra_args: stage1_kwargs.update( { "expert_scales": topk_weights.to(torch.float32), @@ -186,11 +205,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): with_quant: bool = False, dynamic_eplb: bool = False, pertoken_scale: torch.Tensor | None = None, + **kwargs, ): self.with_quant = with_quant - kwargs_mc2 = self.get_dispatch_mc2_kwargs( - hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num + hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num, **kwargs ) output = ( torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2) @@ -337,19 +356,16 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): first_expert_idx = 0 last_expert_idx = self.num_experts_local global_num_experts = self.num_experts_local - - sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = ( - torch.ops._C_ascend.npu_moe_init_routing_custom( - hidden_states, - topk_ids, - scale=pertoken_scale, - active_num=num_tokens * self.top_k, - expert_num=global_num_experts, - expert_tokens_num_type=1, - expert_tokens_num_flag=True, - active_expert_range=[first_expert_idx, last_expert_idx], - quant_mode=1 if self.with_quant and pertoken_scale is None else -1, - ) + sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = DeviceOperator.npu_moe_init_routing( + hidden_states, + topk_ids, + scale=pertoken_scale, + active_num=num_tokens * self.top_k, + expert_num=global_num_experts, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + active_expert_range=[first_expert_idx, last_expert_idx], + quant_mode=1 if self.with_quant and pertoken_scale is None else -1, ) expert_tokens = expert_tokens.to(torch.int64) group_list_type = 1 # `count` mode diff --git a/vllm_ascend/quantization/methods/base.py b/vllm_ascend/quantization/methods/base.py index c525167a..9307eb92 100644 --- a/vllm_ascend/quantization/methods/base.py +++ b/vllm_ascend/quantization/methods/base.py @@ -30,6 +30,7 @@ class QuantType(Enum): NONE = 0 W8A8 = 1 W4A8 = 2 + MXFP8 = 3 class AscendLinearScheme(ABC): diff --git a/vllm_ascend/quantization/methods/w8a8_mxfp8.py b/vllm_ascend/quantization/methods/w8a8_mxfp8.py index 
dc772952..bc25074d 100644 --- a/vllm_ascend/quantization/methods/w8a8_mxfp8.py +++ b/vllm_ascend/quantization/methods/w8a8_mxfp8.py @@ -15,13 +15,24 @@ # limitations under the License. # +from collections.abc import Callable from typing import Any import torch import torch_npu -from vllm.config import get_current_vllm_config +from vllm.config import CompilationMode, get_current_vllm_config +from vllm.distributed import get_ep_group +from vllm.forward_context import get_forward_context -from .base import AscendLinearScheme +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ops.fused_moe.experts_selector import select_experts +from vllm_ascend.quantization.mxfp_compat import ( + FLOAT8_E8M0FNU_DTYPE, + ensure_mxfp8_linear_available, + ensure_mxfp8_moe_available, +) + +from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .registry import register_scheme @@ -37,6 +48,7 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme): model_dtype = None def __init__(self): + ensure_mxfp8_linear_available("W8A8_MXFP8 linear quantization") vllm_config = get_current_vllm_config() self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32) @@ -66,9 +78,9 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme): quantized_x, layer.weight, layer.weight_scale, - scale_dtype=torch_npu.float8_e8m0fnu, + scale_dtype=FLOAT8_E8M0FNU_DTYPE, pertoken_scale=pertoken_scale, - pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale_dtype=FLOAT8_E8M0FNU_DTYPE, bias=bias, output_dtype=output_dtype, group_sizes=[1, 1, self.group_size], @@ -81,3 +93,127 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme): layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2) layer.weight.data = layer.weight.data.transpose(0, 1) layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1) + + +@register_scheme("W8A8_MXFP8", "moe") +class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme): + """FusedMoe method for Ascend W8A8_DYNAMIC.""" + + model_dtype = None + quant_type: QuantType = QuantType.MXFP8 + + def __init__(self): + ensure_mxfp8_moe_available("W8A8_MXFP8 MoE quantization") + self.ep_group = get_ep_group() + + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32) + ascend_config = get_ascend_config() + self.use_aclgraph = ( + vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE + and not vllm_config.model_config.enforce_eager + ) + self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb + + @staticmethod + def get_weight( + num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: + param_dict = {} + param_dict["w13_weight"] = torch.empty( + num_experts, 2 * intermediate_size_per_partition, hidden_sizes, dtype=torch.float8_e4m3fn + ) + param_dict["w2_weight"] = torch.empty( + num_experts, hidden_sizes, intermediate_size_per_partition, dtype=torch.float8_e4m3fn + ) + return param_dict + + def get_dynamic_quant_param( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: + param_dict = {} + param_dict["w13_weight_scale"] = torch.empty( + num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.uint8 + ) + + param_dict["w2_weight_scale"] = torch.empty( + num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, 
dtype=torch.uint8 + ) + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + is_prefill: bool = True, + enable_force_load_balance: bool = True, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + **kwargs, + ) -> torch.Tensor: + expected = global_num_experts - global_redundant_expert_num + assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)" + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) + + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + + topk_weights = topk_weights.to(x.dtype) + + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_int8_w8a8=False, + expert_map=expert_map, + log2phy=log2phy, + dynamic_eplb=self.dynamic_eplb, + mc2_mask=kwargs.get("mc2_mask"), + use_mxfp_quant=True, + act_quant_type=torch.float8_e4m3fn, + weight_quant_type=torch.float8_e4m3fn, + scale_type=FLOAT8_E8M0FNU_DTYPE, + per_token_scale_type=FLOAT8_E8M0FNU_DTYPE, + ) + + def process_weights_after_loading(self, layer): + g_num, n_size, k_size = layer.w13_weight_scale.shape + layer.w13_weight_scale.data = layer.w13_weight_scale.data.reshape(g_num, n_size, k_size // 2, 2) + g_num, n_size, k_size = layer.w2_weight_scale.shape + layer.w2_weight_scale.data = layer.w2_weight_scale.data.reshape(g_num, n_size, k_size // 2, 2) + layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2) + layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2) + layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(1, 2) + layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(1, 2) diff --git a/vllm_ascend/quantization/mxfp_compat.py b/vllm_ascend/quantization/mxfp_compat.py new file mode 100644 index 00000000..dabc8162 --- /dev/null +++ b/vllm_ascend/quantization/mxfp_compat.py @@ -0,0 +1,43 @@ +import torch +import torch_npu + +# TODO(linfeng): Temporary compatibility shim for MXFP4/MXFP8 because current torch_npu +# releases do not expose the required dtype attributes yet. Simplify or remove this +# file after the torch_npu release in March 2026 includes those dtype symbols. 
+FLOAT8_E8M0FNU_DTYPE = getattr(torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)) +FLOAT4_E2M1FN_X2_DTYPE = getattr(torch_npu, "float4_e2m1fn_x2", getattr(torch, "float4_e2m1fn_x2", None)) +HIFLOAT8_DTYPE = getattr(torch_npu, "hifloat8", None) + + +def _get_missing_symbols(symbols: tuple[str, ...]) -> list[str]: + return [symbol for symbol in symbols if not hasattr(torch_npu, symbol)] + + +def _ensure_symbols_available(feature: str, symbols: tuple[str, ...]) -> None: + missing_symbols = _get_missing_symbols(symbols) + if not missing_symbols: + return + missing_symbols_str = ", ".join(missing_symbols) + raise RuntimeError( + f"{feature} requires a newer torch_npu runtime. Missing symbols: {missing_symbols_str}. " + "Please upgrade torch_npu or disable MXFP quantization." + ) + + +def ensure_mxfp8_scale_dtype_available(feature: str) -> None: + _ensure_symbols_available(feature, ("float8_e8m0fnu",)) + + +def ensure_mxfp4_dtype_available(feature: str) -> None: + _ensure_symbols_available(feature, ("float4_e2m1fn_x2", "float8_e8m0fnu")) + + +def ensure_mxfp8_linear_available(feature: str) -> None: + _ensure_symbols_available(feature, ("float8_e8m0fnu", "npu_dynamic_mx_quant", "npu_quant_matmul")) + + +def ensure_mxfp8_moe_available(feature: str) -> None: + _ensure_symbols_available( + feature, + ("float8_e8m0fnu", "npu_dynamic_mx_quant", "npu_grouped_matmul_swiglu_quant_v2"), + ) diff --git a/vllm_ascend/quantization/quant_parser.py b/vllm_ascend/quantization/quant_parser.py new file mode 100644 index 00000000..33144ce5 --- /dev/null +++ b/vllm_ascend/quantization/quant_parser.py @@ -0,0 +1,73 @@ +import torch + +from vllm_ascend.quantization.mxfp_compat import ( + FLOAT4_E2M1FN_X2_DTYPE, + FLOAT8_E8M0FNU_DTYPE, + ensure_mxfp4_dtype_available, + ensure_mxfp8_scale_dtype_available, +) + + +class QuantTypeMapping: + quant_configs = { + "W8A8_MXFP8": { + "act_quant_type": torch.float8_e4m3fn, + "weight_quant_type": None, + "scale_dtype": FLOAT8_E8M0FNU_DTYPE, + "per_token_scale_dtype": FLOAT8_E8M0FNU_DTYPE, + }, + "W4A4_MXFP4": { + "act_quant_type": FLOAT4_E2M1FN_X2_DTYPE, + "weight_quant_type": FLOAT4_E2M1FN_X2_DTYPE, + "scale_dtype": FLOAT8_E8M0FNU_DTYPE, + "per_token_scale_dtype": FLOAT8_E8M0FNU_DTYPE, + }, + "W4A8_MXFP": { + "act_quant_type": torch.float8_e4m3fn, + "weight_quant_type": FLOAT4_E2M1FN_X2_DTYPE, + "scale_dtype": FLOAT8_E8M0FNU_DTYPE, + "per_token_scale_dtype": FLOAT8_E8M0FNU_DTYPE, + }, + } + + @staticmethod + def get_quant_settings(): + return QuantTypeMapping.quant_configs + + +def get_rollback_quant_type(rollback_quant_config): + rollback_quant_type = "W8A8_MXFP8" + for k, v in rollback_quant_config.items(): + if "down_proj" in k: + rollback_quant_type = v + return rollback_quant_type + + +def parse_mxfp_quant_params(**kwargs): + act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn) + weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn) + scale_type = kwargs.get("scale_type") + per_token_scale_type = kwargs.get("per_token_scale_type") + round_mode = kwargs.get("round_mode", "rint") + return act_quant_type, weight_quant_type, scale_type, per_token_scale_type, round_mode + + +def parse_quant_moe_down_proj_params(rollback_quant_type, parsed_round_mode): + if rollback_quant_type == "W4A4_MXFP4": + ensure_mxfp4_dtype_available("W4A4_MXFP4 quantization") + elif rollback_quant_type in ("W8A8_MXFP8", "W4A8_MXFP"): + ensure_mxfp8_scale_dtype_available(f"{rollback_quant_type} quantization") + + quant_type_mapping = 
QuantTypeMapping.get_quant_settings()
+    cur_rollback_quant_config = quant_type_mapping[rollback_quant_type]
+    if rollback_quant_type in ["W4A4_MXFP4"]:  # W4A4_MXFP4 supports both "round" and "rint" round modes
+        round_mode = parsed_round_mode
+    else:  # MXFP8 only supports "rint"
+        round_mode = "rint"
+    return (
+        cur_rollback_quant_config["act_quant_type"],
+        cur_rollback_quant_config["weight_quant_type"],
+        cur_rollback_quant_config["scale_dtype"],
+        cur_rollback_quant_config["per_token_scale_dtype"],
+        round_mode,
+    )
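
A minimal usage sketch of the new `quant_parser` helpers for the MoE down_proj rollback path. The layer-name key in `rollback_quant_config` below is hypothetical, and the call assumes a torch_npu build that already exposes the MX dtypes checked by `ensure_mxfp4_dtype_available`:

```python
from vllm_ascend.quantization.quant_parser import (
    get_rollback_quant_type,
    parse_quant_moe_down_proj_params,
)

# Hypothetical per-layer override map: any key containing "down_proj" selects the
# quantization type used for the MoE down projection (default is "W8A8_MXFP8").
rollback_quant_config = {"model.layers.0.mlp.experts.down_proj": "W4A4_MXFP4"}

rollback_type = get_rollback_quant_type(rollback_quant_config)  # -> "W4A4_MXFP4"

# Returns (act_quant_type, weight_quant_type, scale_dtype, per_token_scale_dtype, round_mode).
act_t, weight_t, scale_t, per_token_scale_t, round_mode = parse_quant_moe_down_proj_params(
    rollback_type,
    parsed_round_mode="round",  # honored only for W4A4_MXFP4; MXFP8 paths force "rint"
)
```

Keeping the `round_mode` decision inside the parser keeps the MoE MLP path free of per-format branching: W8A8_MXFP8 and W4A8_MXFP rollbacks always round with "rint", while W4A4_MXFP4 may opt into "round".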