[main] Use AddRmsNormQuant ops in the custom model to optimize Qwen3's performance (#1806)

### What this PR does / why we need it?
Optimizes the performance of the Qwen3 quantization model by registering
a custom model and adding the AddRmsNormQuant operation. Subsequent PRs
will focus on performance optimizations based on this custom model.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with existing tests.

- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2025-07-22 19:03:13 +08:00
committed by GitHub
parent ce4970eee0
commit 9a3bdf2162
5 changed files with 227 additions and 8 deletions

View File

@@ -91,10 +91,12 @@ class AscendW8A8LinearMethod:
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
original_dtype = x.dtype
if original_dtype != torch.int8:
x = quant_per_tensor(x, layer.aclnn_input_scale,
layer.aclnn_input_offset)
if x.dtype != torch.int8:
x = quant_per_tensor(
x,
layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None
if is_310p():
# On 300I Duo platform, we need transpose again if
@@ -104,7 +106,7 @@ class AscendW8A8LinearMethod:
layer.weight.data.transpose(1, 0),
layer.deq_scale,
bias=quant_bias,
output_dtype=original_dtype,
output_dtype=layer.params_dtype,
)
else:
output = torch_npu.npu_quant_matmul(
@@ -112,13 +114,16 @@ class AscendW8A8LinearMethod:
layer.weight,
layer.deq_scale,
bias=quant_bias,
output_dtype=original_dtype,
output_dtype=layer.params_dtype,
)
return output
def process_weights_after_loading(self, layer):
expanding_factor = layer.weight.data.shape[1]
layer.aclnn_input_scale = 1 / torch.nn.Parameter(
layer.aclnn_input_scale = torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.aclnn_input_offset = torch.nn.Parameter(