[Feat]Qwen3 Moe supports npu_add_rms_norm_quant op by default, update op with bias, resolve conflict with weight prefetch (#3465)
### What this PR does / why we need it? 1.qwen3 moe uses add_rms_norm_quant op instead of 'add_rms_norm op and quant op' during quantization scene. 2.torch_npu.add_rms_norm_quant op fixed accuracy while model weights is quantized by anti_method m4, m4 quantization is asymmetric outlier suppression method, it will generate none-zero norm bias, add_rms_norm_quant op updated to add this parameter to calculate. 3. add torch-npu check ### Does this PR introduce _any_ user-facing change? new feature works if torch_npu version >= torch_npu-2.7.1.dev20250919 ### How was this patch tested? 1.no special parameters to set, no new envs to set. new feature works if torch_npu version >= torch_npu-2.7.1.dev20250919 2.use qwen3 moe quantization model to test ,such as Qwen3-235B-A22B-W8A8, Qwen3-30B-A3B-W8A8, Qwen3-235B-A22B-Instruct-2507-m4 (anti_method m4) - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: h30027576 <huangdong51@huawei.com>
This commit is contained in:
@@ -18,28 +18,43 @@
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
|
||||
from vllm_ascend.utils import version_check
|
||||
|
||||
|
||||
def _addrmsnorm_forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
layer: Optional[torch.nn.Module] = None,
|
||||
bias: Optional[torch.nn.Parameter] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
torch_npu_check = version_check()
|
||||
if layer is not None and not is_310p():
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
self.weight,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_offset,
|
||||
epsilon=self.variance_epsilon)
|
||||
if torch_npu_check:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
self.weight,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_offset,
|
||||
beta=bias,
|
||||
epsilon=self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
self.weight,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_offset,
|
||||
epsilon=self.variance_epsilon)
|
||||
else:
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
@@ -50,12 +65,32 @@ def _addrmsnorm_forward_oot(
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
if torch_npu_check and bias is not None:
|
||||
x.add_(bias)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
return x, residual
|
||||
|
||||
|
||||
class AscendRMSNorm(RMSNorm):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.bias = None
|
||||
self.torch_npu_check = version_check()
|
||||
# quantization with anti_method m4 will generate none-zero norm bias
|
||||
if self.torch_npu_check and vllm_config.quant_config is not None and \
|
||||
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -66,10 +101,13 @@ class AscendRMSNorm(RMSNorm):
|
||||
if residual is not None:
|
||||
assert x.size(0) == residual.size(0)
|
||||
x, residual = _addrmsnorm_forward_oot(
|
||||
self, x, residual, self.next_need_quant_fusion_linear)
|
||||
self, x, residual, self.next_need_quant_fusion_linear,
|
||||
self.bias)
|
||||
return x, residual
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
if self.torch_npu_check and self.bias is not None:
|
||||
x.add_(self.bias)
|
||||
return x
|
||||
|
||||
@property
|
||||
@@ -99,6 +137,13 @@ class AscendRMSNorm(RMSNorm):
|
||||
# does not need to be repeated
|
||||
if not forward_context.prefetch_mlp_enabled:
|
||||
forward_context.layer_idx += 1
|
||||
elif fusion_linear == "qkv_moe":
|
||||
next_linear = model_instance.model.layers[
|
||||
layer_idx].self_attn.qkv_proj
|
||||
forward_context.fusion_linear = "gate_moe"
|
||||
elif fusion_linear == "gate_moe":
|
||||
forward_context.fusion_linear = "qkv_moe"
|
||||
forward_context.layer_idx += 1
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
if next_linear is not None and \
|
||||
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
|
||||
|
||||
@@ -177,7 +177,6 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
from vllm_ascend.utils import version_check
|
||||
|
||||
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
||||
MOE_PREFETCH_TOKEN_THRESHOLD = 96
|
||||
@@ -82,14 +83,15 @@ class WeightPrefetchMethod:
|
||||
if not self.moe.is_active_this_forward:
|
||||
return
|
||||
forward_context = get_forward_context()
|
||||
if not version_check():
|
||||
forward_context.layer_idx += 1
|
||||
weight = forward_context.model_instance.model.layers[
|
||||
forward_context.layer_idx].mlp.experts.w13_weight
|
||||
forward_context.layer_idx - 1].mlp.experts.w13_weight
|
||||
weight_size = weight.data.element_size() * weight.data.numel(
|
||||
) * self.moe.prefetch_ratio.get(prefix, 0)
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||
start_flag=None,
|
||||
max_weight_size=int(weight_size))
|
||||
forward_context.layer_idx += 1
|
||||
|
||||
def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor):
|
||||
if not self.moe.is_active_this_forward:
|
||||
|
||||
Reference in New Issue
Block a user