From acc3578f585272aaea6e8bdbbb0a94320a476b51 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Thu, 18 Dec 2025 20:25:44 +0800 Subject: [PATCH] [Graph][Fusion]Add new pattern for AddRmsnormQuant with SP. (#5077) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? 1. In addition to [#4168](https://github.com/vllm-project/vllm-ascend/pull/4168), [#5011](https://github.com/vllm-project/vllm-ascend/pull/5011), this PR adds two more patterns for AddRmsnormQuant with SP enabled. The key difference is the insertion of an additional `maybe_all_gather_and_maybe_unpad` between `addrmsnorm` and `quantize`. 2. This PR also introduces a new API, `torch.ops.vllm.quantize`, so that we can pass `input_scale` and `input_scale_reciprocal` at the same time. This is because `npu_add_rms_norm_quant` and `npu_quantize` require different `div_mode` settings. To avoid an additional reciprocal calculation at runtime, we have to pass both of them to the quantize API. 3. Removes the redundant `AscendQuantRMSNorm`. - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 --------- Signed-off-by: Angazenn --- .../compile/test_norm_quant_fusion.py | 271 +++++++++++++++--- tests/ut/quantization/test_w8a8.py | 13 +- .../passes/norm_quant_fusion_pass.py | 219 +++++++++++--- vllm_ascend/ops/layernorm.py | 27 +- vllm_ascend/ops/linear_op.py | 8 +- vllm_ascend/ops/register_custom_ops.py | 26 ++ vllm_ascend/quantization/w8a8.py | 6 +- 7 files changed, 454 insertions(+), 116 deletions(-) diff --git a/tests/e2e/singlecard/compile/test_norm_quant_fusion.py b/tests/e2e/singlecard/compile/test_norm_quant_fusion.py index 2dd09de1..1a335135 100644 --- a/tests/e2e/singlecard/compile/test_norm_quant_fusion.py +++ b/tests/e2e/singlecard/compile/test_norm_quant_fusion.py @@ -23,8 +23,13 @@ import torch_npu import vllm.config from vllm.compilation.fx_utils import OpOverload from vllm.config import ModelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.utils.system_utils import update_environment_variables +import vllm_ascend.ops.register_custom_ops # noqa from tests.e2e.singlecard.compile.backend import TestBackend +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \ AddRMSNormQuantFusionPass @@ -35,14 +40,23 @@ class TestModelWithoutBias(nn.Module): AddRMSNorm → Quantization (without bias) """ - def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"): + def __init__(self, + hidden_size: int, + dtype: torch.dtype, + eps: float = 1e-6, + device="npu"): super().__init__() self.hidden_size = hidden_size self.eps = eps self.rms_norm_weight = nn.Parameter( torch.randn(hidden_size, device=device)) - self.quant_scale = torch.tensor([1.0], device=device) - self.quant_offset = torch.tensor([0.0], device=device) + self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device) + self.quant_scale_reciprocal = torch.ones(hidden_size, + dtype=dtype, + device=device) + self.quant_offset = torch.zeros(hidden_size, + dtype=dtype, + device=device) def forward(self, x): """ @@ -56,10 +70,10 @@ class TestModelWithoutBias(nn.Module): norm_output, _, new_residual = torch_npu.npu_add_rms_norm( x, residual, self.rms_norm_weight, self.eps) - quantized_output = torch_npu.npu_quantize(norm_output, -
self.quant_scale, - self.quant_offset, - torch.qint8, -1, False) + quantized_output = torch.ops.vllm.quantize(norm_output, + self.quant_scale, + self.quant_scale_reciprocal, + self.quant_offset) return quantized_output, new_residual @@ -67,7 +81,7 @@ class TestModelWithoutBias(nn.Module): """Return the list of expected operators BEFORE fusion.""" return [ torch.ops.npu.npu_add_rms_norm.default, - torch.ops.npu.npu_quantize.default + torch.ops.vllm.quantize.default ] def ops_in_model_after(self) -> List[OpOverload]: @@ -81,15 +95,24 @@ class TestModelWithBias(nn.Module): AddRMSNorm → Add Bias → Quantization (with bias) """ - def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"): + def __init__(self, + hidden_size: int, + dtype: torch.dtype, + eps: float = 1e-6, + device="npu"): super().__init__() self.hidden_size = hidden_size self.eps = eps self.rms_norm_weight = nn.Parameter( torch.randn(hidden_size, device=device)) self.bias = nn.Parameter(torch.randn(hidden_size, device=device)) - self.quant_scale = torch.tensor([1.0], device=device) - self.quant_offset = torch.tensor([0.0], device=device) + self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device) + self.quant_scale_reciprocal = torch.ones(hidden_size, + dtype=dtype, + device=device) + self.quant_offset = torch.zeros(hidden_size, + dtype=dtype, + device=device) def forward(self, x): """ @@ -107,10 +130,10 @@ class TestModelWithBias(nn.Module): # Add bias norm_output_with_bias = norm_output + self.bias - quantized_output = torch_npu.npu_quantize(norm_output_with_bias, - self.quant_scale, - self.quant_offset, - torch.qint8, -1, False) + quantized_output = torch.ops.vllm.quantize(norm_output_with_bias, + self.quant_scale, + self.quant_scale_reciprocal, + self.quant_offset) return quantized_output, new_residual @@ -119,7 +142,7 @@ class TestModelWithBias(nn.Module): return [ torch.ops.npu.npu_add_rms_norm.default, torch.ops.aten.add.Tensor, # Add bias operation - torch.ops.npu.npu_quantize.default + torch.ops.vllm.quantize.default ] def ops_in_model_after(self) -> List[OpOverload]: @@ -127,13 +150,152 @@ class TestModelWithBias(nn.Module): return [torch.ops.npu.npu_add_rms_norm_quant.default] -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +class TestModelSPWithoutBias(nn.Module): + """ + A minimal test model that simulates the pattern: + AddRMSNorm → maybe_allgather → Quantization (without bias) + """ + + def __init__(self, + hidden_size: int, + dtype: torch.dtype, + eps: float = 1e-6, + device="npu"): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.rms_norm_weight = nn.Parameter( + torch.randn(hidden_size, device=device)) + self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device) + self.quant_scale_reciprocal = torch.ones(hidden_size, + dtype=dtype, + device=device) + self.quant_offset = torch.zeros(hidden_size, + dtype=dtype, + device=device) + + def forward(self, x): + """ + Forward pass: + 1. Perform npu_add_rms_norm + 2. Perform a fake maybe_all_gather_and_maybe_unpad + 3. Quantize the normalized output to int8 + Returns both quantized output and updated residual. 
+ """ + residual = torch.zeros_like(x) + + norm_output, _, new_residual = torch_npu.npu_add_rms_norm( + x, residual, self.rms_norm_weight, self.eps) + + norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + norm_output, True) + + quantized_output = torch.ops.vllm.quantize(norm_output, + self.quant_scale, + self.quant_scale_reciprocal, + self.quant_offset) + + return quantized_output, new_residual + + def ops_in_model_before(self) -> List[OpOverload]: + """Return the list of expected operators BEFORE fusion.""" + return [ + torch.ops.npu.npu_add_rms_norm.default, + torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default, + torch.ops.vllm.quantize.default + ] + + def ops_in_model_after(self) -> List[OpOverload]: + """Return the list of expected operators AFTER successful fusion.""" + return [ + torch.ops.npu.npu_add_rms_norm_quant.default, + torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default + ] + + +class TestModelSPWithBias(nn.Module): + """ + A minimal test model that simulates the pattern: + AddRMSNorm → Add bias → maybe_allgather → Quantization (without bias) + """ + + def __init__(self, + hidden_size: int, + dtype: torch.dtype, + eps: float = 1e-6, + device="npu"): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.rms_norm_weight = nn.Parameter( + torch.randn(hidden_size, device=device)) + self.bias = nn.Parameter(torch.randn(hidden_size, device=device)) + self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device) + self.quant_scale_reciprocal = torch.ones(hidden_size, + dtype=dtype, + device=device) + self.quant_offset = torch.zeros(hidden_size, + dtype=dtype, + device=device) + + def forward(self, x): + """ + Forward pass: + 1. Perform npu_add_rms_norm + 2. Add bias + 3. Perform a fake maybe_all_gather_and_maybe_unpad + 4. Quantize the normalized output to int8 + Returns both quantized output and updated residual. 
+ """ + residual = torch.zeros_like(x) + + norm_output, _, new_residual = torch_npu.npu_add_rms_norm( + x, residual, self.rms_norm_weight, self.eps) + + # Add bias + norm_output_with_bias = norm_output + self.bias + + norm_output_with_bias = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + norm_output_with_bias, True) + + quantized_output = torch.ops.vllm.quantize(norm_output_with_bias, + self.quant_scale, + self.quant_scale_reciprocal, + self.quant_offset) + + return quantized_output, new_residual + + def ops_in_model_before(self) -> List[OpOverload]: + """Return the list of expected operators BEFORE fusion.""" + return [ + torch.ops.npu.npu_add_rms_norm.default, + torch.ops.aten.add.Tensor, # Add bias operation + torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default, + torch.ops.vllm.quantize.default + ] + + def ops_in_model_after(self) -> List[OpOverload]: + """Return the list of expected operators AFTER successful fusion.""" + return [ + torch.ops.npu.npu_add_rms_norm_quant.default, + torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default + ] + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("use_bias", [False, True]) -def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int, - num_tokens: int, eps: float, use_bias: bool): +@pytest.mark.parametrize("sp_enable", [False, True]) +def test_rmsnorm_quant_fusion( + dtype: torch.dtype, + hidden_size: int, + num_tokens: int, + eps: float, + use_bias: bool, + sp_enable: bool, +): """ End-to-end test for AddRMSNorm+Quantize fusion. Compares: Operator presence/absence before and after graph transformation @@ -143,27 +305,58 @@ def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int, vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype)) + update_environment_variables({ + "RANK": "0", + "LOCAL_RANK": "0", + "WORLD_SIZE": "1", + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + }) + init_distributed_environment() + ensure_model_parallel_initialized(1, 1) + with vllm.config.set_current_vllm_config(vllm_config): - backend = TestBackend( - custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)]) - if use_bias: - model = TestModelWithBias(hidden_size, eps, device="npu") - else: - model = TestModelWithoutBias(hidden_size, eps, device="npu") - model = model.to("npu") + with set_ascend_forward_context(None, vllm_config): + backend = TestBackend(custom_passes=[ + AddRMSNormQuantFusionPass(vllm_config=vllm_config) + ]) + if use_bias: + if sp_enable: + model = TestModelSPWithBias(hidden_size, + dtype, + eps, + device="npu") + else: + model = TestModelWithBias(hidden_size, + dtype, + eps, + device="npu") + else: + if sp_enable: + model = TestModelSPWithoutBias(hidden_size, + dtype, + eps, + device="npu") + else: + model = TestModelWithoutBias(hidden_size, + dtype, + eps, + device="npu") + model = model.to("npu") - x = torch.rand(num_tokens, - hidden_size, - device="npu", - dtype=dtype, - requires_grad=False) + x = torch.rand(num_tokens, + hidden_size, + device="npu", + dtype=dtype, + requires_grad=False) - result_unfused = model(x) - print("Unfused result:", [t.shape for t in result_unfused]) - model_fused = torch.compile(model, backend=backend) - result_fused = model_fused(x) - print("Fused result:", [t.shape for t in result_fused]) + result_unfused = model(x) + print("Unfused result:", [t.shape for t in result_unfused]) + model_fused = 
torch.compile(model, backend=backend) + result_fused = model_fused(x) + print("Fused result:", [t.shape for t in result_fused]) - print("=== Checking operator fusion ===") - backend.check_before_ops(model.ops_in_model_before()) - backend.check_after_ops(model.ops_in_model_after()) + print("=== Checking operator fusion ===") + backend.check_before_ops(model.ops_in_model_before(), + fully_replaced=not sp_enable) + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index 9a1c0e8d..c5b13360 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -70,10 +70,9 @@ class TestAscendW8A8LinearMethod(TestBase): self.assertEqual(params['weight_offset'].shape, (10, 1)) @patch("vllm_ascend.quantization.w8a8.get_forward_context") - @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + @patch("torch.ops.vllm.quantize") @patch("torch_npu.npu_quant_matmul") - def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, - mock_quant_per_tensor, + def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize, mock_get_forward_context): layer = MagicMock() layer.aclnn_input_scale = 0.1 @@ -88,10 +87,10 @@ class TestAscendW8A8LinearMethod(TestBase): x = torch.randn(32, 128) bias = torch.randn(256) - mock_quant_per_tensor.return_value = torch.randint(-128, - 127, - x.shape, - dtype=torch.int8) + mock_quantize.return_value = torch.randint(-128, + 127, + x.shape, + dtype=torch.int8) expected_y_output = torch.randn(32, 256) mock_npu_quant_matmul.return_value = expected_y_output diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py index c23449bf..3cdeaaf3 100644 --- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py @@ -28,24 +28,29 @@ class AddRMSNormQuantPattern: def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype self.eps = eps def get_inputs(self): """ Generate example inputs for the AddRMSNormQuant fusion pattern. """ - rms_norm_input = torch.randn(2, 4, device="npu") - residual = torch.randn(2, 4, device="npu") - rms_norm_weight = torch.randn(4, device="npu") - scale = torch.tensor([1.0], device="npu") - offset = torch.tensor([0.0], device="npu") - return [rms_norm_input, residual, rms_norm_weight, scale, offset] + rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype) + residual = torch.randn(2, 4, device="npu", dtype=self.dtype) + rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype) + scale = torch.ones(4, device="npu", dtype=self.dtype) + scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) + offset = torch.zeros(4, device="npu", dtype=self.dtype) + return [ + rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, + offset + ] def register(self, pm_pass: PatternMatcherPass): def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor): + scale_reciprocal: torch.Tensor, offset: torch.Tensor): """ Pattern for AddRMSNormQuant fusion. 
""" @@ -53,24 +58,23 @@ class AddRMSNormQuantPattern: rms_norm_weight, self.eps) out0 = output[0] out1 = output[2] - quantized_output = torch.ops.npu.npu_quantize( - out0, scale, offset, torch.qint8, -1, False) + quantized_output = torch.ops.vllm.quantize(out0, scale, + scale_reciprocal, + offset) return quantized_output, out1 def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor): + scale_reciprocal: torch.Tensor, offset: torch.Tensor): """ Replacement for the AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm_quant( - rms_norm_input, - residual, - rms_norm_weight, - 1. / - scale, # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. - offset, - epsilon=self.eps) + output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, + residual, + rms_norm_weight, + scale, + offset, + epsilon=self.eps) quantized_output = output[0] out1 = output[2] return quantized_output, out1 @@ -83,25 +87,31 @@ class AddRMSNormQuantPatternWithBias: def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype self.eps = eps def get_inputs(self): """ Generate example inputs for the AddRMSNormQuant fusion pattern. """ - rms_norm_input = torch.randn(2, 4, device="npu") - residual = torch.randn(2, 4, device="npu") - rms_norm_weight = torch.randn(4, device="npu") - scale = torch.tensor([1.0], device="npu") - offset = torch.tensor([0.0], device="npu") - bias = torch.randn(4, device="npu") - return [rms_norm_input, residual, rms_norm_weight, scale, offset, bias] + rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype) + residual = torch.randn(2, 4, device="npu", dtype=self.dtype) + rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype) + rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype) + scale = torch.ones(4, device="npu", dtype=self.dtype) + scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) + offset = torch.zeros(4, device="npu", dtype=self.dtype) + return [ + rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, + offset, rmsnorm_bias + ] def register(self, pm_pass: PatternMatcherPass): def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor, bias: torch.Tensor): + scale_reciprocal: torch.Tensor, offset: torch.Tensor, + bias: torch.Tensor): """ Pattern for AddRMSNormQuant fusion. """ @@ -110,25 +120,25 @@ class AddRMSNormQuantPatternWithBias: out0 = output[0] out1 = output[2] out0 = out0 + bias - quantized_output = torch.ops.npu.npu_quantize( - out0, scale, offset, torch.qint8, -1, False) + quantized_output = torch.ops.vllm.quantize(out0, scale, + scale_reciprocal, + offset) return quantized_output, out1 def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor, scale: torch.Tensor, - offset: torch.Tensor, bias: torch.Tensor): + scale_reciprocal: torch.Tensor, offset: torch.Tensor, + bias: torch.Tensor): """ Replacement for the AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm_quant( - rms_norm_input, - residual, - rms_norm_weight, - 1. / - scale, # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. 
- offset, - epsilon=self.eps, - beta=bias) + output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, + residual, + rms_norm_weight, + scale, + offset, + epsilon=self.eps, + beta=bias) quantized_output = output[0] out1 = output[2] return quantized_output, out1 @@ -137,6 +147,135 @@ class AddRMSNormQuantPatternWithBias: pm.fwd_only, pm_pass) +class AddRMSNormQuantSPPattern: + + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): + self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype + self.eps = eps + + def get_inputs(self): + """ + Generate example inputs for the AddRMSNormQuant fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype) + residual = torch.randn(2, 4, device="npu", dtype=self.dtype) + rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype) + scale = torch.ones(4, device="npu", dtype=self.dtype) + scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) + offset = torch.zeros(4, device="npu", dtype=self.dtype) + return [ + rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, + offset + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + scale_reciprocal: torch.Tensor, offset: torch.Tensor): + """ + Pattern for AddRMSNormQuant fusion. + """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, + rms_norm_weight, self.eps) + out0 = output[0] + out1 = output[2] + out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) + quantized_output = torch.ops.vllm.quantize(out0, scale, + scale_reciprocal, + offset) + return quantized_output, out1 + + def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + scale_reciprocal: torch.Tensor, offset: torch.Tensor): + """ + Replacement for the AddRMSNormQuant fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, + residual, + rms_norm_weight, + scale, + offset, + epsilon=self.eps) + quantized_output = output[0] + out1 = output[2] + quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + quantized_output, True) + return quantized_output, out1 + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AddRMSNormQuantSPPatternWithBias: + + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): + self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype + self.eps = eps + + def get_inputs(self): + """ + Generate example inputs for the AddRMSNormQuant fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype) + residual = torch.randn(2, 4, device="npu", dtype=self.dtype) + rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype) + rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype) + scale = torch.ones(4, device="npu", dtype=self.dtype) + scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) + offset = torch.zeros(4, device="npu", dtype=self.dtype) + return [ + rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, + offset, rmsnorm_bias + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + scale_reciprocal: torch.Tensor, offset: torch.Tensor, + bias: torch.Tensor): + """ + Pattern for AddRMSNormQuant fusion. 
+ """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, + rms_norm_weight, self.eps) + out0 = output[0] + out1 = output[2] + out0 = out0 + bias + out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) + quantized_output = torch.ops.vllm.quantize(out0, scale, + scale_reciprocal, + offset) + return quantized_output, out1 + + def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + scale_reciprocal: torch.Tensor, offset: torch.Tensor, + bias: torch.Tensor): + """ + Replacement for the AddRMSNormQuant fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, + residual, + rms_norm_weight, + scale, + offset, + epsilon=self.eps, + beta=bias) + quantized_output = output[0] + out1 = output[2] + quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + quantized_output, True) + return quantized_output, out1 + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + class AddRMSNormQuantFusionPass(VllmInductorPass): """ A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend. @@ -159,6 +298,10 @@ class AddRMSNormQuantFusionPass(VllmInductorPass): eps=eps).register(self.pattern_match_passes) AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register( self.pattern_match_passes) + AddRMSNormQuantSPPattern(vllm_config, eps=eps).register( + self.pattern_match_passes) + AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register( + self.pattern_match_passes) def __call__(self, graph: torch.fx.Graph): self.begin() diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index cdbba32f..98c0e6bb 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. 
# -from typing import Optional, Tuple, Union, cast +from typing import Optional, Tuple, Union import torch from vllm.config import get_current_vllm_config @@ -70,31 +70,6 @@ class AscendRMSNorm(RMSNorm): return x -class AscendQuantRMSNorm(AscendRMSNorm): - - def __init__( - self, - hidden_size: int, - eps: float = 1e-6, - var_hidden_size: Optional[int] = None, - has_weight: bool = True, - dtype: Optional[torch.dtype] = None, - ) -> None: - super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), - requires_grad=False) - - def forward_oot( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if residual is not None: - x, residual = super().forward_oot(x, residual) - return x.add_(self.bias), residual - return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias) - - class AscendGemmaRMSNorm(GemmaRMSNorm): def forward_oot( diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 27310ffd..eec63dd3 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -545,8 +545,7 @@ class SequenceRowParallelOp(CustomRowParallelOp): from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm_ascend.quantization.quant_config import AscendLinearMethod - from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod, - quant_per_tensor) + from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod # For unquant if mmrs_fusion and isinstance(self.layer.quant_method, @@ -568,8 +567,9 @@ class SequenceRowParallelOp(CustomRowParallelOp): and isinstance(self.layer.quant_method.quant_method, AscendW8A8LinearMethod)): if x.dtype != torch.int8: - x_quant = quant_per_tensor( - x, self.layer.aclnn_input_scale_reciprocal, + x_quant = torch.ops.vllm.quantize( + x, self.layer.aclnn_input_scale, + self.layer.aclnn_input_scale_reciprocal, self.layer.aclnn_input_offset) else: x_quant = x diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index b7100991..a534b719 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -282,6 +282,26 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, return output +# TODO(Angazenn): The reason why we use a custom op to encapsulate npu_quantize +# is that aclnnAscendQuantV3(npu_quantize) uses div_mode=False, while +# aclnnAddRmsNormQuantV2(npu_add_rms_norm_quant) uses div_mode=True. We have to +# pass input_scale and input_scale_reciprocal at the same time to avoid redundant +# reciprocal calculation in the fusion pass. We shall remove this once +# aclnnAddRmsNormQuantV2 supports div_mode=False.
+def _quantize_impl(in_tensor: torch.Tensor, input_scale: torch.Tensor, + input_scale_reciprocal: torch.Tensor, + input_offset: torch.Tensor) -> torch.Tensor: + return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, + input_offset, torch.qint8, -1, False) + + +def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor, + input_scale_reciprocal: torch.Tensor, + input_offset: torch.Tensor) -> torch.Tensor: + return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, + input_offset, torch.qint8, -1, False) + + direct_register_custom_op(op_name="maybe_chunk_residual", op_func=_maybe_chunk_residual_impl, fake_impl=lambda x, residual: x, @@ -341,3 +361,9 @@ direct_register_custom_op(op_name="matmul_and_reduce", fake_impl=_matmul_and_reduce_impl_fake, mutates_args=[], dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="quantize", + op_func=_quantize_impl, + fake_impl=_quantize_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index bfa39e69..35691f38 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -128,8 +128,9 @@ class AscendW8A8LinearMethod: if enable_flashcomm2_quant_comm: quant_input_x = x.contiguous().view( -1, layer.aclnn_input_scale_reciprocal.size(0)) - quant_x = quant_per_tensor( + quant_x = torch.ops.vllm.quantize( quant_input_x, + layer.aclnn_input_scale, layer.aclnn_input_scale_reciprocal, layer.aclnn_input_offset, ) @@ -138,8 +139,9 @@ class AscendW8A8LinearMethod: x = comm_fn(comm_input) else: # quant - x = quant_per_tensor( + x = torch.ops.vllm.quantize( x, + layer.aclnn_input_scale, layer.aclnn_input_scale_reciprocal, layer.aclnn_input_offset, )
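For reference, below is a minimal usage sketch of the new `torch.ops.vllm.quantize` op added by this patch. It is illustrative only: it mirrors the call pattern used in `w8a8.py` and the new e2e test, assumes an Ascend NPU with `torch_npu` installed and the custom ops registered via `vllm_ascend.ops.register_custom_ops`, and uses made-up shapes and scale values.

```python
# Illustrative sketch only (not part of this patch): shows how the new
# torch.ops.vllm.quantize custom op is invoked. Requires an Ascend NPU.
import torch
import torch_npu  # noqa: F401
import vllm_ascend.ops.register_custom_ops  # noqa: F401  registers torch.ops.vllm.quantize

hidden_size = 64
x = torch.randn(8, hidden_size, dtype=torch.bfloat16, device="npu")

# Both forms of the scale are kept around: npu_quantize (div_mode=False)
# consumes the reciprocal, while the fused npu_add_rms_norm_quant kernel
# (div_mode=True) consumes the scale itself, so no reciprocal has to be
# computed at runtime.
scale = torch.full((hidden_size,), 0.5, dtype=torch.bfloat16, device="npu")
scale_reciprocal = 1.0 / scale
offset = torch.zeros(hidden_size, dtype=torch.bfloat16, device="npu")

# Quantize activations to int8; the fusion pass can later fold this call
# together with a preceding npu_add_rms_norm into npu_add_rms_norm_quant.
x_int8 = torch.ops.vllm.quantize(x, scale, scale_reciprocal, offset)
```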