diff --git a/tests/ut/compilation/test_add_rms_norm_quant.py b/tests/ut/compilation/test_add_rms_norm_quant.py
new file mode 100644
index 00000000..0e2887a7
--- /dev/null
+++ b/tests/ut/compilation/test_add_rms_norm_quant.py
@@ -0,0 +1,95 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import sys
+from unittest import mock
+
+
+def _extra_stream_scope_check_for_test(match) -> bool:
+    """
+    Copied from the original implementation for testability.
+    Checks whether all nodes in the match run on the same stream.
+    """
+    non_default_streams = set()
+    has_default = False
+
+    for node in match.nodes:
+        if node.op == "call_function":
+            current_stream = node.meta.get("stream_label")
+            if current_stream is None:
+                has_default = True
+            else:
+                non_default_streams.add(current_stream)
+            if len(non_default_streams) > 1:
+                return False
+
+    if has_default and len(non_default_streams) > 0:
+        return False
+
+    return True
+
+
+def test_extra_stream_scope_check():
+    """Test the stream scope check logic."""
+
+    class MockNode:
+
+        def __init__(self, stream_label=None):
+            self.op = "call_function"
+            self.meta = {"stream_label": stream_label}
+
+    class MockMatch:
+
+        def __init__(self, nodes):
+            self.nodes = nodes
+
+    # Test 1: all nodes on the default stream (None) → OK
+    match1 = MockMatch([MockNode(None), MockNode(None)])
+    assert _extra_stream_scope_check_for_test(match1) is True
+
+    # Test 2: all nodes on the same non-default stream → OK
+    match2 = MockMatch([MockNode("s1"), MockNode("s1")])
+    assert _extra_stream_scope_check_for_test(match2) is True
+
+    # Test 3: mixed non-default streams → FAIL
+    match3 = MockMatch([MockNode("s1"), MockNode("s2")])
+    assert _extra_stream_scope_check_for_test(match3) is False
+
+    # Test 4: default + non-default stream → FAIL
+    match4 = MockMatch([MockNode(None), MockNode("s1")])
+    assert _extra_stream_scope_check_for_test(match4) is False
+
+    # Test 5: empty node list → OK (edge case)
+    match5 = MockMatch([])
+    assert _extra_stream_scope_check_for_test(match5) is True
+
+
+def test_replacement_function_without_torch_npu(caplog):
+    with mock.patch.dict(sys.modules, {
+            'torch_npu': None,
+            'torchair': None,
+            'torch_npu.dynamo': None
+    }):
+        if 'vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant' in sys.modules:
+            del sys.modules[
+                'vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant']
+
+        try:
+            from vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant import \
+                replacement_add_rms_norm_quant_with_bias
+            result = replacement_add_rms_norm_quant_with_bias(epsilon=1e-5)
+            assert result is None
+        except (ImportError, AttributeError):
+            pass
diff --git a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py
index 724d8140..3de71e61 100644
--- a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py
+++ b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py
@@ -25,7 +25,7 @@ from vllm.logger import logger
 
 @functools.lru_cache(None)
 # The replacement registered here will be actually executed after AOT.
-def _register_replacement(epsilon):
+def replacement_add_rms_norm_quant(epsilon):
     if 'torch_npu' not in sys.modules:
         logger.info(
             'The AddRMSNormQuant fusion will only be enabled in a torch npu env.'
@@ -114,10 +114,108 @@ def _register_replacement(epsilon):
                                   extra_check=_extra_stream_scope_check)
 
 
+@functools.lru_cache(None)
+# The replacement registered here will actually be executed after AOT.
+def replacement_add_rms_norm_quant_with_bias(epsilon):
+    if 'torch_npu' not in sys.modules:
+        logger.info(
+            'The AddRMSNormQuantWithBias fusion will only be enabled in a torch npu env. '
+            'When there is no torch_npu in the env, skip fusion.')
+        return
+
+    def _extra_stream_scope_check(match: Match) -> bool:
+        """
+        Checks whether all nodes in the match run on the same stream.
+        """
+        non_default_streams = set()
+        has_default = False
+
+        for node in match.nodes:
+            if node.op == "call_function":
+                current_stream = node.meta.get("stream_label")
+                if current_stream is None:
+                    has_default = True
+                else:
+                    non_default_streams.add(current_stream)
+                if len(non_default_streams) > 1:
+                    logger.debug(
+                        f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
+                        f"Multiple streams found: {non_default_streams}. "
+                        f"Fusion is not supported for cross-stream operations."
+                    )
+                    return False
+
+        if has_default and len(non_default_streams) > 0:
+            logger.debug(
+                f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
+                f"Both the default stream and non-default streams found: {non_default_streams}. "
+                f"Fusion is not supported for cross-stream operations.")
+            return False
+
+        return True
+
+    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                offset: torch.Tensor, bias: torch.Tensor):
+        """
+        Pattern for the AddRMSNormQuantWithBias fusion.
+        """
+        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
+                                                rms_norm_weight, epsilon)
+        out0 = output[0]
+        out1 = output[2]
+        out0 = out0 + bias
+        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
+                                                      torch.qint8, -1, False)
+        return quantized_output, out1
+
+    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                    offset: torch.Tensor, bias: torch.Tensor):
+        """
+        Replacement for the AddRMSNormQuantWithBias fusion.
+        """
+        output = torch.ops.npu.npu_add_rms_norm_quant(
+            rms_norm_input,
+            residual,
+            rms_norm_weight,
+            # The npu_add_rms_norm_quant kernel expects the inverse of the
+            # scale taken by the npu_quantize kernel, hence 1. / scale.
+            1. / scale,
+            offset,
+            epsilon=epsilon,
+            beta=bias)
+        quantized_output = output[0]
+        out1 = output[2]
+        return quantized_output, out1
+
+    def get_inputs():
+        """
+        Generate example inputs for the AddRMSNormQuantWithBias fusion
+        pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu")
+        residual = torch.randn(2, 4, device="npu")
+        rms_norm_weight = torch.randn(4, device="npu")
+        rmsnorm_bias = torch.randn(4, device="npu")
+        scale = torch.ones(4, device="npu")
+        offset = torch.zeros(4, device="npu")
+        return [
+            rms_norm_input, residual, rms_norm_weight, scale, offset,
+            rmsnorm_bias
+        ]
+
+    import torchair
+
+    torchair.register_replacement(search_fn=pattern,
+                                  replace_fn=replacement,
+                                  example_inputs=get_inputs(),
+                                  extra_check=_extra_stream_scope_check)
+
+
 # register converter for pass
 common_epsilons = [1e-5, 1e-6]
 for eps in common_epsilons:
     logger.info(
         f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
     )
-    _register_replacement(eps)
+    replacement_add_rms_norm_quant(eps)
+    replacement_add_rms_norm_quant_with_bias(eps)