[Graph][Fusion] Integrating inductor pass and npugraph ex pass (#6354)

### What this PR does / why we need it?
Integrating inductor pass and npugraph ex pass, see RFC:
https://github.com/vllm-project/vllm-ascend/issues/6347

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
all tests passed.

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
Icey
2026-02-13 15:34:55 +08:00
committed by GitHub
parent 87a0b7b7c7
commit 7164990904
16 changed files with 220 additions and 909 deletions

View File

@@ -16,12 +16,12 @@
# limitations under the License.
#
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import logger
from vllm_ascend.compilation.passes.base_pattern import BasePattern
from vllm_ascend.utils import enable_custom_op, vllm_version_is
if vllm_version_is("0.15.0"):
@@ -30,11 +30,9 @@ else:
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
class AddRMSNormQuantPattern:
class AddRMSNormQuantPattern(BasePattern):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
super().__init__(vllm_config, eps)
def get_inputs(self):
"""
@@ -48,7 +46,7 @@ class AddRMSNormQuantPattern:
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
def register(self, pm_pass: PatternMatcherPass):
def get_pattern(self):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -68,6 +66,9 @@ class AddRMSNormQuantPattern:
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
return pattern
def get_replacement(self):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -86,14 +87,12 @@ class AddRMSNormQuantPattern:
out1 = output[2]
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
return replacement
class AddRMSNormQuantPatternWithBias:
class AddRMSNormQuantPatternWithBias(BasePattern):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
super().__init__(vllm_config, eps)
def get_inputs(self):
"""
@@ -108,7 +107,7 @@ class AddRMSNormQuantPatternWithBias:
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
def register(self, pm_pass: PatternMatcherPass):
def get_pattern(self):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -129,6 +128,9 @@ class AddRMSNormQuantPatternWithBias:
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
return pattern
def get_replacement(self):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -148,14 +150,12 @@ class AddRMSNormQuantPatternWithBias:
out1 = output[2]
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
return replacement
class AddRMSNormQuantSPPattern:
class AddRMSNormQuantSPPattern(BasePattern):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
super().__init__(vllm_config, eps)
def get_inputs(self):
"""
@@ -169,7 +169,7 @@ class AddRMSNormQuantSPPattern:
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
def register(self, pm_pass: PatternMatcherPass):
def get_pattern(self):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -190,6 +190,9 @@ class AddRMSNormQuantSPPattern:
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
return pattern
def get_replacement(self):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -209,14 +212,12 @@ class AddRMSNormQuantSPPattern:
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
return replacement
class AddRMSNormQuantSPPatternWithBias:
class AddRMSNormQuantSPPatternWithBias(BasePattern):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
super().__init__(vllm_config, eps)
def get_inputs(self):
"""
@@ -231,7 +232,7 @@ class AddRMSNormQuantSPPatternWithBias:
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
def register(self, pm_pass: PatternMatcherPass):
def get_pattern(self):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -253,6 +254,9 @@ class AddRMSNormQuantSPPatternWithBias:
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
return pattern
def get_replacement(self):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -273,14 +277,12 @@ class AddRMSNormQuantSPPatternWithBias:
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
return replacement
class AddRMSNormDynamicQuantPattern:
class AddRMSNormDynamicQuantPattern(BasePattern):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
super().__init__(vllm_config, eps)
def get_inputs(self):
"""
@@ -291,7 +293,7 @@ class AddRMSNormDynamicQuantPattern:
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight]
def register(self, pm_pass: PatternMatcherPass):
def get_pattern(self):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
"""
Pattern for AddRMSNormQuant fusion.
@@ -302,6 +304,9 @@ class AddRMSNormDynamicQuantPattern:
quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
return quantized_output[0], quantized_output[1], out1
return pattern
def get_replacement(self):
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
"""
Replacement for the AddRMSNormQuant fusion.
@@ -315,14 +320,12 @@ class AddRMSNormDynamicQuantPattern:
output[2],
)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
return replacement
class AddRMSNormDynamicQuantPatternWithBias:
class AddRMSNormDynamicQuantPatternWithBias(BasePattern):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
super().__init__(vllm_config, eps)
def get_inputs(self):
"""
@@ -334,7 +337,7 @@ class AddRMSNormDynamicQuantPatternWithBias:
rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]
def register(self, pm_pass: PatternMatcherPass):
def get_pattern(self):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -352,6 +355,9 @@ class AddRMSNormDynamicQuantPatternWithBias:
quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
return quantized_output[0], quantized_output[1], out1
return pattern
def get_replacement(self):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -370,14 +376,12 @@ class AddRMSNormDynamicQuantPatternWithBias:
output[2],
)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
return replacement
class AddRMSNormDynamicQuantSPPattern:
class AddRMSNormDynamicQuantSPPattern(BasePattern):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
super().__init__(vllm_config, eps)
def get_inputs(self):
"""
@@ -388,7 +392,7 @@ class AddRMSNormDynamicQuantSPPattern:
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight]
def register(self, pm_pass: PatternMatcherPass):
def get_pattern(self):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
"""
Pattern for AddRMSNormQuant fusion.
@@ -400,6 +404,9 @@ class AddRMSNormDynamicQuantSPPattern:
quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
return quantized_output[0], quantized_output[1], out1
return pattern
def get_replacement(self):
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
"""
Replacement for the AddRMSNormQuant fusion.
@@ -412,14 +419,12 @@ class AddRMSNormDynamicQuantSPPattern:
out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
return quantized_output, out3, output[2]
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
return replacement
class AddRMSNormDynamicQuantSPPatternWithBias:
class AddRMSNormDynamicQuantSPPatternWithBias(BasePattern):
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
super().__init__(vllm_config, eps)
def get_inputs(self):
"""
@@ -431,7 +436,7 @@ class AddRMSNormDynamicQuantSPPatternWithBias:
rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]
def register(self, pm_pass: PatternMatcherPass):
def get_pattern(self):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -450,6 +455,9 @@ class AddRMSNormDynamicQuantSPPatternWithBias:
quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
return quantized_output[0], quantized_output[1], out1
return pattern
def get_replacement(self):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
@@ -467,7 +475,7 @@ class AddRMSNormDynamicQuantSPPatternWithBias:
out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
return quantized_output, out3, output[2]
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
return replacement
class AddRMSNormQuantFusionPass(VllmInductorPass):