[Feat][Graph] Support FULL_DECODE_ONLY mode for GQA/MHA models (#2128)

Note: This depends on [vLLM
#25161](https://github.com/vllm-project/vllm/pull/25161) and the
torch\_npu release from September 30.

### What this PR does / why we need it?
This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA
models like DeepSeek V3/R1 are not included). Key improvements include:

* **Reduced dispatch latency:** By replaying the entire model execution
graph at once, we cut overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Captureing the whole model as
one static graph also mitigates the dispatch fluctuations across
devices.
* **Stream/resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured.

**Known issues:**

1. `_npu_paged_attention` currently manages its own workspace in
`torch_npu`, which can deadlock when synchronizing during graph replay —
we’re working on a fix.

There may be other corner cases. This PR is the first in a planned
series; we’ll continue to iterate and address remaining issues in
follow-ups.

This is essentially a port of #1503 and #1677, but includes two major
changes:

1. Let `graph_dispatcher` decide the graph mode instead of hard-coding
it in the backend, which decouples Full Graph and Piecewise Graph and
could make it possible to remove dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in
`update_graph_params`; multi-attention models may or may not be fully
supported yet.

### Does this PR introduce _any_ user-facing change?
```python
compilation_config={
    "cudagraph_mode": "FULL_DECODE_ONLY",
},
```

### How was this patch tested?
Tests included.


- vLLM version: v0.10.2
- vLLM main:
9607d5eb44

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2025-09-22 17:14:28 +08:00
committed by GitHub
parent f39bd309b6
commit 338231acaf
14 changed files with 390 additions and 91 deletions

View File

@@ -179,23 +179,13 @@ class NPUPlatform(Platform):
compilation_config.cudagraph_num_of_warmups = 1
# TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
# if cudagraph_mode is not explicitly set by users, set default value
if compilation_config.level == CompilationLevel.PIECEWISE:
compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
elif compilation_config.level not in [
if compilation_config.level not in [
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
]:
logger.warning(
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
compilation_config.level)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
else:
logger.warning(
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
if ascend_config.torchair_graph_config.enabled:
@@ -221,7 +211,12 @@ class NPUPlatform(Platform):
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
compilation_config.level = CompilationLevel.NO_COMPILATION
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
# TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
# after MLA being supported
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
and model_config.use_mla):
logger.info(
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
@@ -233,6 +228,24 @@ class NPUPlatform(Platform):
"vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
])
update_aclgraph_sizes(vllm_config)
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
compilation_config.use_inductor = False
warning_message = """\033[91m
**********************************************************************************
* WARNING: You have enabled the *full graph* feature.
* This is an early experimental stage and may involve various unknown issues.
* A known problem is that capturing too many batch sizes can lead to OOM
* (Out of Memory) errors or inference hangs. If you encounter such issues,
* consider reducing `gpu_memory_utilization` or manually specifying a smaller
* batch size for graph capture.
* For more details, please refer to:
* https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
**********************************************************************************\033[0m
"""
logger.warning(warning_message)
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -379,3 +392,7 @@ class NPUPlatform(Platform):
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
return True
@classmethod
def support_static_graph_mode(cls) -> bool:
return True