Update torch-npu version to 2.7.1 (#3896)

### What this PR does / why we need it?
Upgrade torch-npu to the official release version 2.7.1


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-10-31 17:16:31 +08:00
committed by GitHub
parent 5f6d1b3323
commit fcc9a0eaeb
15 changed files with 83 additions and 168 deletions

View File

@@ -11,8 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
set_forward_context)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import (enable_sp, has_layer_idx, is_moe_model,
version_check)
from vllm_ascend.utils import enable_sp, has_layer_idx, is_moe_model
if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
@@ -163,9 +162,7 @@ def set_ascend_forward_context(
# this optim now just support dense models due to the specific operators used.
# Once the necessary conditions are met, support for MOE models will also be added.
from vllm_ascend.quantization.quant_config import AscendQuantConfig
model_type_scope = ["llama", "qwen2", "qwen3"]
if version_check():
model_type_scope.append("qwen3_moe")
model_type_scope = ["llama", "qwen2", "qwen3", "qwen3_moe"]
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
vllm_config.model_config.hf_config.model_type in model_type_scope and \
forward_context.layer_idx is not None

View File

@@ -43,7 +43,7 @@ from vllm_ascend.compilation.acl_graph import (get_graph_params,
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec,
prefill_context_parallel_enable, version_check,
prefill_context_parallel_enable,
weak_ref_tensors)
# isort: off
@@ -436,7 +436,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None
self.value_cache = None
self.torch_npu_check = version_check()
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
@@ -581,22 +580,21 @@ class AscendAttentionBackendImpl(AttentionImpl):
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0]
if forward_context.capturing:
if self.torch_npu_check:
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(
num_tokens, weak_ref_tensors(workspace))
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()
@@ -618,30 +616,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
))
torch.npu.graph_task_group_begin(stream)
if self.torch_npu_check:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
workspace=workspace)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
workspace=workspace)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:

View File

@@ -19,8 +19,6 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.platforms import current_platform
from vllm_ascend.utils import version_check
from ..utils import weak_ref_tensors
@@ -214,32 +212,20 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
torch_npu_check = version_check()
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
if torch_npu_check:
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=graph_params.workspaces.get(runtime_shape))
else:
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=graph_params.workspaces.get(runtime_shape))
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)

View File

@@ -22,8 +22,6 @@ from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm_ascend.utils import version_check
def _addrmsnorm_forward_oot(
self,
@@ -36,7 +34,6 @@ def _addrmsnorm_forward_oot(
from vllm_ascend.utils import is_310p
torch_npu_check = version_check()
if layer is not None and not is_310p():
layer_cls_name = layer.__class__.__name__
try:
@@ -53,23 +50,15 @@ def _addrmsnorm_forward_oot(
start_flag=x,
)
# add_rms_norm_quant
if torch_npu_check:
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
beta=bias,
epsilon=self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
beta=bias,
epsilon=self.variance_epsilon)
# prefetch qkvo_proj.weight postprocess
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
@@ -87,7 +76,7 @@ def _addrmsnorm_forward_oot(
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
if torch_npu_check and bias is not None:
if bias is not None:
x.add_(bias)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
@@ -106,9 +95,8 @@ class AscendRMSNorm(RMSNorm):
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
vllm_config = get_current_vllm_config()
self.bias = None
self.torch_npu_check = version_check()
# quantization with anti_method m4 will generate none-zero norm bias
if self.torch_npu_check and vllm_config.quant_config is not None and \
if vllm_config.quant_config is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
@@ -128,7 +116,7 @@ class AscendRMSNorm(RMSNorm):
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
if self.torch_npu_check and self.bias is not None:
if self.bias is not None:
x.add_(self.bias)
return x

View File

@@ -7,7 +7,6 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import WeightPrefetchConfig
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
AscendRowParallelLinear)
from vllm_ascend.utils import version_check
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
MOE_PREFETCH_TOKEN_THRESHOLD = 96
@@ -83,8 +82,7 @@ class WeightPrefetchMethod:
if not self.moe.is_active_this_forward:
return
forward_context = get_forward_context()
if not version_check():
forward_context.layer_idx += 1
# layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm.
weight = forward_context.model_instance.model.layers[
forward_context.layer_idx - 1].mlp.experts.w13_weight
weight_size = weight.data.element_size() * weight.data.numel(

View File

@@ -551,8 +551,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE)
from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
AscendQuantRMSNorm, AscendRMSNorm)
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
@@ -586,12 +585,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"FusedMoE": AscendFusedMoE,
"SharedFusedMoE": AscendSharedFusedMoE,
}
if vllm_config is not None and \
vllm_config.quant_config is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()) and \
not version_check():
REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm
mla_to_register = "MultiHeadLatentAttention" if vllm_version_is(
"0.11.0") else "MultiHeadLatentAttentionWrapper"
if vllm_config and vllm_config.model_config and vllm_config.model_config.use_mla:
@@ -791,21 +784,6 @@ def is_hierarchical_communication_enabled():
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
@functools.cache
def version_check():
"""check if torch_npu version >= dev20250919"""
import re # noqa
torch_npu_version = torch_npu.version.__version__
date_pattern = r'dev(\d{8})'
match = re.search(date_pattern, torch_npu_version)
if match:
full_date = match.group(1)
if full_date >= "20250919":
return True
return False
def has_layer_idx(model_instance: torch.nn.Module) -> bool:
if model_instance is None:
return False