Files
xc-llm-ascend/vllm_ascend/attention/attention_v1.py

981 lines
39 KiB
Python
Raw Permalink Normal View History

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from dataclasses import dataclass
from enum import Enum
import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import ( # type: ignore
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionLayer,
AttentionMetadataBuilder,
AttentionType,
)
from vllm.v1.attention.backends.registry import ( # type: ignore
AttentionBackendEnum,
register_backend,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
from vllm_ascend.attention.utils import (
AscendCommonAttentionMetadata,
enable_cp,
split_decodes_and_prefills,
using_paged_attention,
)
from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params,
get_graph_params,
update_draft_graph_params_workspaces,
update_graph_params_workspaces,
)
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import weak_ref_tensors
# Default max value of sliding window size: 2**31 - 1 (INT32_MAX).
# NOTE(review): presumably used as the "no sliding window limit" sentinel
# passed to the attention kernels — confirm against kernel docs.
SWA_INT_MAX = 2147483647
upgrade to vllm 0.11.2 (#4400) Bump vLLM version to v0.11.2 What's broken and changed by vLLM: 1. structured_output is broken by https://github.com/vllm-project/vllm/pull/26866 2. get_mrope_input_positions is broken by https://github.com/vllm-project/vllm/pull/28399 3. graph mode is broken by https://github.com/vllm-project/vllm/pull/25110 we'll upgrade torch to 2.8 to fix the problem later 4. embedding is broken by https://github.com/vllm-project/vllm/pull/27583 5. `get_attn_backend_cls` and attention backend is broken are broken by https://github.com/vllm-project/vllm/pull/28534 6. spec decode is broken by https://github.com/vllm-project/vllm/pull/28771 7. sp feature is broken by https://github.com/vllm-project/vllm/pull/27126 8. mtp is broken by https://github.com/vllm-project/vllm/pull/27922 9. lora is broken by https://github.com/vllm-project/vllm/pull/21068 10. execute_model is broken by https://github.com/vllm-project/vllm/pull/26866 11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by https://github.com/vllm-project/vllm/pull/28159 12. kv cahe is broken by https://github.com/vllm-project/vllm/pull/27753 13. dp is broken by https://github.com/vllm-project/vllm/pull/25110 What's broken and changed by ourself: 1. qwen vl is broken by https://github.com/vllm-project/vllm/pull/28455 We'll remove model files in the future to avoid this kind of error 2. Engine core is broken by https://github.com/vllm-project/vllm/pull/23691 We'll remove the patch file in the future. 3. Ascend scheduler is broken by https://github.com/vllm-project/vllm/pull/28733 We'll remove ascend scheudler later. 4. qwen3-next is broken by https://github.com/vllm-project/vllm/pull/28083 We'll remove model files in the future to avoid this kind of error 5. qwen vl is broken by https://github.com/vllm-project/vllm/pull/27764. We'll remove model files in the future Known issue: 1. ray doesn't work 2. the accuracy of qwen3-next is not correct 3. qwen3-vl is broken 4. 
prefix cache+ ascend scheduler + deepseek v2 lite is broken. Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: hfadzxy <starmoon_zhang@163.com> Co-authored-by: leo-pony <nengjunma@outlook.com> Co-authored-by: 22dimensions <waitingwind@foxmail.com> Co-authored-by: shen-shanshan <467638484@qq.com> - vLLM version: v0.11.2 --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: hfadzxy <starmoon_zhang@163.com> Signed-off-by: leo-pony <nengjunma@outlook.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: hfadzxy <starmoon_zhang@163.com> Co-authored-by: leo-pony <nengjunma@outlook.com>
2025-11-26 11:48:58 +08:00
@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
class AscendAttentionBackend(AttentionBackend):
    """vLLM attention backend for Ascend NPUs (paged FlashAttention path)."""

    # The attention kernels can write directly into a caller-provided
    # output buffer instead of allocating their own.
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the backend name reported to vLLM."""
        # HACK(Ronald1995): vllm `initialize_kv_cache` method in model runner v2
        # makes an attention-name assertion; we just report FLASH_ATTN there to
        # avoid the assertion error. Rectify this when vllm disables the check.
        return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["AscendAttentionBackendImpl"]:
        """Return the attention implementation class for the current config."""
        # Context-parallel runs use a dedicated implementation; imported
        # lazily to avoid pulling in the CP stack when it is unused.
        if enable_cp():
            from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPImpl
            return AscendAttentionCPImpl
        return AscendAttentionBackendImpl

    @staticmethod
    def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
        """Return the metadata builder class for the current config."""
        if enable_cp():
            from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPMetadataBuilder
            return AscendAttentionCPMetadataBuilder
        return AscendAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """KV cache layout: (2, num_blocks, block_size, num_kv_heads, head_size).

        The leading dimension of 2 separates the key plane from the value plane.
        """
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: list[torch.Tensor],
        dst_kv_cache: list[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy selected cache blocks from `src_kv_cache` into `dst_kv_cache`.

        Args:
            src_kv_cache: [key_cache, value_cache] source tensors.
            dst_kv_cache: [key_cache, value_cache] destination tensors.
            src_to_dst: (num_pairs, 2) tensor of [src_block, dst_block] rows.
        """
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]
        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(dst_key_cache.device)
        # Bugfix: move values to the *value* cache's device. The original used
        # dst_key_cache.device here, which is wrong whenever the key and value
        # caches live on different devices.
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(dst_value_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Duplicate cache blocks in place for every layer's KV cache.

        Args:
            kv_caches: per-layer tensors whose index 0 is the key plane and
                index 1 is the value plane.
            src_to_dists: (num_pairs, 2) tensor of [src_block, dst_block] rows.
        """
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]
        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int]:
        """Block sizes accepted by the fused attention kernels (128 only)."""
        return [128]
[New model] Qwen3-next support (#2917) ### What this PR does / why we need it? Add Qwen3-next support. ### Does this PR introduce _any_ user-facing change? Yes, users can use Qwen3 next. Related doc: https://github.com/vllm-project/vllm-ascend/pull/2916 the tutorial will be ready in [here](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_npu_qwen3_next.html) ### How was this patch tested? Doc CI passed Related: https://github.com/vllm-project/vllm-ascend/issues/2884 Co-Authored-By: Angazenn <supperccell@163.com> Co-Authored-By: zzzzwwjj <1183291235@qq.com> Co-Authored-By: MengqingCao <cmq0113@163.com> Co-Authored-By: linfeng-yuan <1102311262@qq.com> Co-Authored-By: hust17yixuan <303660421@qq.com> Co-Authored-By: SunnyLee219 <3294305115@qq.com> Co-Authored-By: maoxx241 <maoxx241@umn.edu> - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/b834b4cbf1d5094affdf231df2be86920610d83e --------- Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Angazenn <supperccell@163.com> Signed-off-by: Your Name <you@example.com> Signed-off-by: zzzzwwjj <1183291235@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: hust17yixuan <303660421@qq.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: Angazenn <supperccell@163.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: zzzzwwjj <1183291235@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: hust17yixuan <303660421@qq.com>
2025-09-16 01:17:42 +08:00
class AscendAttentionState(Enum):
    """Coarse classification of the attention workload for a model step.

    NOTE(review): member names suggest this selects the kernel path (pure
    prefill, prefix-cache-hit prefill, pure decode, chunked prefill, or
    speculative decoding) — confirm against the impl's dispatch logic.
    """

    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3
    SpecDecoding = 4
support cp&dcp (#3260) ### What this PR does / why we need it? This PR adds the Prefill Context Parallelism (PCP) feature, which corresponds to DCP. For specific implementation details, please refer to the RFC https://github.com/vllm-project/vllm/issues/25749. TL;DR: PCP enhances long-sequence inference capabilities by partitioning the sequence dimension during the prefill stage. ### Does this PR introduce _any_ user-facing change? The current implementation primarily includes the following changes: Modified ModelRunner.py for CP partitioning logic for tokens; Modified attention_v1.py and mla_v1.py to adapt the GQA/MLA backend to PCP. Modified block_tables.py to extend the KV cache storage based on DCP&PCP; Added necessary command-line arguments to control parallelism for PCP; ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: LookAround <lixushi@huawei.com> Signed-off-by: chenjie <chenjie137@huawei.com> Signed-off-by: Delphine-Nic <tanwenqin@huawei.com> Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com> Signed-off-by: Feng Liu <liufeng248@huawei.com> Signed-off-by: gaojc <1055866782@qq.com> Signed-off-by: weiguihua2 <weiguihua2@huawei.com> Signed-off-by: z50049692 <zhangmingwei11@huawei.com> Co-authored-by: chenjie <chenjie137@huawei.com> Co-authored-by: Delphine-Nic <tanwenqin@huawei.com> Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com> Co-authored-by: Feng Liu <liufeng248@huawei.com> Co-authored-by: gaojc <1055866782@qq.com> Co-authored-by: weiguihua2 <weiguihua2@huawei.com> Co-authored-by: z50049692 <zhangmingwei11@huawei.com> Co-authored-by: w00896881 <wangzixuan40@huawei.com>
2025-10-24 10:32:01 +08:00
@dataclass
class AscendMetadata:
    """
    Per-layer attention metadata for Ascend FlashAttention backend.
    Contains attention masks, token counts, sequence lengths and KV cache
    related properties for attention computation.
    """

    # **************************** Basic Properties ************************** #
    # Attention mask for this step (None when no mask is required).
    attn_mask: torch.Tensor | None = None
    # Current state of this attention run.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
    # Token count after prefill-context-parallel (PCP) padding —
    # NOTE(review): presumably 0 when PCP is not in use; confirm in builder.
    num_actual_tokens_pcp_padded: int = 0
    # Number of tokens excluding padding.
    num_actual_tokens: int = 0
    # Decode/prefill split of the current batch.
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0
    num_decodes_flatten: int = 0
    # The sequence length per sequence. Sequence length means the computed
    # tokens + new tokens (is None if it is a decoding).
    # (batch_size,)
    # TODO(Angazenn): The following parameters are quite redundant and
    # contain similar information (such as seq_lens / seq_lens_list). We
    # should simplify these parameters once the attention schema in
    # vLLM-Ascend is unified.
    seq_lens: torch.Tensor | None = None
    seq_lens_cpu: torch.Tensor | None = None
    seq_lens_list: list[int] | None = None
    actual_seq_lengths_q: list[int] | None = None
    query_start_loc: torch.Tensor | None = None
    # Maximum query length in the batch (None for decoding).
    max_query_len: int | None = None

    # ********************** KV Cache Related Properties ********************* #
    # Block addresses per sequence (Seq id -> list of physical block).
    # (batch_size, max_blocks_per_seq)
    block_tables: torch.Tensor | None = None
    # The indices of the token slots that input tokens will be stored into.
    # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
    # three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0,
    # and 1st slot in block 1, respectively.
    # (num_tokens,)
    slot_mapping: torch.Tensor | None = None
    # Prefill metadata for prefill context parallel (pcp).
    prefill: AscendMetadataForPrefill | None = None
    # Decode metadata for decode context parallel (dcp).
    decode_meta: AscendMetadataForDecode | None = None
    # Whether attention is causal.
    causal: bool = True
    # runner_type in model_config.
    model_runner_type: str = ""
    # Prefill reshape_and_cache event — NOTE(review): presumably recorded so
    # attention can wait on the cache write; confirm where it is recorded.
    reshape_cache_event: torch.npu.Event | None = None
    # Sliding window attention mask (None when SWA is not in use).
    swa_mask: torch.Tensor | None = None
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
"""
Builder for constructing AscendMetadata from CommonAttentionMetadata.
Handles attention mask generation and metadata preparation for
Ascend FlashAttention backend.
"""
[Feat][Graph] Support `FULL_DECODE_ONLY` mode for GQA/MHA models (#2128) Note: This depends on [vLLM #25161](https://github.com/vllm-project/vllm/pull/25161) and the torch\_npu release from September 30. ### What this PR does / why we need it? This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA models like DeepSeek V3/R1 are not included). Key improvements include: * **Reduced dispatch latency:** By replaying the entire model execution graph at once, we cut overhead compared with multiple smaller replays. * **Stabilized multi-device performance:** Captureing the whole model as one static graph also mitigates the dispatch fluctuations across devices. * **Stream/resource savings:** Consolidating graph captures frees up streams, allowing more graphs to be captured. **Known issues:** 1. `_npu_paged_attention` currently manages its own workspace in `torch_npu`, which can deadlock when synchronizing during graph replay — we’re working on a fix. There may be other corner cases. This PR is the first in a planned series; we’ll continue to iterate and address remaining issues in follow-ups. This is essentially a port of #1503 and #1677, but includes two major changes: 1. Let `graph_dispatcher` decide the graph mode instead of hard-coding it in the backend, which decouples Full Graph and Piecewise Graph and could make it possible to remove dynamo. 2. Adapt to the new `attn_group` logic, but leave a small hack in `update_graph_params`; multi-attention models may or may not be fully supported yet. ### Does this PR introduce _any_ user-facing change? ```python compilation_config={ "cudagraph_mode": "FULL_DECODE_ONLY", }, ``` ### How was this patch tested? Tests included. - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/9607d5eb449711b349d4c2bee0a9c94afcc7ed14 --------- Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-09-22 17:14:28 +08:00
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: int = 1
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache config handles and derive decode/reorder thresholds.

        Args:
            kv_cache_spec: KV cache layout spec for this group of layers.
            layer_names: names of the attention layers this builder serves.
            vllm_config: full engine configuration.
            device: device the metadata tensors will live on.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device
        # Upper bound on block-table width: one entry per kernel block
        # (128 tokens, see the backend) over the maximum model length.
        self.max_num_blocks_per_req = cdiv(
            self.model_config.max_model_len, AscendAttentionBackend.get_supported_kernel_block_sizes()[0]
        )
        self.speculative_config = vllm_config.speculative_config
        # A decode step covers the verified token plus any speculative
        # tokens, so widen the threshold when spec decoding is enabled.
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            # The TND layout of npu_fused_infer_attention_score caps the
            # per-request decode token count at 16.
            assert self.decode_threshold <= 16, (
                f"decode_threshold exceeded \
                npu_fused_infer_attention_score TND layout's limit of 16, \
                got {self.decode_threshold}"
            )
        self.reorder_batch_threshold = self.decode_threshold
        scheduler_config = vllm_config.scheduler_config
        self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
        self.attn_mask_builder = AttentionMaskBuilder(self.device)
@classmethod
def get_cudagraph_support(
cls: type["AscendAttentionMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.ALWAYS
def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool:
return False
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: AscendCommonAttentionMetadata,
    fast_build: bool = False,
) -> AscendMetadata:
    """Build the per-step attention metadata consumed by the Ascend backend.

    Args:
        common_prefix_len: Length of the prefix shared by all requests
            (currently unused by this builder).
        common_attn_metadata: Backend-agnostic metadata collected by the
            model runner for the current scheduler step.
        fast_build: Hint that a cheaper build is acceptable (unused here).

    Returns:
        A populated ``AscendMetadata`` for the current batch.
    """
    num_reqs = common_attn_metadata.num_reqs
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    # CSR-style offsets: num_reqs + 1 entries.
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]

    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_decodes_and_prefills(
        common_attn_metadata, decode_threshold=self.decode_threshold
    )

    block_table = common_attn_metadata.block_table_tensor
    seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
    slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
    # This slot_mapping override doesn't work since vllm will override it again.
    # We should fix it in vllm.
    # see: https://github.com/vllm-project/vllm/blob/ce88756b967c2c5006746a424c15dd59a284ed8c/vllm/model_executor/layers/attention/cross_attention.py#L117
    if isinstance(self.kv_cache_spec, CrossAttentionSpec):
        seq_lens = common_attn_metadata.seq_lens
        slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32)
    elif self.speculative_config and self.speculative_config.parallel_drafting:
        # Parallel drafting uses the device-resident seq_lens tensor.
        seq_lens = common_attn_metadata.seq_lens

    attn_state = common_attn_metadata.attn_state
    # Get attn_mask and swa_mask from singleton AttentionMaskBuilder.
    attn_mask = self.attn_mask_builder.get_attention_mask(self.model_config)
    # Only build a sliding-window mask for SWA models. Guard the None check
    # *before* touching model_config attributes, so a missing model_config
    # cannot raise AttributeError here.
    swa_mask = None
    if self.model_config is not None and hasattr(self.model_config.hf_text_config, "sliding_window"):
        swa_mask = self.attn_mask_builder.get_swa_mask(
            self.model_config.dtype, self.model_config.hf_text_config.sliding_window
        )

    # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
    query_start_loc = query_start_loc_cpu.pin_memory().to(self.device, non_blocking=True)

    attn_metadata = AscendMetadata(
        num_actual_tokens=num_actual_tokens,
        num_decode_tokens=num_decode_tokens,
        block_tables=block_table,
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens,
        seq_lens_list=seq_lens.tolist(),
        max_query_len=common_attn_metadata.max_query_len,
        actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
        slot_mapping=slot_mapping,
        attn_mask=attn_mask,
        swa_mask=swa_mask,
        attn_state=attn_state,
        num_prefills=num_prefills,
        num_decodes=num_decodes,
        causal=common_attn_metadata.causal,
        model_runner_type=self.model_config.runner_type,
    )
    return attn_metadata
[Feat][Graph] Support `FULL_DECODE_ONLY` mode for GQA/MHA models (#2128) Note: This depends on [vLLM #25161](https://github.com/vllm-project/vllm/pull/25161) and the torch\_npu release from September 30. ### What this PR does / why we need it? This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA models like DeepSeek V3/R1 are not included). Key improvements include: * **Reduced dispatch latency:** By replaying the entire model execution graph at once, we cut overhead compared with multiple smaller replays. * **Stabilized multi-device performance:** Captureing the whole model as one static graph also mitigates the dispatch fluctuations across devices. * **Stream/resource savings:** Consolidating graph captures frees up streams, allowing more graphs to be captured. **Known issues:** 1. `_npu_paged_attention` currently manages its own workspace in `torch_npu`, which can deadlock when synchronizing during graph replay — we’re working on a fix. There may be other corner cases. This PR is the first in a planned series; we’ll continue to iterate and address remaining issues in follow-ups. This is essentially a port of #1503 and #1677, but includes two major changes: 1. Let `graph_dispatcher` decide the graph mode instead of hard-coding it in the backend, which decouples Full Graph and Piecewise Graph and could make it possible to remove dynamo. 2. Adapt to the new `attn_group` logic, but leave a small hack in `update_graph_params`; multi-attention models may or may not be fully supported yet. ### Does this PR introduce _any_ user-facing change? ```python compilation_config={ "cudagraph_mode": "FULL_DECODE_ONLY", }, ``` ### How was this patch tested? Tests included. - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/9607d5eb449711b349d4c2bee0a9c94afcc7ed14 --------- Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-09-22 17:14:28 +08:00
def build_for_graph_capture(
    self,
    common_attn_metadata: AscendCommonAttentionMetadata,
    attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
):
    """Build dummy attention metadata used while capturing an ACL graph.

    Args:
        common_attn_metadata: Backend-agnostic metadata describing the
            capture-sized dummy batch.
        attn_state: The attention state the captured graph will serve.

    Returns:
        An ``AscendMetadata`` whose ``attn_state`` is forced to the
        requested state.

    Raises:
        NotImplementedError: For states other than DecodeOnly,
            ChunkedPrefill and SpecDecoding.
    """
    if attn_state in (
        AscendAttentionState.DecodeOnly,
        AscendAttentionState.ChunkedPrefill,
        AscendAttentionState.SpecDecoding,
    ):
        attn_metadata = self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )
    else:
        raise NotImplementedError(
            "Currently we only support building dummy metadata for "
            "DecodeOnly, ChunkedPrefill and SpecDecoding states"
        )
    # The build() above derives attn_state from the metadata; override it
    # with the state this graph is being captured for.
    attn_metadata.attn_state = attn_state
    return attn_metadata
class AscendAttentionBackendImpl(AttentionImpl):
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    sinks: torch.Tensor | None = None,
    **kwargs,
) -> None:
    """Initialize the Ascend attention backend implementation.

    Args:
        num_heads: Number of query attention heads.
        head_size: Dimension of each attention head.
        scale: Softmax scaling factor applied to attention scores.
        num_kv_heads: Number of key/value heads; falls back to
            ``num_heads`` when ``None``.
        alibi_slopes: Optional per-head ALiBi slopes, moved to NPU.
        sliding_window: Optional sliding-window size for SWA models.
        kv_cache_dtype: Data type name used for the KV cache.
        logits_soft_cap: Unused by this implementation.
        attn_type: The attention type (decoder, encoder, ...).
        kv_sharing_target_layer_name: Unused by this implementation.
        sinks: Optional attention-sink tensor (e.g. for gpt-oss).
    """
    self.vllm_config = get_current_vllm_config()
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
    self.hidden_size = self.num_heads * self.head_size
    self.kv_cache_dtype = kv_cache_dtype
    self.sliding_window = sliding_window
    if alibi_slopes is not None:
        alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32, device="npu")
    self.alibi_slopes = alibi_slopes
    self.attn_type = attn_type

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    # Bound lazily on first forward; the caches live in the KV-cache pool.
    self.key_cache = None
    self.value_cache = None
    # True when this rank produces KV blocks for disaggregated prefill.
    self.is_kv_producer = (
        self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
    )
    self.sinks = sinks
@staticmethod
def update_graph_params(
    update_stream,
    forward_context,
    num_tokens,
    vllm_config,
    speculative_config=None,
    num_dcp_pcp_tokens=None,
    draft_attn_metadatas=None,
):
    """Refresh attention kernels captured in an ACL graph before replay.

    For every layer captured under the ``num_tokens`` bucket, re-issue the
    attention kernel on ``update_stream`` between ``graph_task_update_begin``
    and ``graph_task_update_end`` so the captured task picks up this step's
    metadata (sequence lengths, block tables). The paged-attention path and
    the fused-infer-attention (FIA) path are selected the same way as at
    capture time, via ``using_paged_attention``.

    Args:
        update_stream: NPU stream on which the task updates are issued.
        forward_context: Holds the per-layer ``attn_metadata`` mapping.
        num_tokens: Token-count bucket indexing the captured graph params,
            handles, and events.
        vllm_config: Config consulted by ``using_paged_attention``.
        speculative_config: Unused here; kept for interface compatibility.
        num_dcp_pcp_tokens: Unused here; kept for interface compatibility.
        draft_attn_metadatas: Per-draft-step metadata dicts; used instead of
            ``forward_context.attn_metadata`` when updating the draft model.
    """
    if using_paged_attention(num_tokens, vllm_config):
        # Paged Attention update logic
        if _EXTRA_CTX.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
        with torch.npu.stream(update_stream):
            for key, param, handle, event in zip(
                forward_context.attn_metadata,
                graph_params.attn_params[num_tokens],
                graph_params.handles[num_tokens],
                graph_params.events[num_tokens],
            ):
                (
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    num_heads,
                    scale,
                    block_table,
                    seq_lens,
                    output,
                ) = param
                # The seq_lens recorded at capture time is stale; read the
                # current value from this layer's metadata.
                seq_lens = forward_context.attn_metadata[key].seq_lens
                workspace = torch_npu._npu_paged_attention_get_workspace(
                    query=query,
                    key_cache=key_cache,
                    value_cache=value_cache,
                    num_kv_heads=num_kv_heads,
                    num_heads=num_heads,
                    scale_value=scale,
                    block_table=block_table,
                    context_lens=seq_lens,
                    out=output,
                )
                torch.npu.graph_task_update_begin(update_stream, handle)
                torch_npu._npu_paged_attention(
                    query=query,
                    key_cache=key_cache,
                    value_cache=value_cache,
                    num_kv_heads=num_kv_heads,
                    num_heads=num_heads,
                    scale_value=scale,
                    block_table=block_table,
                    context_lens=seq_lens,
                    out=output,
                    workspace=workspace,
                )
                torch.npu.graph_task_update_end(update_stream)
                # Signal the capture-side ExternalEvent so replay can proceed.
                event.record(update_stream)
    else:
        # FIA update logic
        if _EXTRA_CTX.is_draft_model:
            graph_params = get_draft_graph_params()
            attn_metadata = draft_attn_metadatas
            attn_keys = list(attn_metadata[0].keys())
        else:
            graph_params = get_graph_params()
            attn_metadata = forward_context.attn_metadata
            attn_keys = list(attn_metadata.keys())
        # For Qwen3-next, since the kv_cache_config has already categorized
        # linear_attn and self_attn, the attn_metadata is first arranged with
        # self_attn followed by linear_attn. Therefore, using zip directly
        # filters out the update operations for linear_attn.
        # TODO: We use a new variable `attn_keys` to ensure the loop count is
        # correct after get by `zip` because of the new structure of the attn_metadata
        # when running with the merged full eagle-graph. Should check it with Qwen3-next.
        num_layers = len(attn_keys)
        if num_layers == 0:
            return
        if _EXTRA_CTX.is_draft_model:
            # Repeat the key list so that each captured param (one per layer
            # per draft step) lines up with its layer key in the zip below.
            attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
        attn_count = 0
        with torch.npu.stream(update_stream):
            for key, param, handle, event in zip(
                attn_keys,
                graph_params.attn_params[num_tokens],
                graph_params.handles[num_tokens],
                graph_params.events[num_tokens],
            ):
                (
                    query,
                    key_cache,
                    value,
                    block_tables,
                    attn_mask,
                    block_size,
                    seq_lens,
                    query_start_loc,
                    num_kv_heads,
                    num_heads,
                    scale,
                    attn_output,
                    softmax_lse,
                ) = param
                if _EXTRA_CTX.is_draft_model:
                    # Each draft step has its own metadata dict; advance the
                    # step index once per num_layers iterations.
                    draft_step = attn_count // num_layers
                    seq_lens = attn_metadata[draft_step][key].seq_lens_list
                    actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
                    block_tables = attn_metadata[draft_step][key].block_tables
                    attn_count = attn_count + 1
                else:
                    seq_lens = attn_metadata[key].seq_lens_list
                    actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
                    block_tables = attn_metadata[key].block_tables
                torch.npu.graph_task_update_begin(update_stream, handle)
                torch_npu.npu_fused_infer_attention_score.out(
                    query=query,
                    key=key_cache,
                    value=value,
                    block_table=block_tables,
                    atten_mask=attn_mask,
                    input_layout="TND",
                    block_size=block_size,
                    actual_seq_lengths=actual_seq_lengths_q,
                    actual_seq_lengths_kv=seq_lens,
                    num_key_value_heads=num_kv_heads,
                    num_heads=num_heads,
                    scale=scale,
                    sparse_mode=3,
                    workspace=graph_params.workspaces.get(num_tokens),
                    out=[attn_output, softmax_lse],
                )
                torch.npu.graph_task_update_end(update_stream)
                event.record(update_stream)
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Run the base post-load weight processing, then FlashComm2
    o-shard post-processing when that feature is enabled."""
    super().process_weights_after_loading(act_dtype)
    if not flashcomm2_oshard_manager.flashcomm2_oshard_enable():
        return
    flashcomm2_oshard_manager.post_process_after_loading()
def full_graph_fia(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: torch.Tensor,
) -> tuple[torch.Tensor, int]:
    """Capture fused-infer-attention into the current ACL graph.

    Issues the FIA kernel inside a graph task group and records the
    tensors/handles/events needed by ``update_graph_params`` to refresh
    the captured task at replay time.

    Returns:
        ``(output, num_tokens)`` — the attention output viewed as
        ``(num_tokens, num_heads, head_size)`` and the actual token count.
    """
    key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
    num_tokens = attn_metadata.actual_seq_lengths_q[-1]
    if _EXTRA_CTX.is_draft_model:
        graph_params = get_draft_graph_params()
    else:
        graph_params = get_graph_params()
    actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
    # Prepare tensors for attention output
    # TODO: Refactor this to step-level instead of layer-level
    # Get workspace from cache or calculate it if not present.
    workspace = graph_params.workspaces.get(num_tokens)
    softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
    if workspace is None:
        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            sparse_mode=3,
            scale=self.scale,
        )
        # Cache the workspace for this token bucket so later layers/steps
        # reuse it instead of re-querying the max workspace size.
        if _EXTRA_CTX.is_draft_model:
            update_draft_graph_params_workspaces(num_tokens, workspace)
        else:
            update_graph_params_workspaces(num_tokens, workspace)
    # Handle graph capturing mode
    stream = torch_npu.npu.current_stream()
    # ExternalEvent gates replay until update_graph_params records it.
    event = torch.npu.ExternalEvent()
    event.wait(stream)
    event.reset(stream)
    graph_params.events[num_tokens].append(event)
    # Record weak refs to the kernel operands so update_graph_params can
    # re-issue the kernel without keeping the tensors alive.
    graph_params.attn_params[num_tokens].append(
        (
            weak_ref_tensors(query),
            weak_ref_tensors(key),
            weak_ref_tensors(value),
            weak_ref_tensors(block_table),
            weak_ref_tensors(attn_metadata.attn_mask),
            block_size,
            actual_seq_lengths_kv,
            actual_seq_lengths_q,
            self.num_kv_heads,
            self.num_heads,
            self.scale,
            weak_ref_tensors(output),
            weak_ref_tensors(softmax_lse),
        )
    )
    torch.npu.graph_task_group_begin(stream)
    torch_npu.npu_fused_infer_attention_score.out(
        query=query,
        key=key,
        value=value,
        atten_mask=attn_metadata.attn_mask,
        block_table=block_table,
        input_layout="TND",
        block_size=block_size,
        actual_seq_lengths=actual_seq_lengths_q,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        num_key_value_heads=self.num_kv_heads,
        num_heads=self.num_heads,
        scale=self.scale,
        sparse_mode=3,
        workspace=workspace,
        out=[output, softmax_lse],
    )
    output = output.view(num_tokens, self.num_heads, self.head_size)
    handle = torch.npu.graph_task_group_end(stream)
    graph_params.handles[num_tokens].append(handle)
    return output, num_tokens
