2025-03-20 19:34:44 +08:00
|
|
|
#
|
|
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
2025-04-17 14:59:56 +08:00
|
|
|
# This file is a part of the vllm-ascend project.
|
2025-03-20 19:34:44 +08:00
|
|
|
#
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
2025-04-17 19:31:50 +08:00
|
|
|
from enum import Enum
|
2025-07-26 15:43:29 +08:00
|
|
|
from typing import List, Optional, Tuple, Type
|
2025-03-20 19:34:44 +08:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch_npu
|
|
|
|
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|
|
|
|
AttentionLayer, AttentionType)
|
|
|
|
|
from vllm.attention.backends.utils import CommonAttentionState
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph can
split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is now enabled by default. Users can disable the NPUGraph feature by
configuring enforce_eager.
This places corresponding requirements on the versions of torch_npu and
CANN: both need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
it turn to default
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
from vllm.forward_context import ForwardContext, get_forward_context
|
|
|
|
|
from vllm.utils import direct_register_custom_op
|
2025-04-19 17:38:18 +08:00
|
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
|
|
|
|
|
|
|
|
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
[Platform] Add initial experimental support for Atlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P. This patch squashes
the PRs below into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
Users can run vLLM on Atlas 300I DUO series
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit test missing because need a real 310P image to have the test,
will add in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
|
|
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
2025-07-26 15:43:29 +08:00
|
|
|
nd_to_nz_2d, nd_to_nz_spec)
|
2025-07-28 16:01:59 +08:00
|
|
|
from vllm_ascend.worker.npu_input_batch import InputBatch
|
2025-03-20 19:34:44 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class AscendAttentionBackend(AttentionBackend):
    """Backend descriptor that wires the Ascend attention classes together.

    Every member is static: the class only reports which implementation,
    metadata, state and builder classes belong to this backend, describes
    the KV-cache tensor layouts, and provides block-level KV-cache copy
    helpers.
    """

    # The implementation's ``forward`` takes an optional ``output`` tensor.
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "ASCEND"

    @staticmethod
    def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
        """Return the attention implementation class of this backend."""
        return AscendAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AscendMetadata"]:
        """Return the per-step attention metadata class."""
        return AscendMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class (vLLM's common one)."""
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> Type["AscendAttentionMetadataBuilder"]:
        """Return the class that builds ``AscendMetadata`` objects."""
        return AscendAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the combined key/value cache tensor.

        The leading dimension of size 2 holds the key cache at index 0 and
        the value cache at index 1 (this is how ``swap_blocks`` and
        ``copy_blocks`` index it).

        Args:
            num_blocks: Number of cache blocks.
            block_size: Number of token slots per block.
            num_kv_heads: Number of key/value heads.
            head_size: Size of each head.

        Returns:
            ``(2, num_blocks, block_size, num_kv_heads, head_size)`` in the
            general case, or a 16-wide tiled layout when ``is_310p()`` is
            true.
        """
        if is_310p():
            # The flattened (num_kv_heads * head_size) axis is split into
            # groups of 16. Floor division is used, so the product is
            # assumed to be a multiple of 16 -- TODO confirm with callers.
            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
                    16)
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_bsh_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the cache shape with heads and head size flattened.

        Same arguments as ``get_kv_cache_shape``; the last two dimensions
        are merged into a single ``num_kv_heads * head_size`` axis.
        """
        return (2, num_blocks, block_size, num_kv_heads * head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy selected blocks from one KV cache into another, in place.

        Args:
            src_kv_cache: Source cache; element 0 is the key cache and
                element 1 the value cache.
            dst_kv_cache: Destination cache with the same two-element
                structure. It is modified in place.
            src_to_dst: 2-D tensor whose column 0 holds source block
                indices and column 1 the matching destination indices.
        """
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]

        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
            dst_key_cache.device)
        # Move value blocks to the *value* cache's device; the previous
        # code reused the key cache's device here.
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
            dst_value_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Duplicate blocks inside each KV cache, in place.

        Args:
            kv_caches: Caches to update; each one is indexed with ``[0]``
                for keys and ``[1]`` for values.
            src_to_dists: 2-D tensor whose column 0 holds source block
                indices and column 1 the matching destination indices.
        """
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]

        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]
