Files
xc-llm-ascend/vllm_ascend/quantization/mxfp_compat.py
Eric-dot 3c66a970f2 add mxfp8 moe quantization (#6670)
### What this PR does / why we need it?
Support MXFP8 quantization (Qwen MoE).
Use an adaptor to make the hardware-specific behavior clearer and more
maintainable.
### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
13397841ab

---------

Signed-off-by: fangrongcan <17343701736@163.com>
Signed-off-by: wangyao-i <iwangyao@outlook.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Eric-dot <60131170+Eric-dot@users.noreply.github.com>
Co-authored-by: fangrongcan <f00876277@china.huawei.com>
Co-authored-by: wangyao-i <iwangyao@outlook.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
2026-03-02 11:04:06 +08:00

44 lines
1.7 KiB
Python

import torch
import torch_npu
# TODO(linfeng): Temporary compatibility shim for MXFP4/MXFP8 because current torch_npu
# releases do not expose the required dtype attributes yet. Simplify or remove this
# file after the torch_npu release in March 2026 includes those dtype symbols.
# Resolution order for the MX dtypes: prefer the torch_npu attribute, fall back to
# the torch attribute, and finally to None when neither runtime exposes it.
# Callers must therefore treat a None value as "dtype unavailable on this runtime".
FLOAT8_E8M0FNU_DTYPE = getattr(torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None))
FLOAT4_E2M1FN_X2_DTYPE = getattr(torch_npu, "float4_e2m1fn_x2", getattr(torch, "float4_e2m1fn_x2", None))
# hifloat8 is NPU-specific, so it is only looked up on torch_npu (no torch fallback).
HIFLOAT8_DTYPE = getattr(torch_npu, "hifloat8", None)
def _get_missing_symbols(symbols: tuple[str, ...]) -> list[str]:
    """Return the subset of *symbols* that the torch_npu module does not expose.

    Membership is decided purely via ``hasattr`` on ``torch_npu``; order of the
    input tuple is preserved in the result.
    """
    missing: list[str] = []
    for name in symbols:
        if not hasattr(torch_npu, name):
            missing.append(name)
    return missing
def _ensure_symbols_available(feature: str, symbols: tuple[str, ...]) -> None:
    """Raise ``RuntimeError`` unless torch_npu exposes every symbol in *symbols*.

    Args:
        feature: Human-readable feature name used in the error message.
        symbols: Attribute names that must exist on the ``torch_npu`` module.

    Raises:
        RuntimeError: If at least one required symbol is missing, listing all
            missing names so the user can upgrade torch_npu in one pass.
    """
    missing = _get_missing_symbols(symbols)
    if missing:
        joined = ", ".join(missing)
        raise RuntimeError(
            f"{feature} requires a newer torch_npu runtime. Missing symbols: {joined}. "
            "Please upgrade torch_npu or disable MXFP quantization."
        )
def ensure_mxfp8_scale_dtype_available(feature: str) -> None:
    """Fail fast unless torch_npu exposes the MXFP8 scale dtype ``float8_e8m0fnu``."""
    required = ("float8_e8m0fnu",)
    _ensure_symbols_available(feature, required)
def ensure_mxfp4_dtype_available(feature: str) -> None:
    """Fail fast unless torch_npu exposes both MXFP4 dtypes.

    MXFP4 needs the packed value dtype ``float4_e2m1fn_x2`` and the shared
    scale dtype ``float8_e8m0fnu``.
    """
    required = ("float4_e2m1fn_x2", "float8_e8m0fnu")
    _ensure_symbols_available(feature, required)
def ensure_mxfp8_linear_available(feature: str) -> None:
    """Fail fast unless torch_npu has everything an MXFP8 linear layer needs.

    Required symbols: the scale dtype ``float8_e8m0fnu`` plus the
    ``npu_dynamic_mx_quant`` and ``npu_quant_matmul`` operators.
    """
    required = ("float8_e8m0fnu", "npu_dynamic_mx_quant", "npu_quant_matmul")
    _ensure_symbols_available(feature, required)
def ensure_mxfp8_moe_available(feature: str) -> None:
    """Fail fast unless torch_npu has everything an MXFP8 MoE layer needs.

    Required symbols: the scale dtype ``float8_e8m0fnu`` plus the
    ``npu_dynamic_mx_quant`` and ``npu_grouped_matmul_swiglu_quant_v2``
    operators.
    """
    required = (
        "float8_e8m0fnu",
        "npu_dynamic_mx_quant",
        "npu_grouped_matmul_swiglu_quant_v2",
    )
    _ensure_symbols_available(feature, required)