def full_graph_pa(
    self,
    query: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: torch.Tensor | None = None,
):
    """Capture paged attention into the current ACL graph.

    Mirrors ``full_graph_fia`` for the paged-attention kernel: issues the
    kernel inside a graph task group and records operands/handles/events
    for ``update_graph_params`` to refresh at replay time.

    NOTE(review): only called from ``forward_paged_attention`` when
    ``_EXTRA_CTX.capturing`` is set, so the guard below is expected to be
    true; if it is false, ``output`` is returned untouched.
    """
    graph_params = get_graph_params()
    num_tokens = query.shape[0]
    if _EXTRA_CTX.capturing:
        # Get workspace from cache or calculate it if not present.
        workspace = graph_params.workspaces.get(num_tokens)
        if workspace is None:
            workspace = torch_npu._npu_paged_attention_get_workspace(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output,
            )
            update_graph_params_workspaces(num_tokens, workspace)
        # Handle graph capturing mode
        stream = torch_npu.npu.current_stream()
        # ExternalEvent gates replay until update_graph_params records it.
        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        graph_params.events[num_tokens].append(event)
        # Record weak refs to the kernel operands so update_graph_params can
        # re-issue the kernel without keeping the tensors alive.
        graph_params.attn_params[num_tokens].append(
            (
                weak_ref_tensors(query),
                weak_ref_tensors(self.key_cache),
                weak_ref_tensors(self.value_cache),
                self.num_kv_heads,
                self.num_heads,
                self.scale,
                attn_metadata.block_tables,
                attn_metadata.seq_lens,
                weak_ref_tensors(output),
            )
        )
        torch.npu.graph_task_group_begin(stream)
        torch_npu._npu_paged_attention(
            query=query,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.block_tables,
            context_lens=attn_metadata.seq_lens,
            out=output,
            workspace=workspace,
        )
        handle = torch.npu.graph_task_group_end(stream)
        graph_params.handles[num_tokens].append(handle)
    return output
