[Feat]Qwen3 Moe supports npu_add_rms_norm_quant op by default, update op with bias, resolve conflict with weight prefetch (#3465)

### What this PR does / why we need it?
1.qwen3 moe uses add_rms_norm_quant op instead of 'add_rms_norm op and
quant op' during quantization scene.
2.torch_npu.add_rms_norm_quant op fixed accuracy while model weights is
quantized by anti_method m4, m4 quantization is asymmetric outlier
suppression method, it will generate none-zero norm bias,
add_rms_norm_quant op updated to add this parameter to calculate.
3. add torch-npu check

### Does this PR introduce _any_ user-facing change?
new feature works if torch_npu version >= torch_npu-2.7.1.dev20250919

### How was this patch tested?
1.no special parameters to set, no new envs to set. new feature works if
torch_npu version >= torch_npu-2.7.1.dev20250919
2.use qwen3 moe quantization model to test ,such as
Qwen3-235B-A22B-W8A8, Qwen3-30B-A3B-W8A8,
Qwen3-235B-A22B-Instruct-2507-m4 (anti_method m4)

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: h30027576 <huangdong51@huawei.com>
This commit is contained in:
huangdong2022
2025-10-17 09:30:51 +08:00
committed by GitHub
parent 4c4a8458a5
commit 3a53bbc508
9 changed files with 121 additions and 38 deletions

View File

@@ -546,7 +546,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
if vllm_config is not None and \
vllm_config.quant_config is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()) and \
not version_check():
REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm
for name, op_cls in REGISTERED_ASCEND_OPS.items():
@@ -725,3 +726,18 @@ def calculate_dp_buffer_size() -> int:
def is_hierarchical_communication_enabled():
return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
@functools.cache
def version_check():
"""check if torch_npu version >= dev20250919"""
import re
torch_npu_version = torch_npu.version.__version__
date_pattern = r'dev(\d{8})'
match = re.search(date_pattern, torch_npu_version)
if match:
full_date = match.group(1)
if full_date >= "20250919":
return True
return False