[Feat] Support MTP to running in full graph mode (#3892)
### What this PR does / why we need it?
Currently, the MTP model still runs in eager in full graph mode. This PR
adapts the MTP with the full graph capture and execution. When the graph
mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to
improve the performance.
The change in both disable_padded_drafter_batch is True and False case
include:
1. Add _mtp_graph_params in acl_graph.py to isolate the data of main
model and the data of MTP.
2. Padding some metadata in mla_v1.py when in fullgraph mode.
3. Fixed the essential data address that will be used in model.forward.
4. Adapted according to the aclgraph capture framwork:
1). Rebuild MTP model with ACLGraphWrapper.
2). Add common attn metadata when start capture in MTP dummy_run.
3). Add common attn metadata update in MTP.
4). Addapted data update when num_speculative_tokens > 1.
5. Add a patch of MTP to adapt vllm v0.11.0.
Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the data of the first-round requests in MTP is abnormal. We need to
identify the cause subsequently.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third tokens will decrease
(For example, if we set the num_speculative_tokens=3, the acceptance
rate of first token is 90%, the second is only 50% lower than 60%, the
third is only 20% lower than 30%). The reason is that the data processed
after the model runs does not match. This is a problem from another PR.
It works fine in eager and PIECEWISE mode, but has problem in FullGraph
mode. Once we have a solution, we will submit a bugfix.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import importlib
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from vllm.config import (CUDAGraphMode, VllmConfig,
|
||||
get_layers_from_vllm_config, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
@@ -32,7 +32,11 @@ from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||
set_mtp_graph_params,
|
||||
update_mla_attn_params)
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
||||
prefill_context_parallel_enable,
|
||||
@@ -52,9 +56,14 @@ logger = init_logger(__name__)
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
_deepseek_mtp_path = "vllm.model_executor.models.deepseek_mtp"
|
||||
_deepseek_mtp_model = "DeepSeekMTP"
|
||||
if vllm_version_is("0.11.0"):
|
||||
_deepseek_mtp_path = "vllm_ascend.patch.worker.patch_deepseek_mtp"
|
||||
_deepseek_mtp_model = "AscendDeepSeekMTP"
|
||||
|
||||
_MTP_MODELS = {
|
||||
"DeepseekV3ForCausalLM":
|
||||
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
|
||||
"DeepseekV3ForCausalLM": (_deepseek_mtp_path, _deepseek_mtp_model),
|
||||
"Qwen3NextForCausalLM":
|
||||
("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
|
||||
}
|
||||
@@ -75,6 +84,9 @@ def _load_model(architecture):
|
||||
|
||||
class MtpProposer(Proposer):
|
||||
|
||||
# TODO: Find out why ModelRunner does not this explicit typing?
|
||||
model: Union[nn.Module, ACLGraphWrapper]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@@ -203,6 +215,15 @@ class MtpProposer(Proposer):
|
||||
process_weights_after_loading(self.model, draft_model_config,
|
||||
target_device)
|
||||
|
||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||
):
|
||||
self.update_stream: torch.npu.Stream = torch.npu.Stream()
|
||||
set_mtp_graph_params(
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes)
|
||||
self.model = ACLGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(self,
|
||||
num_tokens: int,
|
||||
@@ -222,12 +243,55 @@ class MtpProposer(Proposer):
|
||||
moe_comm_type = self.runner._select_moe_comm_method(
|
||||
num_tokens, with_prefill)
|
||||
|
||||
attn_metadata = None
|
||||
if skip_attn:
|
||||
attn_metadata = None
|
||||
elif aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
if len(self.runner.attn_groups) > 0:
|
||||
num_computed_tokens_cpu = (
|
||||
self.runner.input_batch.
|
||||
num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.runner.query_start_loc[:num_reqs + 1],
|
||||
query_start_loc_cpu=self.runner.
|
||||
query_start_loc_cpu[:num_reqs + 1],
|
||||
seq_lens_cpu=self.runner.seq_lens_cpu,
|
||||
seq_lens=self.runner.seq_lens_cpu[:num_reqs],
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=self.num_speculative_tokens + 1,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor()[:num_reqs],
|
||||
slot_mapping=self.runner.input_batch.block_table[0].
|
||||
slot_mapping,
|
||||
positions=self.runner.positions,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
cos=self.runner.cos,
|
||||
sin=self.runner.sin,
|
||||
)
|
||||
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata_mtp = builder.build_for_graph_capture(
|
||||
common_attn_metadata, AscendAttentionState.SpecDecoding,
|
||||
self.runner.get_model())
|
||||
attn_metadata = {}
|
||||
for layer_name in self.attn_layer_name:
|
||||
attn_metadata[layer_name] = attn_metadata_mtp
|
||||
else:
|
||||
attn_metadata = None
|
||||
else:
|
||||
attn_metadata = None
|
||||
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
positions = self.positions[:num_tokens]
|
||||
previous_hidden_states = self.hidden_states[:num_tokens]
|
||||
for _ in range(self.num_speculative_tokens):
|
||||
for i in range(self.num_speculative_tokens):
|
||||
if i > 0:
|
||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
@@ -239,10 +303,19 @@ class MtpProposer(Proposer):
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=0,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor):
|
||||
batch_descriptor=batch_descriptor,
|
||||
is_mtp_model=True):
|
||||
self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
hidden_states=previous_hidden_states)
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
||||
not forward_context.capturing:
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
update_mla_attn_params(
|
||||
self.update_stream, forward_context,
|
||||
positions.shape[0],
|
||||
self.vllm_config.speculative_config)
|
||||
if with_prefill:
|
||||
break
|
||||
|
||||
@@ -324,6 +397,7 @@ class MtpProposer(Proposer):
|
||||
common_attn_metadata.query_start_loc = \
|
||||
query_start_loc_pcp_full[:num_reqs + 1]
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
# NOTE: Currently, MTP-fullgraph is incompatibility with pcp
|
||||
token_indices_to_sample = None
|
||||
common_attn_metadata, token_indices =\
|
||||
self._prepare_inputs(
|
||||
@@ -358,6 +432,8 @@ class MtpProposer(Proposer):
|
||||
long_seq_metadata=long_seq_metadata,
|
||||
num_prefill_reqs=num_prefill_reqs,
|
||||
num_decode_reqs=num_decode_reqs,
|
||||
scheduler_output=scheduler_output,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
)
|
||||
|
||||
return draft_token_ids
|
||||
@@ -460,6 +536,13 @@ class MtpProposer(Proposer):
|
||||
token_indices = torch.from_numpy(token_indices_np).to(
|
||||
device, non_blocking=True)
|
||||
|
||||
common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_(
|
||||
common_attn_metadata.slot_mapping[token_indices])
|
||||
common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1)
|
||||
|
||||
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward
|
||||
# so we do not need to fixed them. But if they are used in the future,
|
||||
# we should fixed them.
|
||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=new_query_start_loc_cpu.to(device,
|
||||
non_blocking=True),
|
||||
@@ -472,7 +555,7 @@ class MtpProposer(Proposer):
|
||||
num_actual_tokens=total_num_tokens,
|
||||
max_query_len=new_query_len_per_req.max().item(),
|
||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
positions=common_attn_metadata.positions[token_indices],
|
||||
attn_mask=self.runner.attn_mask,
|
||||
@@ -502,6 +585,8 @@ class MtpProposer(Proposer):
|
||||
long_seq_metadata=None,
|
||||
num_prefill_reqs=0,
|
||||
num_decode_reqs=0,
|
||||
scheduler_output: SchedulerOutput = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
@@ -585,14 +670,11 @@ class MtpProposer(Proposer):
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata_mtp = builder.build(0, common_attn_metadata,
|
||||
self.runner.get_model())
|
||||
attn_metadata = {}
|
||||
for layer_name in self.attn_layer_name:
|
||||
attn_metadata[layer_name] = attn_metadata_mtp
|
||||
|
||||
if self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||
) and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
num_scheduled_tokens)
|
||||
elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
# Acl graph mode, add padding to the batch size
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
else:
|
||||
@@ -609,19 +691,39 @@ class MtpProposer(Proposer):
|
||||
|
||||
moe_comm_type = self.runner._select_moe_comm_method(
|
||||
num_input_tokens, with_prefill)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||
uniform_decode=False)
|
||||
|
||||
if scheduler_output:
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
uniform_decode = (max_query_len in list(
|
||||
range(1, self.num_speculative_tokens +
|
||||
2))) and (scheduler_output.total_num_scheduled_tokens
|
||||
== self.runner.input_batch.num_reqs *
|
||||
(self.num_speculative_tokens + 1))
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||
uniform_decode=uniform_decode)
|
||||
else:
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||
uniform_decode=False)
|
||||
aclgraph_runtime_mode, batch_descriptor = \
|
||||
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
|
||||
if aclgraph_runtime_mode not in [
|
||||
CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE
|
||||
]:
|
||||
# Fallback to piecewise graph, when acl full graph is enabled
|
||||
logger.debug(
|
||||
"Currently the eagle proposer only supports cudagraph_mode "
|
||||
f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
|
||||
"to CUDAGraphMode.PIECEWISE")
|
||||
aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||
) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
graph_pad_size = num_input_tokens
|
||||
else:
|
||||
# Currently, runner.graph_pad_size will always be -1.
|
||||
graph_pad_size = self.runner.graph_pad_size
|
||||
|
||||
# If use fullgraph and disable_padded_drafter_batch=True, We need to
|
||||
# update the graph_pad_size in common_attn_metadata, to tell the
|
||||
# builder padding some elements.
|
||||
common_attn_metadata.graph_pad_size = graph_pad_size
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata_mtp = builder.build(0, common_attn_metadata,
|
||||
self.runner.get_model())
|
||||
attn_metadata = {}
|
||||
for layer_name in self.attn_layer_name:
|
||||
attn_metadata[layer_name] = attn_metadata_mtp
|
||||
|
||||
for step in range(self.num_speculative_tokens):
|
||||
with set_ascend_forward_context(
|
||||
@@ -635,7 +737,8 @@ class MtpProposer(Proposer):
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=num_tokens):
|
||||
num_actual_tokens=num_tokens,
|
||||
is_mtp_model=True):
|
||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||
model_kwargs = {}
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
@@ -644,6 +747,13 @@ class MtpProposer(Proposer):
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens])
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
update_mla_attn_params(
|
||||
self.update_stream, forward_context,
|
||||
num_input_tokens,
|
||||
self.vllm_config.speculative_config)
|
||||
|
||||
num_indices = last_token_indices.shape[0]
|
||||
if lmhead_tp_enable():
|
||||
@@ -699,12 +809,21 @@ class MtpProposer(Proposer):
|
||||
input_ids = draft_token_ids_list[-1].int()
|
||||
positions += 1
|
||||
|
||||
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
|
||||
1:batch_size + 1].tolist()
|
||||
attn_metadata_i.decode.cos = builder.cos_cache[
|
||||
positions].unsqueeze(1).unsqueeze(2)
|
||||
attn_metadata_i.decode.sin = builder.sin_cache[
|
||||
positions].unsqueeze(1).unsqueeze(2)
|
||||
# When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
|
||||
if self.speculative_config.disable_padded_drafter_batch or \
|
||||
aclgraph_runtime_mode != CUDAGraphMode.FULL:
|
||||
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
|
||||
1:batch_size + 1].tolist()
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
attn_metadata_i.decode.actual_seq_lengths_q = \
|
||||
builder.pad_actual_seq_len_q_mtp_disable_pad(
|
||||
graph_pad_size - batch_size,
|
||||
batch_size,
|
||||
attn_metadata_i.decode.actual_seq_lengths_q)
|
||||
attn_metadata_i.decode.cos = builder.cos_cache[
|
||||
positions].unsqueeze(1).unsqueeze(2)
|
||||
attn_metadata_i.decode.sin = builder.sin_cache[
|
||||
positions].unsqueeze(1).unsqueeze(2)
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length. Since it is complex
|
||||
# to remove such requests from the batch, we keep them in the batch
|
||||
@@ -735,6 +854,10 @@ class MtpProposer(Proposer):
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
self.hidden_states[:hidden_states.shape[0]] = hidden_states
|
||||
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
self.positions[batch_size:num_input_tokens] = 0
|
||||
self.input_ids[batch_size:num_input_tokens] = 0
|
||||
self.hidden_states[batch_size:num_input_tokens].fill_(0)
|
||||
|
||||
if attn_metadata_i.prefill is not None:
|
||||
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
|
||||
@@ -751,12 +874,19 @@ class MtpProposer(Proposer):
|
||||
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
|
||||
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
|
||||
)
|
||||
decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
|
||||
self.speculative_config.disable_padded_drafter_batch:
|
||||
attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [
|
||||
0
|
||||
] * (graph_pad_size - len(decode_seq_lens_list))
|
||||
attn_metadata_i.decode.input_positions = self.positions[:
|
||||
num_input_tokens]
|
||||
attn_metadata_i.decode.max_seq_lens += 1
|
||||
attn_metadata_i.decode.max_seq_lens = min(
|
||||
attn_metadata_i.decode.max_seq_lens,
|
||||
self.runner.model_config.max_model_len)
|
||||
torch.npu.synchronize()
|
||||
|
||||
# mtp>1: [batch_size, k]
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
@@ -915,6 +1045,9 @@ class MtpProposer(Proposer):
|
||||
total_num_tokens = query_start_loc_cpu[-1].item()
|
||||
token_indices = self.arange[:total_num_tokens]
|
||||
|
||||
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward
|
||||
# so we do not need to fixed them. But if they are used in the future,
|
||||
# we should fixed them.
|
||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
|
||||
Reference in New Issue
Block a user