def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata):
    """Resolve the per-state FIA kernel inputs from the KV cache and metadata.

    Defect fixed: the original block was interleaved with blame-view commit
    messages (syntax errors); they are removed here. The three identical
    "flatten the KV cache" branches (PrefillCacheHit / DecodeOnly / chunked
    prefill) are also merged — behavior is unchanged.

    Args:
        key/value: Current-step key/value tensors; replaced by flattened
            cache views whenever the KV cache is consulted.
        attn_metadata: Per-step attention metadata.

    Returns:
        ``(key, value, block_size, block_table, actual_seq_lengths_kv)``
        ready to pass to ``npu_fused_infer_attention_score``.
    """
    if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
        # No KV cache yet: attend over the freshly computed key/value.
        block_size = 128
        block_table = None
        actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
        if self.attn_type == AttentionType.ENCODER_DECODER:
            # Cross attention: KV lengths come from the encoder sequences.
            actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, dim=0).tolist()
    else:
        # All cached states (prefill cache hit, decode-only, chunked prefill)
        # read K/V from the paged cache flattened to (num_block, block_size, -1).
        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
        key = self.key_cache.view(  # type: ignore
            num_block, block_size, -1
        )
        value = self.value_cache.view(  # type: ignore
            num_block, block_size, -1
        )
        actual_seq_lengths_kv = attn_metadata.seq_lens_list
        if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
            # Only the rows for the current batch are valid.
            batch_size = attn_metadata.seq_lens.shape[0]
            block_table = attn_metadata.block_tables[:batch_size, :]
        else:
            # DecodeOnly and chunked prefill use the full block table.
            block_table = attn_metadata.block_tables
    return key, value, block_size, block_table, actual_seq_lengths_kv
