[BugFix] Fix AddRMSNormQuant not taking effect (#6620)
### What this PR does / why we need it?
Fix the issue where, in graph mode, the fused `AddRMSNormQuant` operator
does not take effect when there is no bias.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd
---------
Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
@@ -60,7 +60,9 @@ class AddRMSNormQuantPattern:
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||
rms_norm_input, residual, rms_norm_weight, None, self.eps
|
||||
)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
@@ -179,7 +181,9 @@ class AddRMSNormQuantSPPattern:
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||
rms_norm_input, residual, rms_norm_weight, None, self.eps
|
||||
)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
@@ -482,11 +486,11 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
|
||||
|
||||
common_epsilons = [1e-5, 1e-6]
|
||||
for eps in common_epsilons:
|
||||
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormDynamicQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormDynamicQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
if enable_custom_op():
|
||||
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormDynamicQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
|
||||
Reference in New Issue
Block a user