[FEATURE][MTP] Support MTP > 1 (#2708)

### What this PR does / why we need it?
[RFC:Support MTP > 1 for
DeepSeek](https://github.com/vllm-project/vllm-ascend/issues/2745)

- [x] dp1 tp16
- [x] dp4 tp4
- [x] dp2 tp 8
- [x] torchair graph

- vLLM version: v0.10.1.1
- vLLM main:
c9f7081f9c

Signed-off-by: 1092626063 <1092626063@qq.com>
This commit is contained in:
1092626063
2025-09-05 09:11:22 +08:00
committed by GitHub
parent 83eb40a51c
commit 5b3646ab21
5 changed files with 206 additions and 88 deletions

View File

@@ -20,9 +20,10 @@ def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"
def test_mtp_correctness(
def mtp_correctness(
sampling_config: SamplingParams,
model_name: str,
num_speculative_tokens: int,
):
example_prompts = [
"Hello, my name is",
@@ -50,7 +51,7 @@ def test_mtp_correctness(
enable_expert_parallel=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
"num_speculative_tokens": num_speculative_tokens,
},
enforce_eager=True,
max_model_len=2000,
@@ -74,3 +75,18 @@ def test_mtp_correctness(
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm
def test_mtp1_correctness(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1)
def test_mtp2_correctness(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 2)

View File

@@ -841,7 +841,6 @@ class AscendMLAImpl(MLAAttentionImpl):
input_layout = "BNSD"
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1)

View File

@@ -27,6 +27,8 @@ from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
TorchairCommonAttentionMetadata)
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
PADDING_SLOT_ID = -1
class MtpProposer(Proposer):
@@ -40,6 +42,7 @@ class MtpProposer(Proposer):
self.vllm_config = vllm_config
self.device = device
self.runner = runner
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
# persistent buffers for graph
self.input_ids = torch.zeros(self.runner.max_num_tokens,
@@ -57,6 +60,12 @@ class MtpProposer(Proposer):
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
1,
device=self.runner.device,
dtype=torch.int32)
def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
@@ -125,43 +134,47 @@ class MtpProposer(Proposer):
input_ids = self.input_ids[:num_tokens]
positions = self.positions[:num_tokens]
previous_hidden_states = self.hidden_states[:num_tokens]
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if is_running_torchair:
assert attn_metadata is not None
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(previous_hidden_states)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
torchair_compiled_model(
input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states,
inputs_embeds=None,
intermediate_tensors=None,
attn_metadata=attn_metadata,
kv_caches=self.runner.kv_caches[-1:],
spec_step_idx=0)
else:
self.model(input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states)
for _ in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if is_running_torchair:
assert attn_metadata is not None
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(previous_hidden_states)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
torchair_compiled_model(
input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states,
inputs_embeds=None,
intermediate_tensors=None,
attn_metadata=attn_metadata,
kv_caches=self.runner.kv_caches[-1:],
spec_step_idx=0)
else:
self.model(input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states)
if with_prefill:
break
def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]],
@@ -385,57 +398,142 @@ class MtpProposer(Proposer):
num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if is_running_torchair:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_input_tokens)
hidden_states = torchair_compiled_model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
inputs_embeds=None,
intermediate_tensors=None,
spec_step_idx=0,
**model_kwargs)
for step in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if is_running_torchair:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_input_tokens)
hidden_states = torchair_compiled_model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
inputs_embeds=None,
intermediate_tensors=None,
spec_step_idx=0,
**model_kwargs)
else:
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
if not self.runner.with_prefill:
max_num_reqs_across_dp = num_input_tokens
else:
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
last_token_indices = nn.functional.pad(
last_token_indices,
(0, max_num_reqs_across_dp - num_indices))
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
if not self.runner.with_prefill:
max_num_reqs_across_dp = num_input_tokens
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
if step == 0:
draft_token_ids_list = [draft_token_ids]
else:
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
last_token_indices = nn.functional.pad(
last_token_indices, (0, max_num_reqs_across_dp - num_indices))
draft_token_ids_list.append(draft_token_ids)
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
# prepare next mtp inputs
# mtp>1: prefill skip or decode skip last loop
if with_prefill and self.torchair_graph_enabled:
for _ in range(self.num_speculative_tokens - 1):
draft_token_ids_list.append(draft_token_ids)
if step == self.num_speculative_tokens - 1 or with_prefill:
break
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
if step == 0:
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
slot_mapping = attn_metadata.slot_mapping[last_token_indices]
attn_metadata.slot_mapping.fill_(-1)
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
last_token_indices = self.arange[:batch_size]
if attn_metadata.num_decode_tokens != 0:
attn_metadata.num_decode_tokens = batch_size
if is_running_torchair:
attn_metadata.num_actual_tokens = batch_size
attn_metadata.query_lens = [1] * batch_size
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
attn_metadata.seq_lens[:batch_size] += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
attn_metadata.seq_lens.device, non_blocking=True)
attn_metadata.seq_lens[:batch_size].masked_fill_(
exceeds_max_model_len_cpu, 1)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping += 1
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
attn_metadata.slot_mapping[:batch_size] = slot_mapping
if attn_metadata.prefill is not None:
attn_metadata.prefill.seq_lens = attn_metadata.seq_lens
attn_metadata.prefill.context_lens = attn_metadata.seq_lens
attn_metadata.prefill.input_positions = self.positions[:
num_input_tokens]
attn_metadata.prefill.max_seq_lens += 1
attn_metadata.prefill.max_seq_lens = min(
attn_metadata.prefill.max_seq_lens,
self.runner.model_config.max_model_len)
if attn_metadata.decode is not None:
attn_metadata.decode.seq_lens = attn_metadata.seq_lens
attn_metadata.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata.decode.max_seq_lens += 1
attn_metadata.decode.max_seq_lens = min(
attn_metadata.decode.max_seq_lens,
self.runner.model_config.max_model_len)
# mtp>1: [batch_size, k]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
@@ -511,4 +609,4 @@ class MtpProposer(Proposer):
global_indices_flat = global_indices[mask]
values_flat = values[mask]
out_ptr[global_indices_flat] = values_flat
out_ptr[global_indices_flat] = values_flat

View File

@@ -1020,7 +1020,6 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
input_layout = "BNSD"
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1)

View File

@@ -17,6 +17,7 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# isort: skip_file
import math
import types
from typing import Optional
@@ -427,6 +428,11 @@ class NPUTorchairModelRunner(NPUModelRunner):
for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = (graph_batch_size + tp_size -
1) // tp_size * tp_size
# MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size,
# Both adapter multi-dp and FIA operator
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
cur_graph_batch_size = (tp_size * graph_batch_size) \
// math.gcd(tp_size, graph_batch_size)
if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size)