def _forward_fia_slidingwindow(self, query: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor):
    """Decode-only sliding-window attention via the FIA kernel in BSH layout.

    One query token per sequence; ``pre_tokens`` limits each token's
    attention span to the configured sliding window.
    """
    num_seqs = attn_metadata.seq_lens.shape[0]
    # One new token per sequence: (batch, seq=1, hidden).
    fia_query = query.view(num_seqs, 1, self.num_heads * self.head_size)
    fia_key = self.key_cache
    fia_value = self.value_cache
    fia_block_size = 128
    if self.key_cache is not None and self.value_cache is not None:
        fia_block_size = self.key_cache.shape[1]
        # Collapse (num_kv_heads, head_size) into one hidden dim for BSH.
        fia_key = self.key_cache.flatten(2, 3).contiguous()
        fia_value = self.value_cache.flatten(2, 3).contiguous()
    attn_output, _ = torch_npu.npu_fused_infer_attention_score(
        fia_query,
        fia_key,
        fia_value,
        num_heads=self.num_heads,
        num_key_value_heads=self.num_kv_heads,
        input_layout="BSH",
        block_size=fia_block_size,
        pre_tokens=self.sliding_window,
        scale=self.scale,
        block_table=attn_metadata.block_tables,
        actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
        actual_seq_lengths_kv=attn_metadata.seq_lens,
    )
    attn_output = attn_output.view(num_seqs, self.num_heads, self.head_size)
    output[:num_seqs] = attn_output[:num_seqs]
    return output
