diff --git a/tests/ut/compilation/test_add_rms_norm_quant.py b/tests/ut/compilation/test_add_rms_norm_quant.py index 0e2887a7..d056676c 100644 --- a/tests/ut/compilation/test_add_rms_norm_quant.py +++ b/tests/ut/compilation/test_add_rms_norm_quant.py @@ -16,6 +16,23 @@ import sys from unittest import mock +import torch + + +def get_inputs(): + """ + Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern. + """ + rms_norm_input = torch.randn(2, 4) + residual = torch.randn(2, 4) + rms_norm_weight = torch.randn(4) + rmsnorm_bias = torch.randn(4) + scale = torch.ones(4) + offset = torch.zeros(4) + return [ + rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias + ] + def _extra_stream_scope_check_for_test(match) -> bool: """ @@ -93,3 +110,39 @@ def test_replacement_function_without_torch_npu(caplog): assert result is None except (ImportError, AttributeError): pass + + +def test_get_inputs_sp_pattern_with_bias(): + """ + Test that get_inputs generates tensors with correct shapes and device. + This test verifies the internal get_inputs function used in the pattern. + """ + try: + import torch + except ImportError: + return # Skip if torch is not available + + inputs = get_inputs() + ( + rms_norm_input, + residual, + rms_norm_weight, + scale, + offset, + rmsnorm_bias, + ) = inputs + + # Verify shapes + assert rms_norm_input.shape == (2, 4) + assert residual.shape == (2, 4) + assert rms_norm_weight.shape == (4, ) + assert rmsnorm_bias.shape == (4, ) + assert scale.shape == (4, ) + assert offset.shape == (4, ) + + # Verify number of inputs + assert len(inputs) == 6 + + # Verify specific values + assert torch.all(scale == 1.0) + assert torch.all(offset == 0.0) diff --git a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py index 3de71e61..0c12e68d 100644 --- a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py +++ b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py @@ -16,52 +16,47 @@ # limitations under the License. # import functools -import sys import torch from torch._inductor.pattern_matcher import Match from vllm.logger import logger +def _extra_stream_scope_check(match: Match) -> bool: + """ + Checks if all nodes in the same stream. + """ + non_default_streams = set() + has_default = False + + for node in match.nodes: + if node.op == "call_function": + current_stream = node.meta.get("stream_label") + if current_stream is None: + has_default = True + else: + non_default_streams.add(current_stream) + if len(non_default_streams) > 1: + logger.debug( + f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " + f"Multiple streams found: {non_default_streams}. " + f"Fusion is not supported for cross-stream operations." + ) + return False + + if has_default and len(non_default_streams) > 0: + logger.debug( + f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " + f"Multiple streams found: {non_default_streams}. " + f"Fusion is not supported for cross-stream operations.") + return False + + return True + + @functools.lru_cache(None) # The replacement registered here will be actually executed after AOT. def replacement_add_rms_norm_quant(epsilon): - if 'torch_npu' not in sys.modules: - logger.info( - 'The AddRMSNormQuant fusion will only be enabled in a torch npu env.' - 'When there is no torch_npu in the env, skip fusion.') - return - - def _extra_stream_scope_check(match: Match) -> bool: - """ - Checks if all nodes in the same stream. - """ - non_default_streams = set() - has_default = False - - for node in match.nodes: - if node.op == "call_function": - current_stream = node.meta.get("stream_label") - if current_stream is None: - has_default = True - else: - non_default_streams.add(current_stream) - if len(non_default_streams) > 1: - logger.debug( - f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " - f"Multiple streams found: {non_default_streams}. " - f"Fusion is not supported for cross-stream operations." - ) - return False - - if has_default and len(non_default_streams) > 0: - logger.debug( - f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " - f"Multiple streams found: {non_default_streams}. " - f"Fusion is not supported for cross-stream operations.") - return False - - return True def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor, scale: torch.Tensor, @@ -114,45 +109,8 @@ def replacement_add_rms_norm_quant(epsilon): extra_check=_extra_stream_scope_check) -@functools.lru_cache(None) # The replacement registered here will be actually executed after AOT. def replacement_add_rms_norm_quant_with_bias(epsilon): - if 'torch_npu' not in sys.modules: - logger.info( - 'The AddRMSNormQuantWithBias fusion will only be enabled in a torch npu env.' - 'When there is no torch_npu in the env, skip fusion.') - return - - def _extra_stream_scope_check(match: Match) -> bool: - """ - Checks if all nodes in the same stream. - """ - non_default_streams = set() - has_default = False - - for node in match.nodes: - if node.op == "call_function": - current_stream = node.meta.get("stream_label") - if current_stream is None: - has_default = True - else: - non_default_streams.add(current_stream) - if len(non_default_streams) > 1: - logger.debug( - f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. " - f"Multiple streams found: {non_default_streams}. " - f"Fusion is not supported for cross-stream operations." - ) - return False - - if has_default and len(non_default_streams) > 0: - logger.debug( - f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. " - f"Multiple streams found: {non_default_streams}. " - f"Fusion is not supported for cross-stream operations.") - return False - - return True def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor, scale: torch.Tensor, @@ -211,6 +169,126 @@ def replacement_add_rms_norm_quant_with_bias(epsilon): extra_check=_extra_stream_scope_check) +# The replacement registered here will be actually executed after AOT. +def replacement_add_rms_norm_quant_sp_pattern(epsilon): + + def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor): + """ + Pattern for AddRMSNormQuantSPPattern fusion. + """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, + rms_norm_weight, epsilon) + out0 = output[0] + out1 = output[2] + out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, + torch.qint8, -1, False) + return quantized_output, out1 + + def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor): + """ + Replacement for the AddRMSNormQuantSPPattern fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, + residual, + rms_norm_weight, + # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. + 1. / scale, + offset, + epsilon=epsilon) + quantized_output = output[0] + out1 = output[2] + quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + quantized_output, True) + return quantized_output, out1 + + def get_inputs(): + """ + Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + scale = torch.ones(4, device="npu") + offset = torch.zeros(4, device="npu") + return [rms_norm_input, residual, rms_norm_weight, scale, offset] + + import torchair + + torchair.register_replacement(search_fn=pattern, + replace_fn=replacement, + example_inputs=get_inputs(), + extra_check=_extra_stream_scope_check) + + +# The replacement registered here will be actually executed after AOT. +def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon): + + def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor, bias: torch.Tensor): + """ + Pattern for AddRMSNormQuantSPPatternWithBias fusion. + """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, + rms_norm_weight, epsilon) + out0 = output[0] + out1 = output[2] + out0 = out0 + bias + out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, + torch.qint8, -1, False) + return quantized_output, out1 + + def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor, bias: torch.Tensor): + """ + Replacement for the AddRMSNormQuantSPPatternWithBias fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, + residual, + rms_norm_weight, + # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. + 1. / scale, + offset, + epsilon=epsilon, + beta=bias) + quantized_output = output[0] + out1 = output[2] + quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + quantized_output, True) + return quantized_output, out1 + + def get_inputs(): + """ + Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + rmsnorm_bias = torch.randn(4, device="npu") + scale = torch.ones(4, device="npu") + offset = torch.zeros(4, device="npu") + return [ + rms_norm_input, residual, rms_norm_weight, scale, offset, + rmsnorm_bias + ] + + import torchair + + torchair.register_replacement(search_fn=pattern, + replace_fn=replacement, + example_inputs=get_inputs(), + extra_check=_extra_stream_scope_check) + + # register converter for pass common_epsilons = [1e-5, 1e-6] for eps in common_epsilons: @@ -219,3 +297,5 @@ for eps in common_epsilons: ) replacement_add_rms_norm_quant(eps) replacement_add_rms_norm_quant_with_bias(eps) + replacement_add_rms_norm_quant_sp_pattern(eps) + replacement_add_rms_norm_quant_sp_pattern_with_bias(eps)