Upgrade CANN to 8.3.rc1 (#3945)

### What this PR does / why we need it?
This PR upgrade CANN from 8.2rc1 to 8.3rc1 and remove the CANN version
check logic.

TODO: we notice that UT runs failed with CANN 8.3 image. So the base
image for UT is still 8.2. We'll fix it later.


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-11-03 20:21:07 +08:00
committed by GitHub
parent 49d74785c4
commit cc2cd42ad3
39 changed files with 119 additions and 213 deletions

View File

@@ -47,11 +47,10 @@ class AttentionMaskBuilder:
self.attn_mask_cache = attn_mask
self.device = device
self.pooling_mask = None
if torch.version.cann.startswith("8.3"):
assigned_mask_dim = 2048
self.chunked_prefill_attn_mask = torch.triu(
torch.ones(assigned_mask_dim, assigned_mask_dim),
diagonal=1).to(torch.int8).to(device)
assigned_mask_dim = 2048
self.chunked_prefill_attn_mask = torch.triu(
torch.ones(assigned_mask_dim, assigned_mask_dim),
diagonal=1).to(torch.int8).to(device)
@staticmethod
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
@@ -68,7 +67,7 @@ class AttentionMaskBuilder:
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
device: torch.device):
if max_seq_len == 2048 and torch.version.cann.startswith("8.3"):
if max_seq_len == 2048:
return self.chunked_prefill_attn_mask.to(torch.bool)
self._update_attn_cache(max_seq_len, dtype)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
@@ -89,23 +88,7 @@ class AttentionMaskBuilder:
dtype: torch.dtype = None,
device: torch.device = None,
) -> torch.Tensor:
if torch.version.cann.startswith("8.3"):
return self.chunked_prefill_attn_mask
else:
if dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(
"splitfuse_attn_mask now only supports bf16 and fp16")
max_seq_len = max(seq_lens, default=0)
self._update_attn_cache(max_seq_len, dtype)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(
dtype)
attn_mask = torch.index_select(self.attn_mask_cache,
dim=0,
index=position)[:, :max_seq_len]
attn_mask *= mask_scale_factor
return attn_mask.contiguous().to(device, non_blocking=True)
return self.chunked_prefill_attn_mask
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
if seqlen > self._seq_len_cached:

View File

@@ -500,7 +500,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
if torch.version.cann.startswith("8.3") and block_size == 128:
if block_size == 128:
# TODO:The npu_fused_infer_attention_score op is planned to
# be utilized in a wider range in upcoming versions.
key = self.key_cache.view( # type: ignore
@@ -680,43 +680,30 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata.seq_lens = \
attn_metadata.seq_lens.to(device=query.device)
if torch.version.cann.startswith("8.3"):
# TODO:The npu_fused_infer_attention_score op is planned to
# be utilized in a wider range in upcoming versions.
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
# TODO:The npu_fused_infer_attention_score op is planned to
# be utilized in a wider range in upcoming versions.
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
else:
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
return output
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
@@ -1155,12 +1142,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
query, attn_metadata, output)
# Normal V1 situation.
else:
if torch.version.cann.startswith("8.3"):
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
intermediate_output = self._forward_v1_style(
query, attn_metadata, output)

View File

@@ -45,8 +45,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
if (is_enable_nz() and torch.version.cann.startswith("8.3") and
layer.weight.data.dtype in [torch.float16, torch.bfloat16]):
if (is_enable_nz() and layer.weight.data.dtype
in [torch.float16, torch.bfloat16]):
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)

View File

@@ -411,9 +411,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
quant_per_tensor)
# For unquant
if mmrs_fusion and isinstance(
self.layer.quant_method, UnquantizedLinearMethod
) and torch.version.cann.startswith("8.3"):
if mmrs_fusion and isinstance(self.layer.quant_method,
UnquantizedLinearMethod):
output = torch_npu.npu_mm_reduce_scatter_base(
x,
self.layer.weight.t(),
@@ -429,8 +428,7 @@ class SequenceRowParallelOp(CustomRowParallelOp):
elif mmrs_fusion and (
isinstance(self.layer.quant_method, AscendLinearMethod)
and isinstance(self.layer.quant_method.quant_method,
AscendW8A8LinearMethod)
) and torch.version.cann.startswith("8.3"):
AscendW8A8LinearMethod)):
if x.dtype != torch.int8:
x_quant = quant_per_tensor(
x, self.layer.aclnn_input_scale_reciprocal,

View File

@@ -367,13 +367,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
use_sparse=self.use_sparse)
if self.pcp_size > 1:
self.attn_mask_builder = None
elif torch.version.cann.startswith("8.3"):
else:
self.attn_mask_builder = AttentionMaskBuilder(
self.scheduler_config.max_num_batched_tokens, self.dtype,
self.device)
else:
self.attn_mask_builder = AttentionMaskBuilder(
self.model_config.max_model_len, self.dtype)
self._set_up_drafter()
@@ -988,11 +985,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
max_seq_len = max(seq_lens.max().item(), 0)
return self.attn_mask_builder.get_attn_mask(
max_seq_len, self.dtype, self.device)
elif torch.version.cann.startswith("8.3"):
return self.attn_mask_builder.get_splitfuse_attn_mask()
else:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, position, self.dtype, self.device)
return self.attn_mask_builder.get_splitfuse_attn_mask()
# Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache:
@@ -1001,12 +995,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
max_seq_len, self.dtype, self.device)
# Prefill with cache hit.
elif attn_state == AscendAttentionState.PrefillCacheHit:
if torch.version.cann.startswith("8.3"):
return self.attn_mask_builder.get_attn_mask(
2048, self.dtype, self.device)
else:
return self.attn_mask_builder.get_attn_mask(
128, self.dtype, self.device)
return self.attn_mask_builder.get_attn_mask(
2048, self.dtype, self.device)
# Decode-only situation.
else:
return None