[Graph][Fusion] Add AddRMSNormSPPattern and AddRMSNormSPPatternWithBias (#5569)

### What this PR does / why we need it?
This PR builds upon PR
https://github.com/vllm-project/vllm-ascend/pull/5011 and aims to
further enhance the npu_graph_ex_passes module. Based on prior work, we
have added graph optimization support for the add_rms_quant fused
operator in scenarios where a bias term is present—ensuring the fusion
pattern is correctly registered and matched into the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model for
SPPatternWithBias and Qwen3-32B model for SPPattern. Benchmark results
show that, compared to the unfused baseline, enabling this fusion pass
significantly improves inference throughput for W8A8 quantized models.
For more details can refer to the
RFC:https://github.com/vllm-project/vllm-ascend/issues/4715
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
```
llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=False,
        enable_expert_parallel=enable_expert_parallel,
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=512,
        # load_format="dummy",
        max_model_len=2048,
        max_num_seqs=16,
        quantization="ascend",
        additional_config={
            "refresh": True,
            "enable_npugraph_ex": True
        },
        compilation_config={
            "cudagraph_capture_sizes": [8, 16],
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
    )
    if profile_dir:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if profile_dir:
        llm.stop_profile()
    for i, output in enumerate(outputs):
        if i >= 5:
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )
```
- vLLM version: v0.13.0
- vLLM main:
7157596103

Signed-off-by: cjian <2318164299@qq.com>
This commit is contained in:
CodeCat
2026-01-07 09:03:45 +08:00
committed by GitHub
parent ad9b711f89
commit bdedf3c9f8
2 changed files with 207 additions and 74 deletions

View File

@@ -16,6 +16,23 @@
import sys
from unittest import mock
import torch
def get_inputs():
"""
Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
"""
rms_norm_input = torch.randn(2, 4)
residual = torch.randn(2, 4)
rms_norm_weight = torch.randn(4)
rmsnorm_bias = torch.randn(4)
scale = torch.ones(4)
offset = torch.zeros(4)
return [
rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias
]
def _extra_stream_scope_check_for_test(match) -> bool:
"""
@@ -93,3 +110,39 @@ def test_replacement_function_without_torch_npu(caplog):
assert result is None
except (ImportError, AttributeError):
pass
def test_get_inputs_sp_pattern_with_bias():
"""
Test that get_inputs generates tensors with correct shapes and device.
This test verifies the internal get_inputs function used in the pattern.
"""
try:
import torch
except ImportError:
return # Skip if torch is not available
inputs = get_inputs()
(
rms_norm_input,
residual,
rms_norm_weight,
scale,
offset,
rmsnorm_bias,
) = inputs
# Verify shapes
assert rms_norm_input.shape == (2, 4)
assert residual.shape == (2, 4)
assert rms_norm_weight.shape == (4, )
assert rmsnorm_bias.shape == (4, )
assert scale.shape == (4, )
assert offset.shape == (4, )
# Verify number of inputs
assert len(inputs) == 6
# Verify specific values
assert torch.all(scale == 1.0)
assert torch.all(offset == 0.0)

View File

@@ -16,52 +16,47 @@
# limitations under the License.
#
import functools
import sys
import torch
from torch._inductor.pattern_matcher import Match
from vllm.logger import logger
def _extra_stream_scope_check(match: Match) -> bool:
"""
Checks if all nodes in the same stream.
"""
non_default_streams = set()
has_default = False
for node in match.nodes:
if node.op == "call_function":
current_stream = node.meta.get("stream_label")
if current_stream is None:
has_default = True
else:
non_default_streams.add(current_stream)
if len(non_default_streams) > 1:
logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations."
)
return False
if has_default and len(non_default_streams) > 0:
logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations.")
return False
return True
@functools.lru_cache(None)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant(epsilon):
if 'torch_npu' not in sys.modules:
logger.info(
'The AddRMSNormQuant fusion will only be enabled in a torch npu env.'
'When there is no torch_npu in the env, skip fusion.')
return
def _extra_stream_scope_check(match: Match) -> bool:
"""
Checks if all nodes in the same stream.
"""
non_default_streams = set()
has_default = False
for node in match.nodes:
if node.op == "call_function":
current_stream = node.meta.get("stream_label")
if current_stream is None:
has_default = True
else:
non_default_streams.add(current_stream)
if len(non_default_streams) > 1:
logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations."
)
return False
if has_default and len(non_default_streams) > 0:
logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations.")
return False
return True
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
@@ -114,45 +109,8 @@ def replacement_add_rms_norm_quant(epsilon):
extra_check=_extra_stream_scope_check)
@functools.lru_cache(None)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_with_bias(epsilon):
if 'torch_npu' not in sys.modules:
logger.info(
'The AddRMSNormQuantWithBias fusion will only be enabled in a torch npu env.'
'When there is no torch_npu in the env, skip fusion.')
return
def _extra_stream_scope_check(match: Match) -> bool:
"""
Checks if all nodes in the same stream.
"""
non_default_streams = set()
has_default = False
for node in match.nodes:
if node.op == "call_function":
current_stream = node.meta.get("stream_label")
if current_stream is None:
has_default = True
else:
non_default_streams.add(current_stream)
if len(non_default_streams) > 1:
logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations."
)
return False
if has_default and len(non_default_streams) > 0:
logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations.")
return False
return True
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
@@ -211,6 +169,126 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
extra_check=_extra_stream_scope_check)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_sp_pattern(epsilon):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
"""
Pattern for AddRMSNormQuantSPPattern fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
"""
Replacement for the AddRMSNormQuantSPPattern fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
offset,
epsilon=epsilon)
quantized_output = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
quantized_output, True)
return quantized_output, out1
def get_inputs():
"""
Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu")
residual = torch.randn(2, 4, device="npu")
rms_norm_weight = torch.randn(4, device="npu")
scale = torch.ones(4, device="npu")
offset = torch.zeros(4, device="npu")
return [rms_norm_input, residual, rms_norm_weight, scale, offset]
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
"""
Pattern for AddRMSNormQuantSPPatternWithBias fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
"""
Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
offset,
epsilon=epsilon,
beta=bias)
quantized_output = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
quantized_output, True)
return quantized_output, out1
def get_inputs():
"""
Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu")
residual = torch.randn(2, 4, device="npu")
rms_norm_weight = torch.randn(4, device="npu")
rmsnorm_bias = torch.randn(4, device="npu")
scale = torch.ones(4, device="npu")
offset = torch.zeros(4, device="npu")
return [
rms_norm_input, residual, rms_norm_weight, scale, offset,
rmsnorm_bias
]
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
# register converter for pass
common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
@@ -219,3 +297,5 @@ for eps in common_epsilons:
)
replacement_add_rms_norm_quant(eps)
replacement_add_rms_norm_quant_with_bias(eps)
replacement_add_rms_norm_quant_sp_pattern(eps)
replacement_add_rms_norm_quant_sp_pattern_with_bias(eps)