Support prefill cache mode using the FIA op (#3696)

### What this PR does / why we need it?
Support the prefill-cache-hit attention path with the fused infer attention (FIA) op for full graph mode. On CANN 8.3 with a KV-cache block size of 128, the PrefillCacheHit path now calls `torch_npu.npu_fused_infer_attention_score` together with the 2048-length chunked-prefill mask; other configurations keep the existing `torch_npu._npu_flash_attention_qlens` path. A condensed sketch of the dispatch rule follows.
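A minimal sketch of that rule; `pick_prefill_attention_op` is a hypothetical helper written for illustration only, not vllm-ascend API, and it simply mirrors the condition in the diff below:

```python
# Hypothetical helper restating the dispatch introduced by this PR.
def pick_prefill_attention_op(cann_version: str, block_size: int) -> str:
    if cann_version.startswith("8.3") and block_size == 128:
        # New path: fused infer attention score (FIA), targeted at full graph mode.
        return "torch_npu.npu_fused_infer_attention_score"
    # Fallback: the pre-existing qlens flash-attention kernel.
    return "torch_npu._npu_flash_attention_qlens"


assert pick_prefill_attention_op("8.3.RC1", 128).endswith("fused_infer_attention_score")
assert pick_prefill_attention_op("8.2.RC1", 128).endswith("_npu_flash_attention_qlens")
```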
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993

Before (original `_npu_flash_attention_qlens` path):
============ Serving Benchmark Result ============
Successful requests:                     30
Maximum request concurrency:             256
Request rate configured (RPS):           0.70
Benchmark duration (s):                  131.63
Total input tokens:                      61363
Total generated tokens:                  61440
Request throughput (req/s):              0.23
Output token throughput (tok/s):         466.77
Peak output token throughput (tok/s):    750.00
Peak concurrent requests:                30.00
Total Token throughput (tok/s):          932.95
---------------Time to First Token----------------
Mean TTFT (ms):                          125.17
Median TTFT (ms):                        121.51
P50 TTFT (ms):                           121.51
P90 TTFT (ms):                           140.91
P99 TTFT (ms):                           182.36
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          43.85
Median TPOT (ms):                        43.84
P50 TPOT (ms):                           43.84
P90 TPOT (ms):                           44.28
P99 TPOT (ms):                           44.32
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.85
Median ITL (ms):                         42.63
P50 ITL (ms):                            42.63
P90 ITL (ms):                            48.74
P99 ITL (ms):                            59.62
==================================================

After (FIA path):
============ Serving Benchmark Result ============
Successful requests:                     30
Maximum request concurrency:             256
Request rate configured (RPS):           0.70
Benchmark duration (s):                  130.10
Total input tokens:                      61363
Total generated tokens:                  61440
Request throughput (req/s):              0.23
Output token throughput (tok/s):         472.26
Peak output token throughput (tok/s):    750.00
Peak concurrent requests:                30.00
Total Token throughput (tok/s):          943.94
---------------Time to First Token----------------
Mean TTFT (ms):                          123.69
Median TTFT (ms):                        122.51
P50 TTFT (ms):                           122.51
P90 TTFT (ms):                           143.69
P99 TTFT (ms):                           165.00
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          43.07
Median TPOT (ms):                        43.13
P50 TPOT (ms):                           43.13
P90 TPOT (ms):                           43.50
P99 TPOT (ms):                           43.57
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.07
Median ITL (ms):                         41.81
P50 ITL (ms):                            41.81
P90 ITL (ms):                            48.11
P99 ITL (ms):                            62.13
==================================================

Signed-off-by: shiyuan680 <917935075@qq.com>
Commit: 00aa0bf33e (parent 3e5ae49160)
Author: shiyuan680
Date: 2025-10-27 19:41:07 +08:00
3 changed files with 45 additions and 14 deletions


@@ -68,6 +68,8 @@ class AttentionMaskBuilder:
     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
+        if max_seq_len == 2048 and torch.version.cann.startswith("8.3"):
+            return self.chunked_prefill_attn_mask.to(torch.bool)
         self._update_attn_cache(max_seq_len, dtype)
         return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
         ).to(device, non_blocking=True)
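For context on the mask this early return hands back: a plausible reading (an assumption, not something this diff shows) is that `chunked_prefill_attn_mask` is a fixed 2048×2048 strictly upper-triangular causal mask that the sparse_mode=3 kernels reuse across longer sequences. A minimal sketch under that assumption; `build_chunked_prefill_mask` is a hypothetical name, not the builder's real API:

```python
import torch


def build_chunked_prefill_mask(size: int = 2048) -> torch.Tensor:
    # Strictly upper-triangular: True marks the future positions a query token
    # must not attend to; the diagonal and everything below stay visible.
    return torch.triu(torch.ones(size, size), diagonal=1).bool()


mask = build_chunked_prefill_mask()
assert mask.shape == (2048, 2048) and mask.dtype == torch.bool
assert bool(mask[3, 5]) and not bool(mask[5, 3])  # future masked, past visible
```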


@@ -491,19 +491,44 @@ class AscendAttentionBackendImpl(AttentionImpl):
             compress_mask = attn_metadata.attn_mask
             batch_size = attn_metadata.query_lens.shape[0]
             block_table = attn_metadata.block_tables[:batch_size, :]
-            torch_npu._npu_flash_attention_qlens(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                block_table=block_table,
-                mask=compress_mask,
-                seq_len=attn_metadata.query_lens,
-                context_lens=attn_metadata.seq_lens,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                out=output)
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            if torch.version.cann.startswith("8.3") and block_size == 128:
+                # TODO: The npu_fused_infer_attention_score op is planned to
+                # be utilized in a wider range in upcoming versions.
+                key = self.key_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                value = self.value_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                output, _ = torch_npu.npu_fused_infer_attention_score(
+                    query=query,
+                    key=key,
+                    value=value,
+                    atten_mask=compress_mask,
+                    block_table=block_table,
+                    input_layout="TND",
+                    block_size=block_size,
+                    actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                    actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                    num_key_value_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale=self.scale,
+                    sparse_mode=3,
+                )
+            else:
+                torch_npu._npu_flash_attention_qlens(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    block_table=block_table,
+                    mask=compress_mask,
+                    seq_len=attn_metadata.query_lens,
+                    context_lens=attn_metadata.seq_lens,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    out=output)
         return output
 
     def _forward_decode_only(
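The two `view()` calls in the FIA branch only flatten each KV-cache block from `(block_size, num_kv_heads, head_size)` to `(block_size, num_kv_heads * head_size)`. A small, self-contained illustration of that reshape, with made-up sizes:

```python
import torch

# Illustrative sizes only; the real values come from the Ascend KV-cache layout.
num_blocks, block_size, num_kv_heads, head_size = 4, 128, 2, 64
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)

# Same reshape as `self.key_cache.view(num_block, block_size, -1)` in the diff:
# the head and head-size axes collapse into one trailing dimension.
key = key_cache.view(num_blocks, block_size, -1)
assert key.shape == (num_blocks, block_size, num_kv_heads * head_size)

# view() only changes the logical shape; the underlying storage is shared,
# so this adds no copy on the attention hot path.
assert key.data_ptr() == key_cache.data_ptr()
```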


@@ -962,8 +962,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                128, self.dtype, self.device)
+            if torch.version.cann.startswith("8.3"):
+                return self.attn_mask_builder.get_attn_mask(
+                    2048, self.dtype, self.device)
+            else:
+                return self.attn_mask_builder.get_attn_mask(
+                    128, self.dtype, self.device)
         # Decode-only situation.
         else:
             return None
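Tying the hunks together: on CANN 8.3 the model runner requests the 2048-entry mask that the FIA branch in the attention backend consumes, while other toolkits keep the original 128-entry mask. A minimal sketch of that selection; `prefill_cache_hit_mask_len` is a hypothetical helper for illustration only, with made-up version strings:

```python
def prefill_cache_hit_mask_len(cann_version: str) -> int:
    # CANN 8.3 takes the 2048-entry chunked-prefill mask consumed by the
    # FIA path; older toolkits keep the original 128-entry mask.
    return 2048 if cann_version.startswith("8.3") else 128


assert prefill_cache_hit_mask_len("8.3.RC1") == 2048
assert prefill_cache_hit_mask_len("8.2.RC1") == 128
```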