[Feat][Graph] Support FULL_DECODE_ONLY mode for GQA/MHA models (#2128)
Note: This depends on [vLLM
#25161](https://github.com/vllm-project/vllm/pull/25161) and the
torch\_npu release from September 30.
### What this PR does / why we need it?
This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA
models like DeepSeek V3/R1 are not included). Key improvements include:
* **Reduced dispatch latency:** By replaying the entire model execution
graph at once, we cut overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Captureing the whole model as
one static graph also mitigates the dispatch fluctuations across
devices.
* **Stream/resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured.
**Known issues:**
1. `_npu_paged_attention` currently manages its own workspace in
`torch_npu`, which can deadlock when synchronizing during graph replay —
we’re working on a fix.
There may be other corner cases. This PR is the first in a planned
series; we’ll continue to iterate and address remaining issues in
follow-ups.
This is essentially a port of #1503 and #1677, but includes two major
changes:
1. Let `graph_dispatcher` decide the graph mode instead of hard-coding
it in the backend, which decouples Full Graph and Piecewise Graph and
could make it possible to remove dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in
`update_graph_params`; multi-attention models may or may not be fully
supported yet.
### Does this PR introduce _any_ user-facing change?
```python
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
```
### How was this patch tested?
Tests included.
- vLLM version: v0.10.2
- vLLM main:
9607d5eb44
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -175,7 +175,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
@@ -185,11 +185,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
block_table[:num_reqs])
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
|
||||
num_actual_tokens].to(
|
||||
self.device,
|
||||
non_blocking=
|
||||
True)
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
|
||||
Reference in New Issue
Block a user