[Feat][Graph] Support MTP for ACL Graph (#2932)

### What this PR does / why we need it?
This PR depends on the merge of #2707 and has adapted the aclgraph
functionality to support MTP.

### How was this patch tested?


- vLLM version: v0.10.2
- vLLM main:
2b85697031

---------

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
xuyexiong
2025-09-18 14:05:33 +08:00
committed by GitHub
parent cef43b524e
commit 6681dde902
7 changed files with 73 additions and 11 deletions

View File

@@ -187,7 +187,14 @@ class AscendMLAMetadataBuilder:
self.block_size - 1) // self.block_size
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
@@ -275,7 +282,6 @@ class AscendMLAMetadataBuilder:
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
# TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs

View File

@@ -8,7 +8,7 @@ from torchair import patch_for_hcom
from vllm.attention.layer import Attention
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
set_current_vllm_config)
from vllm.forward_context import get_forward_context
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
@@ -363,8 +363,14 @@ class MtpProposer(Proposer):
not self.runner.with_prefill
if is_running_torchair:
# Torchair graph mode, padding is same as the main model
num_input_tokens = self.runner.graph_pad_size
elif (self.runner.use_aclgraph
and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
# Acl graph mode, add padding to the batch size
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
# Eager mode, no padding needed
num_input_tokens = num_tokens
seq_lens = target_positions[last_token_indices] + 1
@@ -410,7 +416,7 @@ class MtpProposer(Proposer):
# TODO: adapt enable_dbo later
(num_input_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._sync_metadata_across_dp(
num_tokens, self.runner.with_prefill, False)
num_input_tokens, self.runner.with_prefill, False)
else:
# torchair mode can reuse self.runner.num_tokens_across_dp
num_tokens_across_dp = self.runner.num_tokens_across_dp
@@ -418,6 +424,10 @@ class MtpProposer(Proposer):
moe_comm_method = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
aclgraph_runtime_mode, batch_descriptor = \
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
for step in range(self.num_speculative_tokens):
with set_ascend_forward_context(
@@ -428,6 +438,7 @@ class MtpProposer(Proposer):
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
aclgraph_runtime_mode=aclgraph_runtime_mode,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):

View File

@@ -52,6 +52,10 @@ class NPUTorchairModelRunner(NPUModelRunner):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
super().__init__(vllm_config, device)
if self.speculative_config:
self.actual_seq_lengths_q = list(
range(self.decode_token_per_req, self.max_num_tokens + 1,
self.decode_token_per_req))
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
None, None, vllm_config, device)

View File

@@ -306,17 +306,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.spec_attn_mask = None
self.drafter: Optional[Union[NgramProposer, EagleProposer,
MtpProposer]] = None
self.actual_seq_lengths_q = []
self.actual_seq_lengths_q: list[int] = []
self.decode_token_per_req = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
assert spec_token_num > 0
self.decode_token_per_req = 1 + spec_token_num
self.actual_seq_lengths_q = [
len for len in
range(self.decode_token_per_req, self.max_num_tokens +
1, self.decode_token_per_req)
]
self.spec_attn_mask = torch.triu(torch.ones(2048,
2048,
dtype=torch.bool),