Update torch-npu version to 2.7.1 (#3896)

### What this PR does / why we need it?
Upgrade torch-npu to the official release version 2.7.1


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-10-31 17:16:31 +08:00
committed by GitHub
parent 5f6d1b3323
commit fcc9a0eaeb
15 changed files with 83 additions and 168 deletions

View File

@@ -43,7 +43,7 @@ from vllm_ascend.compilation.acl_graph import (get_graph_params,
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec,
prefill_context_parallel_enable, version_check,
prefill_context_parallel_enable,
weak_ref_tensors)
# isort: off
@@ -436,7 +436,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None
self.value_cache = None
self.torch_npu_check = version_check()
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
@@ -581,22 +580,21 @@ class AscendAttentionBackendImpl(AttentionImpl):
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0]
if forward_context.capturing:
if self.torch_npu_check:
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(
num_tokens, weak_ref_tensors(workspace))
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()
@@ -618,30 +616,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
))
torch.npu.graph_task_group_begin(stream)
if self.torch_npu_check:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
workspace=workspace)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
workspace=workspace)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else: