From 00aa0bf33e72aa376501195b9df4dcfa11c41b2e Mon Sep 17 00:00:00 2001 From: shiyuan680 <72335504+shiyuan680@users.noreply.github.com> Date: Mon, 27 Oct 2025 19:41:07 +0800 Subject: [PATCH] support prefill cache mode use fia op (#3696) ### What this PR does / why we need it? support prefill cache mode use fia op for full graph ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/17c540a993af88204ad1b78345c8a865cf58ce44 origin ============ Serving Benchmark Result ============ Successful requests: 30 Maximum request concurrency: 256 Request rate configured (RPS): 0.70 Benchmark duration (s): 131.63 Total input tokens: 61363 Total generated tokens: 61440 Request throughput (req/s): 0.23 Output token throughput (tok/s): 466.77 Peak output token throughput (tok/s): 750.00 Peak concurrent requests: 30.00 Total Token throughput (tok/s): 932.95 ---------------Time to First Token---------------- Mean TTFT (ms): 125.17 Median TTFT (ms): 121.51 P50 TTFT (ms): 121.51 P90 TTFT (ms): 140.91 P99 TTFT (ms): 182.36 -----Time per Output Token (excl. 1st token)------ Mean TPOT (ms): 43.85 Median TPOT (ms): 43.84 P50 TPOT (ms): 43.84 P90 TPOT (ms): 44.28 P99 TPOT (ms): 44.32 ---------------Inter-token Latency---------------- Mean ITL (ms): 43.85 Median ITL (ms): 42.63 P50 ITL (ms): 42.63 P90 ITL (ms): 48.74 P99 ITL (ms): 59.62 ================================================== after ============ Serving Benchmark Result ============ Successful requests: 30 Maximum request concurrency: 256 Request rate configured (RPS): 0.70 Benchmark duration (s): 130.10 Total input tokens: 61363 Total generated tokens: 61440 Request throughput (req/s): 0.23 Output token throughput (tok/s): 472.26 Peak output token throughput (tok/s): 750.00 Peak concurrent requests: 30.00 Total Token throughput (tok/s): 943.94 ---------------Time to First Token---------------- Mean TTFT (ms): 123.69 Median TTFT (ms): 122.51 P50 TTFT (ms): 122.51 P90 TTFT (ms): 143.69 P99 TTFT (ms): 165.00 -----Time per Output Token (excl. 1st token)------ Mean TPOT (ms): 43.07 Median TPOT (ms): 43.13 P50 TPOT (ms): 43.13 P90 TPOT (ms): 43.50 P99 TPOT (ms): 43.57 ---------------Inter-token Latency---------------- Mean ITL (ms): 43.07 Median ITL (ms): 41.81 P50 ITL (ms): 41.81 P90 ITL (ms): 48.11 P99 ITL (ms): 62.13 ================================================== Signed-off-by: shiyuan680 <917935075@qq.com> --- vllm_ascend/attention/attention_mask.py | 2 + vllm_ascend/attention/attention_v1.py | 49 +++++++++++++++++++------ vllm_ascend/worker/model_runner_v1.py | 8 +++- 3 files changed, 45 insertions(+), 14 deletions(-) diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index b1da7232..5a94b82b 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -68,6 +68,8 @@ class AttentionMaskBuilder: def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, device: torch.device): + if max_seq_len == 2048 and torch.version.cann.startswith("8.3"): + return self.chunked_prefill_attn_mask.to(torch.bool) self._update_attn_cache(max_seq_len, dtype) return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( ).to(device, non_blocking=True) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 62bca309..07ef6d91 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -491,19 +491,44 @@ class AscendAttentionBackendImpl(AttentionImpl): compress_mask = attn_metadata.attn_mask batch_size = attn_metadata.query_lens.shape[0] block_table = attn_metadata.block_tables[:batch_size, :] + num_block, block_size, _, _ = self.key_cache.shape # type: ignore - torch_npu._npu_flash_attention_qlens( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - block_table=block_table, - mask=compress_mask, - seq_len=attn_metadata.query_lens, - context_lens=attn_metadata.seq_lens, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - out=output) + if torch.version.cann.startswith("8.3") and block_size == 128: + # TODO:The npu_fused_infer_attention_score op is planned to + # be utilized in a wider range in upcoming versions. + key = self.key_cache.view( # type: ignore + num_block, block_size, -1) + value = self.value_cache.view( # type: ignore + num_block, block_size, -1) + + output, _ = torch_npu.npu_fused_infer_attention_score( + query=query, + key=key, + value=value, + atten_mask=compress_mask, + block_table=block_table, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=attn_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=attn_metadata.seq_lens_list, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale=self.scale, + sparse_mode=3, + ) + else: + torch_npu._npu_flash_attention_qlens( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + block_table=block_table, + mask=compress_mask, + seq_len=attn_metadata.query_lens, + context_lens=attn_metadata.seq_lens, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output) return output def _forward_decode_only( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f30a9a39..bf0465c8 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -962,8 +962,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): max_seq_len, self.dtype, self.device) # Prefill with cache hit. elif attn_state == AscendAttentionState.PrefillCacheHit: - return self.attn_mask_builder.get_attn_mask( - 128, self.dtype, self.device) + if torch.version.cann.startswith("8.3"): + return self.attn_mask_builder.get_attn_mask( + 2048, self.dtype, self.device) + else: + return self.attn_mask_builder.get_attn_mask( + 128, self.dtype, self.device) # Decode-only situation. else: return None