def forward_fused_infer_attention(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: torch.Tensor,
):
    """Run attention through the fused-infer-attention (FIA) kernels.

    Defect fixed: the original block was interleaved with blame-view commit
    messages (syntax errors); they are removed here. Logic is unchanged.

    Dispatch order:
      1. ACL-graph capture -> ``full_graph_fia``.
      2. Decode-only sliding window without sinks -> BSH fast path.
      3. Otherwise TND-layout FIA; models with attention sinks (e.g. gpt-oss)
         use the v2 kernel with ``learnable_sink``.
    """
    if _EXTRA_CTX.capturing:
        # Graph-capture path records the kernel for later task updates.
        attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output)
        output[:num_tokens] = attn_output[:num_tokens]
        return output
    if (
        attn_metadata.attn_state == AscendAttentionState.DecodeOnly
        and self.sliding_window is not None
        and attn_metadata.seq_lens.shape[0] == query.size(0)
        and self.sinks is None
    ):
        # Pure decode with a sliding window and no sinks: BSH fast path.
        return self._forward_fia_slidingwindow(query, attn_metadata, output)
    key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
    num_tokens = attn_metadata.actual_seq_lengths_q[-1]
    query = query[:num_tokens]
    if (
        attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
        and self.attn_type != AttentionType.ENCODER_DECODER
    ):
        # Without a cache, key/value are the fresh tensors; trim padding.
        key = key[:num_tokens]
        value = value[:num_tokens]
    if self.sinks is not None:
        # Attention-sink models (gpt-oss) need the v2 kernel.
        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
        if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            # One query token per sequence -> cumulative [1, 2, ...].
            actual_seq_qlen = torch.tensor([1] * len(attn_metadata.seq_lens_list), dtype=torch.int32).cumsum(dim=0)
        if self.sliding_window is not None:
            # sparse_mode 4: band mask for sliding-window attention.
            atten_mask = attn_metadata.swa_mask
            sparse_mode = 4
        else:
            # sparse_mode 3: standard causal mask.
            atten_mask = attn_metadata.attn_mask
            sparse_mode = 3
        attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(
            query,
            key,
            value,
            num_query_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout="TND",
            pre_tokens=self.sliding_window if self.sliding_window is not None else SWA_INT_MAX,
            next_tokens=0,
            atten_mask=atten_mask,
            sparse_mode=sparse_mode,
            softmax_scale=self.scale,
            block_table=block_table,
            block_size=block_size,
            actual_seq_qlen=actual_seq_qlen,
            actual_seq_kvlen=actual_seq_lengths_kv,
            learnable_sink=self.sinks,
        )
    else:
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )
    attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
    output[:num_tokens] = attn_output[:num_tokens]
    return output
