Files
xc-llm-ascend/vllm_ascend/ops/layernorm.py

178 lines
6.7 KiB
Python
Raw Permalink Normal View History

[Core] Init vllm-ascend (#3) ### What this PR does / why we need it? vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on the Ascend NPU. This plugin is the recommended approach for supporting the Ascend backend within the vLLM community. It adheres to the principles outlined in the [RFC]: Hardware pluggable, providing a hardware-pluggable interface that decouples the integration of the Ascend NPU with vLLM. This patch also includes changes to make CI work and uses caching to speed up the e2e test, including: 1. Change push (post merge ci) and pull_request (pr ci) trigger branch to main 2. Make mypy work by ignoring base_communicator and clearing unused deps 3. Several improvements for vllm_ascend_test: - use cache (pip, ms, hf) to speed up e2e test (25mins --> 5mins) - switch `git clone` command to `action/checkout` to speed up checkout - Enable sv for pytest for better info dump - Remove network host to resolve `docker: conflicting options: cannot attach both user-defined and non-user-defined network-modes`, which is a problem on docker 1.45 but not on 1.39. 4. Adapt MLA decode optimizations: https://github.com/vllm-project/vllm/commit/cabaf4eff3c7df30d785769d5a0a1fa1a1c48a8a ### Does this PR introduce _any_ user-facing change? Yes, init the PR. ### How was this patch tested? - This is the first PR to make ascend NPU work on vLLM. All code is tested on ascend with vLLM V0 Engine. - CI passed --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: wangshuai09 <391746016@qq.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: wangli <wangli858794774@gmail.com>
2025-02-05 10:53:12 +08:00
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
[Core] Init vllm-ascend (#3) ### What this PR does / why we need it? vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on the Ascend NPU. This plugin is the recommended approach for supporting the Ascend backend within the vLLM community. It adheres to the principles outlined in the [RFC]: Hardware pluggable, providing a hardware-pluggable interface that decouples the integration of the Ascend NPU with vLLM. This patch also includes changes to make CI work and uses caching to speed up the e2e test, including: 1. Change push (post merge ci) and pull_request (pr ci) trigger branch to main 2. Make mypy work by ignoring base_communicator and clearing unused deps 3. Several improvements for vllm_ascend_test: - use cache (pip, ms, hf) to speed up e2e test (25mins --> 5mins) - switch `git clone` command to `action/checkout` to speed up checkout - Enable sv for pytest for better info dump - Remove network host to resolve `docker: conflicting options: cannot attach both user-defined and non-user-defined network-modes`, which is a problem on docker 1.45 but not on 1.39. 4. Adapt MLA decode optimizations: https://github.com/vllm-project/vllm/commit/cabaf4eff3c7df30d785769d5a0a1fa1a1c48a8a ### Does this PR introduce _any_ user-facing change? Yes, init the PR. ### How was this patch tested? - This is the first PR to make ascend NPU work on vLLM. All code is tested on ascend with vLLM V0 Engine. - CI passed --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: wangshuai09 <391746016@qq.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: wangli <wangli858794774@gmail.com>
2025-02-05 10:53:12 +08:00
#
import torch
from torch import nn
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm, RMSNormGated
from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu
from vllm_ascend.utils import enable_custom_op, get_weight_prefetch_method
[Core] Init vllm-ascend (#3) ### What this PR does / why we need it? vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on the Ascend NPU. This plugin is the recommended approach for supporting the Ascend backend within the vLLM community. It adheres to the principles outlined in the [RFC]: Hardware pluggable, providing a hardware-pluggable interface that decouples the integration of the Ascend NPU with vLLM. This patch also includes changes to make CI work and uses caching to speed up the e2e test, including: 1. Change push (post merge ci) and pull_request (pr ci) trigger branch to main 2. Make mypy work by ignoring base_communicator and clearing unused deps 3. Several improvements for vllm_ascend_test: - use cache (pip, ms, hf) to speed up e2e test (25mins --> 5mins) - switch `git clone` command to `action/checkout` to speed up checkout - Enable sv for pytest for better info dump - Remove network host to resolve `docker: conflicting options: cannot attach both user-defined and non-user-defined network-modes`, which is a problem on docker 1.45 but not on 1.39. 4. Adapt MLA decode optimizations: https://github.com/vllm-project/vllm/commit/cabaf4eff3c7df30d785769d5a0a1fa1a1c48a8a ### Does this PR introduce _any_ user-facing change? Yes, init the PR. ### How was this patch tested? - This is the first PR to make ascend NPU work on vLLM. All code is tested on ascend with vLLM V0 Engine. - CI passed --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: wangshuai09 <391746016@qq.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: wangli <wangli858794774@gmail.com>
2025-02-05 10:53:12 +08:00
class AscendRMSNorm(RMSNorm):
    """RMSNorm that dispatches to Ascend NPU kernels and supports an optional bias.

    A bias parameter is created only when the active quantization description
    contains an entry whose name includes ``"norm.bias"``; otherwise
    ``self.bias`` stays ``None``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Build the base RMSNorm and, if the quantization config asks for it, a bias.

        Args:
            hidden_size: Forwarded to the base class; also the length of the
                bias vector when one is created.
            eps: Forwarded to the base class.
            var_hidden_size: Forwarded to the base class.
            has_weight: Forwarded to the base class.
            dtype: Forwarded to the base class.
        """
        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
        quant_config = get_current_vllm_config().quant_config
        self.bias = None
        # Flipped to True by `_bias_weight_loader` once a checkpoint value has been copied in.
        self.bias_loaded = False

        # Quantization with anti_method m4 produces a non-zero norm bias.
        wants_bias = quant_config is not None and any(
            "norm.bias" in entry for entry in quant_config.quant_description
        )
        if wants_bias:
            bias_param = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
            bias_param.weight_loader = self._bias_weight_loader
            self.bias = bias_param

    def _bias_weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
        """Copy ``loaded_weight`` into ``param`` and record that the bias was loaded.

        Args:
            param: Destination parameter (the bias).
            loaded_weight: Tensor read from the checkpoint.

        Raises:
            AssertionError: If the two tensors are not both single-element and
                their sizes differ.
        """
        both_single_element = param.numel() == 1 and loaded_weight.numel() == 1
        if not both_single_element:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) into parameter ({param.size()})"
            )
            param.data.copy_(loaded_weight)
        else:
            # Scalars do not always carry a comparable shape, so "broadcast"
            # the value into the parameter instead of copying.
            param.data.fill_(loaded_weight.item())
        self.bias_loaded = True

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apply RMSNorm (optionally fused with a residual add) on the NPU.

        Args:
            x: Input tensor.
            residual: Optional residual tensor. When given, the fused
                add + RMSNorm kernel is used.

        Returns:
            ``x`` alone when ``residual`` is ``None``; otherwise the pair
            ``(x, residual)`` produced by the fused kernel.
        """
        import torch_npu

        if residual is None:
            # Second output of the kernel is not needed here.
            x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
            if self.bias_loaded:
                x.add_(self.bias)
            get_weight_prefetch_method().maybe_prefetch_mlp_weight_postprocess(x)
            return x

        residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
        if not enable_custom_op():
            x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight, self.variance_epsilon)
            # The stock kernel takes no bias argument, so add it afterwards.
            if self.bias is not None:
                x.add_(self.bias)
        else:
            # The custom op receives the bias (possibly None) directly.
            x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                x, residual, self.weight, self.bias, self.variance_epsilon
            )
        return x, residual
class AscendGemmaRMSNorm(GemmaRMSNorm):
    """Gemma-style RMSNorm routed to Ascend NPU kernels."""

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Normalize ``x``, optionally fusing in a residual add.

        Args:
            x: Input tensor.
            residual: Optional residual tensor.

        Returns:
            ``x`` alone when ``residual`` is ``None``; otherwise ``(x, residual)``.
        """
        import torch_npu

        eps = self.variance_epsilon
        if residual is None:
            # This kernel is handed the raw weight (no explicit 1.0 offset here).
            x, _ = torch.ops._C_ascend.npu_gemma_rms_norm(x, self.weight, eps)
            return x

        residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
        # The fused add + norm kernels are given the weight shifted by 1.0.
        shifted_weight = 1.0 + self.weight
        if enable_custom_op():
            x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(x, residual, shifted_weight, None, eps)
        else:
            x, _, residual = torch_npu.npu_add_rms_norm(x, residual, shifted_weight, eps)
        return x, residual
class LayerNormFn(torch.autograd.Function):
    """Autograd function whose forward pass delegates to ``layer_norm_fwd_npu``."""

    @staticmethod
    def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
        """Run the (optionally gated) norm on a 2D view of ``x``.

        If ``z`` is not None, the result is ``norm(x) * silu(z)`` when
        ``norm_before_gate`` is set, else ``norm(x * silu(z))``.

        The inputs are flattened to 2D before the kernel call and the output
        is reshaped back to the original shape of ``x``.
        """
        original_shape = x.shape

        def _as_rows(tensor):
            # Collapse all leading dims and make sure the last dim is unit-stride.
            rows = tensor.reshape(-1, tensor.shape[-1])
            return rows if rows.stride(-1) == 1 else rows.contiguous()

        x = _as_rows(x)
        if z is not None:
            assert z.shape == original_shape
            z = _as_rows(z)

        weight = weight.contiguous()
        bias = None if bias is None else bias.contiguous()

        y, mean, rstd = layer_norm_fwd_npu(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )

        # Stash tensors and settings on the context.
        ctx.save_for_backward(x, weight, bias, mean, rstd, z)
        ctx.x_shape_og = original_shape
        ctx.eps = eps
        ctx.group_size = group_size
        ctx.norm_before_gate = norm_before_gate
        ctx.is_rms_norm = is_rms_norm
        return y.reshape(original_shape)
class AscendRMSNormGated(RMSNormGated):
    """Gated RMSNorm whose forward pass goes through ``LayerNormFn``."""

    def __init__(
        self,
        hidden_size,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Set up the weight (initialised to ones) and a ``None`` bias.

        If group_size is not None, GroupNorm is done with each group having
        group_size elements. group_size=None is equivalent to
        group_size=hidden_size (i.e. there is only 1 group).
        """
        super().__init__(hidden_size, eps, group_size, norm_before_gate, device, dtype)
        self.eps = eps
        # Replace the weight with a freshly allocated parameter; it is filled by reset_parameters().
        self.weight = nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight with ones."""
        torch.nn.init.ones_(self.weight)

    def forward_oot(self, x, z=None):
        """Apply the gated RMS norm.

        If z is not None, the result is norm(x) * silu(z) when
        norm_before_gate is set, else norm(x * silu(z)).
        """
        use_rms_norm = True  # last positional argument of LayerNormFn.forward (is_rms_norm)
        return LayerNormFn.apply(
            x,
            self.weight,
            self.bias,
            z,
            self.eps,
            self.group_size,
            self.norm_before_gate,
            use_rms_norm,
        )