Note: This depends on [vLLM
#25161](https://github.com/vllm-project/vllm/pull/25161) and the
torch\_npu release from September 30.
### What this PR does / why we need it?
This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA
models like DeepSeek V3/R1 are not included). Key improvements include:
* **Reduced dispatch latency:** By replaying the entire model execution
graph at once, we cut overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Capturing the whole model as
one static graph also mitigates the dispatch fluctuations across
devices.
* **Stream/resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured.
**Known issues:**
1. `_npu_paged_attention` currently manages its own workspace in
`torch_npu`, which can deadlock when synchronizing during graph replay —
we’re working on a fix.
There may be other corner cases. This PR is the first in a planned
series; we’ll continue to iterate and address remaining issues in
follow-ups.
This is essentially a port of #1503 and #1677, but includes two major
changes:
1. Let `graph_dispatcher` decide the graph mode instead of hard-coding
it in the backend, which decouples Full Graph and Piecewise Graph and
could make it possible to remove dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in
`update_graph_params`; multi-attention models may or may not be fully
supported yet.
### Does this PR introduce _any_ user-facing change?
```python
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
```
### How was this patch tested?
Tests included.
- vLLM version: v0.10.2
- vLLM main:
9607d5eb44
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
103 lines
3.1 KiB
Python
103 lines
3.1 KiB
Python
from dataclasses import dataclass
from typing import Any, Optional

import torch
|
|
|
|
|
|
@dataclass
class AscendCommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both a device version and a CPU version
    (the latter carries a ``_cpu`` suffix).

    Field order is part of the constructor interface: all fields without a
    default must be supplied, and they are accepted positionally in the order
    declared below.
    """

    query_start_loc: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    query_start_loc_cpu: torch.Tensor
    """CPU version of ``query_start_loc``; adjacent differences give the
    per-request query lengths."""

    seq_lens_cpu: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""

    seq_lens: torch.Tensor
    """Same as seq_lens_cpu, for compatibility with some new attn metadata
    (such as GDN)."""

    num_computed_tokens_cpu: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

    num_reqs: int
    """Number of requests"""

    num_actual_tokens: int
    """Total number of tokens in batch"""

    max_query_len: int
    """Max token number of request in batch"""

    decode_token_per_req: int
    """decode token number per request"""

    # NOTE(review): the three fields below carry no documentation in this
    # file; their exact layout is defined by the attention backends that
    # consume them -- confirm there before relying on a particular shape.
    block_table_tensor: torch.Tensor

    slot_mapping: torch.Tensor

    actual_seq_lengths_q: list[int]

    # Optional fields. Each one keeps the default shown here when the caller
    # does not supply it, so tensor-valued ones are annotated as Optional.
    positions: Optional[torch.Tensor] = None

    attn_mask: Optional[torch.Tensor] = None

    spec_attn_mask: Optional[torch.Tensor] = None

    attn_state: Any = None

    enable_dbo_across_dp: bool = False

    is_only_prefill: bool = False

    graph_pad_size: int = -1
|
|
|
|
|
|
def split_decodes_and_prefills(
    common_attn_metadata: AscendCommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
    """
    Locate the decode/prefill boundary in a batch that has already been
    reordered so that decode requests precede prefill requests.

    A request counts as a prefill when its query length is strictly greater
    than ``decode_threshold``.

    Args:
        common_attn_metadata: Batch metadata; this function reads
            ``max_query_len``, ``num_reqs``, ``num_actual_tokens`` and
            ``query_start_loc_cpu`` from it.
        decode_threshold: Largest query length still treated as a decode.

    Returns:
        A 4-tuple ``(num_decodes, num_prefills, num_decode_tokens,
        num_prefill_tokens)``.

    Raises:
        AssertionError: If the query lengths on either side of the detected
            boundary are inconsistent with the expected ordering.
    """
    longest_query = common_attn_metadata.max_query_len
    total_reqs = common_attn_metadata.num_reqs
    total_tokens = common_attn_metadata.num_actual_tokens
    starts = common_attn_metadata.query_start_loc_cpu

    # Result used whenever no request exceeds the threshold.
    all_decode = (total_reqs, 0, total_tokens, 0)

    # Cheap check first: the batch-wide maximum already rules out prefills.
    if longest_query <= decode_threshold:
        return all_decode

    # Per-request query length is the gap between consecutive start offsets.
    lens = starts[1:] - starts[:-1]
    prefill_mask = lens > decode_threshold
    if not torch.any(prefill_mask):
        return all_decode

    # argmax over a 0/1 mask yields the index of the first prefill request.
    boundary = prefill_mask.int().argmax(dim=-1).item()
    assert torch.all(lens[boundary:] >= decode_threshold)
    assert torch.all(lens[:boundary] <= decode_threshold)

    # The start offset of the first prefill equals the decode token count.
    decode_tokens = starts[boundary].item()
    prefill_reqs = total_reqs - boundary
    prefill_tokens = total_tokens - decode_tokens
    return (boundary, prefill_reqs, decode_tokens, prefill_tokens)
|