[main] Use AddRmsNormQuant ops in the custom model to optimize Qwen3's performance (#1806)

### What this PR does / why we need it?
Optimizes the performance of the Qwen3 quantization model by registering
a custom model and adding the AddRmsNormQuant operation. Subsequent PRs
will focus on performance optimizations based on this custom model.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with existing test.

- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2025-07-22 19:03:13 +08:00
committed by GitHub
parent ce4970eee0
commit 9a3bdf2162
5 changed files with 227 additions and 8 deletions

View File

@@ -167,3 +167,20 @@ def test_models_distributed_topk() -> None:
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
def test_models_distributed_Qwen3_W8A8():
example_prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download("vllm-ascend/Qwen3-8B-W8A8"),
max_model_len=8192,
enforce_eager=True,
dtype="auto",
tensor_parallel_size=4,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -11,6 +11,7 @@ def register_model():
from .qwen2_5_vl import \
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
from .qwen3 import CustomQwen3ForCausalLM # noqa: F401
ModelRegistry.register_model(
"DeepSeekMTPModel",
@@ -53,6 +54,9 @@ def register_model():
"Qwen3MoeForCausalLM",
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
ModelRegistry.register_model(
"Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
ModelRegistry.register_model(
"PanguProMoEForCausalLM",
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")

156
vllm_ascend/models/qwen3.py Normal file
View File

@@ -0,0 +1,156 @@
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from transformers import Qwen3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.models.utils import (AutoWeightsLoader,
PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
def __init__(
self,
config: Qwen3Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix)
if quant_config is None:
return
from vllm_ascend.quantization.quant_config import AscendQuantConfig
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
assert isinstance(quant_config, AscendQuantConfig), \
"Expected quant_config to be an instance of AscendQuantConfig"
if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
AscendW8A8LinearMethod):
self.input_layernorm = AddRMSNormW8A8Quant(
config.hidden_size,
layer=self.self_attn.qkv_proj,
eps=config.rms_norm_eps)
if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
AscendW8A8LinearMethod):
self.post_attention_layernorm = AddRMSNormW8A8Quant(
config.hidden_size,
layer=self.mlp.gate_up_proj,
eps=config.rms_norm_eps)
ALL_DECODER_LAYER_TYPES = {
"attention": CustomQwen3DecoderLayer,
}
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class CustomQwen3Model(Qwen2Model):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
decoder_layer_type=CustomQwen3DecoderLayer)
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# add `CustomQwen3Model` to init self.model
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = CustomQwen3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)

View File

@@ -23,6 +23,43 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_ascend.utils import is_310p
class AddRMSNormW8A8Quant(RMSNorm):
# Fuse AddRmsNorm and W8A8 quantization ops together
def __init__(
self,
hidden_size: int,
layer: torch.nn.Module,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
self.layer = layer
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
if residual is not None:
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
self.layer.aclnn_input_scale,
self.layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x
def forward_oot(
self,
x: torch.Tensor,

View File

@@ -91,10 +91,12 @@ class AscendW8A8LinearMethod:
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
original_dtype = x.dtype
if original_dtype != torch.int8:
x = quant_per_tensor(x, layer.aclnn_input_scale,
layer.aclnn_input_offset)
if x.dtype != torch.int8:
x = quant_per_tensor(
x,
layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None
if is_310p():
# On 300I Duo platform, we need transpose again if
@@ -104,7 +106,7 @@ class AscendW8A8LinearMethod:
layer.weight.data.transpose(1, 0),
layer.deq_scale,
bias=quant_bias,
output_dtype=original_dtype,
output_dtype=layer.params_dtype,
)
else:
output = torch_npu.npu_quant_matmul(
@@ -112,13 +114,16 @@ class AscendW8A8LinearMethod:
layer.weight,
layer.deq_scale,
bias=quant_bias,
output_dtype=original_dtype,
output_dtype=layer.params_dtype,
)
return output
def process_weights_after_loading(self, layer):
expanding_factor = layer.weight.data.shape[1]
layer.aclnn_input_scale = 1 / torch.nn.Parameter(
layer.aclnn_input_scale = torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.aclnn_input_offset = torch.nn.Parameter(