[Inductor]change pass to adapt to new addrmsnormBias operator (#6094)

### What this PR does / why we need it?
#5790 changes default addrmsnormBias operator if custom ops is enabled.
This PR modifies AddRmsNormQuant pass to align with addrmsnormBias.

---------

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2026-01-24 20:16:44 +08:00
committed by GitHub
parent 8966a99710
commit 5b746f3e83
2 changed files with 20 additions and 20 deletions

View File

@@ -23,6 +23,8 @@ from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import logger
from vllm_ascend.utils import enable_custom_op
class AddRMSNormQuantPattern:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
@@ -113,10 +115,11 @@ class AddRMSNormQuantPatternWithBias:
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
rms_norm_input, residual, rms_norm_weight, bias, self.eps
)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
@@ -233,10 +236,11 @@ class AddRMSNormQuantSPPatternWithBias:
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
rms_norm_input, residual, rms_norm_weight, bias, self.eps
)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
@@ -281,9 +285,10 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
if enable_custom_op():
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()