support prefill cache mode use fia op (#3696)
### What this PR does / why we need it?
support prefill cache mode use fia op for full graph
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993
origin
============ Serving Benchmark Result ============
Successful requests: 30
Maximum request concurrency: 256
Request rate configured (RPS): 0.70
Benchmark duration (s): 131.63
Total input tokens: 61363
Total generated tokens: 61440
Request throughput (req/s): 0.23
Output token throughput (tok/s): 466.77
Peak output token throughput (tok/s): 750.00
Peak concurrent requests: 30.00
Total Token throughput (tok/s): 932.95
---------------Time to First Token----------------
Mean TTFT (ms): 125.17
Median TTFT (ms): 121.51
P50 TTFT (ms): 121.51
P90 TTFT (ms): 140.91
P99 TTFT (ms): 182.36
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 43.85
Median TPOT (ms): 43.84
P50 TPOT (ms): 43.84
P90 TPOT (ms): 44.28
P99 TPOT (ms): 44.32
---------------Inter-token Latency----------------
Mean ITL (ms): 43.85
Median ITL (ms): 42.63
P50 ITL (ms): 42.63
P90 ITL (ms): 48.74
P99 ITL (ms): 59.62
==================================================
after
============ Serving Benchmark Result ============
Successful requests: 30
Maximum request concurrency: 256
Request rate configured (RPS): 0.70
Benchmark duration (s): 130.10
Total input tokens: 61363
Total generated tokens: 61440
Request throughput (req/s): 0.23
Output token throughput (tok/s): 472.26
Peak output token throughput (tok/s): 750.00
Peak concurrent requests: 30.00
Total Token throughput (tok/s): 943.94
---------------Time to First Token----------------
Mean TTFT (ms): 123.69
Median TTFT (ms): 122.51
P50 TTFT (ms): 122.51
P90 TTFT (ms): 143.69
P99 TTFT (ms): 165.00
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 43.07
Median TPOT (ms): 43.13
P50 TPOT (ms): 43.13
P90 TPOT (ms): 43.50
P99 TPOT (ms): 43.57
---------------Inter-token Latency----------------
Mean ITL (ms): 43.07
Median ITL (ms): 41.81
P50 ITL (ms): 41.81
P90 ITL (ms): 48.11
P99 ITL (ms): 62.13
==================================================
Signed-off-by: shiyuan680 <917935075@qq.com>
This commit is contained in:
@@ -68,6 +68,8 @@ class AttentionMaskBuilder:
|
||||
|
||||
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
if max_seq_len == 2048 and torch.version.cann.startswith("8.3"):
|
||||
return self.chunked_prefill_attn_mask.to(torch.bool)
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
@@ -491,19 +491,44 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
compress_mask = attn_metadata.attn_mask
|
||||
batch_size = attn_metadata.query_lens.shape[0]
|
||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=block_table,
|
||||
mask=compress_mask,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
if torch.version.cann.startswith("8.3") and block_size == 128:
|
||||
# TODO:The npu_fused_infer_attention_score op is planned to
|
||||
# be utilized in a wider range in upcoming versions.
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
atten_mask=compress_mask,
|
||||
block_table=block_table,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
sparse_mode=3,
|
||||
)
|
||||
else:
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=block_table,
|
||||
mask=compress_mask,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def _forward_decode_only(
|
||||
|
||||
@@ -962,8 +962,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_seq_len, self.dtype, self.device)
|
||||
# Prefill with cache hit.
|
||||
elif attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
128, self.dtype, self.device)
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
2048, self.dtype, self.device)
|
||||
else:
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
128, self.dtype, self.device)
|
||||
# Decode-only situation.
|
||||
else:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user