### What this PR does / why we need it?
1. In addition to
[#4168](https://github.com/vllm-project/vllm-ascend/pull/4168) and
[#5011](https://github.com/vllm-project/vllm-ascend/pull/5011), this PR
adds two more patterns for AddRmsnormQuant with SP enabled. The key
difference is that an additional `maybe_all_gather_and_maybe_unpad` is
inserted between `addrmsnorm` and `quantize`.
2. This PR also introduces another API, `torch.ops.vllm.quantize`, so that
we can pass `input_scale` and `input_scale_reciprocal` at the same time.
This is necessary because `npu_add_rms_norm_quant` and `npu_quantize` require
different `div_mode` settings. To avoid an additional reciprocal
calculation at runtime, we pass both of them to the quantize API.
3. Removes redundant `AscendQuantRmsnorm`.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Angazenn <supperccell@163.com>
98 lines
3.7 KiB
Python
98 lines
3.7 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
|
|
from typing import Optional, Tuple, Union
|
|
|
|
import torch
|
|
from vllm.config import get_current_vllm_config
|
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
|
|
|
|
|
class AscendRMSNorm(RMSNorm):
    """RMSNorm that dispatches to torch_npu kernels, with an optional bias.

    An additive ``bias`` parameter (initialised to zeros) is registered when
    the active quantization config's ``quant_description`` contains a key
    with ``"norm.bias"`` in it. When present, the bias is added in place to
    the normalized output.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Build the norm and, if the quant config asks for it, a bias.

        Args:
            hidden_size: Size of the normalized (last) dimension; also the
                length of the optional bias.
            eps: Epsilon forwarded to the base RMSNorm.
            var_hidden_size: Forwarded unchanged to the base RMSNorm.
            has_weight: Forwarded unchanged to the base RMSNorm.
            dtype: Forwarded unchanged to the base RMSNorm.
        """
        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
        vllm_config = get_current_vllm_config()
        self.bias = None

        # Not every quantization config is guaranteed to expose
        # ``quant_description``; treat a missing/None one (or no quant
        # config at all) as "no norm bias" instead of raising AttributeError.
        quant_description = getattr(vllm_config.quant_config,
                                    "quant_description", None) or {}
        # Quantization with anti_method m4 will generate a non-zero norm bias.
        if any("norm.bias" in name for name in quant_description):
            self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                           requires_grad=False)

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply RMSNorm on NPU, optionally fused with a residual add.

        Args:
            x: Input activations.
            residual: Optional residual tensor to add to ``x`` before
                normalization.

        Returns:
            The normalized tensor when ``residual`` is None; otherwise a
            tuple ``(normalized, new_residual)``. On 310P devices the add is
            done explicitly and ``new_residual`` is ``x + residual`` cast
            back to ``residual``'s dtype. On other devices it is the third
            output of ``npu_add_rms_norm``. If ``self.bias`` is set, it is
            added in place to the normalized tensor.
        """
        # Imported lazily so this module can be imported without torch_npu.
        import torch_npu

        from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type

        if residual is not None:
            if get_ascend_device_type() == AscendDeviceType._310P:
                # 310P path: add the residual explicitly, then normalize.
                orig_dtype = residual.dtype
                x = x + residual.to(x.dtype)
                residual = x.to(orig_dtype)
                x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
            else:
                x, _, residual = torch_npu.npu_add_rms_norm(
                    x, residual, self.weight, self.variance_epsilon)
            if self.bias is not None:
                x.add_(self.bias)
            return x, residual

        # The second kernel output is unused here. Do not bind it to
        # ``residual``: it is not a residual, and this branch returns only x.
        x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
        if self.bias is not None:
            x.add_(self.bias)
        return x
|
|
|
|
|
|
class AscendGemmaRMSNorm(GemmaRMSNorm):
    """Gemma-style RMSNorm that dispatches to torch_npu kernels.

    The scale handed to the NPU kernels is ``1.0 + self.weight`` rather than
    ``self.weight`` itself.
    """

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize ``x`` on NPU, optionally adding ``residual`` first.

        Args:
            x: Input activations.
            residual: Optional tensor added to ``x`` before normalization.

        Returns:
            The normalized tensor when ``residual`` is None; otherwise a
            tuple ``(normalized, new_residual)``.
        """
        # Deferred imports: torch_npu / vllm_ascend are only needed at call time.
        import torch_npu

        from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type

        gamma = 1.0 + self.weight
        eps = self.variance_epsilon

        if residual is None:
            normed, _ = torch_npu.npu_rms_norm(x, gamma, eps)
            return normed

        if get_ascend_device_type() != AscendDeviceType._310P:
            normed, _, new_residual = torch_npu.npu_add_rms_norm(
                x, residual, gamma, eps)
            return normed, new_residual

        # 310P path: perform the residual add as a separate tensor op, then
        # run the plain RMSNorm kernel on the sum.
        target_dtype = residual.dtype
        summed = x + residual.to(x.dtype)
        new_residual = summed.to(target_dtype)
        normed, _ = torch_npu.npu_rms_norm(summed, gamma, eps)
        return normed, new_residual
|