[Feat] Support running MTP in full graph mode (#3892)

### What this PR does / why we need it?
Currently, the MTP model still runs in eager mode even when full graph mode is enabled. This PR
adapts MTP to full graph capture and execution: when the graph mode is set to
"FULL_DECODE_ONLY", the MTP model runs as a full graph to improve performance.
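
For reference, below is a minimal sketch (not part of this PR) of how one might request this mode through the offline `LLM` API. The model name and speculative settings are placeholders, and the exact field spellings should be checked against the vLLM/vllm-ascend docs:

```python
from vllm import LLM

# Hedged sketch: enable MTP speculative decoding together with full-graph
# capture for decode-only batches. The model name and the speculative
# settings below are placeholders, not values taken from this PR.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",           # placeholder model
    speculative_config={
        "method": "deepseek_mtp",              # MTP drafter
        "num_speculative_tokens": 1,
    },
    compilation_config={
        "cudagraph_mode": "FULL_DECODE_ONLY",  # full graph for decode-only batches
    },
)
print(llm.generate(["The quick brown fox"])[0].outputs[0].text)
```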

The changes, which cover both the disable_padded_drafter_batch=True and
disable_padded_drafter_batch=False cases, include:

1. Add _mtp_graph_params in acl_graph.py to isolate the main model's graph
data from the MTP model's (see the sketch after this list).
2. Pad some metadata in mla_v1.py when running in full graph mode.
3. Pin the addresses of the essential data used in model.forward, since graph
replay requires stable buffer addresses.
4. Adapt to the ACL graph capture framework:
    1). Rebuild the MTP model with ACLGraphWrapper.
    2). Build common attention metadata when capture starts in the MTP dummy_run.
    3). Update the common attention metadata in MTP.
    4). Adapt data updates when num_speculative_tokens > 1.
5. Add an MTP patch to adapt to vLLM v0.11.0.
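
As a rough illustration of items 1 and 4.1), here is a hedged sketch of the idea behind the isolated MTP graph parameters. Only `set_mtp_graph_params` appears in this PR; the `GraphParams` fields and the `get_graph_params` helper below are assumptions used for illustration and may differ from the real code in acl_graph.py:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class GraphParams:
    # Illustrative container: one slot per captured batch size.
    events: Dict[int, List[Any]] = field(default_factory=dict)
    attn_params: Dict[int, List[Any]] = field(default_factory=dict)
    handles: Dict[int, List[Any]] = field(default_factory=dict)


_graph_params: Optional[GraphParams] = None       # graphs of the main model
_mtp_graph_params: Optional[GraphParams] = None   # graphs of the MTP drafter


def set_mtp_graph_params(capture_sizes: List[int]) -> None:
    """Create a separate registry for the MTP drafter, keyed by capture size,
    so replaying the main model's graphs never touches the drafter's buffers."""
    global _mtp_graph_params
    _mtp_graph_params = GraphParams(
        events={size: [] for size in capture_sizes},
        attn_params={size: [] for size in capture_sizes},
        handles={size: [] for size in capture_sizes},
    )


def get_graph_params(is_mtp_model: bool = False) -> Optional[GraphParams]:
    """Return the registry of whichever model is currently being captured or
    replayed; callers would pass the forward context's is_mtp_model flag."""
    return _mtp_graph_params if is_mtp_model else _graph_params
```

After weight loading, the diff then rebuilds the drafter as `ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)`, which routes capture and replay through these per-size parameters.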

Existing Issues:
1. When disable_padded_drafter_batch=True and running in full graph mode,
the data of the first-round requests in MTP is abnormal. The cause still
needs to be identified.
2. When disable_padded_drafter_batch=False and running in full graph mode,
the acceptance rate of the second and third tokens decreases. For example,
with num_speculative_tokens=3 the acceptance rate of the first token is 90%,
but the second is only about 50% (versus the expected ~60%) and the third
only about 20% (versus ~30%). The cause is a mismatch in the data processed
after the model runs; this problem comes from another PR. It works fine in
eager and PIECEWISE mode but misbehaves in full graph mode. Once we have a
solution, we will submit a bugfix.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main: 2918c1b49c

---------

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Author: anon189Ty
Date: 2025-11-20 20:34:54 +08:00
Committed by: GitHub
Parent: 15c1eb025c
Commit: 5c9f4a40c6
8 changed files with 536 additions and 42 deletions


@@ -1,5 +1,5 @@
import importlib
from typing import Optional
from typing import Optional, Union
import numpy as np
import torch
@@ -7,7 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm.config import (CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
@@ -32,7 +32,11 @@ from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
set_mtp_graph_params,
update_mla_attn_params)
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
prefill_context_parallel_enable,
@@ -52,9 +56,14 @@ logger = init_logger(__name__)
PADDING_SLOT_ID = -1
_deepseek_mtp_path = "vllm.model_executor.models.deepseek_mtp"
_deepseek_mtp_model = "DeepSeekMTP"
if vllm_version_is("0.11.0"):
_deepseek_mtp_path = "vllm_ascend.patch.worker.patch_deepseek_mtp"
_deepseek_mtp_model = "AscendDeepSeekMTP"
_MTP_MODELS = {
"DeepseekV3ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
"DeepseekV3ForCausalLM": (_deepseek_mtp_path, _deepseek_mtp_model),
"Qwen3NextForCausalLM":
("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
}
@@ -75,6 +84,9 @@ def _load_model(architecture):
class MtpProposer(Proposer):
# TODO: Find out why ModelRunner does not need this explicit typing?
model: Union[nn.Module, ACLGraphWrapper]
def __init__(
self,
vllm_config: VllmConfig,
@@ -203,6 +215,15 @@ class MtpProposer(Proposer):
process_weights_after_loading(self.model, draft_model_config,
target_device)
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
):
self.update_stream: torch.npu.Stream = torch.npu.Stream()
set_mtp_graph_params(
self.vllm_config.compilation_config.cudagraph_capture_sizes)
self.model = ACLGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
@torch.inference_mode()
def dummy_run(self,
num_tokens: int,
@@ -222,12 +243,55 @@ class MtpProposer(Proposer):
moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
attn_metadata = None
if skip_attn:
attn_metadata = None
elif aclgraph_runtime_mode == CUDAGraphMode.FULL:
if len(self.runner.attn_groups) > 0:
num_computed_tokens_cpu = (
self.runner.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs])
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.runner.
query_start_loc_cpu[:num_reqs + 1],
seq_lens_cpu=self.runner.seq_lens_cpu,
seq_lens=self.runner.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=self.num_speculative_tokens + 1,
num_computed_tokens_cpu=num_computed_tokens_cpu,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor()[:num_reqs],
slot_mapping=self.runner.input_batch.block_table[0].
slot_mapping,
positions=self.runner.positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
cos=self.runner.cos,
sin=self.runner.sin,
)
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_mtp = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.SpecDecoding,
self.runner.get_model())
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
else:
attn_metadata = None
else:
attn_metadata = None
input_ids = self.input_ids[:num_tokens]
positions = self.positions[:num_tokens]
previous_hidden_states = self.hidden_states[:num_tokens]
for _ in range(self.num_speculative_tokens):
for i in range(self.num_speculative_tokens):
if i > 0:
aclgraph_runtime_mode = CUDAGraphMode.NONE
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
@@ -239,10 +303,19 @@ class MtpProposer(Proposer):
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor):
batch_descriptor=batch_descriptor,
is_mtp_model=True):
self.model(input_ids=input_ids,
positions=positions,
hidden_states=previous_hidden_states)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
not forward_context.capturing:
if self.vllm_config.model_config.use_mla:
update_mla_attn_params(
self.update_stream, forward_context,
positions.shape[0],
self.vllm_config.speculative_config)
if with_prefill:
break
@@ -324,6 +397,7 @@ class MtpProposer(Proposer):
common_attn_metadata.query_start_loc = \
query_start_loc_pcp_full[:num_reqs + 1]
if self.speculative_config.disable_padded_drafter_batch:
# NOTE: Currently, MTP full graph is incompatible with pcp
token_indices_to_sample = None
common_attn_metadata, token_indices =\
self._prepare_inputs(
@@ -358,6 +432,8 @@ class MtpProposer(Proposer):
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
)
return draft_token_ids
@@ -460,6 +536,13 @@ class MtpProposer(Proposer):
token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True)
common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_(
common_attn_metadata.slot_mapping[token_indices])
common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1)
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward,
# so we do not need to fix them here. But if they are used in the future,
# we should fix them.
spec_common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=new_query_start_loc_cpu.to(device,
non_blocking=True),
@@ -472,7 +555,7 @@ class MtpProposer(Proposer):
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
slot_mapping=common_attn_metadata.slot_mapping,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
positions=common_attn_metadata.positions[token_indices],
attn_mask=self.runner.attn_mask,
@@ -502,6 +585,8 @@ class MtpProposer(Proposer):
long_seq_metadata=None,
num_prefill_reqs=0,
num_decode_reqs=0,
scheduler_output: SchedulerOutput = None,
num_scheduled_tokens: int = 0,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
@@ -585,14 +670,11 @@ class MtpProposer(Proposer):
assert self.runner is not None
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_mtp = builder.build(0, common_attn_metadata,
self.runner.get_model())
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
if self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
) and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
# Acl graph mode, add padding to the batch size
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
@@ -609,19 +691,39 @@ class MtpProposer(Proposer):
moe_comm_type = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
if scheduler_output:
max_query_len = common_attn_metadata.max_query_len
uniform_decode = (max_query_len in list(
range(1, self.num_speculative_tokens +
2))) and (scheduler_output.total_num_scheduled_tokens
== self.runner.input_batch.num_reqs *
(self.num_speculative_tokens + 1))
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
else:
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
aclgraph_runtime_mode, batch_descriptor = \
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
if aclgraph_runtime_mode not in [
CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE
]:
# Fall back to piecewise graph when ACL full graph is enabled
logger.debug(
"Currently the eagle proposer only supports cudagraph_mode "
f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
"to CUDAGraphMode.PIECEWISE")
aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
graph_pad_size = num_input_tokens
else:
# Currently, runner.graph_pad_size will always be -1.
graph_pad_size = self.runner.graph_pad_size
# If full graph is used and disable_padded_drafter_batch=True, we need to
# update graph_pad_size in common_attn_metadata to tell the builder to
# pad some elements.
common_attn_metadata.graph_pad_size = graph_pad_size
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_mtp = builder.build(0, common_attn_metadata,
self.runner.get_model())
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
for step in range(self.num_speculative_tokens):
with set_ascend_forward_context(
@@ -635,7 +737,8 @@ class MtpProposer(Proposer):
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
num_actual_tokens=num_tokens,
is_mtp_model=True):
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
@@ -644,6 +747,13 @@ class MtpProposer(Proposer):
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens])
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
if self.vllm_config.model_config.use_mla:
update_mla_attn_params(
self.update_stream, forward_context,
num_input_tokens,
self.vllm_config.speculative_config)
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
@@ -699,12 +809,21 @@ class MtpProposer(Proposer):
input_ids = draft_token_ids_list[-1].int()
positions += 1
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
1:batch_size + 1].tolist()
attn_metadata_i.decode.cos = builder.cos_cache[
positions].unsqueeze(1).unsqueeze(2)
attn_metadata_i.decode.sin = builder.sin_cache[
positions].unsqueeze(1).unsqueeze(2)
# When disable_padded_drafter_batch=False, these params possibly should not be updated.
if self.speculative_config.disable_padded_drafter_batch or \
aclgraph_runtime_mode != CUDAGraphMode.FULL:
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
1:batch_size + 1].tolist()
if aclgraph_runtime_mode == CUDAGraphMode.FULL:
attn_metadata_i.decode.actual_seq_lengths_q = \
builder.pad_actual_seq_len_q_mtp_disable_pad(
graph_pad_size - batch_size,
batch_size,
attn_metadata_i.decode.actual_seq_lengths_q)
attn_metadata_i.decode.cos = builder.cos_cache[
positions].unsqueeze(1).unsqueeze(2)
attn_metadata_i.decode.sin = builder.sin_cache[
positions].unsqueeze(1).unsqueeze(2)
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
@@ -735,6 +854,10 @@ class MtpProposer(Proposer):
self.positions[:batch_size] = clamped_positions
self.hidden_states[:hidden_states.shape[0]] = hidden_states
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
if self.speculative_config.disable_padded_drafter_batch:
self.positions[batch_size:num_input_tokens] = 0
self.input_ids[batch_size:num_input_tokens] = 0
self.hidden_states[batch_size:num_input_tokens].fill_(0)
if attn_metadata_i.prefill is not None:
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
@@ -751,12 +874,19 @@ class MtpProposer(Proposer):
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
)
decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list
if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
self.speculative_config.disable_padded_drafter_batch:
attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [
0
] * (graph_pad_size - len(decode_seq_lens_list))
attn_metadata_i.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.decode.max_seq_lens += 1
attn_metadata_i.decode.max_seq_lens = min(
attn_metadata_i.decode.max_seq_lens,
self.runner.model_config.max_model_len)
torch.npu.synchronize()
# mtp>1: [batch_size, k]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
@@ -915,6 +1045,9 @@ class MtpProposer(Proposer):
total_num_tokens = query_start_loc_cpu[-1].item()
token_indices = self.arange[:total_num_tokens]
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward,
# so we do not need to fix them here. But if they are used in the future,
# we should fix them.
spec_common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=common_attn_metadata.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,