[Graph][Fusion]Add new pattern for AddRmsnormQuant with SP. (#5077)

### What this PR does / why we need it?
1. In addition to
[#4168](https://github.com/vllm-project/vllm-ascend/pull/4168),
[#5011](https://github.com/vllm-project/vllm-ascend/pull/5011), this PR
adds two more patterns for AddRmsnormQuant with SP enabled. The key
difference is to insert an additional `maybe_all_gather_and_maybe_unpad`
between `addrmsnorm` and `quantize`.
2. This PR also introduces another API `torch.ops.vllm.quantize`, so that
we can pass `input_scale` and `input_scale_reciprocal` at the same time.
This is because `npu_add_rms_norm_quant` and `npu_quantize` require
different `div_mode`. To avoid introducing additional reciprocal
calculation in runtime, we have to pass both of them to quantize api.
3. Removes redundant `AscendQuantRMSNorm`.


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2025-12-18 20:25:44 +08:00
committed by GitHub
parent a74a1196c5
commit acc3578f58
7 changed files with 454 additions and 116 deletions

View File

@@ -15,7 +15,7 @@
# This file is a part of the vllm-ascend project.
#
from typing import Optional, Tuple, Union, cast
from typing import Optional, Tuple, Union
import torch
from vllm.config import get_current_vllm_config
@@ -70,31 +70,6 @@ class AscendRMSNorm(RMSNorm):
return x
class AscendQuantRMSNorm(AscendRMSNorm):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
x, residual = super().forward_oot(x, residual)
return x.add_(self.bias), residual
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
class AscendGemmaRMSNorm(GemmaRMSNorm):
def forward_oot(

View File

@@ -545,8 +545,7 @@ class SequenceRowParallelOp(CustomRowParallelOp):
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
quant_per_tensor)
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
# For unquant
if mmrs_fusion and isinstance(self.layer.quant_method,
@@ -568,8 +567,9 @@ class SequenceRowParallelOp(CustomRowParallelOp):
and isinstance(self.layer.quant_method.quant_method,
AscendW8A8LinearMethod)):
if x.dtype != torch.int8:
x_quant = quant_per_tensor(
x, self.layer.aclnn_input_scale_reciprocal,
x_quant = torch.ops.vllm.quantize(
x, self.layer.aclnn_input_scale,
self.layer.aclnn_input_scale_reciprocal,
self.layer.aclnn_input_offset)
else:
x_quant = x

View File

@@ -282,6 +282,26 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
return output
# TODO(Angazenn): The reason why we use a custom op to encapsulate npu_quantize
# is that aclnnAscendQuantV3(npu_quantize) uses div_mode=False, while
# aclnnAddRmsNormQuantV2(npu_add_rms_norm_quant) uses div_mode=True. We have to
# pass input_scale and input_scale_reciprocal at the same time to avoid redundant
# reciprocal calculation in the fusion pass. We shall remove this once
# aclnnAddRmsNormQuantV2 supports div_mode=False.
def _quantize_impl(in_tensor: torch.Tensor, input_scale: torch.Tensor,
                   input_scale_reciprocal: torch.Tensor,
                   input_offset: torch.Tensor) -> torch.Tensor:
    """Quantize ``in_tensor`` to int8 via ``torch_npu.npu_quantize``.

    ``input_scale`` is deliberately accepted but unused here:
    ``npu_quantize`` runs with ``div_mode=False`` and therefore consumes
    the reciprocal scale, while ``npu_add_rms_norm_quant`` (``div_mode=True``)
    consumes the plain scale. Carrying both arguments in the op signature
    lets fusion passes rewrite between the two kernels without inserting a
    runtime reciprocal computation.
    """
    return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
                                  input_offset, torch.qint8, -1, False)
def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor,
input_scale_reciprocal: torch.Tensor,
input_offset: torch.Tensor) -> torch.Tensor:
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
input_offset, torch.qint8, -1, False)
direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: x,
@@ -341,3 +361,9 @@ direct_register_custom_op(op_name="matmul_and_reduce",
fake_impl=_matmul_and_reduce_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
# Register the quantize op (exposed as torch.ops.vllm.quantize) on the
# PrivateUse1 dispatch key so graph fusion passes can match and replace it.
# The fake impl is used during tracing/compilation in place of the NPU kernel.
direct_register_custom_op(op_name="quantize",
                          op_func=_quantize_impl,
                          fake_impl=_quantize_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")