[Feat][Graph] Support FULL_DECODE_ONLY mode for GQA/MHA models (#2128)
Note: This depends on [vLLM
#25161](https://github.com/vllm-project/vllm/pull/25161) and the
torch\_npu release from September 30.
### What this PR does / why we need it?
This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA
models like DeepSeek V3/R1 are not included). Key improvements include:
* **Reduced dispatch latency:** By replaying the entire model execution
graph at once, we cut overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Capturing the whole model as
one static graph also mitigates the dispatch fluctuations across
devices.
* **Stream/resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured.
**Known issues:**
1. `_npu_paged_attention` currently manages its own workspace in
`torch_npu`, which can deadlock when synchronizing during graph replay —
we’re working on a fix.
There may be other corner cases. This PR is the first in a planned
series; we’ll continue to iterate and address remaining issues in
follow-ups.
This is essentially a port of #1503 and #1677, but includes two major
changes:
1. Let `graph_dispatcher` decide the graph mode instead of hard-coding
it in the backend, which decouples Full Graph and Piecewise Graph and
could make it possible to remove dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in
`update_graph_params`; multi-attention models may or may not be fully
supported yet.
### Does this PR introduce _any_ user-facing change?
```python
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
```
### How was this patch tested?
Tests included.
- vLLM version: v0.10.2
- vLLM main:
9607d5eb44
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -101,7 +101,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
max_query_len=5,
|
||||
decode_token_per_req=torch.tensor([1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping_cpu=torch.tensor(range(20)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((10, 10)),
|
||||
@@ -134,7 +134,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
max_query_len=6,
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping_cpu=torch.tensor(range(20)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((15, 15)),
|
||||
@@ -165,7 +165,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
max_query_len=6,
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping_cpu=torch.tensor(range(20)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((15, 15)),
|
||||
@@ -378,10 +378,12 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_flash_attention_qlens.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
def test_forward_decode_only(self, mock_paged_attention,
|
||||
mock_npu_reshape_and_cache):
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_get_forward_context):
|
||||
"""Test forward pass in DecodeOnly state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
@@ -395,6 +397,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
@@ -435,12 +439,13 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_fused_infer_attention_score.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
def test_forward_decode_only_swa_seq_len_mismatch(
|
||||
self, mock_fused_infer_attention_score, mock_paged_attention,
|
||||
mock_npu_reshape_and_cache):
|
||||
mock_npu_reshape_and_cache, mock_get_forward_context):
|
||||
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
@@ -457,6 +462,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
|
||||
64), 1)
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
output = self.impl_swa.forward(self.layer_no_quant,
|
||||
query,
|
||||
key,
|
||||
|
||||
@@ -463,7 +463,7 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
|
||||
max_query_len=1,
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping_cpu=torch.tensor(range(20)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
positions=torch.tensor([1, 1]),
|
||||
attn_mask=torch.ones((15, 15)),
|
||||
|
||||
@@ -31,13 +31,15 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d, nd_to_nz_spec)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
|
||||
get_graph_params, is_310p, nd_to_nz_2d,
|
||||
nd_to_nz_spec)
|
||||
|
||||
|
||||
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
@@ -197,6 +199,12 @@ class AscendMetadata:
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder:
|
||||
# Does this backend/builder support CUDA Graphs for attention (default: no).
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
def __init__(
|
||||
@@ -221,7 +229,7 @@ class AscendAttentionMetadataBuilder:
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
@@ -231,11 +239,7 @@ class AscendAttentionMetadataBuilder:
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
|
||||
num_actual_tokens].to(
|
||||
self.device,
|
||||
non_blocking=
|
||||
True)
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
@@ -268,6 +272,24 @@ class AscendAttentionMetadataBuilder:
|
||||
is_only_prefill=common_attn_metadata.is_only_prefill)
|
||||
return attn_metadata
|
||||
|
||||
def build_for_graph_capture(
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
):
|
||||
if attn_state == AscendAttentionState.DecodeOnly:
|
||||
attn_metadata = self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently we only support building dummy metadata for DecodeOnly state"
|
||||
)
|
||||
|
||||
attn_metadata.attn_state = attn_state
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
@@ -406,16 +428,53 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||
else:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
graph_params = get_graph_params()
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
num_tokens = query.shape[0]
|
||||
if forward_context.capturing:
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
query,
|
||||
self.key_cache,
|
||||
self.value_cache,
|
||||
self.num_kv_heads,
|
||||
self.num_heads,
|
||||
self.scale,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens,
|
||||
output,
|
||||
))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def _forward_v1_style(
|
||||
|
||||
@@ -292,11 +292,7 @@ class AscendMLAMetadataBuilder:
|
||||
device = self.device
|
||||
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
|
||||
num_actual_tokens].to(
|
||||
device,
|
||||
non_blocking=
|
||||
True)
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ class AscendCommonAttentionMetadata:
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
|
||||
slot_mapping_cpu: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
actual_seq_lengths_q: list[int]
|
||||
|
||||
|
||||
@@ -147,6 +147,7 @@ class ACLGraphWrapper:
|
||||
patch("torch.npu.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
forward_context.capturing = True
|
||||
with torch.npu.graph(aclgraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's aclgraph pool
|
||||
output = self.runnable(*args, **kwargs)
|
||||
|
||||
@@ -179,23 +179,13 @@ class NPUPlatform(Platform):
|
||||
|
||||
compilation_config.cudagraph_num_of_warmups = 1
|
||||
|
||||
# TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
|
||||
# if cudagraph_mode is not explicitly set by users, set default value
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
elif compilation_config.level not in [
|
||||
if compilation_config.level not in [
|
||||
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
|
||||
]:
|
||||
logger.warning(
|
||||
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
|
||||
compilation_config.level)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
else:
|
||||
logger.warning(
|
||||
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
@@ -221,7 +211,12 @@ class NPUPlatform(Platform):
|
||||
|
||||
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||
# TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
|
||||
# after MLA being supported
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
|
||||
compilation_config.cudagraph_mode
|
||||
== CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
|
||||
and model_config.use_mla):
|
||||
logger.info(
|
||||
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
|
||||
"using only ACL Graph mode")
|
||||
@@ -233,6 +228,24 @@ class NPUPlatform(Platform):
|
||||
"vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
|
||||
])
|
||||
update_aclgraph_sizes(vllm_config)
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
logger.info(
|
||||
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
|
||||
"using only ACL Graph mode")
|
||||
compilation_config.use_inductor = False
|
||||
warning_message = """\033[91m
|
||||
**********************************************************************************
|
||||
* WARNING: You have enabled the *full graph* feature.
|
||||
* This is an early experimental stage and may involve various unknown issues.
|
||||
* A known problem is that capturing too many batch sizes can lead to OOM
|
||||
* (Out of Memory) errors or inference hangs. If you encounter such issues,
|
||||
* consider reducing `gpu_memory_utilization` or manually specifying a smaller
|
||||
* batch size for graph capture.
|
||||
* For more details, please refer to:
|
||||
* https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
|
||||
**********************************************************************************\033[0m
|
||||
"""
|
||||
logger.warning(warning_message)
|
||||
else:
|
||||
logger.info(
|
||||
"%s cudagraph_mode is not support on NPU. falling back to NONE",
|
||||
@@ -379,3 +392,7 @@ class NPUPlatform(Platform):
|
||||
@classmethod
|
||||
def support_hybrid_kv_cache(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def support_static_graph_mode(cls) -> bool:
|
||||
return True
|
||||
|
||||
@@ -347,7 +347,7 @@ class EagleProposer(Proposer):
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor(),
|
||||
slot_mapping_cpu=self.runner.slot_mapping_cpu,
|
||||
slot_mapping=self.runner.slot_mapping,
|
||||
positions=self.runner.positions,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
@@ -434,7 +434,7 @@ class EagleProposer(Proposer):
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor(),
|
||||
slot_mapping_cpu=target_slot_mapping,
|
||||
slot_mapping=target_slot_mapping,
|
||||
positions=target_positions,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
|
||||
@@ -385,7 +385,7 @@ class MtpProposer(Proposer):
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor(),
|
||||
slot_mapping_cpu=target_slot_mapping,
|
||||
slot_mapping=target_slot_mapping,
|
||||
positions=target_positions,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
|
||||
@@ -175,7 +175,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
@@ -185,11 +185,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
block_table[:num_reqs])
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
|
||||
num_actual_tokens].to(
|
||||
self.device,
|
||||
non_blocking=
|
||||
True)
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
|
||||
@@ -400,11 +400,7 @@ class AscendMLATorchairMetadataBuilder:
|
||||
device = self.device
|
||||
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
|
||||
num_actual_tokens].to(
|
||||
device,
|
||||
non_blocking=
|
||||
True)
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
|
||||
@@ -121,12 +121,14 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
|
||||
|
||||
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
|
||||
def _build_attention_metadata(self, with_prefill, num_reqs, num_tokens,
|
||||
max_query_len, force_attention):
|
||||
# NOTE: If torchair graph mode and not with_prefill,
|
||||
# we can't skip_attn, it will cause graph recompile.
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
attn_metadata = super()._build_attention_metadata(
|
||||
with_prefill, num_reqs, skip_attn)
|
||||
with_prefill, num_reqs, num_tokens, max_query_len,
|
||||
force_attention)
|
||||
else:
|
||||
common_attn_metadata = TorchairCommonAttentionMetadata(
|
||||
num_reqs=num_reqs,
|
||||
|
||||
@@ -22,6 +22,7 @@ import functools
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||
@@ -634,3 +635,34 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
|
||||
return nullcontext()
|
||||
assert target_stream is not None
|
||||
return torch.npu.stream(target_stream)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphParams:
|
||||
events: dict[int, list[torch.npu.ExternalEvent]]
|
||||
workspaces: dict[int, torch.Tensor]
|
||||
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
|
||||
attn_params: dict[int, list[tuple]]
|
||||
|
||||
|
||||
_graph_params: Optional[GraphParams] = None
|
||||
|
||||
|
||||
def set_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
raise ValueError("Graph parameters have already been set!")
|
||||
_graph_params = GraphParams(
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: None
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
)
|
||||
|
||||
|
||||
def get_graph_params():
|
||||
return _graph_params
|
||||
|
||||
@@ -70,8 +70,8 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
LazyLoader, cdiv, get_dtype_size,
|
||||
is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import \
|
||||
reorder_batch_to_split_decodes_and_prefills
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@@ -116,8 +116,9 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
|
||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
AscendSocVersion, ProfileExecuteDuration,
|
||||
get_ascend_soc_version, is_310p,
|
||||
lmhead_tp_enable, vllm_version_is)
|
||||
get_ascend_soc_version, get_graph_params,
|
||||
is_310p, lmhead_tp_enable, set_graph_params,
|
||||
vllm_version_is)
|
||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -352,6 +353,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.seq_lens = torch.zeros(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.slot_mapping = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
self.uses_mrope = self.model_config.uses_mrope
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
@@ -1222,7 +1226,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
|
||||
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
Optional[torch.Tensor], Optional[torch.Tensor], int]:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
@@ -1475,11 +1479,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
blk_table_tensor = blk_table.get_device_tensor()
|
||||
slot_mapping = blk_table.slot_mapping_cpu[:
|
||||
total_num_scheduled_tokens]
|
||||
self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
|
||||
slot_mapping)
|
||||
# # Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# # graph mode.
|
||||
# blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
|
||||
self.slot_mapping[:total_num_scheduled_tokens].copy_(
|
||||
slot_mapping[:total_num_scheduled_tokens],
|
||||
non_blocking=True,
|
||||
)
|
||||
|
||||
# Make AscendCommonAttentionMetadata
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
@@ -1492,7 +1495,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
||||
# TODO: change this to the right block table for linear attn
|
||||
block_table_tensor=blk_table_tensor[:num_reqs],
|
||||
slot_mapping_cpu=self.slot_mapping_cpu,
|
||||
slot_mapping=self.slot_mapping,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
positions=self.positions,
|
||||
attn_mask=self.attn_mask,
|
||||
@@ -1549,7 +1552,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
return (attn_metadata, positions, num_scheduled_tokens,
|
||||
num_input_tokens, num_tokens_across_dp,
|
||||
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
|
||||
input_ids, inputs_embeds, intermediate_tensors)
|
||||
input_ids, inputs_embeds, intermediate_tensors,
|
||||
max_num_scheduled_tokens)
|
||||
|
||||
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||
maybe_padded_num_tokens,
|
||||
@@ -1563,6 +1567,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
graph_params = get_graph_params()
|
||||
self.update_attn_params(graph_params, forward_context,
|
||||
positions.shape[0])
|
||||
|
||||
if get_forward_context().flashcomm_v1_enabled:
|
||||
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
||||
pad_size = get_forward_context().pad_size
|
||||
@@ -1570,6 +1581,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states = hidden_states[:-pad_size, :]
|
||||
return hidden_states
|
||||
|
||||
def update_attn_params(self, graph_params, forward_context, runtime_shape):
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
# for each layer's attention op in the graph.
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
block_table,
|
||||
seq_lens,
|
||||
output,
|
||||
) = param
|
||||
# block_table = forward_context.attn_metadata[key].block_tables
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||
|
||||
with torch.npu.stream(self.update_stream):
|
||||
torch.npu.graph_task_update_begin(self.update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
torch.npu.graph_task_update_end(self.update_stream)
|
||||
|
||||
event.record(self.update_stream)
|
||||
|
||||
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens):
|
||||
ascend_config = get_ascend_config()
|
||||
@@ -1886,8 +1935,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
(attn_metadata, positions, num_scheduled_tokens_np,
|
||||
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
|
||||
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
|
||||
intermediate_tensors) = (self._prepare_inputs(
|
||||
scheduler_output, intermediate_tensors))
|
||||
intermediate_tensors,
|
||||
max_query_len) = (self._prepare_inputs(scheduler_output,
|
||||
intermediate_tensors))
|
||||
|
||||
if self.dynamic_eplb:
|
||||
self.eplb_updator.take_update_info_from_eplb_process()
|
||||
@@ -1895,8 +1945,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
moe_comm_method = self._select_moe_comm_method(num_input_tokens,
|
||||
self.with_prefill)
|
||||
|
||||
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
||||
scheduler_output.total_num_scheduled_tokens
|
||||
== self.input_batch.num_reqs * max_query_len)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||
uniform_decode=False)
|
||||
uniform_decode=uniform_decode)
|
||||
aclgraph_runtime_mode, batch_descriptor = \
|
||||
self.aclgraph_dispatcher.dispatch(batch_descriptor)
|
||||
|
||||
@@ -2215,12 +2268,54 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output.finished_req_ids)
|
||||
return None, None
|
||||
|
||||
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
|
||||
if skip_attn:
|
||||
attn_metadata = None
|
||||
else:
|
||||
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
|
||||
attn_metadata = None
|
||||
def _build_attention_metadata(self, create_mixed_batch, num_reqs,
|
||||
num_tokens, max_query_len, force_attention):
|
||||
attn_metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
if force_attention:
|
||||
attn_metadata = {}
|
||||
|
||||
if create_mixed_batch:
|
||||
raise NotImplementedError(
|
||||
"force_attention=True is not supported for mixed batches.")
|
||||
else:
|
||||
seq_lens = self.model_config.max_model_len
|
||||
self.seq_lens_np[:num_reqs] = seq_lens
|
||||
self.seq_lens_np[num_reqs:] = 0
|
||||
|
||||
num_computed_tokens_cpu = (
|
||||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
block_table_tensor = self.input_batch.block_table[
|
||||
kv_cache_group_id].get_device_tensor()
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
||||
1],
|
||||
seq_lens_cpu=self.seq_lens_cpu,
|
||||
seq_lens=self.seq_lens_cpu[:num_reqs],
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
||||
block_table_tensor=block_table_tensor[:num_reqs],
|
||||
slot_mapping=self.slot_mapping,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
max_query_len=max_query_len,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
)
|
||||
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
if vllm_version_is("0.10.2"):
|
||||
builder = attn_group.metadata_builder
|
||||
else:
|
||||
builder = attn_group.get_metadata_builder()
|
||||
attn_metadata_i = builder.build_for_graph_capture(
|
||||
common_attn_metadata)
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def _generate_dummy_run_hidden_states(self, with_prefill,
|
||||
@@ -2249,12 +2344,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
) -> torch.Tensor:
|
||||
# only support eager mode and piecewise graph now
|
||||
assert aclgraph_runtime_mode in {
|
||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE
|
||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
||||
}
|
||||
if force_attention:
|
||||
raise RuntimeError(
|
||||
"Capturing attention in aclgraph is unexpected, because full graph is not supported now"
|
||||
)
|
||||
|
||||
# Padding for DP
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
@@ -2310,9 +2401,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.is_kv_producer:
|
||||
with_prefill = True
|
||||
|
||||
attn_metadata = self._build_attention_metadata(with_prefill,
|
||||
num_reqs,
|
||||
skip_attn=True)
|
||||
attn_metadata = self._build_attention_metadata(
|
||||
with_prefill,
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
max_query_len,
|
||||
force_attention,
|
||||
)
|
||||
|
||||
if not self.in_profile_run and self.dynamic_eplb:
|
||||
self.eplb_updator.forward_before()
|
||||
@@ -2551,6 +2646,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
logger.info("Loading model weights took %.4f GB",
|
||||
m.consumed_memory / float(2**30))
|
||||
|
||||
# wrap the model with full graph wrapper if needed.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.update_stream = torch.npu.Stream()
|
||||
set_graph_params(self.compilation_config.cudagraph_capture_sizes)
|
||||
self.model = ACLGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
|
||||
def _convert_torch_format(self, tensor):
    # Return `tensor` cast to the ACL-native layout (ACL_FORMAT) via
    # torch_npu's format-cast op; the cast result is what callers use.
    return torch_npu.npu_format_cast(tensor, ACL_FORMAT)
|
||||
@@ -3167,9 +3270,78 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
return kv_cache_spec
|
||||
|
||||
def initialize_aclgraph_capture(self) -> None:
    """Resolve the requested aclgraph mode against attention-backend
    support, downgrading it where necessary, then initialize the
    aclgraph dispatcher keys.

    NOTE(review): this imitates vLLM's GPU-side cudagraph capture
    initialization; the resolved mode is written back to
    ``self.compilation_config.cudagraph_mode``.
    """
    # TODO: Add check of AttentionCGSupport and cudagraph_mode.decode_mode when full graph is supported
    # Trigger aclgraph dispatching keys initialization here (after
    # initializing attn backends).
    # The weakest metadata builder constrains what the whole model can
    # capture, so take the minimum support level across all attn groups.
    weakest_support = AttentionCGSupport.ALWAYS
    weakest_builder = None

    for group in self._attn_group_iterator():
        if vllm_version_is("0.10.2"):
            meta_builder = group.metadata_builder
        else:
            meta_builder = group.get_metadata_builder()
        if meta_builder.cudagraph_support.value < weakest_support.value:
            weakest_support = meta_builder.cudagraph_support
            weakest_builder = meta_builder.__class__.__name__

    # This is an imitation of compilation_config.splitting_ops_contain_attention()
    split_ops = self.compilation_config.splitting_ops
    attention_is_split = (split_ops is not None and all(
        op in split_ops for op in (
            "vllm.unified_ascend_attention_with_output",
            "vllm.mla_forward",
        )))

    # Flexible resolve the aclgraph mode
    graph_mode = self.compilation_config.cudagraph_mode

    # Check that full graphs for mixed (prefill + decode) batches are
    # actually supported by the weakest backend.
    if (graph_mode.mixed_mode() == CUDAGraphMode.FULL
            and weakest_support != AttentionCGSupport.ALWAYS):
        msg = (f"ACLGraphMode.{graph_mode.name} is not supported "
               f"with {weakest_builder} backend (support: "
               f"{weakest_support})")
        if weakest_support == AttentionCGSupport.NEVER:
            # No full graph of any kind is possible, so just raise it.
            msg += ("; please try cudagraph_mode=PIECEWISE, and "
                    "make sure compilation level is piecewise")
            raise ValueError(msg)

        # Otherwise attempt to resolve the full graph related mode.
        if attention_is_split:
            msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
            graph_mode = self.compilation_config.cudagraph_mode = (
                CUDAGraphMode.FULL_AND_PIECEWISE)
        else:
            msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
            graph_mode = self.compilation_config.cudagraph_mode = (
                CUDAGraphMode.FULL_DECODE_ONLY)
        logger.warning(msg)

    # Check that spec-decode combined with full decode graphs is
    # supported (requires at least UNIFORM_BATCH support).
    if (graph_mode.decode_mode() == CUDAGraphMode.FULL
            and self.uniform_decode_query_len > 1
            and weakest_support.value <
            AttentionCGSupport.UNIFORM_BATCH.value):
        msg = (f"CUDAGraphMode.{graph_mode.name} is not supported"
               f" with spec-decode for attention backend "
               f"{weakest_builder} (support: {weakest_support})")
        if attention_is_split:
            msg += "; setting cudagraph_mode=PIECEWISE"
            graph_mode = self.compilation_config.cudagraph_mode = (
                CUDAGraphMode.PIECEWISE)
        else:
            msg += "; setting cudagraph_mode=NONE"
            graph_mode = self.compilation_config.cudagraph_mode = (
                CUDAGraphMode.NONE)
        logger.warning(msg)

    # Double check that we can support full graph if they are requested
    # even after automatic downgrades.
    if (graph_mode.has_full_cudagraphs()
            and weakest_support == AttentionCGSupport.NEVER):
        raise ValueError(f"CUDAGraphMode.{graph_mode.name} is not "
                         f"supported with {weakest_builder} backend ("
                         f"support:{weakest_support}) "
                         "; please try cudagraph_mode=PIECEWISE, "
                         "and make sure compilation level is piecewise")

    self.aclgraph_dispatcher.initialize_cudagraph_keys(
        self.compilation_config.cudagraph_mode,
        self.uniform_decode_query_len)
|
||||
@@ -3178,10 +3350,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
aclgraph_runtime_mode: CUDAGraphMode,
|
||||
uniform_decode: bool):
|
||||
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
|
||||
aclgraph_runtime_mode in [CUDAGraphMode.PIECEWISE]
|
||||
aclgraph_runtime_mode in [CUDAGraphMode.FULL,
|
||||
CUDAGraphMode.PIECEWISE]
|
||||
|
||||
# Only rank 0 should print progress bar during capture
|
||||
if is_global_first_rank():
|
||||
logger.info(
|
||||
"Starting to capture ACL graphs for cases: %s, "
|
||||
"mode: %s, uniform_decode: %s", compilation_cases,
|
||||
aclgraph_runtime_mode.name, uniform_decode)
|
||||
compilation_cases = tqdm(
|
||||
compilation_cases,
|
||||
disable=not self.load_config.use_tqdm_on_load,
|
||||
@@ -3203,6 +3380,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
uniform_decode=uniform_decode)
|
||||
self._dummy_run(num_tokens,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
force_attention=force_attention,
|
||||
uniform_decode=uniform_decode)
|
||||
|
||||
def _capture_model(self):
|
||||
@@ -3229,6 +3407,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
uniform_decode=False)
|
||||
|
||||
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
||||
aclgraph_mode.separate_routine():
|
||||
max_num_tokens = self.scheduler_config.max_num_seqs * \
|
||||
self.uniform_decode_query_len
|
||||
decode_cudagraph_batch_sizes = [
|
||||
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
|
||||
and x >= self.uniform_decode_query_len
|
||||
]
|
||||
compilation_cases_decode = list(
|
||||
reversed(decode_cudagraph_batch_sizes))
|
||||
self._capture_aclgraphs(
|
||||
compilation_cases=compilation_cases_decode,
|
||||
aclgraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
uniform_decode=True)
|
||||
|
||||
# Disable aclgraph capturing globally, so any unexpected aclgraph
|
||||
# capturing will be detected and raise an error after here.
|
||||
# Note: We don't put it into graph_capture context manager because
|
||||
|
||||
Reference in New Issue
Block a user