def forward_paged_attention(
    self,
    query: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run paged attention, either capturing into an ACL graph or eagerly.

    Writes the result into ``output`` and returns it.
    """
    if _EXTRA_CTX.capturing:
        # Graph-capture path records the kernel for later task updates.
        return self.full_graph_pa(query, attn_metadata, output)
    # Eager path: launch the kernel directly into the preallocated output.
    kernel_args = dict(
        query=query,
        key_cache=self.key_cache,
        value_cache=self.value_cache,
        num_kv_heads=self.num_kv_heads,
        num_heads=self.num_heads,
        scale_value=self.scale,
        block_table=attn_metadata.block_tables,
        context_lens=attn_metadata.seq_lens,
        out=output,
    )
    torch_npu._npu_paged_attention(**kernel_args)
    return output
def _forward_encoder_attention(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: AscendMetadata,
    _: torch.Tensor,
) -> torch.Tensor:
    """Bidirectional (non-causal) attention for encoder/pooling models.

    Defect fixed: the original block was interleaved with a blame-view
    commit message (syntax error); it is removed here. Logic is unchanged.

    Query and key lengths are identical (self-attention over the full
    sequence), hence the same ``actual_seq_lengths_q`` for both sides.
    The unused trailing parameter keeps the shared forward signature.
    """
    # use default sparse_mode 0 in normal scenario, which means no mask works on it
    return torch_npu.npu_fusion_attention(
        query=query,
        key=key,
        value=value,
        head_num=self.num_heads,
        input_layout="TND",
        scale=self.scale,
        actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
        actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
    )[0]
def reshape_and_cache(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: tuple[torch.Tensor],
    attn_metadata: AscendMetadata,
    output: torch.Tensor,
):
    """Write this step's key/value into the paged KV cache.

    Defect fixed: the original block was interleaved with a blame-view
    commit message (syntax error); it is removed here. Logic is unchanged.

    Skipped entirely when ``kv_cache`` has fewer than two entries (e.g.
    profiling runs with no allocated cache). For KV producers in a
    disaggregated setup, an event brackets the cache write so consumers
    can wait on it.

    Returns:
        ``(query, key, value, output)`` unchanged, for caller convenience.
    """
    if len(kv_cache) > 1:
        if self.is_kv_producer:
            attn_metadata.reshape_cache_event = torch.npu.Event()
        # Bind the cache tensors once; they are reused across steps.
        if self.key_cache is None:
            self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
        slots = attn_metadata.slot_mapping
        encoder_decoder = self.attn_type == AttentionType.ENCODER_DECODER
        DeviceOperator.reshape_and_cache(
            key=key[: attn_metadata.num_actual_tokens] if not encoder_decoder else key,
            value=value[: attn_metadata.num_actual_tokens] if not encoder_decoder else value,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            # quick fix to make sure slots is int32 for cross attention case.
            # see: https://github.com/vllm-project/vllm/blob/ce88756b967c2c5006746a424c15dd59a284ed8c/vllm/model_executor/layers/attention/cross_attention.py#L117
            slot_mapping=slots[: attn_metadata.num_actual_tokens] if not encoder_decoder else slots.to(torch.int32),
        )
        if self.is_kv_producer:
            attn_metadata.reshape_cache_event.record()
    return query, key, value, output
def forward_impl(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: tuple[torch.Tensor],
    attn_metadata: AscendMetadata,
    output: torch.Tensor,
):
    """Route the attention computation to the appropriate kernel.

    A pure-decode batch that qualifies for paged attention (and does not
    use sliding-window attention) takes the paged-attention path; every
    other case goes through the fused infer-attention path.
    """
    token_count = query.shape[0]
    decode_only = attn_metadata.attn_state == AscendAttentionState.DecodeOnly
    if (
        decode_only
        and using_paged_attention(token_count, self.vllm_config)
        and self.sliding_window is None
    ):
        return self.forward_paged_attention(query, attn_metadata, output)
    return self.forward_fused_infer_attention(query, key, value, attn_metadata, output)
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: tuple[torch.Tensor],
    attn_metadata: AscendMetadata,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Forward pass with Ascend attention.

    Args:
        layer: Attention layer; only its K/V scale factors are inspected
            (non-unit quantization scales are not supported here).
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape =
            [2, num_blocks, block_size, num_kv_heads, head_size]
        attn_metadata: Metadata for attention; ``None`` indicates a dummy
            (e.g. profiling) run and yields a zero-filled output.
        output: Pre-allocated output buffer; must be provided.
        output_scale: Fused output quantization scale; unsupported.
        output_block_scale: Fused output block scale; unsupported.

    Returns:
        shape = [num_tokens, num_heads * head_size]
    """
    assert output is not None, "Output tensor must be provided."
    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError("fused output quantization is not yet supported for AscendAttentionBackendImpl")
    # Non-unit KV-cache quantization scales are not handled by this impl.
    assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
    num_tokens = query.shape[0]
    if attn_metadata is None:
        # No metadata means a dummy run: just return a zeroed buffer.
        return output.fill_(0)
    output_padded = None
    if key is not None and value is not None:
        # Write the new K/V entries into the cache; reshape_and_cache may
        # hand back a padded output buffer for the kernels to fill.
        output_padded = output
        query, key, value, output_padded = self.reshape_and_cache(
            query, key, value, kv_cache, attn_metadata, output
        )
    # Pooling-model branch: only true (non-causal) encoders take the
    # dedicated encoder-attention path; causal pooling models fall through
    # to the regular path, which supports prefix caching.
    if attn_metadata.model_runner_type == "pooling" and not attn_metadata.causal:
        attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
        output[:num_tokens] = attn_output[:num_tokens]
        return output
    # Prefer the (possibly padded) buffer produced by reshape_and_cache.
    target = output_padded if output_padded is not None else output
    attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, target)
    output[:num_tokens] = attn_output[:num_tokens]
    return output