[bugfix] adapt bugfix for norm_quant_fusion_pass to npugraph_ex (#6726)
### What this PR does / why we need it?
This PR adapts bugfixes from `norm_quant_fusion_pass` to
`graphex_norm_quant_fusion_pass` for the `npugraph_ex` backend.
The main changes are:
- Replaced `torch.ops.npu.npu_add_rms_norm` with
`torch.ops._C_ascend.npu_add_rms_norm_bias`.
- For patterns without bias, `None` is passed as the bias argument.
- For patterns with bias, the separate `add` operation for bias is
removed and the bias is passed directly to `npu_add_rms_norm_bias`. This
improves fusion.
These changes ensure consistency and correctness for RMSNorm and
quantization fusion patterns when using `npugraph_ex`.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main:
9562912cea
Signed-off-by: huyuanquan1 <huyuanquan1@huawei.com>
Co-authored-by: huyuanquan1 <huyuanquan1@huawei.com>
This commit is contained in:
@@ -58,7 +58,9 @@ class GraphEXAddRMSNormQuantPattern:
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||
rms_norm_input, residual, rms_norm_weight, None, self.eps
|
||||
)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
@@ -123,10 +125,11 @@ class GraphEXAddRMSNormQuantPatternWithBias:
|
||||
"""
|
||||
Pattern for AddRMSNormQuantWithBias fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||
rms_norm_input, residual, rms_norm_weight, bias, self.eps
|
||||
)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
@@ -188,7 +191,9 @@ class GraphEXAddRMSNormQuantSPPattern:
|
||||
"""
|
||||
Pattern for AddRMSNormQuantSPPattern fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||
rms_norm_input, residual, rms_norm_weight, None, self.eps
|
||||
)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
@@ -255,10 +260,11 @@ class GraphEXAddRMSNormQuantSPPatternWithBias:
|
||||
"""
|
||||
Pattern for AddRMSNormQuantSPPatternWithBias fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||
rms_norm_input, residual, rms_norm_weight, bias, self.eps
|
||||
)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
Reference in New Issue
Block a user