[Graph][Fusion] Add AddRMSNorm(with bias) (#5491)
### What this PR does / why we need it?
This PR builds upon PR #5011 and aims to further enhance the
npu_graph_ex_passes module. Based on prior work, we have added graph
optimization support for the add_rms_quant fused operator in scenarios
where a bias term is present—ensuring the fusion pattern is correctly
registered and matched into the computation graph.
For validation, we switched to the Qwen3-235B-A22B-W8A8 model. Benchmark
results show that, compared to the unfused baseline, enabling this
fusion pass significantly improves inference throughput for W8A8
quantized models.
For more details can refer to the
RFC:https://github.com/vllm-project/vllm-ascend/issues/4715
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
```
llm = LLM(
model=model,
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=False,
enable_expert_parallel=enable_expert_parallel,
trust_remote_code=trust_remote_code,
gpu_memory_utilization=0.98,
max_num_batched_tokens=512,
# load_format="dummy",
max_model_len=2048,
max_num_seqs=16,
quantization="ascend",
additional_config={
"refresh": True,
"enable_npugraph_ex": True
},
compilation_config={
"cudagraph_capture_sizes": [8, 16],
"cudagraph_mode": "FULL_DECODE_ONLY",
},
)
if profile_dir:
llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
if profile_dir:
llm.stop_profile()
for i, output in enumerate(outputs):
if i >= 5:
break
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}"
)
```
- vLLM version: v0.13.0
- vLLM main:
5326c89803
Signed-off-by: cjian <2318164299@qq.com>
This commit is contained in:
95
tests/ut/compilation/test_add_rms_norm_quant.py
Normal file
95
tests/ut/compilation/test_add_rms_norm_quant.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
|
||||||
|
def _extra_stream_scope_check_for_test(match) -> bool:
|
||||||
|
"""
|
||||||
|
Copied from the original implementation for testability.
|
||||||
|
Checks if all nodes in the same stream.
|
||||||
|
"""
|
||||||
|
non_default_streams = set()
|
||||||
|
has_default = False
|
||||||
|
|
||||||
|
for node in match.nodes:
|
||||||
|
if node.op == "call_function":
|
||||||
|
current_stream = node.meta.get("stream_label")
|
||||||
|
if current_stream is None:
|
||||||
|
has_default = True
|
||||||
|
else:
|
||||||
|
non_default_streams.add(current_stream)
|
||||||
|
if len(non_default_streams) > 1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if has_default and len(non_default_streams) > 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_extra_stream_scope_check():
|
||||||
|
"""Test the stream scope check logic."""
|
||||||
|
|
||||||
|
class MockNode:
|
||||||
|
|
||||||
|
def __init__(self, stream_label=None):
|
||||||
|
self.op = "call_function"
|
||||||
|
self.meta = {"stream_label": stream_label}
|
||||||
|
|
||||||
|
class MockMatch:
|
||||||
|
|
||||||
|
def __init__(self, nodes):
|
||||||
|
self.nodes = nodes
|
||||||
|
|
||||||
|
# Test 1: all default stream (None) → OK
|
||||||
|
match1 = MockMatch([MockNode(None), MockNode(None)])
|
||||||
|
assert _extra_stream_scope_check_for_test(match1) is True
|
||||||
|
|
||||||
|
# Test 2: all same non-default stream → OK
|
||||||
|
match2 = MockMatch([MockNode("s1"), MockNode("s1")])
|
||||||
|
assert _extra_stream_scope_check_for_test(match2) is True
|
||||||
|
|
||||||
|
# Test 3: mixed streams → FAIL
|
||||||
|
match3 = MockMatch([MockNode("s1"), MockNode("s2")])
|
||||||
|
assert _extra_stream_scope_check_for_test(match3) is False
|
||||||
|
|
||||||
|
# Test 4: default + non-default → FAIL
|
||||||
|
match4 = MockMatch([MockNode(None), MockNode("s1")])
|
||||||
|
assert _extra_stream_scope_check_for_test(match4) is False
|
||||||
|
|
||||||
|
# Test 5: empty nodes → OK (edge case)
|
||||||
|
match5 = MockMatch([])
|
||||||
|
assert _extra_stream_scope_check_for_test(match5) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_replacement_function_without_torch_npu(caplog):
|
||||||
|
with mock.patch.dict(sys.modules, {
|
||||||
|
'torch_npu': None,
|
||||||
|
'torchair': None,
|
||||||
|
'torch_npu.dynamo': None
|
||||||
|
}):
|
||||||
|
if 'vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant' in sys.modules:
|
||||||
|
del sys.modules[
|
||||||
|
'vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant']
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant import \
|
||||||
|
replacement_add_rms_norm_quant_with_bias
|
||||||
|
result = replacement_add_rms_norm_quant_with_bias(epsilon=1e-5)
|
||||||
|
assert result is None
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
pass
|
||||||
@@ -25,7 +25,7 @@ from vllm.logger import logger
|
|||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
# The replacement registered here will be actually executed after AOT.
|
# The replacement registered here will be actually executed after AOT.
|
||||||
def _register_replacement(epsilon):
|
def replacement_add_rms_norm_quant(epsilon):
|
||||||
if 'torch_npu' not in sys.modules:
|
if 'torch_npu' not in sys.modules:
|
||||||
logger.info(
|
logger.info(
|
||||||
'The AddRMSNormQuant fusion will only be enabled in a torch npu env.'
|
'The AddRMSNormQuant fusion will only be enabled in a torch npu env.'
|
||||||
@@ -114,10 +114,108 @@ def _register_replacement(epsilon):
|
|||||||
extra_check=_extra_stream_scope_check)
|
extra_check=_extra_stream_scope_check)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(None)
|
||||||
|
# The replacement registered here will be actually executed after AOT.
|
||||||
|
def replacement_add_rms_norm_quant_with_bias(epsilon):
|
||||||
|
if 'torch_npu' not in sys.modules:
|
||||||
|
logger.info(
|
||||||
|
'The AddRMSNormQuantWithBias fusion will only be enabled in a torch npu env.'
|
||||||
|
'When there is no torch_npu in the env, skip fusion.')
|
||||||
|
return
|
||||||
|
|
||||||
|
def _extra_stream_scope_check(match: Match) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if all nodes in the same stream.
|
||||||
|
"""
|
||||||
|
non_default_streams = set()
|
||||||
|
has_default = False
|
||||||
|
|
||||||
|
for node in match.nodes:
|
||||||
|
if node.op == "call_function":
|
||||||
|
current_stream = node.meta.get("stream_label")
|
||||||
|
if current_stream is None:
|
||||||
|
has_default = True
|
||||||
|
else:
|
||||||
|
non_default_streams.add(current_stream)
|
||||||
|
if len(non_default_streams) > 1:
|
||||||
|
logger.debug(
|
||||||
|
f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
|
||||||
|
f"Multiple streams found: {non_default_streams}. "
|
||||||
|
f"Fusion is not supported for cross-stream operations."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if has_default and len(non_default_streams) > 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
|
||||||
|
f"Multiple streams found: {non_default_streams}. "
|
||||||
|
f"Fusion is not supported for cross-stream operations.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||||
|
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||||
|
offset: torch.Tensor, bias: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Pattern for AddRMSNormQuantWithBias fusion.
|
||||||
|
"""
|
||||||
|
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||||
|
rms_norm_weight, epsilon)
|
||||||
|
out0 = output[0]
|
||||||
|
out1 = output[2]
|
||||||
|
out0 = out0 + bias
|
||||||
|
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
|
||||||
|
torch.qint8, -1, False)
|
||||||
|
return quantized_output, out1
|
||||||
|
|
||||||
|
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||||
|
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||||
|
offset: torch.Tensor, bias: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Replacement for AddRMSNormQuantWithBias fusion.
|
||||||
|
"""
|
||||||
|
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||||
|
rms_norm_input,
|
||||||
|
residual,
|
||||||
|
rms_norm_weight,
|
||||||
|
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||||
|
1. / scale,
|
||||||
|
offset,
|
||||||
|
epsilon=epsilon,
|
||||||
|
beta=bias)
|
||||||
|
quantized_output = output[0]
|
||||||
|
out1 = output[2]
|
||||||
|
return quantized_output, out1
|
||||||
|
|
||||||
|
def get_inputs():
|
||||||
|
"""
|
||||||
|
Generate example inputs for the AddRMSNormQuantWithBias fusion pattern.
|
||||||
|
"""
|
||||||
|
rms_norm_input = torch.randn(2, 4, device="npu")
|
||||||
|
residual = torch.randn(2, 4, device="npu")
|
||||||
|
rms_norm_weight = torch.randn(4, device="npu")
|
||||||
|
rmsnorm_bias = torch.randn(4, device="npu")
|
||||||
|
scale = torch.ones(4, device="npu")
|
||||||
|
offset = torch.zeros(4, device="npu")
|
||||||
|
return [
|
||||||
|
rms_norm_input, residual, rms_norm_weight, scale, offset,
|
||||||
|
rmsnorm_bias
|
||||||
|
]
|
||||||
|
|
||||||
|
import torchair
|
||||||
|
|
||||||
|
torchair.register_replacement(search_fn=pattern,
|
||||||
|
replace_fn=replacement,
|
||||||
|
example_inputs=get_inputs(),
|
||||||
|
extra_check=_extra_stream_scope_check)
|
||||||
|
|
||||||
|
|
||||||
# register converter for pass
|
# register converter for pass
|
||||||
common_epsilons = [1e-5, 1e-6]
|
common_epsilons = [1e-5, 1e-6]
|
||||||
for eps in common_epsilons:
|
for eps in common_epsilons:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
|
f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
|
||||||
)
|
)
|
||||||
_register_replacement(eps)
|
replacement_add_rms_norm_quant(eps)
|
||||||
|
replacement_add_rms_norm_quant_with_bias(eps)
|
||||||
|
|||||||
Reference in New Issue
Block a user