[main] Use AddRmsNormQuant ops in the custom model to optimize Qwen3's performance (#1806)
### What this PR does / why we need it?
Optimizes the performance of the Qwen3 quantization model by registering
a custom model and adding the AddRmsNormQuant operation. Subsequent PRs
will focus on performance optimizations based on this custom model.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with existing test.
- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -167,3 +167,20 @@ def test_models_distributed_topk() -> None:
|
|||||||
distributed_executor_backend="mp",
|
distributed_executor_backend="mp",
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
vllm_model.generate(example_prompts, sampling_params)
|
vllm_model.generate(example_prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
|
def test_models_distributed_Qwen3_W8A8():
|
||||||
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
]
|
||||||
|
max_tokens = 5
|
||||||
|
|
||||||
|
with VllmRunner(
|
||||||
|
snapshot_download("vllm-ascend/Qwen3-8B-W8A8"),
|
||||||
|
max_model_len=8192,
|
||||||
|
enforce_eager=True,
|
||||||
|
dtype="auto",
|
||||||
|
tensor_parallel_size=4,
|
||||||
|
quantization="ascend",
|
||||||
|
) as vllm_model:
|
||||||
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ def register_model():
|
|||||||
from .qwen2_5_vl import \
|
from .qwen2_5_vl import \
|
||||||
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
|
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
|
||||||
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
|
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
|
||||||
|
from .qwen3 import CustomQwen3ForCausalLM # noqa: F401
|
||||||
|
|
||||||
ModelRegistry.register_model(
|
ModelRegistry.register_model(
|
||||||
"DeepSeekMTPModel",
|
"DeepSeekMTPModel",
|
||||||
@@ -53,6 +54,9 @@ def register_model():
|
|||||||
"Qwen3MoeForCausalLM",
|
"Qwen3MoeForCausalLM",
|
||||||
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
||||||
|
|
||||||
|
ModelRegistry.register_model(
|
||||||
|
"Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
|
||||||
|
|
||||||
ModelRegistry.register_model(
|
ModelRegistry.register_model(
|
||||||
"PanguProMoEForCausalLM",
|
"PanguProMoEForCausalLM",
|
||||||
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
|
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
|
||||||
|
|||||||
156
vllm_ascend/models/qwen3.py
Normal file
156
vllm_ascend/models/qwen3.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import Qwen3Config
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
|
from vllm.distributed import get_pp_group
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||||
|
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||||
|
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
|
||||||
|
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
||||||
|
PPMissingLayer, maybe_prefix)
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
|
||||||
|
|
||||||
|
|
||||||
|
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Qwen3Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(config=config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix)
|
||||||
|
if quant_config is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||||
|
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||||
|
|
||||||
|
assert isinstance(quant_config, AscendQuantConfig), \
|
||||||
|
"Expected quant_config to be an instance of AscendQuantConfig"
|
||||||
|
|
||||||
|
if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
|
||||||
|
AscendW8A8LinearMethod):
|
||||||
|
self.input_layernorm = AddRMSNormW8A8Quant(
|
||||||
|
config.hidden_size,
|
||||||
|
layer=self.self_attn.qkv_proj,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
|
||||||
|
AscendW8A8LinearMethod):
|
||||||
|
self.post_attention_layernorm = AddRMSNormW8A8Quant(
|
||||||
|
config.hidden_size,
|
||||||
|
layer=self.mlp.gate_up_proj,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
|
||||||
|
ALL_DECODER_LAYER_TYPES = {
|
||||||
|
"attention": CustomQwen3DecoderLayer,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(
|
||||||
|
dynamic_arg_dims={
|
||||||
|
"input_ids": 0,
|
||||||
|
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
|
||||||
|
# otherwise (seq_len, ).
|
||||||
|
"positions": -1,
|
||||||
|
"intermediate_tensors": 0,
|
||||||
|
"inputs_embeds": 0,
|
||||||
|
})
|
||||||
|
class CustomQwen3Model(Qwen2Model):
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__(vllm_config=vllm_config,
|
||||||
|
prefix=prefix,
|
||||||
|
decoder_layer_type=CustomQwen3DecoderLayer)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
|
# add `CustomQwen3Model` to init self.model
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
lora_config = vllm_config.lora_config
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.model = CustomQwen3Model(vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
|
||||||
|
if get_pp_group().is_last_rank:
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(
|
||||||
|
prefix, "lm_head"))
|
||||||
|
else:
|
||||||
|
self.lm_head = PPMissingLayer()
|
||||||
|
|
||||||
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
self.model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||||
|
inputs_embeds)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
loader = AutoWeightsLoader(
|
||||||
|
self,
|
||||||
|
skip_prefixes=(["lm_head."]
|
||||||
|
if self.config.tie_word_embeddings else None),
|
||||||
|
)
|
||||||
|
return loader.load_weights(weights)
|
||||||
@@ -23,6 +23,43 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
|||||||
from vllm_ascend.utils import is_310p
|
from vllm_ascend.utils import is_310p
|
||||||
|
|
||||||
|
|
||||||
|
class AddRMSNormW8A8Quant(RMSNorm):
|
||||||
|
# Fuse AddRmsNorm and W8A8 quantization ops together
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
var_hidden_size: Optional[int] = None,
|
||||||
|
has_weight: bool = True,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||||
|
self.layer = layer
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
if residual is not None:
|
||||||
|
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
self.weight,
|
||||||
|
self.layer.aclnn_input_scale,
|
||||||
|
self.layer.aclnn_input_offset,
|
||||||
|
epsilon=self.variance_epsilon)
|
||||||
|
return x, residual
|
||||||
|
|
||||||
|
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||||
|
self.variance_epsilon)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def forward_oot(
|
def forward_oot(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
|
|||||||
@@ -91,10 +91,12 @@ class AscendW8A8LinearMethod:
|
|||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
tp_rank: Optional[int] = 0,
|
tp_rank: Optional[int] = 0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
original_dtype = x.dtype
|
if x.dtype != torch.int8:
|
||||||
if original_dtype != torch.int8:
|
x = quant_per_tensor(
|
||||||
x = quant_per_tensor(x, layer.aclnn_input_scale,
|
x,
|
||||||
layer.aclnn_input_offset)
|
layer.aclnn_input_scale_reciprocal,
|
||||||
|
layer.aclnn_input_offset,
|
||||||
|
)
|
||||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
||||||
if is_310p():
|
if is_310p():
|
||||||
# On 300I Duo platform, we need transpose again if
|
# On 300I Duo platform, we need transpose again if
|
||||||
@@ -104,7 +106,7 @@ class AscendW8A8LinearMethod:
|
|||||||
layer.weight.data.transpose(1, 0),
|
layer.weight.data.transpose(1, 0),
|
||||||
layer.deq_scale,
|
layer.deq_scale,
|
||||||
bias=quant_bias,
|
bias=quant_bias,
|
||||||
output_dtype=original_dtype,
|
output_dtype=layer.params_dtype,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output = torch_npu.npu_quant_matmul(
|
output = torch_npu.npu_quant_matmul(
|
||||||
@@ -112,13 +114,16 @@ class AscendW8A8LinearMethod:
|
|||||||
layer.weight,
|
layer.weight,
|
||||||
layer.deq_scale,
|
layer.deq_scale,
|
||||||
bias=quant_bias,
|
bias=quant_bias,
|
||||||
output_dtype=original_dtype,
|
output_dtype=layer.params_dtype,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer):
|
def process_weights_after_loading(self, layer):
|
||||||
expanding_factor = layer.weight.data.shape[1]
|
expanding_factor = layer.weight.data.shape[1]
|
||||||
layer.aclnn_input_scale = 1 / torch.nn.Parameter(
|
layer.aclnn_input_scale = torch.nn.Parameter(
|
||||||
|
layer.input_scale.data.repeat(expanding_factor),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
|
||||||
layer.input_scale.data.repeat(expanding_factor),
|
layer.input_scale.data.repeat(expanding_factor),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.aclnn_input_offset = torch.nn.Parameter(
|
layer.aclnn_input_offset = torch.nn.Parameter(
|
||||||
|
|||||||
Reference in New Issue
Block a user