[Feat][Graph] Support FULL_DECODE_ONLY mode for GQA/MHA models (#2128)
Note: This depends on [vLLM
#25161](https://github.com/vllm-project/vllm/pull/25161) and the
torch\_npu release from September 30.
### What this PR does / why we need it?
This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA
models like DeepSeek V3/R1 are not included). Key improvements include:
* **Reduced dispatch latency:** By replaying the entire model execution
graph at once, we cut overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Captureing the whole model as
one static graph also mitigates the dispatch fluctuations across
devices.
* **Stream/resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured.
**Known issues:**
1. `_npu_paged_attention` currently manages its own workspace in
`torch_npu`, which can deadlock when synchronizing during graph replay —
we’re working on a fix.
There may be other corner cases. This PR is the first in a planned
series; we’ll continue to iterate and address remaining issues in
follow-ups.
This is essentially a port of #1503 and #1677, but includes two major
changes:
1. Let `graph_dispatcher` decide the graph mode instead of hard-coding
it in the backend, which decouples Full Graph and Piecewise Graph and
could make it possible to remove dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in
`update_graph_params`; multi-attention models may or may not be fully
supported yet.
### Does this PR introduce _any_ user-facing change?
```python
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
```
### How was this patch tested?
Tests included.
- vLLM version: v0.10.2
- vLLM main:
9607d5eb44
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -70,8 +70,8 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
LazyLoader, cdiv, get_dtype_size,
|
||||
is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import \
|
||||
reorder_batch_to_split_decodes_and_prefills
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@@ -116,8 +116,9 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
|
||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
AscendSocVersion, ProfileExecuteDuration,
|
||||
get_ascend_soc_version, is_310p,
|
||||
lmhead_tp_enable, vllm_version_is)
|
||||
get_ascend_soc_version, get_graph_params,
|
||||
is_310p, lmhead_tp_enable, set_graph_params,
|
||||
vllm_version_is)
|
||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -352,6 +353,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.seq_lens = torch.zeros(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.slot_mapping = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
self.uses_mrope = self.model_config.uses_mrope
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
@@ -1222,7 +1226,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
|
||||
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
Optional[torch.Tensor], Optional[torch.Tensor], int]:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
@@ -1475,11 +1479,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
blk_table_tensor = blk_table.get_device_tensor()
|
||||
slot_mapping = blk_table.slot_mapping_cpu[:
|
||||
total_num_scheduled_tokens]
|
||||
self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
|
||||
slot_mapping)
|
||||
# # Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# # graph mode.
|
||||
# blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
|
||||
self.slot_mapping[:total_num_scheduled_tokens].copy_(
|
||||
slot_mapping[:total_num_scheduled_tokens],
|
||||
non_blocking=True,
|
||||
)
|
||||
|
||||
# Make AscendCommonAttentionMetadata
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
@@ -1492,7 +1495,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
||||
# TODO: change this to the right block table for linear attn
|
||||
block_table_tensor=blk_table_tensor[:num_reqs],
|
||||
slot_mapping_cpu=self.slot_mapping_cpu,
|
||||
slot_mapping=self.slot_mapping,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
positions=self.positions,
|
||||
attn_mask=self.attn_mask,
|
||||
@@ -1549,7 +1552,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
return (attn_metadata, positions, num_scheduled_tokens,
|
||||
num_input_tokens, num_tokens_across_dp,
|
||||
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
|
||||
input_ids, inputs_embeds, intermediate_tensors)
|
||||
input_ids, inputs_embeds, intermediate_tensors,
|
||||
max_num_scheduled_tokens)
|
||||
|
||||
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||
maybe_padded_num_tokens,
|
||||
@@ -1563,6 +1567,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
graph_params = get_graph_params()
|
||||
self.update_attn_params(graph_params, forward_context,
|
||||
positions.shape[0])
|
||||
|
||||
if get_forward_context().flashcomm_v1_enabled:
|
||||
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
||||
pad_size = get_forward_context().pad_size
|
||||
@@ -1570,6 +1581,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states = hidden_states[:-pad_size, :]
|
||||
return hidden_states
|
||||
|
||||
def update_attn_params(self, graph_params, forward_context, runtime_shape):
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
# for each layer's attention op in the graph.
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
block_table,
|
||||
seq_lens,
|
||||
output,
|
||||
) = param
|
||||
# block_table = forward_context.attn_metadata[key].block_tables
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||
|
||||
with torch.npu.stream(self.update_stream):
|
||||
torch.npu.graph_task_update_begin(self.update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
torch.npu.graph_task_update_end(self.update_stream)
|
||||
|
||||
event.record(self.update_stream)
|
||||
|
||||
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens):
|
||||
ascend_config = get_ascend_config()
|
||||
@@ -1886,8 +1935,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
(attn_metadata, positions, num_scheduled_tokens_np,
|
||||
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
|
||||
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
|
||||
intermediate_tensors) = (self._prepare_inputs(
|
||||
scheduler_output, intermediate_tensors))
|
||||
intermediate_tensors,
|
||||
max_query_len) = (self._prepare_inputs(scheduler_output,
|
||||
intermediate_tensors))
|
||||
|
||||
if self.dynamic_eplb:
|
||||
self.eplb_updator.take_update_info_from_eplb_process()
|
||||
@@ -1895,8 +1945,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
moe_comm_method = self._select_moe_comm_method(num_input_tokens,
|
||||
self.with_prefill)
|
||||
|
||||
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
||||
scheduler_output.total_num_scheduled_tokens
|
||||
== self.input_batch.num_reqs * max_query_len)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||
uniform_decode=False)
|
||||
uniform_decode=uniform_decode)
|
||||
aclgraph_runtime_mode, batch_descriptor = \
|
||||
self.aclgraph_dispatcher.dispatch(batch_descriptor)
|
||||
|
||||
@@ -2215,12 +2268,54 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output.finished_req_ids)
|
||||
return None, None
|
||||
|
||||
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
|
||||
if skip_attn:
|
||||
attn_metadata = None
|
||||
else:
|
||||
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
|
||||
attn_metadata = None
|
||||
def _build_attention_metadata(self, create_mixed_batch, num_reqs,
|
||||
num_tokens, max_query_len, force_attention):
|
||||
attn_metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
if force_attention:
|
||||
attn_metadata = {}
|
||||
|
||||
if create_mixed_batch:
|
||||
raise NotImplementedError(
|
||||
"force_attention=True is not supported for mixed batches.")
|
||||
else:
|
||||
seq_lens = self.model_config.max_model_len
|
||||
self.seq_lens_np[:num_reqs] = seq_lens
|
||||
self.seq_lens_np[num_reqs:] = 0
|
||||
|
||||
num_computed_tokens_cpu = (
|
||||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
block_table_tensor = self.input_batch.block_table[
|
||||
kv_cache_group_id].get_device_tensor()
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
||||
1],
|
||||
seq_lens_cpu=self.seq_lens_cpu,
|
||||
seq_lens=self.seq_lens_cpu[:num_reqs],
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
||||
block_table_tensor=block_table_tensor[:num_reqs],
|
||||
slot_mapping=self.slot_mapping,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
max_query_len=max_query_len,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
)
|
||||
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
if vllm_version_is("0.10.2"):
|
||||
builder = attn_group.metadata_builder
|
||||
else:
|
||||
builder = attn_group.get_metadata_builder()
|
||||
attn_metadata_i = builder.build_for_graph_capture(
|
||||
common_attn_metadata)
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def _generate_dummy_run_hidden_states(self, with_prefill,
|
||||
@@ -2249,12 +2344,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
) -> torch.Tensor:
|
||||
# only support eager mode and piecewise graph now
|
||||
assert aclgraph_runtime_mode in {
|
||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE
|
||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
||||
}
|
||||
if force_attention:
|
||||
raise RuntimeError(
|
||||
"Capturing attention in aclgraph is unexpected, because full graph is not supported now"
|
||||
)
|
||||
|
||||
# Padding for DP
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
@@ -2310,9 +2401,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.is_kv_producer:
|
||||
with_prefill = True
|
||||
|
||||
attn_metadata = self._build_attention_metadata(with_prefill,
|
||||
num_reqs,
|
||||
skip_attn=True)
|
||||
attn_metadata = self._build_attention_metadata(
|
||||
with_prefill,
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
max_query_len,
|
||||
force_attention,
|
||||
)
|
||||
|
||||
if not self.in_profile_run and self.dynamic_eplb:
|
||||
self.eplb_updator.forward_before()
|
||||
@@ -2551,6 +2646,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
logger.info("Loading model weights took %.4f GB",
|
||||
m.consumed_memory / float(2**30))
|
||||
|
||||
# wrap the model with full graph wrapper if needed.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.update_stream = torch.npu.Stream()
|
||||
set_graph_params(self.compilation_config.cudagraph_capture_sizes)
|
||||
self.model = ACLGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
|
||||
def _convert_torch_format(self, tensor):
|
||||
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
|
||||
return tensor
|
||||
@@ -3167,9 +3270,78 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
return kv_cache_spec
|
||||
|
||||
def initialize_aclgraph_capture(self) -> None:
|
||||
# TODO: Add check of AttentionCGSupport and cudagraph_mode.decode_mode when full graph is supported
|
||||
# Trigger aclgraph dispatching keys initialization here (after
|
||||
# initializing attn backends).
|
||||
min_ag_support = AttentionCGSupport.ALWAYS
|
||||
min_ag_builder_name = None
|
||||
|
||||
for attn_group in self._attn_group_iterator():
|
||||
if vllm_version_is("0.10.2"):
|
||||
builder = attn_group.metadata_builder
|
||||
else:
|
||||
builder = attn_group.get_metadata_builder()
|
||||
if builder.cudagraph_support.value < min_ag_support.value:
|
||||
min_ag_support = builder.cudagraph_support
|
||||
min_ag_builder_name = builder.__class__.__name__
|
||||
|
||||
# This is an imitation of compilation_config.splitting_ops_contain_attention()
|
||||
splitting_ops_contain_attention = (
|
||||
self.compilation_config.splitting_ops is not None
|
||||
and all(op in self.compilation_config.splitting_ops for op in [
|
||||
"vllm.unified_ascend_attention_with_output",
|
||||
"vllm.mla_forward",
|
||||
]))
|
||||
|
||||
# Flexible resolve the aclgraph mode
|
||||
aclgraph_mode = self.compilation_config.cudagraph_mode
|
||||
# check graph for mixed batch is supported
|
||||
if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
|
||||
and min_ag_support != AttentionCGSupport.ALWAYS:
|
||||
msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
|
||||
f"with {min_ag_builder_name} backend (support: "
|
||||
f"{min_ag_support})")
|
||||
if min_ag_support == AttentionCGSupport.NEVER:
|
||||
# if not supported any full graphs, just raise it.
|
||||
msg += "; please try cudagraph_mode=PIECEWISE, and "\
|
||||
"make sure compilation level is piecewise"
|
||||
raise ValueError(msg)
|
||||
|
||||
# attempt to resolve the full graph related mode
|
||||
if splitting_ops_contain_attention:
|
||||
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
|
||||
aclgraph_mode = self.compilation_config.cudagraph_mode = (
|
||||
CUDAGraphMode.FULL_AND_PIECEWISE)
|
||||
else:
|
||||
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
|
||||
aclgraph_mode = self.compilation_config.cudagraph_mode = (
|
||||
CUDAGraphMode.FULL_DECODE_ONLY)
|
||||
logger.warning(msg)
|
||||
|
||||
# check that if spec-decode + decode full-graphs is supported
|
||||
if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
|
||||
and self.uniform_decode_query_len > 1 and min_ag_support.value
|
||||
< AttentionCGSupport.UNIFORM_BATCH.value):
|
||||
msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported"
|
||||
f" with spec-decode for attention backend "
|
||||
f"{min_ag_builder_name} (support: {min_ag_support})")
|
||||
if splitting_ops_contain_attention:
|
||||
msg += "; setting cudagraph_mode=PIECEWISE"
|
||||
aclgraph_mode = self.compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
msg += "; setting cudagraph_mode=NONE"
|
||||
aclgraph_mode = self.compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.NONE
|
||||
logger.warning(msg)
|
||||
|
||||
# double check that we can support full graph if they are requested
|
||||
# even after automatic downgrades
|
||||
if aclgraph_mode.has_full_cudagraphs() \
|
||||
and min_ag_support == AttentionCGSupport.NEVER:
|
||||
raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
|
||||
f"supported with {min_ag_builder_name} backend ("
|
||||
f"support:{min_ag_support}) "
|
||||
"; please try cudagraph_mode=PIECEWISE, "
|
||||
"and make sure compilation level is piecewise")
|
||||
|
||||
self.aclgraph_dispatcher.initialize_cudagraph_keys(
|
||||
self.compilation_config.cudagraph_mode,
|
||||
self.uniform_decode_query_len)
|
||||
@@ -3178,10 +3350,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
aclgraph_runtime_mode: CUDAGraphMode,
|
||||
uniform_decode: bool):
|
||||
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
|
||||
aclgraph_runtime_mode in [CUDAGraphMode.PIECEWISE]
|
||||
aclgraph_runtime_mode in [CUDAGraphMode.FULL,
|
||||
CUDAGraphMode.PIECEWISE]
|
||||
|
||||
# Only rank 0 should print progress bar during capture
|
||||
if is_global_first_rank():
|
||||
logger.info(
|
||||
"Starting to capture ACL graphs for cases: %s, "
|
||||
"mode: %s, uniform_decode: %s", compilation_cases,
|
||||
aclgraph_runtime_mode.name, uniform_decode)
|
||||
compilation_cases = tqdm(
|
||||
compilation_cases,
|
||||
disable=not self.load_config.use_tqdm_on_load,
|
||||
@@ -3203,6 +3380,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
uniform_decode=uniform_decode)
|
||||
self._dummy_run(num_tokens,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
force_attention=force_attention,
|
||||
uniform_decode=uniform_decode)
|
||||
|
||||
def _capture_model(self):
|
||||
@@ -3229,6 +3407,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
uniform_decode=False)
|
||||
|
||||
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
||||
aclgraph_mode.separate_routine():
|
||||
max_num_tokens = self.scheduler_config.max_num_seqs * \
|
||||
self.uniform_decode_query_len
|
||||
decode_cudagraph_batch_sizes = [
|
||||
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
|
||||
and x >= self.uniform_decode_query_len
|
||||
]
|
||||
compilation_cases_decode = list(
|
||||
reversed(decode_cudagraph_batch_sizes))
|
||||
self._capture_aclgraphs(
|
||||
compilation_cases=compilation_cases_decode,
|
||||
aclgraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
uniform_decode=True)
|
||||
|
||||
# Disable aclgraph capturing globally, so any unexpected aclgraph
|
||||
# capturing will be detected and raise an error after here.
|
||||
# Note: We don't put it into graph_capture context manager because
|
||||
|
||||
Reference in New Issue
Block a user