|
|
|
|
|
|
|
|
|
|
|
2025-04-17 19:31:50 +08:00
|
|
|
class AscendAttentionState(Enum):
    """Kinds of attention run that the Ascend backend distinguishes.

    The value is stored in ``AscendMetadata.attn_state``; the metadata
    builder compares against ``PrefillNoCache`` and ``ChunkedPrefill`` to
    choose an attention-mask conversion on 310P devices.
    """

    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3
    SpecDecoding = 4
|
2025-04-17 19:31:50 +08:00
|
|
|
|
|
|
|
|
|
2025-03-20 19:34:44 +08:00
|
|
|
@dataclass
class AscendMetadata:
    """Per-step attention metadata consumed by the Ascend attention backend.

    Instances are created by ``AscendAttentionMetadataBuilder.build``. All
    fields have defaults, so the object can also be constructed empty.
    """

    # **************************** Basic Properties ****************************
    attn_mask: Optional[torch.Tensor] = None
    # Current state of this attention run.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    # Number of tokens excluding padding.
    num_actual_tokens: int = 0

    # The sequence length per sequence. Sequence length means the computed
    # tokens + new tokens (is None if it is a decoding).
    # (batch_size,)
    seq_lens: Optional[torch.Tensor] = None

    query_start_loc: Optional[torch.Tensor] = None
    query_lens: Optional[torch.Tensor] = None
    # Maximum query length in the batch (None for decoding).
    max_query_len: Optional[int] = None

    # ********************** KV Cache Related Properties ***********************
    # Block addresses per sequence (Seq id -> list of physical block).
    # (batch_size, max_blocks_per_seq)
    block_tables: Optional[torch.Tensor] = None

    # The indices of the token slots that input tokens will be stored into.
    # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, then
    # 35 = 2 * 16 + 3, 2 = 0 * 16 + 2 and 17 = 1 * 16 + 1, so the three
    # tokens are stored at (0-based) offset 3 in block 2, offset 2 in
    # block 0, and offset 1 in block 1, respectively.
    # (num_tokens,)
    slot_mapping: Optional[torch.Tensor] = None

    # Forwarded unchanged from the builder's `build(enable_dbo_across_dp=...)`
    # argument; its consumers live outside this part of the file.
    enable_dbo_across_dp: bool = False
