### What this PR does / why we need it?
Optimizes the performance of the quantized Qwen3 model by registering
a custom model and adding the AddRmsNormQuant operation. Subsequent PRs
will build further performance optimizations on top of this custom model.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with the existing tests.
- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2
Signed-off-by: rjg-lyh <1318825571@qq.com>
157 lines
5.7 KiB
Python
157 lines
5.7 KiB
Python
from collections.abc import Iterable
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import Qwen3Config
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import get_pp_group
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
|
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
|
from vllm.model_executor.models.qwen2 import Qwen2Model
|
|
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
|
|
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
|
PPMissingLayer, maybe_prefix)
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
|
|
|
|
|
|
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    """Qwen3 decoder layer that can swap in ``AddRMSNormW8A8Quant`` norms.

    The parent ``Qwen3DecoderLayer`` is constructed unchanged. Afterwards,
    when a quantization config is supplied, each of the two layer norms is
    replaced by ``AddRMSNormW8A8Quant`` if the linear layer associated with
    it here uses ``AscendW8A8LinearMethod``:

    * ``input_layernorm`` is tied to ``self_attn.qkv_proj``
    * ``post_attention_layernorm`` is tied to ``mlp.gate_up_proj``
    """

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the stock layer, then optionally replace its layer norms.

        Args:
            config: Hugging Face Qwen3 config; ``hidden_size`` and
                ``rms_norm_eps`` are read for the replacement norms.
            cache_config: Forwarded to the parent constructor.
            quant_config: Forwarded to the parent constructor. When it is
                ``None`` no norm is replaced; otherwise it must be an
                ``AscendQuantConfig``.
            prefix: Forwarded to the parent constructor.

        Raises:
            AssertionError: If ``quant_config`` is given but is not an
                ``AscendQuantConfig``.
        """
        super().__init__(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

        # Without quantization the parent's layer norms are left as they are.
        if quant_config is None:
            return

        # Imported inside the constructor rather than at module level, so
        # these modules are only loaded on the quantized path.
        from vllm_ascend.quantization.quant_config import AscendQuantConfig
        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod

        assert isinstance(quant_config, AscendQuantConfig), \
            "Expected quant_config to be an instance of AscendQuantConfig"

        hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps

        # The wrapper stored on the linear layer exposes the concrete
        # method as ``.quant_method`` (hence the double attribute access).
        qkv_proj = self.self_attn.qkv_proj
        if isinstance(qkv_proj.quant_method.quant_method,
                      AscendW8A8LinearMethod):
            self.input_layernorm = AddRMSNormW8A8Quant(
                hidden_size,
                layer=qkv_proj,
                eps=norm_eps,
            )

        gate_up_proj = self.mlp.gate_up_proj
        if isinstance(gate_up_proj.quant_method.quant_method,
                      AscendW8A8LinearMethod):
            self.post_attention_layernorm = AddRMSNormW8A8Quant(
                hidden_size,
                layer=gate_up_proj,
                eps=norm_eps,
            )
|
|
|
|
|
|
# Maps a layer-type name to the decoder layer class implementing it.
# Only the "attention" type is registered.
ALL_DECODER_LAYER_TYPES = {"attention": CustomQwen3DecoderLayer}
|
|
|
|
|
|
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for
        # qwen2-vl, otherwise (seq_len, ); -1 selects the last dimension
        # in both cases.
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    }
)
class CustomQwen3Model(Qwen2Model):
    """``Qwen2Model`` built from ``CustomQwen3DecoderLayer`` layers.

    The only customization is the ``decoder_layer_type`` handed to the
    parent constructor; everything else is inherited from ``Qwen2Model``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Construct the parent model with the custom decoder layer type.

        Args:
            vllm_config: Forwarded to ``Qwen2Model.__init__``.
            prefix: Forwarded to ``Qwen2Model.__init__``.
        """
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            decoder_layer_type=CustomQwen3DecoderLayer,
        )
|
|
|
|
|
|
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper whose backbone is ``CustomQwen3Model``.

    Holds the backbone (``self.model``), the output head (``self.lm_head``)
    and a ``LogitsProcessor``, and exposes the forward pass, logits
    computation and weight loading on top of them.
    """

    # Uses `CustomQwen3Model` to initialise self.model.
    # Packed module name -> names of the projections it is listed with.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Source of the HF model config
                (``model_config.hf_config``), ``quant_config`` and
                ``lora_config``; also forwarded to ``CustomQwen3Model``.
            prefix: Name prefix combined with ``"model"`` / ``"lm_head"``
                via ``maybe_prefix`` for the submodules.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config

        self.config = hf_config
        self.lora_config = vllm_config.lora_config
        self.quant_config = quantization

        self.model = CustomQwen3Model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        # Only the last pipeline-parallel rank gets a real head; other
        # ranks get a placeholder layer.
        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        elif hf_config.tie_word_embeddings:
            # Tied embeddings: reuse the input embedding module as the head.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quantization,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)

        # Re-export the backbone's factory under the same name on this
        # wrapper.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's input embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged.

        All four arguments are passed positionally, in this order, to
        ``self.model``.
        """
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the LM head.

        Delegates to ``self.logits_processor`` with ``self.lm_head``,
        ``hidden_states`` and ``sampling_metadata``.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> set[str]:
        """Load ``weights`` into this module via ``AutoWeightsLoader``.

        With tied word embeddings, entries under the ``"lm_head."`` prefix
        are skipped; otherwise nothing is skipped.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated
            as a set of names).
        """
        if self.config.tie_word_embeddings:
            skipped = ["lm_head."]
        else:
            skipped = None

        weights_loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return weights_loader.load_weights(weights)
|