### What this PR does / why we need it?
The main goal of this PR is to alleviate the high maintenance burden caused by model duplication when we optimize models. Some of our optimized models diverge only slightly from vLLM's modeling code, yet still require rewriting several parts of the original, which brings a non-negligible maintenance burden to vllm-ascend. To solve this, we propose to leverage `torch.compile` and the inductor pattern matcher to automatically fuse the patterns we want to merge. For more details, refer to the RFC: https://github.com/vllm-project/vllm-ascend/issues/4239

This PR fuses `AddRMSNorm` with the `Quant` operator, which improves the inference speed of models using `w8a8` quantization.
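As a rough illustration of the approach (not the actual code in this PR), a fusion pass registers a search pattern and its fused replacement with `torch._inductor.pattern_matcher`. The names `add_rmsnorm_quant_pattern` and `fused_add_rmsnorm_quant` below are hypothetical, and the "fused" function is a plain reference implementation standing in for the real `torch_npu` kernel:

```python
# Illustrative sketch only; function names are hypothetical, not this PR's.
import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)


def add_rmsnorm_quant_pattern(x, residual, weight, scale):
    # Unfused subgraph: residual add -> RMSNorm -> static int8 quant.
    added = x + residual
    var = added.pow(2).mean(dim=-1, keepdim=True)
    normed = added * torch.rsqrt(var + 1e-6) * weight
    quant = torch.clamp(torch.round(normed / scale), -128, 127).to(torch.int8)
    return quant, added


def fused_add_rmsnorm_quant(x, residual, weight, scale):
    # Placeholder: the real pass would call a single fused NPU custom op here.
    return add_rmsnorm_quant_pattern(x, residual, weight, scale)


ascend_patterns = PatternMatcherPass()
example_inputs = [torch.empty(4, 8), torch.empty(4, 8),
                  torch.empty(8), torch.tensor(0.1)]
# Trace both functions and rewrite every occurrence of the unfused
# subgraph in the compiled FX graph with the fused replacement.
register_replacement(add_rmsnorm_quant_pattern, fused_add_rmsnorm_quant,
                     example_inputs, fwd_only, ascend_patterns)
```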
### Does this PR introduce _any_ user-facing change?
Yes, this PR adds a new `additional_config` option.
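For illustration only, such an option is passed through the engine's `additional_config`; the key name below is hypothetical, as the exact schema is defined by this PR:

```python
from vllm import LLM

# "enable_fusion_pass" is a hypothetical key used for illustration;
# see this PR / the RFC for the actual additional_config schema.
llm = LLM(
    model="vllm-ascend/Qwen3-8B-W8A8",
    quantization="ascend",
    additional_config={"enable_fusion_pass": True},
)
```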
### How was this patch tested?
```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The president of the United States is Mr.",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100,
                                     temperature=0.6,
                                     top_k=40,
                                     top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="/root/.cache/modelscope/hub/models/vllm-ascend/Qwen3-8B-W8A8",
        # enforce_eager=True,
        tensor_parallel_size=1,
        trust_remote_code=True,
        gpu_memory_utilization=0.7,
        quantization="ascend",
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```
```text
Prompt: 'The president of the United States is Mr.', Generated text: ' Trump. The president of the United States is Mr. Biden. Which of the following statements is correct? \n\nA. Mr. Trump is Mr. Biden. \nB. Mr. Trump is not Mr. Biden. \nC. The president of the United States is not Mr. Trump. \nD. The president of the United States is not Mr. Biden.\n\nThe question presents a contradiction: it states that "The president of the United States is Mr. Trump" and "The president of'
```
- vLLM version: 86e178f7c4d8c3b0eaf3c8e3f810a83f63b90e24
- vLLM main: 86e178f7c4
---------
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from typing import Optional, Tuple, Union, cast

import torch
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm


class AscendRMSNorm(RMSNorm):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
        vllm_config = get_current_vllm_config()
        self.bias = None
        # Quantization with anti_method m4 generates a non-zero norm bias.
        if vllm_config.quant_config is not None and \
                any("norm.bias" in name for name in
                    vllm_config.quant_config.quant_description.keys()):
            self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                           requires_grad=False)

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        import torch_npu

        from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
        if residual is not None:
            if get_ascend_device_type() == AscendDeviceType._310P:
                # 310P does not support the fused kernel: do the residual
                # add separately, then a plain RMSNorm.
                orig_dtype = residual.dtype
                x = x + residual.to(x.dtype)
                residual = x.to(orig_dtype)
                x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
            else:
                # Fused residual add + RMSNorm in a single NPU kernel.
                x, _, residual = torch_npu.npu_add_rms_norm(
                    x, residual, self.weight, self.variance_epsilon)
            if self.bias is not None:
                x.add_(self.bias)
            return x, residual

        x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
        if self.bias is not None:
            x.add_(self.bias)
        return x


class AscendQuantRMSNorm(AscendRMSNorm):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
        # The quantized variant always carries a norm bias.
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            x, residual = super().forward_oot(x, residual)
            return x.add_(self.bias), residual
        return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)


class AscendGemmaRMSNorm(GemmaRMSNorm):

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        import torch_npu

        from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
        if residual is not None:
            if get_ascend_device_type() == AscendDeviceType._310P:
                orig_dtype = residual.dtype
                x = x + residual.to(x.dtype)
                residual = x.to(orig_dtype)
                x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
                                              self.variance_epsilon)
            else:
                x, _, residual = torch_npu.npu_add_rms_norm(
                    x, residual, 1.0 + self.weight, self.variance_epsilon)
            return x, residual

        # Gemma parameterizes the norm weight as (w - 1), hence (1 + weight).
        x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
                                      self.variance_epsilon)
        return x
```
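For reference, a minimal sketch of how these layers are exercised; this assumes an Ascend environment with `torch_npu` installed (`forward_oot` is vLLM's out-of-tree override hook that replaces the default `RMSNorm` forward on the NPU platform):

```python
# Minimal usage sketch; assumes torch_npu / an Ascend NPU is available.
import torch
import torch_npu  # noqa: F401
from vllm.config import VllmConfig, set_current_vllm_config

# AscendRMSNorm.__init__ reads the current vllm config for quant metadata.
with set_current_vllm_config(VllmConfig()):
    norm = AscendRMSNorm(hidden_size=64)

x = torch.randn(4, 64).npu()
residual = torch.randn(4, 64).npu()
out, new_residual = norm.forward_oot(x, residual)  # fused npu_add_rms_norm path
out_only = norm.forward_oot(x)                     # plain npu_rms_norm path
```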