[Feat][Graph] Support FULL_DECODE_ONLY mode for GQA/MHA models (#2128)

Note: This depends on [vLLM
#25161](https://github.com/vllm-project/vllm/pull/25161) and the
`torch_npu` release from September 30.

### What this PR does / why we need it?
This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA
models such as DeepSeek V3/R1 are not yet supported). Key improvements include:

* **Reduced dispatch latency:** By replaying the entire model execution
graph at once, we cut overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Capturing the whole model as
one static graph also mitigates dispatch fluctuations across
devices.
* **Stream/resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured.

**Known issues:**

1. `_npu_paged_attention` currently manages its own workspace in
`torch_npu`, which can deadlock when synchronizing during graph replay —
we’re working on a fix.

There may be other corner cases. This PR is the first in a planned
series; we’ll continue to iterate and address remaining issues in
follow-ups.

This is essentially a port of #1503 and #1677, but includes two major
changes:

1. Let `graph_dispatcher` decide the graph mode instead of hard-coding
it in the backend; this decouples Full Graph from Piecewise Graph and
could eventually make it possible to remove Dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in
`update_graph_params`; multi-attention models are not guaranteed to be
fully supported yet.

### Does this PR introduce _any_ user-facing change?
Yes. The new mode can be enabled through `compilation_config`:
```python
compilation_config={
    "cudagraph_mode": "FULL_DECODE_ONLY",
},
```
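For reference, a minimal offline-inference sketch assuming the standard `LLM` entry point; the model name below is illustrative only, and any non-MLA GQA/MHA model applies:

```python
from vllm import LLM, SamplingParams

# Hypothetical example model; any GQA/MHA (non-MLA) model can be used here.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    compilation_config={
        # Capture the decode pass as one full graph and replay it in a single shot.
        "cudagraph_mode": "FULL_DECODE_ONLY",
    },
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```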

### How was this patch tested?
Unit tests are included in this PR.


- vLLM version: v0.10.2
- vLLM main: 9607d5eb44

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Author: Yizhou
Date: 2025-09-22 17:14:28 +08:00
Committed by: GitHub
Parent: f39bd309b6
Commit: 338231acaf
14 changed files with 390 additions and 91 deletions


@@ -101,7 +101,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
max_query_len=5,
decode_token_per_req=torch.tensor([1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
@@ -134,7 +134,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
@@ -165,7 +165,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
@@ -378,10 +378,12 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_flash_attention_qlens.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
def test_forward_decode_only(self, mock_paged_attention,
mock_npu_reshape_and_cache):
mock_npu_reshape_and_cache,
mock_get_forward_context):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -395,6 +397,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
output = self.impl.forward(layer,
query,
key,
@@ -435,12 +439,13 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
@patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_decode_only_swa_seq_len_mismatch(
self, mock_fused_infer_attention_score, mock_paged_attention,
mock_npu_reshape_and_cache):
mock_npu_reshape_and_cache, mock_get_forward_context):
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -457,6 +462,8 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
mock_get_forward_context.return_value = MagicMock(capturing=False)
output = self.impl_swa.forward(self.layer_no_quant,
query,
key,


@@ -463,7 +463,7 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
max_query_len=1,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([1, 1]),
attn_mask=torch.ones((15, 15)),


@@ -31,13 +31,15 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
get_graph_params, is_310p, nd_to_nz_2d,
nd_to_nz_spec)
def wait_for_kv_layer_from_connector(layer_name: str):
@@ -197,6 +199,12 @@ class AscendMetadata:
class AscendAttentionMetadataBuilder:
# Does this backend/builder support CUDA Graphs for attention (default: no).
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[int] = 1
def __init__(
@@ -221,7 +229,7 @@ class AscendAttentionMetadataBuilder:
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
model: Optional[nn.Module] = None,
):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -231,11 +239,7 @@ class AscendAttentionMetadataBuilder:
block_table = common_attn_metadata.block_table_tensor
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
self.device,
non_blocking=
True)
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -268,6 +272,24 @@ class AscendAttentionMetadataBuilder:
is_only_prefill=common_attn_metadata.is_only_prefill)
return attn_metadata
def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
):
if attn_state == AscendAttentionState.DecodeOnly:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise NotImplementedError(
"Currently we only support building dummy metadata for DecodeOnly state"
)
attn_metadata.attn_state = attn_state
return attn_metadata
class AscendAttentionBackendImpl(AttentionImpl):
@@ -406,16 +428,53 @@ class AscendAttentionBackendImpl(AttentionImpl):
output = output.view(batch_size, self.num_heads, self.head_size)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0]
if forward_context.capturing:
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
query,
self.key_cache,
self.value_cache,
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
output,
))
torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
return output
def _forward_v1_style(


@@ -292,11 +292,7 @@ class AscendMLAMetadataBuilder:
device = self.device
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
device,
non_blocking=
True)
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)


@@ -41,7 +41,7 @@ class AscendCommonAttentionMetadata:
block_table_tensor: torch.Tensor
slot_mapping_cpu: torch.Tensor
slot_mapping: torch.Tensor
actual_seq_lengths_q: list[int]


@@ -147,6 +147,7 @@ class ACLGraphWrapper:
patch("torch.npu.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
forward_context.capturing = True
with torch.npu.graph(aclgraph, pool=self.graph_pool):
# `output` is managed by pytorch's aclgraph pool
output = self.runnable(*args, **kwargs)


@@ -179,23 +179,13 @@ class NPUPlatform(Platform):
compilation_config.cudagraph_num_of_warmups = 1
# TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
# if cudagraph_mode is not explicitly set by users, set default value
if compilation_config.level == CompilationLevel.PIECEWISE:
compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
elif compilation_config.level not in [
if compilation_config.level not in [
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
]:
logger.warning(
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
compilation_config.level)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
else:
logger.warning(
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# set CUDAGraphMode to None when torchair is enabled, no matter what compilation_config.level is.
if ascend_config.torchair_graph_config.enabled:
@@ -221,7 +211,12 @@ class NPUPlatform(Platform):
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
compilation_config.level = CompilationLevel.NO_COMPILATION
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
# TODO: Currently MLA does not support FULL_DECODE_ONLY; remove the second condition
# after MLA is supported
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
and model_config.use_mla):
logger.info(
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
@@ -233,6 +228,24 @@ class NPUPlatform(Platform):
"vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
])
update_aclgraph_sizes(vllm_config)
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
compilation_config.use_inductor = False
warning_message = """\033[91m
**********************************************************************************
* WARNING: You have enabled the *full graph* feature.
* This is an early experimental stage and may involve various unknown issues.
* A known problem is that capturing too many batch sizes can lead to OOM
* (Out of Memory) errors or inference hangs. If you encounter such issues,
* consider reducing `gpu_memory_utilization` or manually specifying a smaller
* batch size for graph capture.
* For more details, please refer to:
* https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
**********************************************************************************\033[0m
"""
logger.warning(warning_message)
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -379,3 +392,7 @@ class NPUPlatform(Platform):
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
return True
@classmethod
def support_static_graph_mode(cls) -> bool:
return True


@@ -347,7 +347,7 @@ class EagleProposer(Proposer):
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping_cpu=self.runner.slot_mapping_cpu,
slot_mapping=self.runner.slot_mapping,
positions=self.runner.positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
@@ -434,7 +434,7 @@ class EagleProposer(Proposer):
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping_cpu=target_slot_mapping,
slot_mapping=target_slot_mapping,
positions=target_positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,


@@ -385,7 +385,7 @@ class MtpProposer(Proposer):
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping_cpu=target_slot_mapping,
slot_mapping=target_slot_mapping,
positions=target_positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,


@@ -175,7 +175,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
model: Optional[nn.Module] = None,
):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -185,11 +185,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
block_table[:num_reqs])
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
self.device,
non_blocking=
True)
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state


@@ -400,11 +400,7 @@ class AscendMLATorchairMetadataBuilder:
device = self.device
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
device,
non_blocking=
True)
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)


@@ -121,12 +121,14 @@ class NPUTorchairModelRunner(NPUModelRunner):
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
def _build_attention_metadata(self, with_prefill, num_reqs, num_tokens,
max_query_len, force_attention):
# NOTE: If torchair graph mode and not with_prefill,
# we can't skip_attn, it will cause graph recompile.
if with_prefill or self.enable_shared_expert_dp:
attn_metadata = super()._build_attention_metadata(
with_prefill, num_reqs, skip_attn)
with_prefill, num_reqs, num_tokens, max_query_len,
force_attention)
else:
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs,


@@ -22,6 +22,7 @@ import functools
import math
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from threading import Lock
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -634,3 +635,34 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
return nullcontext()
assert target_stream is not None
return torch.npu.stream(target_stream)
@dataclass
class GraphParams:
events: dict[int, list[torch.npu.ExternalEvent]]
workspaces: dict[int, torch.Tensor]
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
attn_params: dict[int, list[tuple]]
_graph_params: Optional[GraphParams] = None
def set_graph_params(aclgraph_capture_sizes: set[int]):
global _graph_params
if _graph_params is not None:
raise ValueError("Graph parameters have already been set!")
_graph_params = GraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
)
def get_graph_params():
return _graph_params


@@ -70,8 +70,8 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LazyLoader, cdiv, get_dtype_size,
is_pin_memory_available)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import \
reorder_batch_to_split_decodes_and_prefills
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
@@ -116,8 +116,9 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, is_310p,
lmhead_tp_enable, vllm_version_is)
get_ascend_soc_version, get_graph_params,
is_310p, lmhead_tp_enable, set_graph_params,
vllm_version_is)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
@@ -352,6 +353,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.uses_mrope = self.model_config.uses_mrope
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -1222,7 +1226,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
Optional[torch.Tensor], Optional[torch.Tensor], int]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -1475,11 +1479,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
blk_table_tensor = blk_table.get_device_tensor()
slot_mapping = blk_table.slot_mapping_cpu[:
total_num_scheduled_tokens]
self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
slot_mapping)
# # Fill unused with -1. Needed for reshape_and_cache in full cuda
# # graph mode.
# blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
self.slot_mapping[:total_num_scheduled_tokens].copy_(
slot_mapping[:total_num_scheduled_tokens],
non_blocking=True,
)
# Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1492,7 +1495,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
actual_seq_lengths_q=self.actual_seq_lengths_q,
# TODO: change this to the right block table for linear attn
block_table_tensor=blk_table_tensor[:num_reqs],
slot_mapping_cpu=self.slot_mapping_cpu,
slot_mapping=self.slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
@@ -1549,7 +1552,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return (attn_metadata, positions, num_scheduled_tokens,
num_input_tokens, num_tokens_across_dp,
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
input_ids, inputs_embeds, intermediate_tensors)
input_ids, inputs_embeds, intermediate_tensors,
max_num_scheduled_tokens)
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
maybe_padded_num_tokens,
@@ -1563,6 +1567,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
graph_params = get_graph_params()
self.update_attn_params(graph_params, forward_context,
positions.shape[0])
if get_forward_context().flashcomm_v1_enabled:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
pad_size = get_forward_context().pad_size
@@ -1570,6 +1581,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
hidden_states = hidden_states[:-pad_size, :]
return hidden_states
def update_attn_params(self, graph_params, forward_context, runtime_shape):
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
# block_table = forward_context.attn_metadata[key].block_tables
seq_lens = forward_context.attn_metadata[key].seq_lens
with torch.npu.stream(self.update_stream):
torch.npu.graph_task_update_begin(self.update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_end(self.update_stream)
event.record(self.update_stream)
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
num_valid_tokens):
ascend_config = get_ascend_config()
@@ -1886,8 +1935,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
(attn_metadata, positions, num_scheduled_tokens_np,
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
intermediate_tensors) = (self._prepare_inputs(
scheduler_output, intermediate_tensors))
intermediate_tensors,
max_query_len) = (self._prepare_inputs(scheduler_output,
intermediate_tensors))
if self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process()
@@ -1895,8 +1945,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
moe_comm_method = self._select_moe_comm_method(num_input_tokens,
self.with_prefill)
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
scheduler_output.total_num_scheduled_tokens
== self.input_batch.num_reqs * max_query_len)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
uniform_decode=uniform_decode)
aclgraph_runtime_mode, batch_descriptor = \
self.aclgraph_dispatcher.dispatch(batch_descriptor)
@@ -2215,12 +2268,54 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output.finished_req_ids)
return None, None
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
if skip_attn:
attn_metadata = None
else:
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
attn_metadata = None
def _build_attention_metadata(self, create_mixed_batch, num_reqs,
num_tokens, max_query_len, force_attention):
attn_metadata: Optional[dict[str, Any]] = None
if force_attention:
attn_metadata = {}
if create_mixed_batch:
raise NotImplementedError(
"force_attention=True is not supported for mixed batches.")
else:
seq_lens = self.model_config.max_model_len
self.seq_lens_np[:num_reqs] = seq_lens
self.seq_lens_np[num_reqs:] = 0
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
block_table_tensor = self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
1],
seq_lens_cpu=self.seq_lens_cpu,
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
block_table_tensor=block_table_tensor[:num_reqs],
slot_mapping=self.slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
decode_token_per_req=self.decode_token_per_req,
)
for attn_group in self.attn_groups[kv_cache_group_id]:
if vllm_version_is("0.10.2"):
builder = attn_group.metadata_builder
else:
builder = attn_group.get_metadata_builder()
attn_metadata_i = builder.build_for_graph_capture(
common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
return attn_metadata
def _generate_dummy_run_hidden_states(self, with_prefill,
@@ -2249,12 +2344,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
) -> torch.Tensor:
# only support eager mode and piecewise graph now
assert aclgraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
if force_attention:
raise RuntimeError(
"Capturing attention in aclgraph is unexpected, because full graph is not supported now"
)
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
@@ -2310,9 +2401,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.is_kv_producer:
with_prefill = True
attn_metadata = self._build_attention_metadata(with_prefill,
num_reqs,
skip_attn=True)
attn_metadata = self._build_attention_metadata(
with_prefill,
num_reqs,
num_tokens,
max_query_len,
force_attention,
)
if not self.in_profile_run and self.dynamic_eplb:
self.eplb_updator.forward_before()
@@ -2551,6 +2646,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB",
m.consumed_memory / float(2**30))
# wrap the model with full graph wrapper if needed.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.update_stream = torch.npu.Stream()
set_graph_params(self.compilation_config.cudagraph_capture_sizes)
self.model = ACLGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
def _convert_torch_format(self, tensor):
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
return tensor
@@ -3167,9 +3270,78 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return kv_cache_spec
def initialize_aclgraph_capture(self) -> None:
# TODO: Add check of AttentionCGSupport and cudagraph_mode.decode_mode when full graph is supported
# Trigger aclgraph dispatching keys initialization here (after
# initializing attn backends).
min_ag_support = AttentionCGSupport.ALWAYS
min_ag_builder_name = None
for attn_group in self._attn_group_iterator():
if vllm_version_is("0.10.2"):
builder = attn_group.metadata_builder
else:
builder = attn_group.get_metadata_builder()
if builder.cudagraph_support.value < min_ag_support.value:
min_ag_support = builder.cudagraph_support
min_ag_builder_name = builder.__class__.__name__
# This is an imitation of compilation_config.splitting_ops_contain_attention()
splitting_ops_contain_attention = (
self.compilation_config.splitting_ops is not None
and all(op in self.compilation_config.splitting_ops for op in [
"vllm.unified_ascend_attention_with_output",
"vllm.mla_forward",
]))
# Flexibly resolve the aclgraph mode
aclgraph_mode = self.compilation_config.cudagraph_mode
# check graph for mixed batch is supported
if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
and min_ag_support != AttentionCGSupport.ALWAYS:
msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
f"with {min_ag_builder_name} backend (support: "
f"{min_ag_support})")
if min_ag_support == AttentionCGSupport.NEVER:
# if not supported any full graphs, just raise it.
msg += "; please try cudagraph_mode=PIECEWISE, and "\
"make sure compilation level is piecewise"
raise ValueError(msg)
# attempt to resolve the full graph related mode
if splitting_ops_contain_attention:
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_AND_PIECEWISE)
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY)
logger.warning(msg)
# check that if spec-decode + decode full-graphs is supported
if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.uniform_decode_query_len > 1 and min_ag_support.value
< AttentionCGSupport.UNIFORM_BATCH.value):
msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported"
f" with spec-decode for attention backend "
f"{min_ag_builder_name} (support: {min_ag_support})")
if splitting_ops_contain_attention:
msg += "; setting cudagraph_mode=PIECEWISE"
aclgraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE"
aclgraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.NONE
logger.warning(msg)
# double check that we can support full graph if they are requested
# even after automatic downgrades
if aclgraph_mode.has_full_cudagraphs() \
and min_ag_support == AttentionCGSupport.NEVER:
raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
f"supported with {min_ag_builder_name} backend ("
f"support:{min_ag_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation level is piecewise")
self.aclgraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode,
self.uniform_decode_query_len)
@@ -3178,10 +3350,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool):
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
aclgraph_runtime_mode in [CUDAGraphMode.PIECEWISE]
aclgraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE]
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
logger.info(
"Starting to capture ACL graphs for cases: %s, "
"mode: %s, uniform_decode: %s", compilation_cases,
aclgraph_runtime_mode.name, uniform_decode)
compilation_cases = tqdm(
compilation_cases,
disable=not self.load_config.use_tqdm_on_load,
@@ -3203,6 +3380,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
uniform_decode=uniform_decode)
self._dummy_run(num_tokens,
aclgraph_runtime_mode=aclgraph_runtime_mode,
force_attention=force_attention,
uniform_decode=uniform_decode)
def _capture_model(self):
@@ -3229,6 +3407,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode=aclgraph_runtime_mode,
uniform_decode=False)
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
aclgraph_mode.separate_routine():
max_num_tokens = self.scheduler_config.max_num_seqs * \
self.uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
and x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
reversed(decode_cudagraph_batch_sizes))
self._capture_aclgraphs(
compilation_cases=compilation_cases_decode,
aclgraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True)
# Disable aclgraph capturing globally, so any unexpected aclgraph
# capturing will be detected and raise an error after here.
# Note: We don't put it into graph_capture context manager because