[FEATURE][MTP] Support MTP > 1 (#2708)

### What this PR does / why we need it?
[RFC: Support MTP > 1 for DeepSeek](https://github.com/vllm-project/vllm-ascend/issues/2745)

- [x] dp1 tp16
- [x] dp4 tp4
- [x] dp2 tp8
- [x] torchair graph

- vLLM version: v0.10.1.1
- vLLM main:
c9f7081f9c

Signed-off-by: 1092626063 <1092626063@qq.com>
This commit is contained in:
1092626063
2025-09-05 09:11:22 +08:00
committed by GitHub
parent 83eb40a51c
commit 5b3646ab21
5 changed files with 206 additions and 88 deletions

View File

@@ -20,9 +20,10 @@ def model_name():
return "wemaster/deepseek_mtp_main_random_bf16" return "wemaster/deepseek_mtp_main_random_bf16"
def test_mtp_correctness( def mtp_correctness(
sampling_config: SamplingParams, sampling_config: SamplingParams,
model_name: str, model_name: str,
num_speculative_tokens: int,
): ):
example_prompts = [ example_prompts = [
"Hello, my name is", "Hello, my name is",
@@ -50,7 +51,7 @@ def test_mtp_correctness(
enable_expert_parallel=True, enable_expert_parallel=True,
speculative_config={ speculative_config={
"method": "deepseek_mtp", "method": "deepseek_mtp",
"num_speculative_tokens": 1, "num_speculative_tokens": num_speculative_tokens,
}, },
enforce_eager=True, enforce_eager=True,
max_model_len=2000, max_model_len=2000,
@@ -74,3 +75,18 @@ def test_mtp_correctness(
# Heuristic: expect at least 66% of the prompts to match exactly # Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy. # Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs)) assert matches > int(0.66 * len(ref_outputs))
del spec_llm
def test_mtp1_correctness(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Call ``mtp_correctness`` with one speculative token (MTP = 1)."""
    mtp_correctness(
        sampling_config,
        model_name,
        num_speculative_tokens=1,
    )
def test_mtp2_correctness(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Call ``mtp_correctness`` with two speculative tokens (MTP = 2)."""
    mtp_correctness(
        sampling_config,
        model_name,
        num_speculative_tokens=2,
    )

View File

@@ -841,7 +841,6 @@ class AscendMLAImpl(MLAAttentionImpl):
input_layout = "BNSD" input_layout = "BNSD"
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
input_layout = "TND" input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim] # [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1) q_nope = q_nope.view(num_tokens, self.num_heads, -1)

View File

@@ -27,6 +27,8 @@ from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
TorchairCommonAttentionMetadata) TorchairCommonAttentionMetadata)
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
PADDING_SLOT_ID = -1
class MtpProposer(Proposer): class MtpProposer(Proposer):
@@ -40,6 +42,7 @@ class MtpProposer(Proposer):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.device = device self.device = device
self.runner = runner self.runner = runner
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
# persistent buffers for graph # persistent buffers for graph
self.input_ids = torch.zeros(self.runner.max_num_tokens, self.input_ids = torch.zeros(self.runner.max_num_tokens,
@@ -57,6 +60,12 @@ class MtpProposer(Proposer):
self.torchair_compiled_models = {} # type: ignore self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = get_ascend_config( self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled ).torchair_graph_config.enabled
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
1,
device=self.runner.device,
dtype=torch.int32)
def load_model(self, model) -> None: def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config) loader = get_model_loader(self.vllm_config.load_config)
@@ -125,43 +134,47 @@ class MtpProposer(Proposer):
input_ids = self.input_ids[:num_tokens] input_ids = self.input_ids[:num_tokens]
positions = self.positions[:num_tokens] positions = self.positions[:num_tokens]
previous_hidden_states = self.hidden_states[:num_tokens] previous_hidden_states = self.hidden_states[:num_tokens]
with set_ascend_forward_context( for _ in range(self.num_speculative_tokens):
attn_metadata, with set_ascend_forward_context(
self.vllm_config, attn_metadata,
num_tokens=num_tokens, self.vllm_config,
with_prefill=with_prefill, num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp, with_prefill=with_prefill,
reserved_mc2_mask=self.runner.reserved_mc2_mask, num_tokens_across_dp=num_tokens_across_dp,
in_profile_run=self.runner.in_profile_run, reserved_mc2_mask=self.runner.reserved_mc2_mask,
num_actual_tokens=0): in_profile_run=self.runner.in_profile_run,
if is_running_torchair: num_actual_tokens=0):
assert attn_metadata is not None if is_running_torchair:
torch._dynamo.mark_static(input_ids) assert attn_metadata is not None
torch._dynamo.mark_static(positions) torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(previous_hidden_states) torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(attn_metadata.decode.block_table) torch._dynamo.mark_static(previous_hidden_states)
torch._dynamo.mark_static(attn_metadata.decode.input_positions) torch._dynamo.mark_static(attn_metadata.decode.block_table)
if hasattr(attn_metadata.decode, "sin"): torch._dynamo.mark_static(
torch._dynamo.mark_static(attn_metadata.decode.sin) attn_metadata.decode.input_positions)
torch._dynamo.mark_static(attn_metadata.decode.cos) if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(get_forward_context().mc2_mask) torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.slot_mapping) torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(attn_metadata.decode.attn_mask) torch._dynamo.mark_static(get_forward_context().mc2_mask)
torchair_compiled_model = self._get_torchair_lazy_compiled_model( torch._dynamo.mark_static(attn_metadata.slot_mapping)
num_tokens) torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
torchair_compiled_model( torchair_compiled_model = self._get_torchair_lazy_compiled_model(
input_ids=input_ids, num_tokens)
positions=positions, torchair_compiled_model(
previous_hidden_states=previous_hidden_states, input_ids=input_ids,
inputs_embeds=None, positions=positions,
intermediate_tensors=None, previous_hidden_states=previous_hidden_states,
attn_metadata=attn_metadata, inputs_embeds=None,
kv_caches=self.runner.kv_caches[-1:], intermediate_tensors=None,
spec_step_idx=0) attn_metadata=attn_metadata,
else: kv_caches=self.runner.kv_caches[-1:],
self.model(input_ids=input_ids, spec_step_idx=0)
positions=positions, else:
previous_hidden_states=previous_hidden_states) self.model(input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states)
if with_prefill:
break
def generate_token_ids(self, def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]], valid_sampled_token_ids: list[list[int]],
@@ -385,57 +398,142 @@ class MtpProposer(Proposer):
num_tokens_across_dp = self.runner.num_tokens_across_dp num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill with_prefill = self.runner.with_prefill
with set_ascend_forward_context( for step in range(self.num_speculative_tokens):
attn_metadata, with set_ascend_forward_context(
self.vllm_config, attn_metadata,
num_tokens=num_input_tokens, self.vllm_config,
with_prefill=with_prefill, num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp, with_prefill=with_prefill,
reserved_mc2_mask=self.runner.reserved_mc2_mask, num_tokens_across_dp=num_tokens_across_dp,
in_profile_run=self.runner.in_profile_run, reserved_mc2_mask=self.runner.reserved_mc2_mask,
num_actual_tokens=num_tokens): in_profile_run=self.runner.in_profile_run,
with ProfileExecuteDuration().capture_async('mtp_forward'): num_actual_tokens=num_tokens):
model_kwargs = {} with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs["attn_metadata"] = attn_metadata model_kwargs = {}
if self.torchair_graph_enabled: model_kwargs["attn_metadata"] = attn_metadata
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] if self.torchair_graph_enabled:
if is_running_torchair: model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
torchair_compiled_model = self._get_torchair_lazy_compiled_model( if is_running_torchair:
num_input_tokens) torchair_compiled_model = self._get_torchair_lazy_compiled_model(
hidden_states = torchair_compiled_model( num_input_tokens)
input_ids=self.input_ids[:num_input_tokens], hidden_states = torchair_compiled_model(
positions=self.positions[:num_input_tokens], input_ids=self.input_ids[:num_input_tokens],
previous_hidden_states=self. positions=self.positions[:num_input_tokens],
hidden_states[:num_input_tokens], previous_hidden_states=self.
inputs_embeds=None, hidden_states[:num_input_tokens],
intermediate_tensors=None, inputs_embeds=None,
spec_step_idx=0, intermediate_tensors=None,
**model_kwargs) spec_step_idx=0,
**model_kwargs)
else:
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
if not self.runner.with_prefill:
max_num_reqs_across_dp = num_input_tokens
else: else:
hidden_states = self.model( max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
input_ids=self.input_ids[:num_input_tokens], last_token_indices = nn.functional.pad(
positions=self.positions[:num_input_tokens], last_token_indices,
previous_hidden_states=self. (0, max_num_reqs_across_dp - num_indices))
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])
num_indices = last_token_indices.shape[0] sample_hidden_states = hidden_states[last_token_indices]
if lmhead_tp_enable(): logits = self.model.compute_logits(sample_hidden_states, None)
if not self.runner.with_prefill: if lmhead_tp_enable() and num_indices < logits.shape[0]:
max_num_reqs_across_dp = num_input_tokens logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
if step == 0:
draft_token_ids_list = [draft_token_ids]
else: else:
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs draft_token_ids_list.append(draft_token_ids)
last_token_indices = nn.functional.pad(
last_token_indices, (0, max_num_reqs_across_dp - num_indices))
sample_hidden_states = hidden_states[last_token_indices] # prepare next mtp inputs
logits = self.model.compute_logits(sample_hidden_states, None) # mtp>1: prefill skip or decode skip last loop
if lmhead_tp_enable() and num_indices < logits.shape[0]: if with_prefill and self.torchair_graph_enabled:
logits = logits[:num_indices] for _ in range(self.num_speculative_tokens - 1):
draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids)
if step == self.num_speculative_tokens - 1 or with_prefill:
break
# [batch_size, 1] if step == 0:
return draft_token_ids.view(-1, 1) positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
slot_mapping = attn_metadata.slot_mapping[last_token_indices]
attn_metadata.slot_mapping.fill_(-1)
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
last_token_indices = self.arange[:batch_size]
if attn_metadata.num_decode_tokens != 0:
attn_metadata.num_decode_tokens = batch_size
if is_running_torchair:
attn_metadata.num_actual_tokens = batch_size
attn_metadata.query_lens = [1] * batch_size
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
attn_metadata.seq_lens[:batch_size] += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
attn_metadata.seq_lens.device, non_blocking=True)
attn_metadata.seq_lens[:batch_size].masked_fill_(
exceeds_max_model_len_cpu, 1)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping += 1
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
attn_metadata.slot_mapping[:batch_size] = slot_mapping
if attn_metadata.prefill is not None:
attn_metadata.prefill.seq_lens = attn_metadata.seq_lens
attn_metadata.prefill.context_lens = attn_metadata.seq_lens
attn_metadata.prefill.input_positions = self.positions[:
num_input_tokens]
attn_metadata.prefill.max_seq_lens += 1
attn_metadata.prefill.max_seq_lens = min(
attn_metadata.prefill.max_seq_lens,
self.runner.model_config.max_model_len)
if attn_metadata.decode is not None:
attn_metadata.decode.seq_lens = attn_metadata.seq_lens
attn_metadata.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata.decode.max_seq_lens += 1
attn_metadata.decode.max_seq_lens = min(
attn_metadata.decode.max_seq_lens,
self.runner.model_config.max_model_len)
# mtp>1: [batch_size, k]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def _get_torchair_lazy_compiled_model(self, batch_size: int): def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[ if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[

View File

@@ -1020,7 +1020,6 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
input_layout = "BNSD" input_layout = "BNSD"
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
input_layout = "TND" input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim] # [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1) q_nope = q_nope.view(num_tokens, self.num_heads, -1)

View File

@@ -17,6 +17,7 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# isort: skip_file # isort: skip_file
import math
import types import types
from typing import Optional from typing import Optional
@@ -427,6 +428,11 @@ class NPUTorchairModelRunner(NPUModelRunner):
for graph_batch_size in self.torchair_graph_batch_sizes: for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = (graph_batch_size + tp_size - cur_graph_batch_size = (graph_batch_size + tp_size -
1) // tp_size * tp_size 1) // tp_size * tp_size
# MTP > 1: compute the LCM (least common multiple) of graph_batch_size and tp_size,
# to accommodate both multi-DP and the FIA operator.
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
cur_graph_batch_size = (tp_size * graph_batch_size) \
// math.gcd(tp_size, graph_batch_size)
if cur_graph_batch_size not in new_graph_batch_sizes and \ if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens: cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size) new_graph_batch_sizes.append(cur_graph_batch_size)