add mxfp8 moe quantization (#6670)
### What this PR does / why we need it?
support mxfp8 quantization (Qwen MOE )
Using adaptor to make the hardware-specific behavior clearer and more
maintainable
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: fangrongcan <17343701736@163.com>
Signed-off-by: wangyao-i <iwangyao@outlook.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Eric-dot <60131170+Eric-dot@users.noreply.github.com>
Co-authored-by: fangrongcan <f00876277@china.huawei.com>
Co-authored-by: wangyao-i <iwangyao@outlook.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -30,6 +30,7 @@ class QuantType(Enum):
|
||||
NONE = 0
|
||||
W8A8 = 1
|
||||
W4A8 = 2
|
||||
MXFP8 = 3
|
||||
|
||||
|
||||
class AscendLinearScheme(ABC):
|
||||
|
||||
@@ -15,13 +15,24 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import CompilationMode, get_current_vllm_config
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from .base import AscendLinearScheme
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.quantization.mxfp_compat import (
|
||||
FLOAT8_E8M0FNU_DTYPE,
|
||||
ensure_mxfp8_linear_available,
|
||||
ensure_mxfp8_moe_available,
|
||||
)
|
||||
|
||||
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
||||
from .registry import register_scheme
|
||||
|
||||
|
||||
@@ -37,6 +48,7 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
|
||||
model_dtype = None
|
||||
|
||||
def __init__(self):
|
||||
ensure_mxfp8_linear_available("W8A8_MXFP8 linear quantization")
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)
|
||||
|
||||
@@ -66,9 +78,9 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
|
||||
quantized_x,
|
||||
layer.weight,
|
||||
layer.weight_scale,
|
||||
scale_dtype=torch_npu.float8_e8m0fnu,
|
||||
scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
||||
pertoken_scale=pertoken_scale,
|
||||
pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
|
||||
pertoken_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
||||
bias=bias,
|
||||
output_dtype=output_dtype,
|
||||
group_sizes=[1, 1, self.group_size],
|
||||
@@ -81,3 +93,127 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
|
||||
layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2)
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1)
|
||||
layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1)
|
||||
|
||||
|
||||
@register_scheme("W8A8_MXFP8", "moe")
|
||||
class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
"""FusedMoe method for Ascend W8A8_DYNAMIC."""
|
||||
|
||||
model_dtype = None
|
||||
quant_type: QuantType = QuantType.MXFP8
|
||||
|
||||
def __init__(self):
|
||||
ensure_mxfp8_moe_available("W8A8_MXFP8 MoE quantization")
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)
|
||||
ascend_config = get_ascend_config()
|
||||
self.use_aclgraph = (
|
||||
vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
|
||||
and not vllm_config.model_config.enforce_eager
|
||||
)
|
||||
self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb
|
||||
|
||||
@staticmethod
|
||||
def get_weight(
|
||||
num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
|
||||
) -> dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight"] = torch.empty(
|
||||
num_experts, 2 * intermediate_size_per_partition, hidden_sizes, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
param_dict["w2_weight"] = torch.empty(
|
||||
num_experts, hidden_sizes, intermediate_size_per_partition, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
return param_dict
|
||||
|
||||
def get_dynamic_quant_param(
|
||||
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
|
||||
) -> dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.uint8
|
||||
)
|
||||
|
||||
param_dict["w2_weight_scale"] = torch.empty(
|
||||
num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.uint8
|
||||
)
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
expected = global_num_experts - global_redundant_expert_num
|
||||
assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)"
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int8_w8a8=False,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
mc2_mask=kwargs.get("mc2_mask"),
|
||||
use_mxfp_quant=True,
|
||||
act_quant_type=torch.float8_e4m3fn,
|
||||
weight_quant_type=torch.float8_e4m3fn,
|
||||
scale_type=FLOAT8_E8M0FNU_DTYPE,
|
||||
per_token_scale_type=FLOAT8_E8M0FNU_DTYPE,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
g_num, n_size, k_size = layer.w13_weight_scale.shape
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.reshape(g_num, n_size, k_size // 2, 2)
|
||||
g_num, n_size, k_size = layer.w2_weight_scale.shape
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.reshape(g_num, n_size, k_size // 2, 2)
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2)
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2)
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(1, 2)
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(1, 2)
|
||||
|
||||
43
vllm_ascend/quantization/mxfp_compat.py
Normal file
43
vllm_ascend/quantization/mxfp_compat.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
# TODO(linfeng): Temporary compatibility shim for MXFP4/MXFP8 because current torch_npu
|
||||
# releases do not expose the required dtype attributes yet. Simplify or remove this
|
||||
# file after the torch_npu release in March 2026 includes those dtype symbols.
|
||||
FLOAT8_E8M0FNU_DTYPE = getattr(torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None))
|
||||
FLOAT4_E2M1FN_X2_DTYPE = getattr(torch_npu, "float4_e2m1fn_x2", getattr(torch, "float4_e2m1fn_x2", None))
|
||||
HIFLOAT8_DTYPE = getattr(torch_npu, "hifloat8", None)
|
||||
|
||||
|
||||
def _get_missing_symbols(symbols: tuple[str, ...]) -> list[str]:
|
||||
return [symbol for symbol in symbols if not hasattr(torch_npu, symbol)]
|
||||
|
||||
|
||||
def _ensure_symbols_available(feature: str, symbols: tuple[str, ...]) -> None:
|
||||
missing_symbols = _get_missing_symbols(symbols)
|
||||
if not missing_symbols:
|
||||
return
|
||||
missing_symbols_str = ", ".join(missing_symbols)
|
||||
raise RuntimeError(
|
||||
f"{feature} requires a newer torch_npu runtime. Missing symbols: {missing_symbols_str}. "
|
||||
"Please upgrade torch_npu or disable MXFP quantization."
|
||||
)
|
||||
|
||||
|
||||
def ensure_mxfp8_scale_dtype_available(feature: str) -> None:
|
||||
_ensure_symbols_available(feature, ("float8_e8m0fnu",))
|
||||
|
||||
|
||||
def ensure_mxfp4_dtype_available(feature: str) -> None:
|
||||
_ensure_symbols_available(feature, ("float4_e2m1fn_x2", "float8_e8m0fnu"))
|
||||
|
||||
|
||||
def ensure_mxfp8_linear_available(feature: str) -> None:
|
||||
_ensure_symbols_available(feature, ("float8_e8m0fnu", "npu_dynamic_mx_quant", "npu_quant_matmul"))
|
||||
|
||||
|
||||
def ensure_mxfp8_moe_available(feature: str) -> None:
|
||||
_ensure_symbols_available(
|
||||
feature,
|
||||
("float8_e8m0fnu", "npu_dynamic_mx_quant", "npu_grouped_matmul_swiglu_quant_v2"),
|
||||
)
|
||||
73
vllm_ascend/quantization/quant_parser.py
Normal file
73
vllm_ascend/quantization/quant_parser.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
|
||||
from vllm_ascend.quantization.mxfp_compat import (
|
||||
FLOAT4_E2M1FN_X2_DTYPE,
|
||||
FLOAT8_E8M0FNU_DTYPE,
|
||||
ensure_mxfp4_dtype_available,
|
||||
ensure_mxfp8_scale_dtype_available,
|
||||
)
|
||||
|
||||
|
||||
class QuantTypeMapping:
|
||||
quant_configs = {
|
||||
"W8A8_MXFP8": {
|
||||
"act_quant_type": torch.float8_e4m3fn,
|
||||
"weight_quant_type": None,
|
||||
"scale_dtype": FLOAT8_E8M0FNU_DTYPE,
|
||||
"per_token_scale_dtype": FLOAT8_E8M0FNU_DTYPE,
|
||||
},
|
||||
"W4A4_MXFP4": {
|
||||
"act_quant_type": FLOAT4_E2M1FN_X2_DTYPE,
|
||||
"weight_quant_type": FLOAT4_E2M1FN_X2_DTYPE,
|
||||
"scale_dtype": FLOAT8_E8M0FNU_DTYPE,
|
||||
"per_token_scale_dtype": FLOAT8_E8M0FNU_DTYPE,
|
||||
},
|
||||
"W4A8_MXFP": {
|
||||
"act_quant_type": torch.float8_e4m3fn,
|
||||
"weight_quant_type": FLOAT4_E2M1FN_X2_DTYPE,
|
||||
"scale_dtype": FLOAT8_E8M0FNU_DTYPE,
|
||||
"per_token_scale_dtype": FLOAT8_E8M0FNU_DTYPE,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_quant_settings():
|
||||
return QuantTypeMapping.quant_configs
|
||||
|
||||
|
||||
def get_rollback_quant_type(rollback_quant_config):
|
||||
rollback_quant_type = "W8A8_MXFP8"
|
||||
for k, v in rollback_quant_config.items():
|
||||
if "down_proj" in k:
|
||||
rollback_quant_type = v
|
||||
return rollback_quant_type
|
||||
|
||||
|
||||
def parse_mxfp_quant_params(**kwargs):
|
||||
act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn)
|
||||
weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn)
|
||||
scale_type = kwargs.get("scale_type")
|
||||
per_token_scale_type = kwargs.get("per_token_scale_type")
|
||||
round_mode = kwargs.get("round_mode", "rint")
|
||||
return act_quant_type, weight_quant_type, scale_type, per_token_scale_type, round_mode
|
||||
|
||||
|
||||
def parse_quant_moe_down_proj_params(rollback_quant_type, parsed_round_mode):
|
||||
if rollback_quant_type == "W4A4_MXFP4":
|
||||
ensure_mxfp4_dtype_available("W4A4_MXFP4 quantization")
|
||||
elif rollback_quant_type in ("W8A8_MXFP8", "W4A8_MXFP"):
|
||||
ensure_mxfp8_scale_dtype_available(f"{rollback_quant_type} quantization")
|
||||
|
||||
quant_type_mapping = QuantTypeMapping.get_quant_settings()
|
||||
cur_rollback_quant_config = quant_type_mapping[rollback_quant_type]
|
||||
if rollback_quant_type in ["W4A4_MXFP4"]: # w4a4mxfp4 round mode support round、rint
|
||||
round_mode = parsed_round_mode
|
||||
else: # mxfp8 only support rint
|
||||
round_mode = "rint"
|
||||
return (
|
||||
cur_rollback_quant_config["act_quant_type"],
|
||||
cur_rollback_quant_config["weight_quant_type"],
|
||||
cur_rollback_quant_config["scale_dtype"],
|
||||
cur_rollback_quant_config["per_token_scale_dtype"],
|
||||
round_mode,
|
||||
)
|
||||
Reference in New Issue
Block a user