[Graph][Fusion]Add new pattern for AddRmsnormQuant with SP. (#5077)
### What this PR does / why we need it?
1. In addition to
[#4168](https://github.com/vllm-project/vllm-ascend/pull/4168),
[#5011](https://github.com/vllm-project/vllm-ascend/pull/5011), this PR
adds two more pattern for AddRmsnormQuant with SP enabled. The key
difference is to insert an additional `maybe_all_gather_and_maybe_unpad`
between `addrmsnorm` and `quantize`.
2. This PR also introduce another api `torch.ops.vllm.quantize`, so that
we pass `input_scale` and `input_scale_reciprocal` at the same time.
This is because `npu_add_rms_norm_quant` and `npu_quantize` requires
different `div_mode`. To avoid introducing additional reciprocal
calculation in runtime, we have to pass both of them to quantize api.
3. Removes redundant `AscendQuantRmsnorm`.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -28,24 +28,29 @@ class AddRMSNormQuantPattern:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.eps = eps
|
||||
|
||||
def get_inputs(self):
|
||||
"""
|
||||
Generate example inputs for the AddRMSNormQuant fusion pattern.
|
||||
"""
|
||||
rms_norm_input = torch.randn(2, 4, device="npu")
|
||||
residual = torch.randn(2, 4, device="npu")
|
||||
rms_norm_weight = torch.randn(4, device="npu")
|
||||
scale = torch.tensor([1.0], device="npu")
|
||||
offset = torch.tensor([0.0], device="npu")
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, offset]
|
||||
rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
|
||||
residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
|
||||
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor):
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
@@ -53,24 +58,23 @@ class AddRMSNormQuantPattern:
|
||||
rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.npu.npu_quantize(
|
||||
out0, scale, offset, torch.qint8, -1, False)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor):
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
1. /
|
||||
scale, # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||
offset,
|
||||
epsilon=self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
@@ -83,25 +87,31 @@ class AddRMSNormQuantPatternWithBias:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.eps = eps
|
||||
|
||||
def get_inputs(self):
|
||||
"""
|
||||
Generate example inputs for the AddRMSNormQuant fusion pattern.
|
||||
"""
|
||||
rms_norm_input = torch.randn(2, 4, device="npu")
|
||||
residual = torch.randn(2, 4, device="npu")
|
||||
rms_norm_weight = torch.randn(4, device="npu")
|
||||
scale = torch.tensor([1.0], device="npu")
|
||||
offset = torch.tensor([0.0], device="npu")
|
||||
bias = torch.randn(4, device="npu")
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, offset, bias]
|
||||
rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
|
||||
residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
|
||||
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
|
||||
rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset, rmsnorm_bias
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor, bias: torch.Tensor):
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
@@ -110,25 +120,25 @@ class AddRMSNormQuantPatternWithBias:
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
quantized_output = torch.ops.npu.npu_quantize(
|
||||
out0, scale, offset, torch.qint8, -1, False)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor, bias: torch.Tensor):
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
1. /
|
||||
scale, # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||
offset,
|
||||
epsilon=self.eps,
|
||||
beta=bias)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps,
|
||||
beta=bias)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
@@ -137,6 +147,135 @@ class AddRMSNormQuantPatternWithBias:
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantSPPattern:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.eps = eps
|
||||
|
||||
def get_inputs(self):
|
||||
"""
|
||||
Generate example inputs for the AddRMSNormQuant fusion pattern.
|
||||
"""
|
||||
rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
|
||||
residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
|
||||
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantSPPatternWithBias:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.eps = eps
|
||||
|
||||
def get_inputs(self):
|
||||
"""
|
||||
Generate example inputs for the AddRMSNormQuant fusion pattern.
|
||||
"""
|
||||
rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
|
||||
residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
|
||||
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
|
||||
rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset, rmsnorm_bias
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps,
|
||||
beta=bias)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantFusionPass(VllmInductorPass):
|
||||
"""
|
||||
A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
|
||||
@@ -159,6 +298,10 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
|
||||
eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(
|
||||
self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(
|
||||
self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(
|
||||
self.pattern_match_passes)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
@@ -70,31 +70,6 @@ class AscendRMSNorm(RMSNorm):
|
||||
return x
|
||||
|
||||
|
||||
class AscendQuantRMSNorm(AscendRMSNorm):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if residual is not None:
|
||||
x, residual = super().forward_oot(x, residual)
|
||||
return x.add_(self.bias), residual
|
||||
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
|
||||
|
||||
|
||||
class AscendGemmaRMSNorm(GemmaRMSNorm):
|
||||
|
||||
def forward_oot(
|
||||
|
||||
@@ -545,8 +545,7 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
|
||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
|
||||
quant_per_tensor)
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
|
||||
# For unquant
|
||||
if mmrs_fusion and isinstance(self.layer.quant_method,
|
||||
@@ -568,8 +567,9 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
and isinstance(self.layer.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod)):
|
||||
if x.dtype != torch.int8:
|
||||
x_quant = quant_per_tensor(
|
||||
x, self.layer.aclnn_input_scale_reciprocal,
|
||||
x_quant = torch.ops.vllm.quantize(
|
||||
x, self.layer.aclnn_input_scale,
|
||||
self.layer.aclnn_input_scale_reciprocal,
|
||||
self.layer.aclnn_input_offset)
|
||||
else:
|
||||
x_quant = x
|
||||
|
||||
@@ -282,6 +282,26 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
|
||||
return output
|
||||
|
||||
|
||||
# TODO(Angazenn): The reason why we use a custom op to encapsulate npu_quantize
|
||||
# is that aclnnAscendQuantV3(npu_quantize) use div_mode=False, while
|
||||
# aclnnAddRmsNormQuantV2(npu_add_rms_norm_quant) use div_moe=True. We have to
|
||||
# pass input_scale and input_scale_reciprocal at the same time to avoid redundant
|
||||
# reciprocal calculation in fussion pass. We shall remove this once
|
||||
# aclnnAddRmsNormQuantV2 supports div_moe=False.
|
||||
def _quantize_impl(in_tensor: torch.Tensor, input_scale: torch.Tensor,
|
||||
input_scale_reciprocal: torch.Tensor,
|
||||
input_offset: torch.Tensor) -> torch.Tensor:
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
|
||||
input_offset, torch.qint8, -1, False)
|
||||
|
||||
|
||||
def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor,
|
||||
input_scale_reciprocal: torch.Tensor,
|
||||
input_offset: torch.Tensor) -> torch.Tensor:
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
|
||||
input_offset, torch.qint8, -1, False)
|
||||
|
||||
|
||||
direct_register_custom_op(op_name="maybe_chunk_residual",
|
||||
op_func=_maybe_chunk_residual_impl,
|
||||
fake_impl=lambda x, residual: x,
|
||||
@@ -341,3 +361,9 @@ direct_register_custom_op(op_name="matmul_and_reduce",
|
||||
fake_impl=_matmul_and_reduce_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="quantize",
|
||||
op_func=_quantize_impl,
|
||||
fake_impl=_quantize_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
@@ -128,8 +128,9 @@ class AscendW8A8LinearMethod:
|
||||
if enable_flashcomm2_quant_comm:
|
||||
quant_input_x = x.contiguous().view(
|
||||
-1, layer.aclnn_input_scale_reciprocal.size(0))
|
||||
quant_x = quant_per_tensor(
|
||||
quant_x = torch.ops.vllm.quantize(
|
||||
quant_input_x,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_scale_reciprocal,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
@@ -138,8 +139,9 @@ class AscendW8A8LinearMethod:
|
||||
x = comm_fn(comm_input)
|
||||
else:
|
||||
# quant
|
||||
x = quant_per_tensor(
|
||||
x = torch.ops.vllm.quantize(
|
||||
x,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_scale_reciprocal,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user