From 5b3646ab2142131579661ce12e4f0e4ba731ad06 Mon Sep 17 00:00:00 2001 From: 1092626063 <30970038+1092626063@users.noreply.github.com> Date: Fri, 5 Sep 2025 09:11:22 +0800 Subject: [PATCH] [FEATURE][MTP] Support MTP > 1 (#2708) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? [RFC:Support MTP > 1 for DeepSeek](https://github.com/vllm-project/vllm-ascend/issues/2745) - [x] dp1 tp16 - [x] dp4 tp4 - [x] dp2 tp 8 - [x] torchair graph - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/c9f7081f9c848d83ecbf42b57591451d6ff5a7a9 Signed-off-by: 1092626063 <1092626063@qq.com> --- .../spec_decode_v1/test_v1_mtp_correctness.py | 20 +- vllm_ascend/attention/mla_v1.py | 1 - vllm_ascend/spec_decode/mtp_proposer.py | 266 ++++++++++++------ vllm_ascend/torchair/torchair_mla.py | 1 - vllm_ascend/torchair/torchair_model_runner.py | 6 + 5 files changed, 206 insertions(+), 88 deletions(-) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index 0c01a07..bbb6e01 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -20,9 +20,10 @@ def model_name(): return "wemaster/deepseek_mtp_main_random_bf16" -def test_mtp_correctness( +def mtp_correctness( sampling_config: SamplingParams, model_name: str, + num_speculative_tokens: int, ): example_prompts = [ "Hello, my name is", @@ -50,7 +51,7 @@ def test_mtp_correctness( enable_expert_parallel=True, speculative_config={ "method": "deepseek_mtp", - "num_speculative_tokens": 1, + "num_speculative_tokens": num_speculative_tokens, }, enforce_eager=True, max_model_len=2000, @@ -74,3 +75,18 @@ def test_mtp_correctness( # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.66 * len(ref_outputs)) + del spec_llm + + +def test_mtp1_correctness( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, model_name, 1) + + +def test_mtp2_correctness( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, model_name, 2) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index a386f63..0031513 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -841,7 +841,6 @@ class AscendMLAImpl(MLAAttentionImpl): input_layout = "BNSD" if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: - assert num_tokens % self.spec_token_num == 0 input_layout = "TND" # [bs * q_seq_len, num_heads_per_rank, dim] q_nope = q_nope.view(num_tokens, self.num_heads, -1) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index a5211d4..5e56493 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -27,6 +27,8 @@ from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata) from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable +PADDING_SLOT_ID = -1 + class MtpProposer(Proposer): @@ -40,6 +42,7 @@ class MtpProposer(Proposer): self.vllm_config = vllm_config self.device = device self.runner = runner + self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens # persistent buffers for graph self.input_ids = torch.zeros(self.runner.max_num_tokens, @@ -57,6 +60,12 @@ class MtpProposer(Proposer): self.torchair_compiled_models = {} # type: ignore self.torchair_graph_enabled = get_ascend_config( ).torchair_graph_config.enabled + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + + 1, + device=self.runner.device, + dtype=torch.int32) def load_model(self, model) -> None: loader = get_model_loader(self.vllm_config.load_config) @@ -125,43 +134,47 @@ class MtpProposer(Proposer): input_ids = self.input_ids[:num_tokens] positions = self.positions[:num_tokens] previous_hidden_states = self.hidden_states[:num_tokens] - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - with_prefill=with_prefill, - num_tokens_across_dp=num_tokens_across_dp, - reserved_mc2_mask=self.runner.reserved_mc2_mask, - in_profile_run=self.runner.in_profile_run, - num_actual_tokens=0): - if is_running_torchair: - assert attn_metadata is not None - torch._dynamo.mark_static(input_ids) - torch._dynamo.mark_static(positions) - torch._dynamo.mark_static(previous_hidden_states) - torch._dynamo.mark_static(attn_metadata.decode.block_table) - torch._dynamo.mark_static(attn_metadata.decode.input_positions) - if hasattr(attn_metadata.decode, "sin"): - torch._dynamo.mark_static(attn_metadata.decode.sin) - torch._dynamo.mark_static(attn_metadata.decode.cos) - torch._dynamo.mark_static(get_forward_context().mc2_mask) - torch._dynamo.mark_static(attn_metadata.slot_mapping) - torch._dynamo.mark_static(attn_metadata.decode.attn_mask) - torchair_compiled_model = self._get_torchair_lazy_compiled_model( - num_tokens) - torchair_compiled_model( - input_ids=input_ids, - positions=positions, - previous_hidden_states=previous_hidden_states, - inputs_embeds=None, - intermediate_tensors=None, - attn_metadata=attn_metadata, - kv_caches=self.runner.kv_caches[-1:], - spec_step_idx=0) - else: - self.model(input_ids=input_ids, - positions=positions, - previous_hidden_states=previous_hidden_states) + for _ in range(self.num_speculative_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + with_prefill=with_prefill, + num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, + in_profile_run=self.runner.in_profile_run, + num_actual_tokens=0): + if is_running_torchair: + assert attn_metadata is not None + torch._dynamo.mark_static(input_ids) + torch._dynamo.mark_static(positions) + torch._dynamo.mark_static(previous_hidden_states) + torch._dynamo.mark_static(attn_metadata.decode.block_table) + torch._dynamo.mark_static( + attn_metadata.decode.input_positions) + if hasattr(attn_metadata.decode, "sin"): + torch._dynamo.mark_static(attn_metadata.decode.sin) + torch._dynamo.mark_static(attn_metadata.decode.cos) + torch._dynamo.mark_static(get_forward_context().mc2_mask) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + torch._dynamo.mark_static(attn_metadata.decode.attn_mask) + torchair_compiled_model = self._get_torchair_lazy_compiled_model( + num_tokens) + torchair_compiled_model( + input_ids=input_ids, + positions=positions, + previous_hidden_states=previous_hidden_states, + inputs_embeds=None, + intermediate_tensors=None, + attn_metadata=attn_metadata, + kv_caches=self.runner.kv_caches[-1:], + spec_step_idx=0) + else: + self.model(input_ids=input_ids, + positions=positions, + previous_hidden_states=previous_hidden_states) + if with_prefill: + break def generate_token_ids(self, valid_sampled_token_ids: list[list[int]], @@ -385,57 +398,142 @@ class MtpProposer(Proposer): num_tokens_across_dp = self.runner.num_tokens_across_dp with_prefill = self.runner.with_prefill - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - with_prefill=with_prefill, - num_tokens_across_dp=num_tokens_across_dp, - reserved_mc2_mask=self.runner.reserved_mc2_mask, - in_profile_run=self.runner.in_profile_run, - num_actual_tokens=num_tokens): - with ProfileExecuteDuration().capture_async('mtp_forward'): - model_kwargs = {} - model_kwargs["attn_metadata"] = attn_metadata - if self.torchair_graph_enabled: - model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] - if is_running_torchair: - torchair_compiled_model = self._get_torchair_lazy_compiled_model( - num_input_tokens) - hidden_states = torchair_compiled_model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - previous_hidden_states=self. - hidden_states[:num_input_tokens], - inputs_embeds=None, - intermediate_tensors=None, - spec_step_idx=0, - **model_kwargs) + for step in range(self.num_speculative_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + with_prefill=with_prefill, + num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, + in_profile_run=self.runner.in_profile_run, + num_actual_tokens=num_tokens): + with ProfileExecuteDuration().capture_async('mtp_forward'): + model_kwargs = {} + model_kwargs["attn_metadata"] = attn_metadata + if self.torchair_graph_enabled: + model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] + if is_running_torchair: + torchair_compiled_model = self._get_torchair_lazy_compiled_model( + num_input_tokens) + hidden_states = torchair_compiled_model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + previous_hidden_states=self. + hidden_states[:num_input_tokens], + inputs_embeds=None, + intermediate_tensors=None, + spec_step_idx=0, + **model_kwargs) + else: + hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + previous_hidden_states=self. + hidden_states[:num_input_tokens], + kv_caches=self.runner.kv_caches[-1:]) + + num_indices = last_token_indices.shape[0] + if lmhead_tp_enable(): + if not self.runner.with_prefill: + max_num_reqs_across_dp = num_input_tokens else: - hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - previous_hidden_states=self. - hidden_states[:num_input_tokens], - kv_caches=self.runner.kv_caches[-1:]) + max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs + last_token_indices = nn.functional.pad( + last_token_indices, + (0, max_num_reqs_across_dp - num_indices)) - num_indices = last_token_indices.shape[0] - if lmhead_tp_enable(): - if not self.runner.with_prefill: - max_num_reqs_across_dp = num_input_tokens + sample_hidden_states = hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if lmhead_tp_enable() and num_indices < logits.shape[0]: + logits = logits[:num_indices] + draft_token_ids = logits.argmax(dim=-1) + + if self.num_speculative_tokens == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + if step == 0: + draft_token_ids_list = [draft_token_ids] else: - max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs - last_token_indices = nn.functional.pad( - last_token_indices, (0, max_num_reqs_across_dp - num_indices)) + draft_token_ids_list.append(draft_token_ids) - sample_hidden_states = hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - if lmhead_tp_enable() and num_indices < logits.shape[0]: - logits = logits[:num_indices] - draft_token_ids = logits.argmax(dim=-1) + # prepare next mtp inputs + # mtp>1: prefill skip or decode skip last loop + if with_prefill and self.torchair_graph_enabled: + for _ in range(self.num_speculative_tokens - 1): + draft_token_ids_list.append(draft_token_ids) + if step == self.num_speculative_tokens - 1 or with_prefill: + break - # [batch_size, 1] - return draft_token_ids.view(-1, 1) + if step == 0: + positions = target_positions[last_token_indices] + hidden_states = hidden_states[last_token_indices] + slot_mapping = attn_metadata.slot_mapping[last_token_indices] + attn_metadata.slot_mapping.fill_(-1) + attn_metadata.query_start_loc = self.arange[:batch_size + 1] + last_token_indices = self.arange[:batch_size] + if attn_metadata.num_decode_tokens != 0: + attn_metadata.num_decode_tokens = batch_size + if is_running_torchair: + attn_metadata.num_actual_tokens = batch_size + attn_metadata.query_lens = [1] * batch_size + + input_ids = draft_token_ids_list[-1].int() + positions += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions >= self.runner.model_config.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) + # Increment the sequence lengths. + attn_metadata.seq_lens[:batch_size] += 1 + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. + exceeds_max_model_len_cpu = exceeds_max_model_len.to( + attn_metadata.seq_lens.device, non_blocking=True) + attn_metadata.seq_lens[:batch_size].masked_fill_( + exceeds_max_model_len_cpu, 1) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + slot_mapping += 1 + slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:batch_size] = hidden_states + attn_metadata.slot_mapping[:batch_size] = slot_mapping + + if attn_metadata.prefill is not None: + attn_metadata.prefill.seq_lens = attn_metadata.seq_lens + attn_metadata.prefill.context_lens = attn_metadata.seq_lens + attn_metadata.prefill.input_positions = self.positions[: + num_input_tokens] + attn_metadata.prefill.max_seq_lens += 1 + attn_metadata.prefill.max_seq_lens = min( + attn_metadata.prefill.max_seq_lens, + self.runner.model_config.max_model_len) + if attn_metadata.decode is not None: + attn_metadata.decode.seq_lens = attn_metadata.seq_lens + attn_metadata.decode.input_positions = self.positions[: + num_input_tokens] + attn_metadata.decode.max_seq_lens += 1 + attn_metadata.decode.max_seq_lens = min( + attn_metadata.decode.max_seq_lens, + self.runner.model_config.max_model_len) + + # mtp>1: [batch_size, k] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + return draft_token_ids def _get_torchair_lazy_compiled_model(self, batch_size: int): if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[ @@ -511,4 +609,4 @@ class MtpProposer(Proposer): global_indices_flat = global_indices[mask] values_flat = values[mask] - out_ptr[global_indices_flat] = values_flat \ No newline at end of file + out_ptr[global_indices_flat] = values_flat diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 30ef293..95ca3bd 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -1020,7 +1020,6 @@ class AscendMLATorchairImpl(MLAAttentionImpl): input_layout = "BNSD" if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: - assert num_tokens % self.spec_token_num == 0 input_layout = "TND" # [bs * q_seq_len, num_heads_per_rank, dim] q_nope = q_nope.view(num_tokens, self.num_heads, -1) diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 2b34f9b..71315b1 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -17,6 +17,7 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # isort: skip_file +import math import types from typing import Optional @@ -427,6 +428,11 @@ class NPUTorchairModelRunner(NPUModelRunner): for graph_batch_size in self.torchair_graph_batch_sizes: cur_graph_batch_size = (graph_batch_size + tp_size - 1) // tp_size * tp_size + # MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size, + # Both adapter multi-dp and FIA operator + if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1: + cur_graph_batch_size = (tp_size * graph_batch_size) \ + // math.gcd(tp_size, graph_batch_size) if cur_graph_batch_size not in new_graph_batch_sizes and \ cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens: new_graph_batch_sizes.append(cur_graph_batch_size)