[Perf] Add new npu_fused_infer_attention_score op to improve performance in splitfuse cases and resolve long-seq mask problems (#2962)

### What this PR does / why we need it?
Add new npu_fused_infer_attention_score op to improve performance in
splitfuse cases and resolve long-seq mask problems.

1. The original op's performance is suboptimal in certain scenarios,
necessitating optimization through the _new op_
(npu_fused_infer_attention_score).
2. For ultra-long sequences (128k), the original operator will allocate
a large attn_mask, which consumes excessive CPU memory. In contrast, the
_new op_ supports a fixed-size compressed mask, effectively resolving
this issue.

NOTE1: The current PR retains the original logic and uses a version
check of the CANN package to determine whether the _new op_ can be
enabled. This ensures no impact on existing users. In future versions,
this version check and the original logic will be deprecated, and the
_new op_ scheduling will be uniformly adopted.
NOTE2: This PR relies on a future CANN version, which is not yet
available.
NOTE3: To enable the new op in chunked prefill, the parameter
additional_config should be set like `--additional-config
'{"ascend_scheduler_config":
{"enabled":true,"enable_chunked_prefill":true}}' \` at least.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI passed




- vLLM version: v0.10.2
- vLLM main:
6c5f82e5aa

---------

Signed-off-by: tangtianyi <tangtianyi4@huawei.com>
Signed-off-by: Angazenn <supperccell@163.com>
Co-authored-by: Angazenn <supperccell@163.com>
This commit is contained in:
tianyitang
2025-09-22 14:56:14 +08:00
committed by GitHub
parent c90a6d3658
commit f1f2c8f5e5
3 changed files with 88 additions and 34 deletions

View File

@@ -39,11 +39,22 @@ class AttentionMaskBuilder:
self, self,
max_seq_len: int, max_seq_len: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device = None,
): ):
# NOTE: The device argument specifies the target NPU
# to be used for the newly added FIA operator.
# Only pass this parameter when using the new FIA operator.
attn_mask = _generate_attn_mask(max_seq_len, dtype) attn_mask = _generate_attn_mask(max_seq_len, dtype)
self._seq_len_cached = attn_mask.shape[0] self._seq_len_cached = attn_mask.shape[0]
self.attn_mask_cache = attn_mask self.attn_mask_cache = attn_mask
self.device = device
if torch.version.cann.startswith("8.3"):
assigned_mask_dim = 2048
self.chunked_prefill_attn_mask = torch.triu(
torch.ones(assigned_mask_dim, assigned_mask_dim),
diagonal=1).to(torch.int8).to(device)
@staticmethod @staticmethod
def get_mask_scale_factor(dtype: torch.dtype = torch.float16): def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
@@ -66,24 +77,28 @@ class AttentionMaskBuilder:
def get_splitfuse_attn_mask( def get_splitfuse_attn_mask(
self, self,
seq_lens: torch.Tensor, seq_lens: torch.Tensor = None,
position: torch.Tensor, position: torch.Tensor = None,
dtype: torch.dtype, dtype: torch.dtype = None,
device: torch.device, device: torch.device = None,
) -> torch.Tensor: ) -> torch.Tensor:
if dtype not in [torch.float16, torch.bfloat16]: if torch.version.cann.startswith("8.3"):
raise ValueError( return self.chunked_prefill_attn_mask
"splitfuse_attn_mask now only supports bf16 and fp16") else:
max_seq_len = max(seq_lens, default=0) if dtype not in [torch.float16, torch.bfloat16]:
self._update_attn_cache(max_seq_len, dtype) raise ValueError(
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation "splitfuse_attn_mask now only supports bf16 and fp16")
# is not the same. Fix this in the future when kernel is ready. max_seq_len = max(seq_lens, default=0)
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype) self._update_attn_cache(max_seq_len, dtype)
attn_mask = torch.index_select(self.attn_mask_cache, # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
dim=0, # is not the same. Fix this in the future when kernel is ready.
index=position)[:, :max_seq_len] mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(
attn_mask *= mask_scale_factor dtype)
return attn_mask.contiguous().to(device, non_blocking=True) attn_mask = torch.index_select(self.attn_mask_cache,
dim=0,
index=position)[:, :max_seq_len]
attn_mask *= mask_scale_factor
return attn_mask.contiguous().to(device, non_blocking=True)
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype): def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
if seqlen > self._seq_len_cached: if seqlen > self._seq_len_cached:

View File

@@ -456,18 +456,43 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata.seq_lens = \ attn_metadata.seq_lens = \
attn_metadata.seq_lens.to(device=query.device) attn_metadata.seq_lens.to(device=query.device)
torch_npu._npu_paged_attention_splitfuse( if torch.version.cann.startswith("8.3"):
query=query, # TODO:The npu_fused_infer_attention_score op is planned to
key_cache=self.key_cache, # be utilized in a wider range in upcoming versions.
value_cache=self.value_cache, num_block, block_size, _, _ = self.key_cache.shape # type: ignore
mask=attn_metadata.attn_mask, key = self.key_cache.view( # type: ignore
block_table=attn_metadata.block_tables, num_block, block_size, -1)
seq_len=attn_metadata.query_lens, value = self.value_cache.view( # type: ignore
context_lens=attn_metadata.seq_lens, num_block, block_size, -1)
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads, output, _ = torch_npu.npu_fused_infer_attention_score(
scale_value=self.scale, query=query,
out=output) key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.query_start_loc[1:],
actual_seq_lengths_kv=attn_metadata.seq_lens,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
else:
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
return output return output
def forward( def forward(
@@ -561,12 +586,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
output) output)
# Normal V1 situation. # Normal V1 situation.
else: else:
if torch.version.cann.startswith("8.3"):
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
output = self._forward_v1_style(query, attn_metadata, output) output = self._forward_v1_style(query, attn_metadata, output)
# to make in-place change to the output tensor # to make in-place change to the output tensor
if hasattr(layer, 'quant_method') and use_kv_cache_int8: if hasattr(layer, 'quant_method') and use_kv_cache_int8:
output = output.view(num_tokens, self.num_heads, self.head_size) output = output.view(num_tokens, self.num_heads, self.head_size)
ori_output[:, :, :] = output[:num_tokens, :, :] ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
return output.view(num_tokens, self.hidden_size) return output.view(num_tokens, self.hidden_size)

View File

@@ -301,8 +301,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
use_mla=self.model_config.use_mla, use_mla=self.model_config.use_mla,
) )
self.attn_mask_builder = AttentionMaskBuilder( if torch.version.cann.startswith("8.3"):
self.model_config.max_model_len, self.dtype) self.attn_mask_builder = AttentionMaskBuilder(
self.scheduler_config.max_num_batched_tokens, self.dtype,
self.device)
else:
self.attn_mask_builder = AttentionMaskBuilder(
self.model_config.max_model_len, self.dtype)
# Set up speculative decoding. # Set up speculative decoding.
self.spec_attn_mask = None self.spec_attn_mask = None
@@ -860,8 +865,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
attn_state) -> torch.Tensor: attn_state) -> torch.Tensor:
# Chunk Prefill situation. # Chunk Prefill situation.
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla: if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
return self.attn_mask_builder.get_splitfuse_attn_mask( if torch.version.cann.startswith("8.3"):
seq_lens, position, self.dtype, self.device) return self.attn_mask_builder.get_splitfuse_attn_mask()
else:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, position, self.dtype, self.device)
# Prefill without cache situation. # Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache: elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens, default=0) max_seq_len = max(seq_lens, default=0)