[Feat][Graph] Support FULL_DECODE_ONLY mode for GQA/MHA models (#2128)
Note: This depends on [vLLM #25161](https://github.com/vllm-project/vllm/pull/25161) and the
`torch_npu` release from September 30.
### What this PR does / why we need it?
This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA
models like DeepSeek V3/R1 are not included). Key improvements include:
* **Reduced dispatch latency:** Replaying the entire model execution
graph at once cuts dispatch overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Capturing the whole model as
one static graph also mitigates dispatch fluctuations across
devices.
* **Stream/resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured.
**Known issues:**
1. `_npu_paged_attention` currently manages its own workspace in
`torch_npu`, which can deadlock when synchronizing during graph replay.
A fix is in progress.
There may be other corner cases. This PR is the first in a planned
series; we’ll continue to iterate and address remaining issues in
follow-ups.
This is essentially a port of #1503 and #1677, but includes two major
changes:
1. Let `graph_dispatcher` decide the graph mode instead of hard-coding
it in the backend, which decouples Full Graph and Piecewise Graph and
could make it possible to remove dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in
`update_graph_params`; multi-attention models may or may not be fully
supported yet.
### Does this PR introduce _any_ user-facing change?
Yes. Users can opt in to the new mode through `compilation_config`:
```python
compilation_config={
    "cudagraph_mode": "FULL_DECODE_ONLY",
},
```
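As a rough illustration of how the requested mode is resolved on NPU after this change, the branch structure in the diff can be sketched in plain Python. This is a minimal sketch, not the actual platform code: `GraphMode` and `resolve_graph_mode` are hypothetical names that stand in for vLLM's `CUDAGraphMode` and the logic inside `NPUPlatform`, and `use_mla=None` is assumed to represent a missing `model_config`.

```python
from enum import Enum
from typing import Optional


class GraphMode(Enum):
    """Hypothetical stand-in for vLLM's CUDAGraphMode (not the real class)."""
    NONE = "NONE"
    PIECEWISE = "PIECEWISE"
    FULL_DECODE_ONLY = "FULL_DECODE_ONLY"


def resolve_graph_mode(requested: str, use_mla: Optional[bool]) -> GraphMode:
    """Return the mode the NPU platform would effectively use.

    Assumption: `use_mla` is None when no model config is available.
    """
    try:
        mode = GraphMode(requested)
    except ValueError:
        # Modes outside this sketch are treated like the final `else`
        # branch in the diff: fall back to NONE.
        return GraphMode.NONE
    # MLA models do not support FULL_DECODE_ONLY yet, so they take the
    # PIECEWISE branch instead.
    if mode is GraphMode.FULL_DECODE_ONLY and use_mla:
        return GraphMode.PIECEWISE
    return mode


# The user-facing setting from this PR, as it would be passed to the engine.
compilation_config = {"cudagraph_mode": "FULL_DECODE_ONLY"}
effective = resolve_graph_mode(compilation_config["cudagraph_mode"], use_mla=False)
print(effective.name)
```

With a GQA/MHA model the requested mode is kept, while the same setting on an MLA model is handled by the PIECEWISE branch until MLA support lands.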
### How was this patch tested?
Tests included.
- vLLM version: v0.10.2
- vLLM main: 9607d5eb44
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@@ -179,23 +179,13 @@ class NPUPlatform(Platform):

         compilation_config.cudagraph_num_of_warmups = 1

-        # TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
-        # if cudagraph_mode is not explicitly set by users, set default value
-        if compilation_config.level == CompilationLevel.PIECEWISE:
-            compilation_config.cudagraph_mode = \
-                CUDAGraphMode.PIECEWISE
-        elif compilation_config.level not in [
+        if compilation_config.level not in [
                 CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
         ]:
             logger.warning(
                 "NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
                 compilation_config.level)
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
-        else:
-            logger.warning(
-                "compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
-            )
-            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

         # set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
         if ascend_config.torchair_graph_config.enabled:
@@ -221,7 +211,12 @@ class NPUPlatform(Platform):

         if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             compilation_config.level = CompilationLevel.NO_COMPILATION
-        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+        # TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
+        # after MLA being supported
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
+                compilation_config.cudagraph_mode
+                == CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
+                and model_config.use_mla):
             logger.info(
                 "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                 "using only ACL Graph mode")
@@ -233,6 +228,24 @@ class NPUPlatform(Platform):
                 "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
             ])
             update_aclgraph_sizes(vllm_config)
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            logger.info(
+                "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
+                "using only ACL Graph mode")
+            compilation_config.use_inductor = False
+            warning_message = """\033[91m
+            **********************************************************************************
+            * WARNING: You have enabled the *full graph* feature.
+            * This is an early experimental stage and may involve various unknown issues.
+            * A known problem is that capturing too many batch sizes can lead to OOM
+            * (Out of Memory) errors or inference hangs. If you encounter such issues,
+            * consider reducing `gpu_memory_utilization` or manually specifying a smaller
+            * batch size for graph capture.
+            * For more details, please refer to:
+            * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
+            **********************************************************************************\033[0m
+            """
+            logger.warning(warning_message)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -379,3 +392,7 @@ class NPUPlatform(Platform):
     @classmethod
     def support_hybrid_kv_cache(cls) -> bool:
         return True
+
+    @classmethod
+    def support_static_graph_mode(cls) -> bool:
+        return True