### What this PR does / why we need it?
- `qkv_proj.weight` prefetching was implemented around the `Quant` op, so when `AddRmsNormQuant` fusion is enabled (#3465) the prefetching no longer takes effect.
- Implement `qkv_proj.weight` prefetching with `AddRmsNormQuant`, which has been merged into the `main` branch (#3517).

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Tested on `Qwen3-235B-A22B-W8A8`:

<img width="1868" height="109" alt="image" src="https://github.com/user-attachments/assets/0bc28082-0287-4d5c-b8f6-f907c3134d36" />

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from typing import Optional, Tuple, Union, cast

import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm

from vllm_ascend.utils import version_check


def _addrmsnorm_forward_oot(
    self,
    x: torch.Tensor,
    residual: torch.Tensor,
    layer: Optional[torch.nn.Module] = None,
    bias: Optional[torch.nn.Parameter] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
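    """Fused residual-add + RMSNorm on Ascend NPU, optionally fused with quantization.

    When ``layer`` is the next W8A8-quantized linear module, the residual add,
    RMSNorm and that layer's input quantization are executed as a single
    ``npu_add_rms_norm_quant`` kernel, and the weight prefetch pre/post hooks
    are invoked around the kernel so the weight load overlaps with the fused
    computation. Otherwise the code falls back to ``npu_add_rms_norm``
    (or a plain add + ``npu_rms_norm`` on 310P devices).
    """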
    import torch_npu

    from vllm_ascend.utils import is_310p

    torch_npu_check = version_check()
    if layer is not None and not is_310p():
        layer_cls_name = layer.__class__.__name__
        try:
            weight_prefetch_method = get_forward_context(
            ).weight_prefetch_method
        except AssertionError:
            weight_prefetch_method = None

        # prefetch qkvo_proj.weight preprocess
        if weight_prefetch_method:
            weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
                layer_cls_name=layer_cls_name,
                weight=layer.weight,
                start_flag=x,
            )
        # add_rms_norm_quant
        if torch_npu_check:
            x, _, residual = torch_npu.npu_add_rms_norm_quant(
                x,
                residual,
                self.weight,
                layer.aclnn_input_scale,
                layer.aclnn_input_offset,
                beta=bias,
                epsilon=self.variance_epsilon)
        else:
            x, _, residual = torch_npu.npu_add_rms_norm_quant(
                x,
                residual,
                self.weight,
                layer.aclnn_input_scale,
                layer.aclnn_input_offset,
                epsilon=self.variance_epsilon)
        # prefetch qkvo_proj.weight postprocess
        if weight_prefetch_method:
            weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
                layer_cls_name=layer_cls_name,
                stop_flag=x,
            )
    else:
        if is_310p():
            orig_dtype = residual.dtype
            x = x + residual.to(x.dtype)
            residual = x.to(orig_dtype)
            x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                          self.variance_epsilon)
        else:
            x, _, residual = torch_npu.npu_add_rms_norm(
                x, residual, self.weight, self.variance_epsilon)
        if torch_npu_check and bias is not None:
            x.add_(bias)
        torch.ops.vllm.maybe_wait_prefetch_done(x)
    return x, residual


class AscendRMSNorm(RMSNorm):
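    """RMSNorm variant that dispatches to Ascend NPU fused kernels.

    A norm bias is registered only when the quantization config declares
    ``norm.bias`` entries (e.g. anti_method m4); it is then applied after
    (or fused into) the normalization.
    """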

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
        vllm_config = get_current_vllm_config()
        self.bias = None
        self.torch_npu_check = version_check()
        # quantization with anti_method m4 generates a non-zero norm bias
        if self.torch_npu_check and vllm_config.quant_config is not None and \
                any("norm.bias" in name
                    for name in vllm_config.quant_config.quant_description.keys()):
            self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                           requires_grad=False)

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        import torch_npu

        if residual is not None:
            assert x.size(0) == residual.size(0)
            x, residual = _addrmsnorm_forward_oot(
                self, x, residual, self.next_need_quant_fusion_linear,
                self.bias)
            return x, residual
        x, residual = torch_npu.npu_rms_norm(x, self.weight,
                                             self.variance_epsilon)
        if self.torch_npu_check and self.bias is not None:
            x.add_(self.bias)
        return x

    @property
    def next_need_quant_fusion_linear(self):
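        """Return the next linear module whose input quantization can be
        fused into this AddRMSNorm, or None.

        The forward context keeps a small state machine that walks through
        each decoder layer's qkv_proj and gate_up_proj (qkv_proj only for
        MoE layers). A module is returned only if it uses the Ascend W8A8
        linear quant method; otherwise None is returned and the caller
        falls back to the non-quantized fused kernel.
        """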
        try:
            forward_context = get_forward_context()
            if not forward_context.addrmsnorm_quant_fusion_enabled or \
                    forward_context.layer_idx == forward_context.num_hidden_layers:
                return None
        except AssertionError:
            return None

        next_linear = None
        model_instance = forward_context.model_instance
        layer_idx = forward_context.layer_idx
        fusion_linear = forward_context.fusion_linear
        if fusion_linear == "qkv_dense":
            next_linear = model_instance.model.layers[
                layer_idx].self_attn.qkv_proj
            forward_context.fusion_linear = "gate_up_dense"
        elif fusion_linear == "gate_up_dense":
            next_linear = model_instance.model.layers[
                layer_idx].mlp.gate_up_proj
            forward_context.fusion_linear = "qkv_dense"
            # if prefetch_mlp_weight is enabled, the layer_idx accumulation
            # is handled elsewhere and does not need to be repeated here
            if not forward_context.prefetch_mlp_enabled:
                forward_context.layer_idx += 1
        elif fusion_linear == "qkv_moe":
            next_linear = model_instance.model.layers[
                layer_idx].self_attn.qkv_proj
            forward_context.fusion_linear = "gate_moe"
        elif fusion_linear == "gate_moe":
            forward_context.fusion_linear = "qkv_moe"
            forward_context.layer_idx += 1
        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
        if next_linear is not None and \
                not isinstance(next_linear.quant_method.quant_method,
                               AscendW8A8LinearMethod):
            next_linear = None
        return next_linear


class AscendQuantRMSNorm(AscendRMSNorm):
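    """AscendRMSNorm that always carries a non-trainable ``bias`` parameter,
    which is added to the normalized output in ``forward_oot``.
    """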

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            x, residual = super().forward_oot(x, residual)
            return x.add_(self.bias), residual
        return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)


class AscendGemmaRMSNorm(GemmaRMSNorm):
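    """Gemma-style RMSNorm (weight offset by 1.0) using Ascend fused kernels."""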

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        import torch_npu

        from vllm_ascend.utils import is_310p
        if residual is not None:
            if is_310p():
                orig_dtype = residual.dtype
                x = x + residual.to(x.dtype)
                residual = x.to(orig_dtype)
                x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
                                              self.variance_epsilon)
            else:
                x, _, residual = torch_npu.npu_add_rms_norm(
                    x, residual, 1.0 + self.weight, self.variance_epsilon)
            return x, residual

        x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
                                      self.variance_epsilon)
        return x