|
|
|
|
|
|
2025-03-20 19:34:44 +08:00
|
|
|
|
2025-04-19 17:38:18 +08:00
|
|
|
class AscendAttentionMetadataBuilder:
    """Builds ``AscendMetadata`` from the state held by a model runner."""

    def __init__(self, runner):
        """Remember the model runner whose buffers ``build`` reads.

        Args:
            runner: Object exposing ``input_batch``, ``query_lens``,
                ``seq_lens_cpu``, ``slot_mapping_cpu``,
                ``query_start_loc_cpu``, ``attn_mask``, ``attn_state``,
                ``device`` and ``max_num_blocks_per_req``.
        """
        self.runner = runner

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Report that the batch was not reordered.

        Both arguments are ignored.

        Returns:
            Always ``False``.
        """
        return False

    def build(self,
              num_reqs: int,
              num_actual_tokens: int,
              max_query_len: int,
              enable_dbo_across_dp: bool = False) -> "AscendMetadata":
        """Assemble the attention metadata for the current step.

        Args:
            num_reqs: Number of requests; selects the leading rows of the
                block table, ``seq_lens_cpu`` and ``query_start_loc_cpu``.
            num_actual_tokens: Number of tokens; selects the leading
                entries of ``slot_mapping_cpu``.
            max_query_len: Stored unchanged in the metadata.
            enable_dbo_across_dp: Stored unchanged in the metadata.

        Returns:
            A populated ``AscendMetadata``.
        """
        runner = self.runner

        block_table = runner.input_batch.block_table[0].get_device_tensor()
        # NOTE(review): this writes the first `num_reqs` rows back onto
        # themselves, limited to `max_num_blocks_per_req` columns. It looks
        # like a no-op when the table is exactly that wide -- confirm the
        # intent before changing it.
        block_table[:num_reqs, :runner.max_num_blocks_per_req] = (
            block_table[:num_reqs])

        query_lens = runner.query_lens
        # Taken from the CPU buffer and passed on without a device transfer.
        seq_lens = runner.seq_lens_cpu[:num_reqs]
        slot_mapping = runner.slot_mapping_cpu[:num_actual_tokens].to(
            runner.device, non_blocking=True)
        attn_mask = runner.attn_mask
        attn_state = runner.attn_state
        # One more entry than requests: per-request start offsets plus the
        # final end offset (presumably cumulative -- verify against runner).
        query_start_loc_cpu = runner.query_start_loc_cpu[:num_reqs + 1]
        query_start_loc = query_start_loc_cpu.to(runner.device,
                                                 non_blocking=True)

        if is_310p():
            # On 310P the mask is converted and cast to the FRACTAL_NZ
            # format; the converter depends on the attention state. Other
            # states keep the mask untouched.
            if attn_state == AscendAttentionState.PrefillNoCache:
                to_nz = nd_to_nz_2d
            elif attn_state == AscendAttentionState.ChunkedPrefill:
                to_nz = nd_to_nz_spec
            else:
                to_nz = None
            if to_nz is not None:
                attn_mask = torch_npu.npu_format_cast(
                    to_nz(attn_mask).contiguous(), ACL_FORMAT_FRACTAL_NZ)

        return AscendMetadata(
            num_actual_tokens=num_actual_tokens,
            block_tables=block_table,
            query_start_loc=query_start_loc,
            query_lens=query_lens,
            seq_lens=seq_lens,
            max_query_len=max_query_len,
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            attn_state=attn_state,
            enable_dbo_across_dp=enable_dbo_across_dp)
|
|
|
|
|
|
|
|
|
|
|
2025-03-20 19:34:44 +08:00
|
|
|
class AscendAttentionBackendImpl(AttentionImpl):
|
|
|
|
|
|
|
|
|
|
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: Optional[int],
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **kwargs,
) -> None:
    """Store the attention hyper-parameters on the instance.

    Args:
        num_heads: Number of query heads.
        head_size: Size of each head.
        scale: Attention scale; coerced to ``float``.
        num_kv_heads: Number of key/value heads; ``None`` falls back to
            ``num_heads``. Must divide ``num_heads`` evenly.
        alibi_slopes: Optional slopes; when given they are turned into a
            float32 tensor on the ``"npu"`` device.
        sliding_window: Stored as-is.
        kv_cache_dtype: Stored as-is.
        logits_soft_cap: Accepted but not stored by this constructor.
        attn_type: Stored as-is.
        kv_sharing_target_layer_name: Accepted but not stored by this
            constructor.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If ``num_heads`` is not a multiple of the
            effective number of key/value heads.
    """
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
    self.hidden_size = self.num_heads * self.head_size
    self.kv_cache_dtype = kv_cache_dtype
    self.sliding_window = sliding_window
    if alibi_slopes is not None:
        alibi_slopes = torch.tensor(alibi_slopes,
                                    dtype=torch.float32,
                                    device="npu")
    self.alibi_slopes = alibi_slopes
    self.attn_type = attn_type

    assert self.num_heads % self.num_kv_heads == 0, (
        f"num_heads ({self.num_heads}) must be divisible by "
        f"num_kv_heads ({self.num_kv_heads})")
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    # Cache handles start empty; nothing in this constructor fills them.
    self.key_cache = None
    self.value_cache = None
|
2025-03-20 19:34:44 +08:00
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
layer: AttentionLayer,
|
|
|
|
|
query: torch.Tensor,
|
|
|
|
|
key: torch.Tensor,
|
|
|
|
|
value: torch.Tensor,
|
2025-07-26 17:15:47 +08:00
|
|
|
kv_cache: Tuple[torch.Tensor],
|
2025-03-20 19:34:44 +08:00
|
|
|
attn_metadata: AscendMetadata,
|
|
|
|
|
output: Optional[torch.Tensor] = None,
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR supports the access of vllm-acend to the piecewise_graph feature
provided by the v1 engine.
1. register unifiled_ascend_attention_with_output for piecewise_graph to
split graph.
2. support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
support npugraph to default, Users can disenable the npugraph feature by
configuring enforce_eager.
This has corresponding requirements for the versions of torch_npu and
CANN, and they need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
it turn to default
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
trace_flag: bool = True,
|
2025-03-20 19:34:44 +08:00
|
|
|
) -> torch.Tensor:
|
|
|
|
|
"""Forward pass with Ascend attention.
|
|
|
|
|
Args:
|
|
|
|
|
query: shape = [batch_size, seq_len, num_heads * head_size]
|
|
|
|
|
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
|
|
|
|
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
2025-07-26 17:15:47 +08:00
|
|
|
kv_cache: shape = [key_cache, value_cache]
|
2025-03-20 19:34:44 +08:00
|
|
|
key_cache = [num_blocks, block_size,
|
2025-05-20 09:31:30 +08:00
|
|
|
num_kv_heads, head_size]
|
2025-03-20 19:34:44 +08:00
|
|
|
value_cache = [num_blocks, block_size,
|
2025-05-20 09:31:30 +08:00
|
|
|
num_kv_heads, head_size]
|
2025-03-20 19:34:44 +08:00
|
|
|
attn_metadata: Metadata for attention.
|
|
|
|
|
Returns:
|
|
|
|
|
shape = [batch_size * seq_len, num_heads, head_size]
|
|
|
|
|
"""
|
|
|
|
|
num_tokens = query.shape[0]
|
2025-07-26 17:15:47 +08:00
|
|
|
use_kv_cache_int8 = len(
|
|
|
|
|
kv_cache) > 0 and kv_cache[0].dtype == torch.int8
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR supports the access of vllm-acend to the piecewise_graph feature
provided by the v1 engine.
1. register unifiled_ascend_attention_with_output for piecewise_graph to
split graph.
2. support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
support npugraph to default, Users can disenable the npugraph feature by
configuring enforce_eager.
This has corresponding requirements for the versions of torch_npu and
CANN, and they need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
it turn to default
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
if output is None:
|
|
|
|
|
output = torch.empty(num_tokens,
|
|
|
|
|
self.num_heads,
|
|
|
|
|
self.head_size,
|
|
|
|
|
dtype=query.dtype,
|
|
|
|
|
device=query.device)
|
[Platform] Add initial experimental support for Altlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P, this patch squash
below PR into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
User can run vLLM on Altlas 300I DUO series
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit test missing because need a real 310P image to have the test,
will add in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
|
|
|
ori_output = output
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph
can split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is now enabled by default. Users can disable the NPUGraph feature by
configuring enforce_eager.
This places corresponding requirements on the versions of torch_npu and
CANN: they need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
It is turned on by default.
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
if trace_flag:
|
|
|
|
|
torch.ops.vllm.unified_ascend_attention_with_output(
|
|
|
|
|
query=query,
|
|
|
|
|
key=key,
|
|
|
|
|
value=value,
|
|
|
|
|
output=output,
|
|
|
|
|
layer_name=layer.layer_name)
|
2025-06-28 18:51:07 +08:00
|
|
|
|
2025-07-02 16:40:51 +08:00
|
|
|
elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
2025-06-28 18:51:07 +08:00
|
|
|
output = layer.quant_method.apply(layer, query, key, value,
|
|
|
|
|
kv_cache, attn_metadata,
|
|
|
|
|
self.attn_type, self.scale,
|
|
|
|
|
output)
|
|
|
|
|
|
2025-03-20 19:34:44 +08:00
|
|
|
else:
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph
can split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is now enabled by default. Users can disable the NPUGraph feature by
configuring enforce_eager.
This places corresponding requirements on the versions of torch_npu and
CANN: they need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
It is turned on by default.
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
if attn_metadata is None:
|
|
|
|
|
return output.view(num_tokens, self.hidden_size)
|
2025-05-12 20:26:22 +08:00
|
|
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph
can split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is now enabled by default. Users can disable the NPUGraph feature by
configuring enforce_eager.
This places corresponding requirements on the versions of torch_npu and
CANN: they need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
It is turned on by default.
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
|
|
|
|
attn_type = self.attn_type
|
|
|
|
|
if attn_type != AttentionType.DECODER:
|
|
|
|
|
raise NotImplementedError("Encoder self-attention and "
|
|
|
|
|
"encoder/decoder cross-attention "
|
|
|
|
|
"are not implemented for "
|
|
|
|
|
"PallasAttentionBackendImpl")
|
|
|
|
|
# View q k v to BSH.
|
|
|
|
|
query = query.view(-1, self.num_heads, self.head_size)
|
|
|
|
|
key = key.view(-1, self.num_kv_heads, self.head_size)
|
|
|
|
|
value = value.view(-1, self.num_kv_heads, self.head_size)
|
|
|
|
|
# TODO: Remove this contiguous in the future.
|
|
|
|
|
value = value.contiguous()
|
|
|
|
|
|
2025-07-26 17:15:47 +08:00
|
|
|
if len(kv_cache) > 1:
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph
can split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is now enabled by default. Users can disable the NPUGraph feature by
configuring enforce_eager.
This places corresponding requirements on the versions of torch_npu and
CANN: they need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
It is turned on by default.
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
if self.key_cache is None:
|
|
|
|
|
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
|
|
|
|
slots = attn_metadata.slot_mapping
|
2025-05-12 20:26:22 +08:00
|
|
|
torch_npu._npu_reshape_and_cache(
|
|
|
|
|
key=key[:num_actual_tokens],
|
|
|
|
|
value=value[:num_actual_tokens],
|
|
|
|
|
key_cache=self.key_cache,
|
|
|
|
|
value_cache=self.value_cache,
|
|
|
|
|
slot_indices=slots)
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph
can split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is now enabled by default. Users can disable the NPUGraph feature by
configuring enforce_eager.
This places corresponding requirements on the versions of torch_npu and
CANN: they need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
It is turned on by default.
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
|
|
|
|
|
# V0-Style scheduler situation.
|
2025-06-28 18:51:07 +08:00
|
|
|
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph
can split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is now enabled by default. Users can disable the NPUGraph feature by
configuring enforce_eager.
This places corresponding requirements on the versions of torch_npu and
CANN: they need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
It is turned on by default.
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
assert attn_metadata is not None
|
|
|
|
|
assert attn_metadata.attn_mask is not None
|
|
|
|
|
mask = attn_metadata.attn_mask
|
[Platform] Add initial experimental support for Atlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P. This patch squashes
the PRs below into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
Users can run vLLM on the Atlas 300I DUO series.
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit tests are missing because they need a real 310P image to run;
they will be added in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
|
|
|
if is_310p():
|
|
|
|
|
# align q k v output tensors
|
|
|
|
|
query = aligned_16(query)
|
|
|
|
|
key = aligned_16(key)
|
|
|
|
|
value = aligned_16(value)
|
|
|
|
|
output = aligned_16(output)
|
|
|
|
|
|
|
|
|
|
# do reformat in case of broadcasted tensors
|
|
|
|
|
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
|
|
|
|
|
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
|
|
|
|
ACL_FORMAT_FRACTAL_NZ)
|
|
|
|
|
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph
can split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is now enabled by default. Users can disable the NPUGraph feature by
configuring enforce_eager.
This places corresponding requirements on the versions of torch_npu and
CANN: they need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
It is turned on by default.
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
torch_npu._npu_flash_attention(query=query,
|
|
|
|
|
key=key,
|
|
|
|
|
value=value,
|
|
|
|
|
mask=mask,
|
|
|
|
|
seq_len=attn_metadata.seq_lens,
|
|
|
|
|
scale_value=self.scale,
|
|
|
|
|
num_heads=self.num_heads,
|
|
|
|
|
num_kv_heads=self.num_kv_heads,
|
|
|
|
|
out=output)
|
[Platform] Add initial experimental support for Atlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P. This patch squashes
the PRs below into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
Users can run vLLM on the Atlas 300I DUO series.
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit tests are missing because they need a real 310P image to run;
they will be added in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
|
|
|
output = output[:num_tokens, :, :]
|
2025-05-09 16:39:28 +08:00
|
|
|
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
|
|
|
|
assert attn_metadata is not None
|
|
|
|
|
assert attn_metadata.attn_mask is not None
|
|
|
|
|
compress_mask = attn_metadata.attn_mask
|
2025-06-30 11:25:19 +08:00
|
|
|
batch_size = attn_metadata.query_lens.shape[0]
|
|
|
|
|
block_table = attn_metadata.block_tables[:batch_size, :]
|
2025-05-09 16:39:28 +08:00
|
|
|
torch_npu._npu_flash_attention_qlens(
|
|
|
|
|
query=query,
|
|
|
|
|
key_cache=self.key_cache,
|
|
|
|
|
value_cache=self.value_cache,
|
2025-06-30 11:25:19 +08:00
|
|
|
block_table=block_table,
|
2025-05-09 16:39:28 +08:00
|
|
|
mask=compress_mask,
|
|
|
|
|
seq_len=attn_metadata.query_lens,
|
|
|
|
|
context_lens=attn_metadata.seq_lens,
|
|
|
|
|
num_kv_heads=self.num_kv_heads,
|
|
|
|
|
num_heads=self.num_heads,
|
|
|
|
|
scale_value=self.scale,
|
|
|
|
|
out=output)
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph
can split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is now enabled by default. Users can disable the NPUGraph feature by
configuring enforce_eager.
This places corresponding requirements on the versions of torch_npu and
CANN: they need to support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
It is turned on by default.
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
[Platform] Add initial experimental support for Atlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P. This patch squashes
the PRs below into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
Users can run vLLM on the Atlas 300I DUO series.
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit tests are missing because they need a real 310P image to run;
they will be added in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
|
|
|
if is_310p():
|
|
|
|
|
# # seq_lens_tensor needs to be transferred to the device for 310P
|
|
|
|
|
attn_metadata.seq_lens = \
|
|
|
|
|
attn_metadata.seq_lens.to(device=query.device)
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph can
split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is enabled by default; users can disable the NPUGraph feature by
configuring enforce_eager.
This requires versions of torch_npu and
CANN that support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
it turn to default
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
torch_npu._npu_paged_attention(
|
2025-04-19 17:38:18 +08:00
|
|
|
query=query,
|
|
|
|
|
key_cache=self.key_cache,
|
|
|
|
|
value_cache=self.value_cache,
|
|
|
|
|
num_kv_heads=self.num_kv_heads,
|
|
|
|
|
num_heads=self.num_heads,
|
|
|
|
|
scale_value=self.scale,
|
2025-05-09 16:39:28 +08:00
|
|
|
block_table=attn_metadata.block_tables,
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph can
split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is enabled by default; users can disable the NPUGraph feature by
configuring enforce_eager.
This requires versions of torch_npu and
CANN that support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
it turn to default
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
context_lens=attn_metadata.seq_lens,
|
2025-04-19 17:38:18 +08:00
|
|
|
out=output)
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph can
split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is enabled by default; users can disable the NPUGraph feature by
configuring enforce_eager.
This requires versions of torch_npu and
CANN that support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
it turn to default
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
# Normal V1 situation.
|
|
|
|
|
else:
|
|
|
|
|
# use chunked prefill for head size 192 scenario, like deepseek
|
|
|
|
|
# paged_attention_splitfuse maybe crash at such scenario
|
|
|
|
|
# TODO: vanilla path will be removed after the kernel support
|
|
|
|
|
# head_size 192 scenario
|
|
|
|
|
if self.head_size == 192:
|
|
|
|
|
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
|
|
|
|
|
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
|
2025-07-23 14:52:52 +08:00
|
|
|
cu_seqlen_q = torch.tensor(cu_seqlen_q,
|
|
|
|
|
device=query.device)
|
|
|
|
|
cu_seqlen_k = torch.tensor(cu_seqlen_k,
|
|
|
|
|
device=query.device)
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph can
split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is enabled by default; users can disable the NPUGraph feature by
configuring enforce_eager.
This requires versions of torch_npu and
CANN that support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
it turn to default
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
|
|
|
|
|
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
|
|
|
|
|
max_seqlen_q = torch.max(attn_metadata.query_lens)
|
|
|
|
|
max_seqlen_k = torch.max(attn_metadata.seq_lens)
|
|
|
|
|
vanilla_chunked_prefill(output, query, self.key_cache,
|
|
|
|
|
self.value_cache,
|
|
|
|
|
attn_metadata.block_tables,
|
|
|
|
|
cu_seqlen_q, cu_seqlen_k,
|
|
|
|
|
max_seqlen_q, max_seqlen_k,
|
|
|
|
|
self.scale, None, True)
|
|
|
|
|
else:
|
|
|
|
|
# use paged attention
|
[Platform] Add initial experimental support for Atlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P. This patch squashes
the PRs below into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
Users can run vLLM on the Atlas 300I DUO series.
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit tests are missing because they need a real 310P image to run;
they will be added in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
|
|
|
assert attn_metadata is not None
|
|
|
|
|
assert attn_metadata.attn_mask is not None
|
|
|
|
|
if is_310p():
|
|
|
|
|
# do reformat in case of broadcasted tensors
|
|
|
|
|
attn_metadata.attn_mask = \
|
|
|
|
|
torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
|
|
|
|
|
attn_metadata.seq_lens = \
|
|
|
|
|
attn_metadata.seq_lens.to(device=query.device)
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph can
split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is enabled by default; users can disable the NPUGraph feature by
configuring enforce_eager.
This requires versions of torch_npu and
CANN that support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
it turn to default
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
torch_npu._npu_paged_attention_splitfuse(
|
|
|
|
|
query=query,
|
|
|
|
|
key_cache=self.key_cache,
|
|
|
|
|
value_cache=self.value_cache,
|
|
|
|
|
mask=attn_metadata.attn_mask,
|
|
|
|
|
block_table=attn_metadata.block_tables,
|
|
|
|
|
seq_len=attn_metadata.query_lens,
|
|
|
|
|
context_lens=attn_metadata.seq_lens,
|
|
|
|
|
num_kv_heads=self.num_kv_heads,
|
|
|
|
|
num_heads=self.num_heads,
|
|
|
|
|
scale_value=self.scale,
|
|
|
|
|
out=output)
|
[Platform] Add initial experimental support for Atlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P. This patch squashes
the PRs below into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
Users can run vLLM on the Atlas 300I DUO series.
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit tests are missing because they need a real 310P image to run;
they will be added in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
|
|
|
|
|
|
|
|
# to make in-place change to the output tensor
|
2025-07-02 16:40:51 +08:00
|
|
|
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
2025-06-28 18:51:07 +08:00
|
|
|
output = output.view(num_tokens, self.num_heads, self.head_size)
|
[Platform] Add initial experimental support for Atlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P. This patch squashes
the PRs below into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
Users can run vLLM on the Atlas 300I DUO series.
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit tests are missing because they need a real 310P image to run;
they will be added in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
|
|
|
ori_output[:, :, :] = output[:num_tokens, :, :]
|
2025-03-20 19:34:44 +08:00
|
|
|
return output.view(num_tokens, self.hidden_size)
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.
1. Register unified_ascend_attention_with_output so that piecewise_graph can
split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is enabled by default; users can disable the NPUGraph feature by
configuring enforce_eager.
This requires versions of torch_npu and
CANN that support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
it turn to default
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def unified_ascend_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Dispatch attention for ``layer_name`` through the current forward context.

    The attention layer, its KV cache and the attention metadata are not
    passed in; they are looked up from the object returned by
    ``get_forward_context()``.

    Args:
        query: Query tensor forwarded unchanged to the layer's impl.
        key: Key tensor forwarded unchanged to the layer's impl.
        value: Value tensor forwarded unchanged to the layer's impl.
        output: Tensor handed to ``impl.forward`` as its output argument.
            Presumably filled in place by the impl (this op is registered
            with ``mutates_args=["output"]``) -- confirm against the impl.
        layer_name: Key into ``forward_context.no_compile_layers`` that
            selects the attention layer to run.

    Returns:
        None. The return value of ``impl.forward`` is discarded.
    """
    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    # The looked-up layer is passed explicitly as the first argument of
    # ``impl.forward``.
    layer = ctx.no_compile_layers[layer_name]
    # Each layer keeps one KV cache per virtual engine; pick the active one.
    cache = layer.kv_cache[ctx.virtual_engine]
    layer.impl.forward(
        layer,
        query,
        key,
        value,
        cache,
        metadata,
        output,
        trace_flag=False,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op counterpart of ``unified_ascend_attention_with_output``.

    Accepts the same arguments as the real op but performs no computation
    and leaves every argument, including ``output``, untouched.

    Returns:
        None.
    """
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the attention entry point as a custom op. ``output`` is declared
# as the only mutated argument, and the fake implementation is the no-op
# ``unified_attention_with_output_fake``.
direct_register_custom_op(
    op_name="unified_ascend_attention_with_output",
    op_func=unified_ascend_attention_with_output,
    fake_impl=unified_attention_with_output_fake,
    mutates_args=["output"],
    dispatch_key="PrivateUse1",
)
|