[BugFix] Fix AddRMSNormQuant not taking effect (#6620)

### What this PR does / why we need it?
Fix the issue where, in graph mode, the fused `AddRMSNormQuant` operator
does not take effect when there is no bias.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
ZYang6263
2026-02-12 09:26:05 +08:00
committed by GitHub
parent 052cc4e61b
commit 56269eae0e
3 changed files with 16 additions and 12 deletions

View File

@@ -74,8 +74,8 @@ class TestModelWithoutBias(nn.Module):
"""
residual = torch.zeros_like(x)
-norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
-    x, residual, self.rms_norm_weight, self.eps)
+norm_output, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
+    x, residual, self.rms_norm_weight, None, self.eps)
quantized_output = torch.ops.vllm.quantize(norm_output,
self.quant_scale,
@@ -87,7 +87,7 @@ class TestModelWithoutBias(nn.Module):
def ops_in_model_before(self) -> List[OpOverload]:
"""Return the list of expected operators BEFORE fusion."""
return [
-torch.ops.npu.npu_add_rms_norm.default,
+torch.ops._C_ascend.npu_add_rms_norm_bias.default,
torch.ops.vllm.quantize.default
]
@@ -187,8 +187,8 @@ class TestModelSPWithoutBias(nn.Module):
"""
residual = torch.zeros_like(x)
-norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
-    x, residual, self.rms_norm_weight, self.eps)
+norm_output, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
+    x, residual, self.rms_norm_weight, None, self.eps)
norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
norm_output, True)
@@ -203,7 +203,7 @@ class TestModelSPWithoutBias(nn.Module):
def ops_in_model_before(self) -> List[OpOverload]:
"""Return the list of expected operators BEFORE fusion."""
return [
-torch.ops.npu.npu_add_rms_norm.default,
+torch.ops._C_ascend.npu_add_rms_norm_bias.default,
torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default,
torch.ops.vllm.quantize.default
]

View File

@@ -26,8 +26,8 @@ def test_qwen3_w8a8_quant():
]
vllm_target_outputs = [([
85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
-13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 369, 3460
-], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed for large'
+13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
+], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be'
)]
with VllmRunner(