[cherry-pick]Upgrade CANN to 8.3.rc1 (#3945) (#3962)

This PR upgrade CANN from 8.2rc1 to 8.3rc1 and remove the CANN version
check logic.

TODO: we notice that UT runs failed with CANN 8.3 image. So the base
image for UT is still 8.2. We'll fix it later.

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-11-06 09:05:08 +08:00
committed by GitHub
parent 66b67f9cf2
commit 7ee0b0b5d8
36 changed files with 104 additions and 192 deletions

View File

@@ -47,11 +47,10 @@ class AttentionMaskBuilder:
self.attn_mask_cache = attn_mask
self.device = device
self.pooling_mask = None
if torch.version.cann.startswith("8.3"):
assigned_mask_dim = 2048
self.chunked_prefill_attn_mask = torch.triu(
torch.ones(assigned_mask_dim, assigned_mask_dim),
diagonal=1).to(torch.int8).to(device)
assigned_mask_dim = 2048
self.chunked_prefill_attn_mask = torch.triu(
torch.ones(assigned_mask_dim, assigned_mask_dim),
diagonal=1).to(torch.int8).to(device)
@staticmethod
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
@@ -87,23 +86,7 @@ class AttentionMaskBuilder:
dtype: torch.dtype = None,
device: torch.device = None,
) -> torch.Tensor:
if torch.version.cann.startswith("8.3"):
return self.chunked_prefill_attn_mask
else:
if dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(
"splitfuse_attn_mask now only supports bf16 and fp16")
max_seq_len = max(seq_lens, default=0)
self._update_attn_cache(max_seq_len, dtype)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(
dtype)
attn_mask = torch.index_select(self.attn_mask_cache,
dim=0,
index=position)[:, :max_seq_len]
attn_mask *= mask_scale_factor
return attn_mask.contiguous().to(device, non_blocking=True)
return self.chunked_prefill_attn_mask
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
if seqlen > self._seq_len_cached:

View File

@@ -528,43 +528,30 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata.seq_lens = \
attn_metadata.seq_lens.to(device=query.device)
if torch.version.cann.startswith("8.3"):
# TODO:The npu_fused_infer_attention_score op is planned to
# be utilized in a wider range in upcoming versions.
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
# TODO:The npu_fused_infer_attention_score op is planned to
# be utilized in a wider range in upcoming versions.
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
else:
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
return output
def forward(
@@ -673,12 +660,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
output)
# Normal V1 situation.
else:
if torch.version.cann.startswith("8.3"):
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
output = self._forward_v1_style(query, attn_metadata, output)
# to make in-place change to the output tensor

View File

@@ -45,8 +45,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
if (is_enable_nz() and torch.version.cann.startswith("8.3") and
layer.weight.data.dtype in [torch.float16, torch.bfloat16]):
if (is_enable_nz() and layer.weight.data.dtype
in [torch.float16, torch.bfloat16]):
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)

View File

@@ -411,9 +411,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
quant_per_tensor)
# For unquant
if mmrs_fusion and isinstance(
self.layer.quant_method, UnquantizedLinearMethod
) and torch.version.cann.startswith("8.3"):
if mmrs_fusion and isinstance(self.layer.quant_method,
UnquantizedLinearMethod):
output = torch_npu.npu_mm_reduce_scatter_base(
x,
self.layer.weight.t(),
@@ -429,8 +428,7 @@ class SequenceRowParallelOp(CustomRowParallelOp):
elif mmrs_fusion and (
isinstance(self.layer.quant_method, AscendLinearMethod)
and isinstance(self.layer.quant_method.quant_method,
AscendW8A8LinearMethod)
) and torch.version.cann.startswith("8.3"):
AscendW8A8LinearMethod)):
if x.dtype != torch.int8:
x_quant = quant_per_tensor(
x, self.layer.aclnn_input_scale_reciprocal,

View File

@@ -319,13 +319,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.block_size,
use_mla=self.model_config.use_mla,
use_sparse=self.use_sparse)
if torch.version.cann.startswith("8.3"):
self.attn_mask_builder = AttentionMaskBuilder(
self.scheduler_config.max_num_batched_tokens, self.dtype,
self.device)
else:
self.attn_mask_builder = AttentionMaskBuilder(
self.model_config.max_model_len, self.dtype)
self.attn_mask_builder = AttentionMaskBuilder(
self.scheduler_config.max_num_batched_tokens, self.dtype,
self.device)
# Set up speculative decoding.
self.spec_attn_mask = None
@@ -899,11 +895,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return self.attn_mask_builder.get_pooling_mask(self.device)
# Chunk Prefill situation.
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
if torch.version.cann.startswith("8.3"):
return self.attn_mask_builder.get_splitfuse_attn_mask()
else:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, position, self.dtype, self.device)
return self.attn_mask_builder.get_splitfuse_attn_mask()
# Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache: