Files
xc-llm-ascend/vllm_ascend/ops/layernorm.py

214 lines
7.7 KiB
Python
Raw Normal View History

[Core] Init vllm-ascend (#3) ### What this PR does / why we need it? vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on the Ascend NPU. This plugin is the recommended approach for supporting the Ascend backend within the vLLM community. It adheres to the principles outlined in the [RFC]: Hardware pluggable, providing a hardware-pluggable interface that decouples the integration of the Ascend NPU with vLLM. This patch also include changes to make CI work and use cache speed up e2e test, including: 1. Change push (post merge ci) and pull_request (pr ci) trigger branch to main 2. Make mypy work by ignore base_communicator and clear unused deps 3. Several improvements for vllm_ascend_test: - use cache (pip, ms, hf) speed up e2e test (25mins --> 5mins) - switch `git clone` command to `action/checkout` to speedup checkout and - Enable sv for pytest for better info dump - Remove network host to resole `docker: conflicting ontions: cannot attach both user-defined and non-user-definednetwork-modes`, which is a problem on docker 1.45 but not on 1.39. 4. Adapt MLA decode optimizations: https://github.com/vllm-project/vllm/commit/cabaf4eff3c7df30d785769d5a0a1fa1a1c48a8a ### Does this PR introduce _any_ user-facing change? Yes, init the PR. ### How was this patch tested? - This is the first PR to make ascend NPU work on vLLM. All code is tested on ascend with vLLM V0 Engine. - CI passed --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: wangshuai09 <391746016@qq.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: wangli <wangli858794774@gmail.com>
2025-02-05 10:53:12 +08:00
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
[Core] Init vllm-ascend (#3) ### What this PR does / why we need it? vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on the Ascend NPU. This plugin is the recommended approach for supporting the Ascend backend within the vLLM community. It adheres to the principles outlined in the [RFC]: Hardware pluggable, providing a hardware-pluggable interface that decouples the integration of the Ascend NPU with vLLM. This patch also include changes to make CI work and use cache speed up e2e test, including: 1. Change push (post merge ci) and pull_request (pr ci) trigger branch to main 2. Make mypy work by ignore base_communicator and clear unused deps 3. Several improvements for vllm_ascend_test: - use cache (pip, ms, hf) speed up e2e test (25mins --> 5mins) - switch `git clone` command to `action/checkout` to speedup checkout and - Enable sv for pytest for better info dump - Remove network host to resole `docker: conflicting ontions: cannot attach both user-defined and non-user-definednetwork-modes`, which is a problem on docker 1.45 but not on 1.39. 4. Adapt MLA decode optimizations: https://github.com/vllm-project/vllm/commit/cabaf4eff3c7df30d785769d5a0a1fa1a1c48a8a ### Does this PR introduce _any_ user-facing change? Yes, init the PR. ### How was this patch tested? - This is the first PR to make ascend NPU work on vLLM. All code is tested on ascend with vLLM V0 Engine. - CI passed --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: wangshuai09 <391746016@qq.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: wangli <wangli858794774@gmail.com>
2025-02-05 10:53:12 +08:00
#
from typing import Optional, Tuple, Union, cast
[Core] Init vllm-ascend (#3) ### What this PR does / why we need it? vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on the Ascend NPU. This plugin is the recommended approach for supporting the Ascend backend within the vLLM community. It adheres to the principles outlined in the [RFC]: Hardware pluggable, providing a hardware-pluggable interface that decouples the integration of the Ascend NPU with vLLM. This patch also include changes to make CI work and use cache speed up e2e test, including: 1. Change push (post merge ci) and pull_request (pr ci) trigger branch to main 2. Make mypy work by ignore base_communicator and clear unused deps 3. Several improvements for vllm_ascend_test: - use cache (pip, ms, hf) speed up e2e test (25mins --> 5mins) - switch `git clone` command to `action/checkout` to speedup checkout and - Enable sv for pytest for better info dump - Remove network host to resole `docker: conflicting ontions: cannot attach both user-defined and non-user-definednetwork-modes`, which is a problem on docker 1.45 but not on 1.39. 4. Adapt MLA decode optimizations: https://github.com/vllm-project/vllm/commit/cabaf4eff3c7df30d785769d5a0a1fa1a1c48a8a ### Does this PR introduce _any_ user-facing change? Yes, init the PR. ### How was this patch tested? - This is the first PR to make ascend NPU work on vLLM. All code is tested on ascend with vLLM V0 Engine. - CI passed --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: wangshuai09 <391746016@qq.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: wangli <wangli858794774@gmail.com>
2025-02-05 10:53:12 +08:00
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
[Core] Init vllm-ascend (#3) ### What this PR does / why we need it? vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on the Ascend NPU. This plugin is the recommended approach for supporting the Ascend backend within the vLLM community. It adheres to the principles outlined in the [RFC]: Hardware pluggable, providing a hardware-pluggable interface that decouples the integration of the Ascend NPU with vLLM. This patch also include changes to make CI work and use cache speed up e2e test, including: 1. Change push (post merge ci) and pull_request (pr ci) trigger branch to main 2. Make mypy work by ignore base_communicator and clear unused deps 3. Several improvements for vllm_ascend_test: - use cache (pip, ms, hf) speed up e2e test (25mins --> 5mins) - switch `git clone` command to `action/checkout` to speedup checkout and - Enable sv for pytest for better info dump - Remove network host to resole `docker: conflicting ontions: cannot attach both user-defined and non-user-definednetwork-modes`, which is a problem on docker 1.45 but not on 1.39. 4. Adapt MLA decode optimizations: https://github.com/vllm-project/vllm/commit/cabaf4eff3c7df30d785769d5a0a1fa1a1c48a8a ### Does this PR introduce _any_ user-facing change? Yes, init the PR. ### How was this patch tested? - This is the first PR to make ascend NPU work on vLLM. All code is tested on ascend with vLLM V0 Engine. - CI passed --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: wangshuai09 <391746016@qq.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: wangli <wangli858794774@gmail.com>
2025-02-05 10:53:12 +08:00
def _addrmsnorm_forward_oot(
self,
x: torch.Tensor,
residual: torch.Tensor,
layer: Optional[torch.nn.Module] = None,
bias: Optional[torch.nn.Parameter] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
from vllm_ascend.utils import is_310p
if layer is not None and not is_310p():
[v0.11.0][Feat] Prefetching Attention QKV Linear Weight With `AddRmsNormQuant` Custom Op (#3649) ### What this PR does / why we need it? - `qkv_proj.weight` prefetching has been implemented with `Quant` op, when `AddRmsNormQuant` is enabled (#3465) `qkv_proj.weight` prefetching won't work - Implement `qkv_proj.weight` prefetching with `AddRmsNormQuant`, which has been merged on `main` branch (#3517) ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? Tested on `Qwen3-235B-A22B-W8A8` <img width="1868" height="109" alt="image" src="https://github.com/user-attachments/assets/0bc28082-0287-4d5c-b8f6-f907c3134d36" /> - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-10-27 09:42:09 +08:00
layer_cls_name = layer.__class__.__name__
try:
weight_prefetch_method = get_forward_context(
).weight_prefetch_method
except AssertionError:
weight_prefetch_method = None
# prefetch qkvo_proj.weight preprocess
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
layer_cls_name=layer_cls_name,
weight=layer.weight,
start_flag=x,
)
# add_rms_norm_quant
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
beta=bias,
epsilon=self.variance_epsilon)
[v0.11.0][Feat] Prefetching Attention QKV Linear Weight With `AddRmsNormQuant` Custom Op (#3649) ### What this PR does / why we need it? - `qkv_proj.weight` prefetching has been implemented with `Quant` op, when `AddRmsNormQuant` is enabled (#3465) `qkv_proj.weight` prefetching won't work - Implement `qkv_proj.weight` prefetching with `AddRmsNormQuant`, which has been merged on `main` branch (#3517) ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? Tested on `Qwen3-235B-A22B-W8A8` <img width="1868" height="109" alt="image" src="https://github.com/user-attachments/assets/0bc28082-0287-4d5c-b8f6-f907c3134d36" /> - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-10-27 09:42:09 +08:00
# prefetch qkvo_proj.weight postprocess
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
layer_cls_name=layer_cls_name,
stop_flag=x,
)
else:
if is_310p():
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
if bias is not None:
x.add_(bias)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
class AscendRMSNorm(RMSNorm):
[Core] Init vllm-ascend (#3) ### What this PR does / why we need it? vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on the Ascend NPU. This plugin is the recommended approach for supporting the Ascend backend within the vLLM community. It adheres to the principles outlined in the [RFC]: Hardware pluggable, providing a hardware-pluggable interface that decouples the integration of the Ascend NPU with vLLM. This patch also include changes to make CI work and use cache speed up e2e test, including: 1. Change push (post merge ci) and pull_request (pr ci) trigger branch to main 2. Make mypy work by ignore base_communicator and clear unused deps 3. Several improvements for vllm_ascend_test: - use cache (pip, ms, hf) speed up e2e test (25mins --> 5mins) - switch `git clone` command to `action/checkout` to speedup checkout and - Enable sv for pytest for better info dump - Remove network host to resole `docker: conflicting ontions: cannot attach both user-defined and non-user-definednetwork-modes`, which is a problem on docker 1.45 but not on 1.39. 4. Adapt MLA decode optimizations: https://github.com/vllm-project/vllm/commit/cabaf4eff3c7df30d785769d5a0a1fa1a1c48a8a ### Does this PR introduce _any_ user-facing change? Yes, init the PR. ### How was this patch tested? - This is the first PR to make ascend NPU work on vLLM. All code is tested on ascend with vLLM V0 Engine. - CI passed --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: wangshuai09 <391746016@qq.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: wangli <wangli858794774@gmail.com>
2025-02-05 10:53:12 +08:00
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
vllm_config = get_current_vllm_config()
self.bias = None
# quantization with anti_method m4 will generate none-zero norm bias
if vllm_config.quant_config is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
[Core] Init vllm-ascend (#3) ### What this PR does / why we need it? vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on the Ascend NPU. This plugin is the recommended approach for supporting the Ascend backend within the vLLM community. It adheres to the principles outlined in the [RFC]: Hardware pluggable, providing a hardware-pluggable interface that decouples the integration of the Ascend NPU with vLLM. This patch also include changes to make CI work and use cache speed up e2e test, including: 1. Change push (post merge ci) and pull_request (pr ci) trigger branch to main 2. Make mypy work by ignore base_communicator and clear unused deps 3. Several improvements for vllm_ascend_test: - use cache (pip, ms, hf) speed up e2e test (25mins --> 5mins) - switch `git clone` command to `action/checkout` to speedup checkout and - Enable sv for pytest for better info dump - Remove network host to resole `docker: conflicting ontions: cannot attach both user-defined and non-user-definednetwork-modes`, which is a problem on docker 1.45 but not on 1.39. 4. Adapt MLA decode optimizations: https://github.com/vllm-project/vllm/commit/cabaf4eff3c7df30d785769d5a0a1fa1a1c48a8a ### Does this PR introduce _any_ user-facing change? Yes, init the PR. ### How was this patch tested? - This is the first PR to make ascend NPU work on vLLM. All code is tested on ascend with vLLM V0 Engine. - CI passed --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: wangshuai09 <391746016@qq.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: wangli <wangli858794774@gmail.com>
2025-02-05 10:53:12 +08:00
if residual is not None:
assert x.size(0) == residual.size(0)
x, residual = _addrmsnorm_forward_oot(
self, x, residual, self.next_need_quant_fusion_linear,
self.bias)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
if self.bias is not None:
x.add_(self.bias)
return x
@property
def next_need_quant_fusion_linear(self):
try:
forward_context = get_forward_context()
if not forward_context.addrmsnorm_quant_fusion_enabled or \
forward_context.layer_idx == forward_context.num_hidden_layers:
return None
except AssertionError:
return None
next_linear = None
model_instance = forward_context.model_instance
layer_idx = forward_context.layer_idx
fusion_linear = forward_context.fusion_linear
next_linear = None
if fusion_linear == "qkv_dense":
next_linear = model_instance.model.layers[
layer_idx].self_attn.qkv_proj
forward_context.fusion_linear = "gate_up_dense"
elif fusion_linear == "gate_up_dense":
next_linear = model_instance.model.layers[
layer_idx].mlp.gate_up_proj
forward_context.fusion_linear = "qkv_dense"
# if prefetch_mlp_weight enabled, following accumulation operation
# does not need to be repeated
if not forward_context.prefetch_mlp_enabled:
forward_context.layer_idx += 1
elif fusion_linear == "qkv_moe":
next_linear = model_instance.model.layers[
layer_idx].self_attn.qkv_proj
forward_context.fusion_linear = "gate_moe"
elif fusion_linear == "gate_moe":
forward_context.fusion_linear = "qkv_moe"
forward_context.layer_idx += 1
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
if next_linear is not None and \
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
next_linear = None
return next_linear
class AscendQuantRMSNorm(AscendRMSNorm):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
x, residual = super().forward_oot(x, residual)
return x.add_(self.bias), residual
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
class AscendGemmaRMSNorm(GemmaRMSNorm):
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
from vllm_ascend.utils import is_310p
if residual is not None:
if is_310p():
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, 1.0 + self.weight, self.variance_epsilon)
return x, residual
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
self.variance_epsilon)
return x