[BugFix] Fix AddRMSNormQuant not taking effect (#6620)
### What this PR does / why we need it?
Fix the issue where, in graph mode, the fused `AddRMSNormQuant` operator
does not take effect when there is no bias.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd
---------
Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
@@ -74,8 +74,8 @@ class TestModelWithoutBias(nn.Module):
|
|||||||
"""
|
"""
|
||||||
residual = torch.zeros_like(x)
|
residual = torch.zeros_like(x)
|
||||||
|
|
||||||
norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
|
norm_output, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||||
x, residual, self.rms_norm_weight, self.eps)
|
x, residual, self.rms_norm_weight, None, self.eps)
|
||||||
|
|
||||||
quantized_output = torch.ops.vllm.quantize(norm_output,
|
quantized_output = torch.ops.vllm.quantize(norm_output,
|
||||||
self.quant_scale,
|
self.quant_scale,
|
||||||
@@ -87,7 +87,7 @@ class TestModelWithoutBias(nn.Module):
|
|||||||
def ops_in_model_before(self) -> List[OpOverload]:
|
def ops_in_model_before(self) -> List[OpOverload]:
|
||||||
"""Return the list of expected operators BEFORE fusion."""
|
"""Return the list of expected operators BEFORE fusion."""
|
||||||
return [
|
return [
|
||||||
torch.ops.npu.npu_add_rms_norm.default,
|
torch.ops._C_ascend.npu_add_rms_norm_bias.default,
|
||||||
torch.ops.vllm.quantize.default
|
torch.ops.vllm.quantize.default
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -187,8 +187,8 @@ class TestModelSPWithoutBias(nn.Module):
|
|||||||
"""
|
"""
|
||||||
residual = torch.zeros_like(x)
|
residual = torch.zeros_like(x)
|
||||||
|
|
||||||
norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
|
norm_output, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||||
x, residual, self.rms_norm_weight, self.eps)
|
x, residual, self.rms_norm_weight, None, self.eps)
|
||||||
|
|
||||||
norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
norm_output, True)
|
norm_output, True)
|
||||||
@@ -203,7 +203,7 @@ class TestModelSPWithoutBias(nn.Module):
|
|||||||
def ops_in_model_before(self) -> List[OpOverload]:
|
def ops_in_model_before(self) -> List[OpOverload]:
|
||||||
"""Return the list of expected operators BEFORE fusion."""
|
"""Return the list of expected operators BEFORE fusion."""
|
||||||
return [
|
return [
|
||||||
torch.ops.npu.npu_add_rms_norm.default,
|
torch.ops._C_ascend.npu_add_rms_norm_bias.default,
|
||||||
torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default,
|
torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default,
|
||||||
torch.ops.vllm.quantize.default
|
torch.ops.vllm.quantize.default
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ def test_qwen3_w8a8_quant():
|
|||||||
]
|
]
|
||||||
vllm_target_outputs = [([
|
vllm_target_outputs = [([
|
||||||
85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
|
85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
|
||||||
13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 369, 3460
|
13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
|
||||||
], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed for large'
|
], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be'
|
||||||
)]
|
)]
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
|
|||||||
@@ -60,7 +60,9 @@ class AddRMSNormQuantPattern:
|
|||||||
"""
|
"""
|
||||||
Pattern for AddRMSNormQuant fusion.
|
Pattern for AddRMSNormQuant fusion.
|
||||||
"""
|
"""
|
||||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||||
|
rms_norm_input, residual, rms_norm_weight, None, self.eps
|
||||||
|
)
|
||||||
out0 = output[0]
|
out0 = output[0]
|
||||||
out1 = output[2]
|
out1 = output[2]
|
||||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||||
@@ -179,7 +181,9 @@ class AddRMSNormQuantSPPattern:
|
|||||||
"""
|
"""
|
||||||
Pattern for AddRMSNormQuant fusion.
|
Pattern for AddRMSNormQuant fusion.
|
||||||
"""
|
"""
|
||||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||||
|
rms_norm_input, residual, rms_norm_weight, None, self.eps
|
||||||
|
)
|
||||||
out0 = output[0]
|
out0 = output[0]
|
||||||
out1 = output[2]
|
out1 = output[2]
|
||||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||||
@@ -482,11 +486,11 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
|
|||||||
|
|
||||||
common_epsilons = [1e-5, 1e-6]
|
common_epsilons = [1e-5, 1e-6]
|
||||||
for eps in common_epsilons:
|
for eps in common_epsilons:
|
||||||
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
|
||||||
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
|
||||||
AddRMSNormDynamicQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
AddRMSNormDynamicQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||||
AddRMSNormDynamicQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
AddRMSNormDynamicQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||||
if enable_custom_op():
|
if enable_custom_op():
|
||||||
|
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||||
|
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||||
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||||
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||||
AddRMSNormDynamicQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
AddRMSNormDynamicQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||||
|
|||||||
Reference in New Issue
Block a user