[Feat][Graph] Support FULL_DECODE_ONLY mode for GQA/MHA models (#2128)

Note: This depends on [vLLM
#25161](https://github.com/vllm-project/vllm/pull/25161) and the
torch\_npu release from September 30.

### What this PR does / why we need it?
This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA
models like DeepSeek V3/R1 are not included). Key improvements include:

* **Reduced dispatch latency:** By replaying the entire model execution
graph at once, we cut overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Captureing the whole model as
one static graph also mitigates the dispatch fluctuations across
devices.
* **Stream/resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured.

**Known issues:**

1. `_npu_paged_attention` currently manages its own workspace in
`torch_npu`, which can deadlock when synchronizing during graph replay —
we’re working on a fix.

There may be other corner cases. This PR is the first in a planned
series; we’ll continue to iterate and address remaining issues in
follow-ups.

This is essentially a port of #1503 and #1677, but includes two major
changes:

1. Let `graph_dispatcher` decide the graph mode instead of hard-coding
it in the backend, which decouples Full Graph and Piecewise Graph and
could make it possible to remove dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in
`update_graph_params`; multi-attention models may or may not be fully
supported yet.

### Does this PR introduce _any_ user-facing change?
```python
compilation_config={
    "cudagraph_mode": "FULL_DECODE_ONLY",
},
```

### How was this patch tested?
Tests included.


- vLLM version: v0.10.2
- vLLM main:
9607d5eb44

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2025-09-22 17:14:28 +08:00
committed by GitHub
parent f39bd309b6
commit 338231acaf
14 changed files with 390 additions and 91 deletions

View File

@@ -101,7 +101,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
max_query_len=5, max_query_len=5,
decode_token_per_req=torch.tensor([1, 1]), decode_token_per_req=torch.tensor([1, 1]),
block_table_tensor=torch.zeros((10, 10)), block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)), slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]), actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]), positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)), attn_mask=torch.ones((10, 10)),
@@ -134,7 +134,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
max_query_len=6, max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]), decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)), block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)), slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]), actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]), positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)), attn_mask=torch.ones((15, 15)),
@@ -165,7 +165,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
max_query_len=6, max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]), decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)), block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)), slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]), actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]), positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)), attn_mask=torch.ones((15, 15)),
@@ -378,10 +378,12 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_flash_attention_qlens.assert_called_once() mock_flash_attention_qlens.assert_called_once()
assert output.shape == (10, 8 * 64) assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention') @patch('torch_npu._npu_paged_attention')
def test_forward_decode_only(self, mock_paged_attention, def test_forward_decode_only(self, mock_paged_attention,
mock_npu_reshape_and_cache): mock_npu_reshape_and_cache,
mock_get_forward_context):
"""Test forward pass in DecodeOnly state""" """Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64) query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64)
@@ -395,6 +397,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.slot_mapping = torch.zeros(10, dtype=torch.long) metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
output = self.impl.forward(layer, output = self.impl.forward(layer,
query, query,
key, key,
@@ -435,12 +439,13 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_fused_infer_attention_score.assert_called_once() mock_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64) assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention') @patch('torch_npu._npu_paged_attention')
@patch('torch_npu.npu_fused_infer_attention_score') @patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_decode_only_swa_seq_len_mismatch( def test_forward_decode_only_swa_seq_len_mismatch(
self, mock_fused_infer_attention_score, mock_paged_attention, self, mock_fused_infer_attention_score, mock_paged_attention,
mock_npu_reshape_and_cache): mock_npu_reshape_and_cache, mock_get_forward_context):
"""Test forward pass in DecodeOnly state when seq)len_mismatch""" """Test forward pass in DecodeOnly state when seq)len_mismatch"""
query = torch.randn(10, 8 * 64) query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64)
@@ -457,6 +462,8 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1) 64), 1)
mock_get_forward_context.return_value = MagicMock(capturing=False)
output = self.impl_swa.forward(self.layer_no_quant, output = self.impl_swa.forward(self.layer_no_quant,
query, query,
key, key,

View File

@@ -463,7 +463,7 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
max_query_len=1, max_query_len=1,
decode_token_per_req=torch.tensor([1, 1, 1]), decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)), block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)), slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]), actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([1, 1]), positions=torch.tensor([1, 1]),
attn_mask=torch.ones((15, 15)), attn_mask=torch.ones((15, 15)),

View File

@@ -31,13 +31,15 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group) is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
nd_to_nz_2d, nd_to_nz_spec) get_graph_params, is_310p, nd_to_nz_2d,
nd_to_nz_spec)
def wait_for_kv_layer_from_connector(layer_name: str): def wait_for_kv_layer_from_connector(layer_name: str):
@@ -197,6 +199,12 @@ class AscendMetadata:
class AscendAttentionMetadataBuilder: class AscendAttentionMetadataBuilder:
# Does this backend/builder support CUDA Graphs for attention (default: no).
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: ClassVar[int] = 1
def __init__( def __init__(
@@ -221,7 +229,7 @@ class AscendAttentionMetadataBuilder:
self, self,
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata, common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module, model: Optional[nn.Module] = None,
): ):
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -231,11 +239,7 @@ class AscendAttentionMetadataBuilder:
block_table = common_attn_metadata.block_table_tensor block_table = common_attn_metadata.block_table_tensor
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping_cpu[: slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
num_actual_tokens].to(
self.device,
non_blocking=
True)
attn_mask = common_attn_metadata.attn_mask attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -268,6 +272,24 @@ class AscendAttentionMetadataBuilder:
is_only_prefill=common_attn_metadata.is_only_prefill) is_only_prefill=common_attn_metadata.is_only_prefill)
return attn_metadata return attn_metadata
def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
):
if attn_state == AscendAttentionState.DecodeOnly:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise NotImplementedError(
"Currently we only support building dummy metadata for DecodeOnly state"
)
attn_metadata.attn_state = attn_state
return attn_metadata
class AscendAttentionBackendImpl(AttentionImpl): class AscendAttentionBackendImpl(AttentionImpl):
@@ -406,16 +428,53 @@ class AscendAttentionBackendImpl(AttentionImpl):
output = output.view(batch_size, self.num_heads, self.head_size) output = output.view(batch_size, self.num_heads, self.head_size)
else: else:
torch_npu._npu_paged_attention( graph_params = get_graph_params()
query=query, forward_context: ForwardContext = get_forward_context()
key_cache=self.key_cache, num_tokens = query.shape[0]
value_cache=self.value_cache, if forward_context.capturing:
num_kv_heads=self.num_kv_heads, stream = torch_npu.npu.current_stream()
num_heads=self.num_heads,
scale_value=self.scale, event = torch.npu.ExternalEvent()
block_table=attn_metadata.block_tables, event.wait(stream)
context_lens=attn_metadata.seq_lens, event.reset(stream)
out=output) graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
query,
self.key_cache,
self.value_cache,
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
output,
))
torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
return output return output
def _forward_v1_style( def _forward_v1_style(

View File

@@ -292,11 +292,7 @@ class AscendMLAMetadataBuilder:
device = self.device device = self.device
block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping_cpu[: slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
num_actual_tokens].to(
device,
non_blocking=
True)
input_positions = common_attn_metadata.positions[: input_positions = common_attn_metadata.positions[:
num_actual_tokens].long( num_actual_tokens].long(
) )

View File

@@ -41,7 +41,7 @@ class AscendCommonAttentionMetadata:
block_table_tensor: torch.Tensor block_table_tensor: torch.Tensor
slot_mapping_cpu: torch.Tensor slot_mapping: torch.Tensor
actual_seq_lengths_q: list[int] actual_seq_lengths_q: list[int]

View File

@@ -147,6 +147,7 @@ class ACLGraphWrapper:
patch("torch.npu.empty_cache", lambda: None)) patch("torch.npu.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory. # mind-exploding: carefully manage the reference and memory.
forward_context.capturing = True
with torch.npu.graph(aclgraph, pool=self.graph_pool): with torch.npu.graph(aclgraph, pool=self.graph_pool):
# `output` is managed by pytorch's aclgraph pool # `output` is managed by pytorch's aclgraph pool
output = self.runnable(*args, **kwargs) output = self.runnable(*args, **kwargs)

View File

@@ -179,23 +179,13 @@ class NPUPlatform(Platform):
compilation_config.cudagraph_num_of_warmups = 1 compilation_config.cudagraph_num_of_warmups = 1
# TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode` if compilation_config.level not in [
# if cudagraph_mode is not explicitly set by users, set default value
if compilation_config.level == CompilationLevel.PIECEWISE:
compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
elif compilation_config.level not in [
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
]: ]:
logger.warning( logger.warning(
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE", "NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
compilation_config.level) compilation_config.level)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE compilation_config.cudagraph_mode = CUDAGraphMode.NONE
else:
logger.warning(
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is. # set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
if ascend_config.torchair_graph_config.enabled: if ascend_config.torchair_graph_config.enabled:
@@ -221,7 +211,12 @@ class NPUPlatform(Platform):
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
compilation_config.level = CompilationLevel.NO_COMPILATION compilation_config.level = CompilationLevel.NO_COMPILATION
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: # TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
# after MLA being supported
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
and model_config.use_mla):
logger.info( logger.info(
"PIECEWISE compilation enabled on NPU. use_inductor not supported - " "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode") "using only ACL Graph mode")
@@ -233,6 +228,24 @@ class NPUPlatform(Platform):
"vllm.unified_ascend_attention_with_output", "vllm.mla_forward" "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
]) ])
update_aclgraph_sizes(vllm_config) update_aclgraph_sizes(vllm_config)
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
compilation_config.use_inductor = False
warning_message = """\033[91m
**********************************************************************************
* WARNING: You have enabled the *full graph* feature.
* This is an early experimental stage and may involve various unknown issues.
* A known problem is that capturing too many batch sizes can lead to OOM
* (Out of Memory) errors or inference hangs. If you encounter such issues,
* consider reducing `gpu_memory_utilization` or manually specifying a smaller
* batch size for graph capture.
* For more details, please refer to:
* https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
**********************************************************************************\033[0m
"""
logger.warning(warning_message)
else: else:
logger.info( logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE", "%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -379,3 +392,7 @@ class NPUPlatform(Platform):
@classmethod @classmethod
def support_hybrid_kv_cache(cls) -> bool: def support_hybrid_kv_cache(cls) -> bool:
return True return True
@classmethod
def support_static_graph_mode(cls) -> bool:
return True

View File

@@ -347,7 +347,7 @@ class EagleProposer(Proposer):
actual_seq_lengths_q=self.runner.actual_seq_lengths_q, actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0]. block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(), get_device_tensor(),
slot_mapping_cpu=self.runner.slot_mapping_cpu, slot_mapping=self.runner.slot_mapping,
positions=self.runner.positions, positions=self.runner.positions,
attn_mask=self.runner.attn_mask, attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask, spec_attn_mask=self.runner.spec_attn_mask,
@@ -434,7 +434,7 @@ class EagleProposer(Proposer):
actual_seq_lengths_q=self.runner.actual_seq_lengths_q, actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0]. block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(), get_device_tensor(),
slot_mapping_cpu=target_slot_mapping, slot_mapping=target_slot_mapping,
positions=target_positions, positions=target_positions,
attn_mask=self.runner.attn_mask, attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask, spec_attn_mask=self.runner.spec_attn_mask,

View File

@@ -385,7 +385,7 @@ class MtpProposer(Proposer):
actual_seq_lengths_q=self.runner.actual_seq_lengths_q, actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0]. block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(), get_device_tensor(),
slot_mapping_cpu=target_slot_mapping, slot_mapping=target_slot_mapping,
positions=target_positions, positions=target_positions,
attn_mask=self.runner.attn_mask, attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask, spec_attn_mask=self.runner.spec_attn_mask,

View File

@@ -175,7 +175,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
self, self,
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata, common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module, model: Optional[nn.Module] = None,
): ):
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -185,11 +185,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
block_table[:num_reqs]) block_table[:num_reqs])
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping_cpu[: slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
num_actual_tokens].to(
self.device,
non_blocking=
True)
attn_mask = common_attn_metadata.attn_mask attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state attn_state = common_attn_metadata.attn_state

View File

@@ -400,11 +400,7 @@ class AscendMLATorchairMetadataBuilder:
device = self.device device = self.device
block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping_cpu[: slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
num_actual_tokens].to(
device,
non_blocking=
True)
input_positions = common_attn_metadata.positions[: input_positions = common_attn_metadata.positions[:
num_actual_tokens].long( num_actual_tokens].long(
) )

View File

@@ -121,12 +121,14 @@ class NPUTorchairModelRunner(NPUModelRunner):
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn): def _build_attention_metadata(self, with_prefill, num_reqs, num_tokens,
max_query_len, force_attention):
# NOTE: If torchair graph mode and not with_prefill, # NOTE: If torchair graph mode and not with_prefill,
# we can't skip_attn, it will cause graph recompile. # we can't skip_attn, it will cause graph recompile.
if with_prefill or self.enable_shared_expert_dp: if with_prefill or self.enable_shared_expert_dp:
attn_metadata = super()._build_attention_metadata( attn_metadata = super()._build_attention_metadata(
with_prefill, num_reqs, skip_attn) with_prefill, num_reqs, num_tokens, max_query_len,
force_attention)
else: else:
common_attn_metadata = TorchairCommonAttentionMetadata( common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs, num_reqs=num_reqs,

View File

@@ -22,6 +22,7 @@ import functools
import math import math
import os import os
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum from enum import Enum
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -634,3 +635,34 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
return nullcontext() return nullcontext()
assert target_stream is not None assert target_stream is not None
return torch.npu.stream(target_stream) return torch.npu.stream(target_stream)
@dataclass
class GraphParams:
events: dict[int, list[torch.npu.ExternalEvent]]
workspaces: dict[int, torch.Tensor]
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
attn_params: dict[int, list[tuple]]
_graph_params: Optional[GraphParams] = None
def set_graph_params(aclgraph_capture_sizes: set[int]):
global _graph_params
if _graph_params is not None:
raise ValueError("Graph parameters have already been set!")
_graph_params = GraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
)
def get_graph_params():
return _graph_params

View File

@@ -70,8 +70,8 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LazyLoader, cdiv, get_dtype_size, LazyLoader, cdiv, get_dtype_size,
is_pin_memory_available) is_pin_memory_available)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import \ from vllm.v1.attention.backends.utils import (
reorder_batch_to_split_decodes_and_prefills AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
@@ -116,8 +116,9 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration, AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, is_310p, get_ascend_soc_version, get_graph_params,
lmhead_tp_enable, vllm_version_is) is_310p, lmhead_tp_enable, set_graph_params,
vllm_version_is)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -352,6 +353,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens = torch.zeros(self.max_num_reqs, self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
self.slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.uses_mrope = self.model_config.uses_mrope self.uses_mrope = self.model_config.uses_mrope
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -1222,7 +1226,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor, ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor], int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]: Optional[torch.Tensor], Optional[torch.Tensor], int]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0 assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
@@ -1475,11 +1479,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
blk_table_tensor = blk_table.get_device_tensor() blk_table_tensor = blk_table.get_device_tensor()
slot_mapping = blk_table.slot_mapping_cpu[: slot_mapping = blk_table.slot_mapping_cpu[:
total_num_scheduled_tokens] total_num_scheduled_tokens]
self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_( self.slot_mapping[:total_num_scheduled_tokens].copy_(
slot_mapping) slot_mapping[:total_num_scheduled_tokens],
# # Fill unused with -1. Needed for reshape_and_cache in full cuda non_blocking=True,
# # graph mode. )
# blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
# Make AscendCommonAttentionMetadata # Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata( common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1492,7 +1495,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
actual_seq_lengths_q=self.actual_seq_lengths_q, actual_seq_lengths_q=self.actual_seq_lengths_q,
# TODO: change this to the right block table for linear attn # TODO: change this to the right block table for linear attn
block_table_tensor=blk_table_tensor[:num_reqs], block_table_tensor=blk_table_tensor[:num_reqs],
slot_mapping_cpu=self.slot_mapping_cpu, slot_mapping=self.slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu, num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions, positions=self.positions,
attn_mask=self.attn_mask, attn_mask=self.attn_mask,
@@ -1549,7 +1552,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return (attn_metadata, positions, num_scheduled_tokens, return (attn_metadata, positions, num_scheduled_tokens,
num_input_tokens, num_tokens_across_dp, num_input_tokens, num_tokens_across_dp,
maybe_padded_num_tokens, logits_indices, spec_decode_metadata, maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
input_ids, inputs_embeds, intermediate_tensors) input_ids, inputs_embeds, intermediate_tensors,
max_num_scheduled_tokens)
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
maybe_padded_num_tokens, maybe_padded_num_tokens,
@@ -1563,6 +1567,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
graph_params = get_graph_params()
self.update_attn_params(graph_params, forward_context,
positions.shape[0])
if get_forward_context().flashcomm_v1_enabled: if get_forward_context().flashcomm_v1_enabled:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
pad_size = get_forward_context().pad_size pad_size = get_forward_context().pad_size
@@ -1570,6 +1581,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
hidden_states = hidden_states[:-pad_size, :] hidden_states = hidden_states[:-pad_size, :]
return hidden_states return hidden_states
def update_attn_params(self, graph_params, forward_context, runtime_shape):
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
# block_table = forward_context.attn_metadata[key].block_tables
seq_lens = forward_context.attn_metadata[key].seq_lens
with torch.npu.stream(self.update_stream):
torch.npu.graph_task_update_begin(self.update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_end(self.update_stream)
event.record(self.update_stream)
def _build_attn_state(self, num_reqs, num_scheduled_tokens, def _build_attn_state(self, num_reqs, num_scheduled_tokens,
num_valid_tokens): num_valid_tokens):
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
@@ -1886,8 +1935,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
(attn_metadata, positions, num_scheduled_tokens_np, (attn_metadata, positions, num_scheduled_tokens_np,
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
logits_indices, spec_decode_metadata, input_ids, inputs_embeds, logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
intermediate_tensors) = (self._prepare_inputs( intermediate_tensors,
scheduler_output, intermediate_tensors)) max_query_len) = (self._prepare_inputs(scheduler_output,
intermediate_tensors))
if self.dynamic_eplb: if self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process() self.eplb_updator.take_update_info_from_eplb_process()
@@ -1895,8 +1945,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
moe_comm_method = self._select_moe_comm_method(num_input_tokens, moe_comm_method = self._select_moe_comm_method(num_input_tokens,
self.with_prefill) self.with_prefill)
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
scheduler_output.total_num_scheduled_tokens
== self.input_batch.num_reqs * max_query_len)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False) uniform_decode=uniform_decode)
aclgraph_runtime_mode, batch_descriptor = \ aclgraph_runtime_mode, batch_descriptor = \
self.aclgraph_dispatcher.dispatch(batch_descriptor) self.aclgraph_dispatcher.dispatch(batch_descriptor)
@@ -2215,12 +2268,54 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output.finished_req_ids) scheduler_output.finished_req_ids)
return None, None return None, None
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn): def _build_attention_metadata(self, create_mixed_batch, num_reqs,
if skip_attn: num_tokens, max_query_len, force_attention):
attn_metadata = None attn_metadata: Optional[dict[str, Any]] = None
else:
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata if force_attention:
attn_metadata = None attn_metadata = {}
if create_mixed_batch:
raise NotImplementedError(
"force_attention=True is not supported for mixed batches.")
else:
seq_lens = self.model_config.max_model_len
self.seq_lens_np[:num_reqs] = seq_lens
self.seq_lens_np[num_reqs:] = 0
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
block_table_tensor = self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
1],
seq_lens_cpu=self.seq_lens_cpu,
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
block_table_tensor=block_table_tensor[:num_reqs],
slot_mapping=self.slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
decode_token_per_req=self.decode_token_per_req,
)
for attn_group in self.attn_groups[kv_cache_group_id]:
if vllm_version_is("0.10.2"):
builder = attn_group.metadata_builder
else:
builder = attn_group.get_metadata_builder()
attn_metadata_i = builder.build_for_graph_capture(
common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
return attn_metadata return attn_metadata
def _generate_dummy_run_hidden_states(self, with_prefill, def _generate_dummy_run_hidden_states(self, with_prefill,
@@ -2249,12 +2344,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
) -> torch.Tensor: ) -> torch.Tensor:
# only support eager mode and piecewise graph now # only support eager mode and piecewise graph now
assert aclgraph_runtime_mode in { assert aclgraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
} }
if force_attention:
raise RuntimeError(
"Capturing attention in aclgraph is unexpected, because full graph is not supported now"
)
# Padding for DP # Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill, (num_tokens, num_tokens_across_dp, with_prefill,
@@ -2310,9 +2401,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.is_kv_producer: if self.is_kv_producer:
with_prefill = True with_prefill = True
attn_metadata = self._build_attention_metadata(with_prefill, attn_metadata = self._build_attention_metadata(
num_reqs, with_prefill,
skip_attn=True) num_reqs,
num_tokens,
max_query_len,
force_attention,
)
if not self.in_profile_run and self.dynamic_eplb: if not self.in_profile_run and self.dynamic_eplb:
self.eplb_updator.forward_before() self.eplb_updator.forward_before()
@@ -2551,6 +2646,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB", logger.info("Loading model weights took %.4f GB",
m.consumed_memory / float(2**30)) m.consumed_memory / float(2**30))
# wrap the model with full graph wrapper if needed.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.update_stream = torch.npu.Stream()
set_graph_params(self.compilation_config.cudagraph_capture_sizes)
self.model = ACLGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
def _convert_torch_format(self, tensor): def _convert_torch_format(self, tensor):
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
return tensor return tensor
@@ -3167,9 +3270,78 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return kv_cache_spec return kv_cache_spec
def initialize_aclgraph_capture(self) -> None: def initialize_aclgraph_capture(self) -> None:
# TODO: Add check of AttentionCGSupport and cudagraph_mode.decode_mode when full graph is supported min_ag_support = AttentionCGSupport.ALWAYS
# Trigger aclgraph dispatching keys initialization here (after min_ag_builder_name = None
# initializing attn backends).
for attn_group in self._attn_group_iterator():
if vllm_version_is("0.10.2"):
builder = attn_group.metadata_builder
else:
builder = attn_group.get_metadata_builder()
if builder.cudagraph_support.value < min_ag_support.value:
min_ag_support = builder.cudagraph_support
min_ag_builder_name = builder.__class__.__name__
# This is an imitation of compilation_config.splitting_ops_contain_attention()
splitting_ops_contain_attention = (
self.compilation_config.splitting_ops is not None
and all(op in self.compilation_config.splitting_ops for op in [
"vllm.unified_ascend_attention_with_output",
"vllm.mla_forward",
]))
# Flexible resolve the aclgraph mode
aclgraph_mode = self.compilation_config.cudagraph_mode
# check graph for mixed batch is supported
if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
and min_ag_support != AttentionCGSupport.ALWAYS:
msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
f"with {min_ag_builder_name} backend (support: "
f"{min_ag_support})")
if min_ag_support == AttentionCGSupport.NEVER:
# if not supported any full graphs, just raise it.
msg += "; please try cudagraph_mode=PIECEWISE, and "\
"make sure compilation level is piecewise"
raise ValueError(msg)
# attempt to resolve the full graph related mode
if splitting_ops_contain_attention:
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_AND_PIECEWISE)
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY)
logger.warning(msg)
# check that if spec-decode + decode full-graphs is supported
if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.uniform_decode_query_len > 1 and min_ag_support.value
< AttentionCGSupport.UNIFORM_BATCH.value):
msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported"
f" with spec-decode for attention backend "
f"{min_ag_builder_name} (support: {min_ag_support})")
if splitting_ops_contain_attention:
msg += "; setting cudagraph_mode=PIECEWISE"
aclgraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE"
aclgraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.NONE
logger.warning(msg)
# double check that we can support full graph if they are requested
# even after automatic downgrades
if aclgraph_mode.has_full_cudagraphs() \
and min_ag_support == AttentionCGSupport.NEVER:
raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
f"supported with {min_ag_builder_name} backend ("
f"support:{min_ag_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation level is piecewise")
self.aclgraph_dispatcher.initialize_cudagraph_keys( self.aclgraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode, self.compilation_config.cudagraph_mode,
self.uniform_decode_query_len) self.uniform_decode_query_len)
@@ -3178,10 +3350,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode: CUDAGraphMode, aclgraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool): uniform_decode: bool):
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \ assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
aclgraph_runtime_mode in [CUDAGraphMode.PIECEWISE] aclgraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE]
# Only rank 0 should print progress bar during capture # Only rank 0 should print progress bar during capture
if is_global_first_rank(): if is_global_first_rank():
logger.info(
"Starting to capture ACL graphs for cases: %s, "
"mode: %s, uniform_decode: %s", compilation_cases,
aclgraph_runtime_mode.name, uniform_decode)
compilation_cases = tqdm( compilation_cases = tqdm(
compilation_cases, compilation_cases,
disable=not self.load_config.use_tqdm_on_load, disable=not self.load_config.use_tqdm_on_load,
@@ -3203,6 +3380,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
uniform_decode=uniform_decode) uniform_decode=uniform_decode)
self._dummy_run(num_tokens, self._dummy_run(num_tokens,
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=aclgraph_runtime_mode,
force_attention=force_attention,
uniform_decode=uniform_decode) uniform_decode=uniform_decode)
def _capture_model(self): def _capture_model(self):
@@ -3229,6 +3407,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=aclgraph_runtime_mode,
uniform_decode=False) uniform_decode=False)
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
aclgraph_mode.separate_routine():
max_num_tokens = self.scheduler_config.max_num_seqs * \
self.uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
and x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
reversed(decode_cudagraph_batch_sizes))
self._capture_aclgraphs(
compilation_cases=compilation_cases_decode,
aclgraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True)
# Disable aclgraph capturing globally, so any unexpected aclgraph # Disable aclgraph capturing globally, so any unexpected aclgraph
# capturing will be detected and raise an error after here. # capturing will be detected and raise an error after here.
# Note: We don't put it into graph_capture context manager because # Note: We don't put it into graph_capture context manager because