[Feat][Graph] Support FULL_DECODE_ONLY mode for GQA/MHA models (#2128)

Note: This depends on [vLLM
#25161](https://github.com/vllm-project/vllm/pull/25161) and the
torch\_npu release from September 30.

### What this PR does / why we need it?
This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA
models like DeepSeek V3/R1 are not included). Key improvements include:

* **Reduced dispatch latency:** By replaying the entire model execution
graph at once, we cut overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Capturing the whole model as
one static graph also mitigates the dispatch fluctuations across
devices.
* **Stream/resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured.

**Known issues:**

1. `_npu_paged_attention` currently manages its own workspace in
`torch_npu`, which can deadlock when synchronizing during graph replay —
we’re working on a fix.

There may be other corner cases. This PR is the first in a planned
series; we’ll continue to iterate and address remaining issues in
follow-ups.

This is essentially a port of #1503 and #1677, but includes two major
changes:

1. Let `graph_dispatcher` decide the graph mode instead of hard-coding
it in the backend, which decouples Full Graph and Piecewise Graph and
could make it possible to remove dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in
`update_graph_params`; multi-attention models may or may not be fully
supported yet.

### Does this PR introduce _any_ user-facing change?
```python
compilation_config={
    "cudagraph_mode": "FULL_DECODE_ONLY",
},
```

### How was this patch tested?
Tests included.


- vLLM version: v0.10.2
- vLLM main:
9607d5eb44

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2025-09-22 17:14:28 +08:00
committed by GitHub
parent f39bd309b6
commit 338231acaf
14 changed files with 390 additions and 91 deletions

View File

@@ -70,8 +70,8 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LazyLoader, cdiv, get_dtype_size,
is_pin_memory_available)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import \
reorder_batch_to_split_decodes_and_prefills
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
@@ -116,8 +116,9 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, is_310p,
lmhead_tp_enable, vllm_version_is)
get_ascend_soc_version, get_graph_params,
is_310p, lmhead_tp_enable, set_graph_params,
vllm_version_is)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
@@ -352,6 +353,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.uses_mrope = self.model_config.uses_mrope
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -1222,7 +1226,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
Optional[torch.Tensor], Optional[torch.Tensor], int]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -1475,11 +1479,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
blk_table_tensor = blk_table.get_device_tensor()
slot_mapping = blk_table.slot_mapping_cpu[:
total_num_scheduled_tokens]
self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
slot_mapping)
# # Fill unused with -1. Needed for reshape_and_cache in full cuda
# # graph mode.
# blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
self.slot_mapping[:total_num_scheduled_tokens].copy_(
slot_mapping[:total_num_scheduled_tokens],
non_blocking=True,
)
# Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1492,7 +1495,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
actual_seq_lengths_q=self.actual_seq_lengths_q,
# TODO: change this to the right block table for linear attn
block_table_tensor=blk_table_tensor[:num_reqs],
slot_mapping_cpu=self.slot_mapping_cpu,
slot_mapping=self.slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
@@ -1549,7 +1552,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return (attn_metadata, positions, num_scheduled_tokens,
num_input_tokens, num_tokens_across_dp,
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
input_ids, inputs_embeds, intermediate_tensors)
input_ids, inputs_embeds, intermediate_tensors,
max_num_scheduled_tokens)
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
maybe_padded_num_tokens,
@@ -1563,6 +1567,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
graph_params = get_graph_params()
self.update_attn_params(graph_params, forward_context,
positions.shape[0])
if get_forward_context().flashcomm_v1_enabled:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
pad_size = get_forward_context().pad_size
@@ -1570,6 +1581,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
hidden_states = hidden_states[:-pad_size, :]
return hidden_states
def update_attn_params(self, graph_params, forward_context, runtime_shape):
"""Refresh the attention-op arguments recorded in the captured full graph.

For each attention layer, re-issues ``_npu_paged_attention`` on the
dedicated ``self.update_stream`` between ``graph_task_update_begin`` /
``graph_task_update_end`` so the replayed ACL graph picks up this step's
sequence lengths instead of the values captured at graph-build time.

NOTE(review): this zips ``forward_context.attn_metadata`` (layer keys)
against ``graph_params.attn_params/handles/events[runtime_shape]`` —
presumably these are kept in matching per-layer order; confirm for
multi-attention models (the PR text itself flags this as a hack).
"""
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
# Unpack the arguments that were recorded for this layer at capture
# time; all but seq_lens are reused verbatim below.
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
# block_table = forward_context.attn_metadata[key].block_tables
# Overwrite the captured seq_lens with the live per-step value.
seq_lens = forward_context.attn_metadata[key].seq_lens
with torch.npu.stream(self.update_stream):
torch.npu.graph_task_update_begin(self.update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_end(self.update_stream)
# NOTE(review): event recorded on the update stream — presumably the
# replay path waits on it to order update before replay; verify.
event.record(self.update_stream)
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
num_valid_tokens):
ascend_config = get_ascend_config()
@@ -1886,8 +1935,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
(attn_metadata, positions, num_scheduled_tokens_np,
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
intermediate_tensors) = (self._prepare_inputs(
scheduler_output, intermediate_tensors))
intermediate_tensors,
max_query_len) = (self._prepare_inputs(scheduler_output,
intermediate_tensors))
if self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process()
@@ -1895,8 +1945,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
moe_comm_method = self._select_moe_comm_method(num_input_tokens,
self.with_prefill)
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
scheduler_output.total_num_scheduled_tokens
== self.input_batch.num_reqs * max_query_len)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
uniform_decode=uniform_decode)
aclgraph_runtime_mode, batch_descriptor = \
self.aclgraph_dispatcher.dispatch(batch_descriptor)
@@ -2215,12 +2268,54 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output.finished_req_ids)
return None, None
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
if skip_attn:
attn_metadata = None
else:
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
attn_metadata = None
def _build_attention_metadata(self, create_mixed_batch, num_reqs,
num_tokens, max_query_len, force_attention):
"""Build per-layer attention metadata for a dummy (graph-capture) run.

Returns ``None`` unless ``force_attention`` is set; when set, builds one
metadata object per attn group via ``build_for_graph_capture`` and maps
every layer name in the KV-cache group to it.

Raises NotImplementedError when ``force_attention`` is combined with
``create_mixed_batch`` (unsupported).
"""
attn_metadata: Optional[dict[str, Any]] = None
if force_attention:
attn_metadata = {}
if create_mixed_batch:
raise NotImplementedError(
"force_attention=True is not supported for mixed batches.")
else:
# Capture with worst-case sequence length so the recorded graph
# covers any real request; unused slots are zeroed.
seq_lens = self.model_config.max_model_len
self.seq_lens_np[:num_reqs] = seq_lens
self.seq_lens_np[num_reqs:] = 0
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
block_table_tensor = self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
1],
seq_lens_cpu=self.seq_lens_cpu,
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
block_table_tensor=block_table_tensor[:num_reqs],
slot_mapping=self.slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
decode_token_per_req=self.decode_token_per_req,
)
for attn_group in self.attn_groups[kv_cache_group_id]:
# metadata_builder accessor changed after vLLM 0.10.2.
if vllm_version_is("0.10.2"):
builder = attn_group.metadata_builder
else:
builder = attn_group.get_metadata_builder()
attn_metadata_i = builder.build_for_graph_capture(
common_attn_metadata)
# Share one metadata object across all layers of this group.
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
return attn_metadata
def _generate_dummy_run_hidden_states(self, with_prefill,
@@ -2249,12 +2344,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
) -> torch.Tensor:
# only support eager mode and piecewise graph now
assert aclgraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
if force_attention:
raise RuntimeError(
"Capturing attention in aclgraph is unexpected, because full graph is not supported now"
)
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
@@ -2310,9 +2401,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.is_kv_producer:
with_prefill = True
attn_metadata = self._build_attention_metadata(with_prefill,
num_reqs,
skip_attn=True)
attn_metadata = self._build_attention_metadata(
with_prefill,
num_reqs,
num_tokens,
max_query_len,
force_attention,
)
if not self.in_profile_run and self.dynamic_eplb:
self.eplb_updator.forward_before()
@@ -2551,6 +2646,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB",
m.consumed_memory / float(2**30))
# wrap the model with full graph wrapper if needed.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.update_stream = torch.npu.Stream()
set_graph_params(self.compilation_config.cudagraph_capture_sizes)
self.model = ACLGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
def _convert_torch_format(self, tensor):
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
return tensor
@@ -3167,9 +3270,78 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return kv_cache_spec
def initialize_aclgraph_capture(self) -> None:
"""Resolve the effective aclgraph mode against backend capabilities.

Scans all attention metadata builders for the weakest
``AttentionCGSupport`` level, then downgrades (or rejects) the
configured ``cudagraph_mode`` accordingly before initializing the
dispatcher's capture keys. Mirrors vLLM's cudagraph-mode resolution.
"""
# Trigger aclgraph dispatching keys initialization here (after
# initializing attn backends).
# Find the weakest full-graph support level across all attn backends.
min_ag_support = AttentionCGSupport.ALWAYS
min_ag_builder_name = None
for attn_group in self._attn_group_iterator():
if vllm_version_is("0.10.2"):
builder = attn_group.metadata_builder
else:
builder = attn_group.get_metadata_builder()
if builder.cudagraph_support.value < min_ag_support.value:
min_ag_support = builder.cudagraph_support
min_ag_builder_name = builder.__class__.__name__
# This is an imitation of compilation_config.splitting_ops_contain_attention()
splitting_ops_contain_attention = (
self.compilation_config.splitting_ops is not None
and all(op in self.compilation_config.splitting_ops for op in [
"vllm.unified_ascend_attention_with_output",
"vllm.mla_forward",
]))
# Flexibly resolve the aclgraph mode.
aclgraph_mode = self.compilation_config.cudagraph_mode
# Check whether full-graph capture of mixed batches is supported.
if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
and min_ag_support != AttentionCGSupport.ALWAYS:
msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
f"with {min_ag_builder_name} backend (support: "
f"{min_ag_support})")
if min_ag_support == AttentionCGSupport.NEVER:
# No full graphs are supported at all: fail fast.
msg += "; please try cudagraph_mode=PIECEWISE, and "\
"make sure compilation level is piecewise"
raise ValueError(msg)
# Otherwise downgrade to the strongest mode the backend allows.
if splitting_ops_contain_attention:
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_AND_PIECEWISE)
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY)
logger.warning(msg)
# Check whether spec-decode (uniform_decode_query_len > 1) combined
# with full decode graphs is supported by the backend.
if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.uniform_decode_query_len > 1 and min_ag_support.value
< AttentionCGSupport.UNIFORM_BATCH.value):
msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported"
f" with spec-decode for attention backend "
f"{min_ag_builder_name} (support: {min_ag_support})")
if splitting_ops_contain_attention:
msg += "; setting cudagraph_mode=PIECEWISE"
aclgraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE"
aclgraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.NONE
logger.warning(msg)
# Double-check that full graphs are still viable after the automatic
# downgrades above.
if aclgraph_mode.has_full_cudagraphs() \
and min_ag_support == AttentionCGSupport.NEVER:
raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
f"supported with {min_ag_builder_name} backend ("
f"support:{min_ag_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation level is piecewise")
self.aclgraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode,
self.uniform_decode_query_len)
@@ -3178,10 +3350,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool):
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
aclgraph_runtime_mode in [CUDAGraphMode.PIECEWISE]
aclgraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE]
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
logger.info(
"Starting to capture ACL graphs for cases: %s, "
"mode: %s, uniform_decode: %s", compilation_cases,
aclgraph_runtime_mode.name, uniform_decode)
compilation_cases = tqdm(
compilation_cases,
disable=not self.load_config.use_tqdm_on_load,
@@ -3203,6 +3380,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
uniform_decode=uniform_decode)
self._dummy_run(num_tokens,
aclgraph_runtime_mode=aclgraph_runtime_mode,
force_attention=force_attention,
uniform_decode=uniform_decode)
def _capture_model(self):
@@ -3229,6 +3407,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode=aclgraph_runtime_mode,
uniform_decode=False)
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
aclgraph_mode.separate_routine():
max_num_tokens = self.scheduler_config.max_num_seqs * \
self.uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
and x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
reversed(decode_cudagraph_batch_sizes))
self._capture_aclgraphs(
compilation_cases=compilation_cases_decode,
aclgraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True)
# Disable aclgraph capturing globally, so any unexpected aclgraph
# capturing will be detected and raise an error after here.
# Note: We don't put it into graph_capture context manager because