[Perf] Add new npu_fused_infer_attention_score op to improve perfomance in splitfuse cases and resolve long-seq mask problems (#2962)
### What this PR does / why we need it?
Add new npu_fused_infer_attention_score op to improve perfomance in
splitfuse cases and resolve long-seq mask problems .
1. The original op's performance is suboptimal in certain scenarios,
necessitating optimization through the _new op_
(npu_fused_infer_attention_score)。
2. For ultra-long sequences (128k), the original operator will allocate
a large attn_mask, which consumes excessive CPU memory. In contrast, the
_new op_ supports a fixed-size compressed mask, effectively resolving
this issue.
NOTE1: The current PR retains the original logic and uses a version
check of the CANN package to determine whether the _new op_ can be
enabled. This ensures no impact on existing users. In future versions,
this version check and the original logic will be deprecated, and the
_new op_ scheduling will be uniformly adopted.
NOTE2: This pr relies on future CANN version, which is not available
now.
NOTE3: To enable the new op in chunked prefill, the parameter
additional_config should be set like `--additional-config
'{"ascend_scheduler_config":
{"enabled":true,"enable_chunked_prefill":true}}' \` at least.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI passed
- vLLM version: v0.10.2
- vLLM main:
6c5f82e5aa
---------
Signed-off-by: tangtianyi <tangtianyi4@huawei.com>
Signed-off-by: Angazenn <supperccell@163.com>
Co-authored-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -456,18 +456,43 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
attn_metadata.seq_lens = \
|
||||
attn_metadata.seq_lens.to(device=query.device)
|
||||
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
# TODO:The npu_fused_infer_attention_score op is planned to
|
||||
# be utilized in a wider range in upcoming versions.
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=attn_metadata.query_start_loc[1:],
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
sparse_mode=3,
|
||||
)
|
||||
else:
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
@@ -561,12 +586,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
output)
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
# npu_fused_infer_attention_score does not support cases
|
||||
# where query.shape[0] != attn_metadata.query_start_loc[-1].
|
||||
# Thus we need unpad it here.
|
||||
num_tokens = attn_metadata.query_start_loc[-1]
|
||||
query = query[:num_tokens]
|
||||
output = self._forward_v1_style(query, attn_metadata, output)
|
||||
|
||||
# to make in-place change to the output tensor
|
||||
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
ori_output[:, :, :] = output[:num_tokens, :, :]
|
||||
ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user