From 3629bc4431d3edb4224761f9036b3bddb16158d6 Mon Sep 17 00:00:00 2001 From: ZhaoJiangJiang <41458538+ZhaoJiangJiang@users.noreply.github.com> Date: Fri, 22 Aug 2025 17:09:08 +0800 Subject: [PATCH] feat: add mtp ut and fix some bugs (#2453) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? Fix mtp mode ut ### Does this PR introduce _any_ user-facing change? Nothing ### How was this patch tested? This can be tested in the same way as a unit test. - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/53415653ff24be03e7c90f5b42ef9cb3f72aad71 Signed-off-by: 赵江江 Co-authored-by: 赵江江 --- .../spec_decode_v1/test_v1_mtp_correctness.py | 91 ++++++++----------- tests/ut/quantization/test_quant_config.py | 1 + tests/ut/torchair/test_torchair_mla.py | 64 +++++++++++++ vllm_ascend/attention/mla_v1.py | 8 +- vllm_ascend/models/deepseek_mtp.py | 2 +- vllm_ascend/ops/fused_moe.py | 6 +- vllm_ascend/quantization/quant_config.py | 2 +- vllm_ascend/torchair/torchair_mla.py | 13 +-- vllm_ascend/worker/mtp_proposer_v1.py | 7 +- vllm_ascend/worker/worker_v1.py | 10 +- 10 files changed, 129 insertions(+), 75 deletions(-) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index 10322f4..71d274c 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -1,43 +1,13 @@ from __future__ import annotations -import random -from typing import Any +import os import pytest -from vllm import LLM, SamplingParams +from vllm import SamplingParams +from tests.e2e.conftest import VllmRunner -@pytest.fixture -def test_prompts(): - prompt_types = ["repeat", "sentence"] - num_prompts = 10 - prompts = [] - - random.seed(0) - random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) - - # Generate a mixed batch of prompts, some of which can be easily - # predicted by n-gram matching and some which likely cannot. - for kind in random_prompt_type_choices: - word_choices = ["test", "temp", "hello", "where"] - word = random.choice(word_choices) - if kind == "repeat": - prompt = f""" - please repeat the word '{word}' 10 times. - give no other output than the word at least ten times in a row, - in lowercase with spaces between each word and without quotes. - """ - elif kind == "sentence": - prompt = f""" - please give a ten-word sentence that - uses the word {word} at least once. - give no other output than that simple sentence without quotes. - """ - else: - raise ValueError(f"Unknown prompt type: {kind}") - prompts.append([{"role": "user", "content": prompt}]) - - return prompts +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @pytest.fixture @@ -50,39 +20,56 @@ def model_name(): return "wemaster/deepseek_mtp_main_random_bf16" -@pytest.mark.skipif( - True, reason="TODO: Enable me after test_mtp_correctness is fixed") def test_mtp_correctness( - test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, model_name: str, ): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] ''' Compare the outputs of a original LLM and a speculative LLM should be the same when using mtp speculative decoding. ''' - ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True) - ref_outputs = ref_llm.chat(test_prompts, sampling_config) - del ref_llm + with VllmRunner(model_name, + tensor_parallel_size=1, + gpu_memory_utilization=0.7, + max_model_len=256, + enforce_eager=True) as ref_llm: + ref_outputs = ref_llm.generate(example_prompts, sampling_config) + + with VllmRunner( + model_name, + tensor_parallel_size=1, + max_num_seqs=256, + gpu_memory_utilization=0.7, + distributed_executor_backend="mp", + enable_expert_parallel=True, + speculative_config={ + "method": "deepseek_mtp", + "num_speculative_tokens": 1, + }, + enforce_eager=True, + max_model_len=2000, + additional_config={"ascend_scheduler_config": { + "enabled": False + }}) as spec_llm: + spec_outputs = spec_llm.generate(example_prompts, sampling_config) - spec_llm = LLM(model=model_name, - trust_remote_code=True, - speculative_config={ - "method": "deepseek_mtp", - "num_speculative_tokens": 1, - }, - max_model_len=256, - enforce_eager=True) - spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 misses = 0 for ref_output, spec_output in zip(ref_outputs, spec_outputs): - if ref_output.outputs[0].text == spec_output.outputs[0].text: + ref_token_ids = ref_output[0][0] + spec_token_ids = spec_output[0][0] + if ref_token_ids == spec_token_ids[:len(ref_token_ids)]: matches += 1 else: misses += 1 - print(f"ref_output: {ref_output.outputs[0].text}") - print(f"spec_output: {spec_output.outputs[0].text}") + print(f"ref_output: {ref_output[1][0]}") + print(f"spec_output: {spec_output[1][0]}") # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_quant_config.py index 5a15cf9..55c72c6 100644 --- a/tests/ut/quantization/test_quant_config.py +++ b/tests/ut/quantization/test_quant_config.py @@ -113,6 +113,7 @@ class TestAscendQuantConfig(TestBase): def test_get_quant_method_for_fused_moe(self): fused_moe_layer = MagicMock(spec=FusedMoE) fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig) + fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig) # Test skipped layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \ diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py index 8a6c14d..6ee983a 100644 --- a/tests/ut/torchair/test_torchair_mla.py +++ b/tests/ut/torchair/test_torchair_mla.py @@ -1,11 +1,13 @@ from unittest.mock import MagicMock, patch import torch +from torch import nn from vllm.distributed.parallel_state import GroupCoordinator from vllm.model_executor.layers.linear import LinearBase from tests.ut.base import TestBase from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.torchair.torchair_mla import ( AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata, AscendMLATorchairImpl, AscendMLATorchairMetadata, @@ -398,6 +400,68 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): assert torch.equal(sin_golden, metadata.decode.sin) assert torch.equal(cos_golden, metadata.decode.cos) + @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") + def test_build_decode(self, mock_ascend_config): + ascend_config = MagicMock() + mock_ascend_config.return_value = ascend_config + ascend_config.torchair_graph_config.enabled = False + + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_vllm_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_device = 'cpu' + model = MagicMock(spec=nn.Module) + model.model = MagicMock(spec=nn.Module) + + builder = AscendMLATorchairMetadataBuilder( + mock_vllm_config, + mock_device, + metadata_cls=AscendMLATorchairMetadata) + builder.rope_dim = 64 + + builder.sin_cache = torch.tensor([10, 10]) + builder.cos_cache = torch.tensor([10, 10]) + + with patch.object(builder, + "_get_graph_runner_block_tables", + side_effect=lambda x, y: y): + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=torch.tensor([0, 1, 2, 3]), + query_start_loc_cpu=torch.tensor([0, 1, 2, 3]), + seq_lens_cpu=torch.tensor([1, 1, 1]), + num_reqs=3, + num_actual_tokens=3, + max_query_len=1, + decode_token_per_req=torch.tensor([1, 1, 1]), + block_table_tensor=torch.zeros((10, 10)), + slot_mapping_cpu=torch.tensor(range(20)), + actual_seq_lengths_q=torch.tensor([0, 1, 2]), + positions=torch.tensor([1, 1]), + attn_mask=torch.ones((15, 15)), + spec_attn_mask=None, + attn_state=AscendAttentionState.ChunkedPrefill) + + metadata = builder.build(common_attn_metadata, model) + + self.assertIsInstance(metadata, AscendMLATorchairMetadata) + self.assertEqual(metadata.num_input_tokens, 0) + self.assertEqual(metadata.num_actual_tokens, 3) + self.assertEqual(metadata.num_decodes, 3) + self.assertEqual(metadata.num_decode_tokens, 3) + self.assertEqual(metadata.num_prefills, 0) + self.assertEqual(metadata.attn_state, + AscendAttentionState.ChunkedPrefill) + self.assertIsNone(metadata.prefill) + self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata) + self.assertEqual(metadata.block_tables.shape[0], 3) + self.assertEqual(metadata.block_tables.shape[1], 10) + self.assertEqual(metadata.seq_lens.shape[0], 3) + self.assertEqual(metadata.slot_mapping.shape[0], 3) + self.assertEqual(metadata.query_start_loc.shape[0], 4) + class TestAscendMLATorchairImpl(TestBase): diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index fcad4c8..605d8c1 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -374,18 +374,12 @@ class AscendMLAMetadataBuilder: decode_metadata = None if num_decodes > 0: - actual_seq_lengths_q = query_start_loc[1:].tolist() + actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist() max_seq_lens = seq_lens[:num_decodes].max().item() seq_lens = seq_lens[:num_decode_tokens] input_positions = input_positions[:num_decode_tokens] block_table = block_table[:num_decode_tokens, ...] seq_lens_list = seq_lens.tolist() - # TODO(xyx): whether this block is necessary without torchair - # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) - batch_size = slot_mapping.size(0) - if actual_seq_lengths_q[-1] != batch_size \ - and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: - actual_seq_lengths_q[-1] = batch_size cos = self.cos_cache[input_positions].unsqueeze( # type: ignore 1).unsqueeze(2) diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py index 03abca4..8bcc4fb 100644 --- a/vllm_ascend/models/deepseek_mtp.py +++ b/vllm_ascend/models/deepseek_mtp.py @@ -215,4 +215,4 @@ class CustomDeepSeekMTP(DeepSeekMTP): hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, previous_hidden_states, inputs_embeds, spec_step_idx) - return hidden_states + return hidden_states \ No newline at end of file diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index c772124..0d6dc9c 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1178,7 +1178,7 @@ class AscendFusedMoE(FusedMoE): if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") - self.moe = FusedMoEConfig.make( + moe = FusedMoEConfig.make( num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, @@ -1188,8 +1188,10 @@ class AscendFusedMoE(FusedMoE): in_dtype=params_dtype, quant_config=quant_config) + self.moe_config = moe + if quant_config is None: - self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe) + self.quant_method = AscendUnquantizedFusedMoEMethod(moe) else: self.quant_method = quant_config.get_quant_method(self, prefix) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index ee6793b..65f682d 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -102,7 +102,7 @@ class AscendQuantConfig(QuantizationConfig): elif isinstance(layer, FusedMoE): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): - return AscendUnquantizedFusedMoEMethod(layer.moe) + return AscendUnquantizedFusedMoEMethod(layer.moe_config) return AscendFusedMoEMethod(self, prefix, self.packed_modules_mapping) elif isinstance(layer, VocabParallelEmbedding): diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 10718b7..036db47 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -492,17 +492,17 @@ class AscendMLATorchairMetadataBuilder: graph_pad_size = common_attn_metadata.graph_pad_size use_torchair_graph = graph_pad_size != -1 if num_decodes > 0: - actual_seq_lengths_q = query_start_loc[1:].tolist() + actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist() max_seq_lens = seq_lens[:num_decodes].max().item() seq_lens = seq_lens[:num_decode_tokens] input_positions = input_positions[:num_decode_tokens] block_table = block_table[:num_decode_tokens, ...] + num_token_pad_size = 0 if use_torchair_graph and common_attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ]: num_reqs_pad_size = 0 - num_token_pad_size = 0 if graph_pad_size != 0: pad_value = 0 num_token_pad_size = graph_pad_size - num_decode_tokens @@ -535,13 +535,14 @@ class AscendMLATorchairMetadataBuilder: device=input_positions.device) input_positions = torch.cat( [input_positions, position_padding]) - actual_seq_lengths_q = query_start_loc[1:].tolist( - ) + common_attn_metadata.actual_seq_lengths_q[ - num_reqs:num_reqs + num_reqs_pad_size] + actual_seq_lengths_q = ( + actual_seq_lengths_q + common_attn_metadata. + actual_seq_lengths_q[num_reqs:num_reqs + + num_reqs_pad_size]) else: seq_lens_list = seq_lens.tolist() # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) - batch_size = slot_mapping.size(0) + batch_size = num_decode_tokens + num_token_pad_size if actual_seq_lengths_q[-1] != batch_size \ and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: actual_seq_lengths_q[-1] = batch_size diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 61320fa..1ec1436 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -190,11 +190,6 @@ class MtpProposer: self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states - if attn_metadata.prefill is not None: - attn_metadata.prefill.query_lens = query_lens.cpu() - attn_metadata.prefill.input_positions = target_positions - attn_metadata.prefill.seq_lens = seq_lens - if not self.torchair_graph_enabled: # torch mode need to update num_tokens_across_dp # TODO: adapt enable_dbo later @@ -213,6 +208,7 @@ class MtpProposer: num_tokens=num_input_tokens, with_prefill=with_prefill, num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, in_profile_run=self.runner.in_profile_run, num_actual_tokens=num_tokens): with ProfileExecuteDuration().capture_async('mtp_forward'): @@ -315,6 +311,7 @@ class MtpProposer: num_tokens=num_tokens, with_prefill=with_prefill, num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, in_profile_run=self.runner.in_profile_run, num_actual_tokens=0): if is_running_torchair: diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 8bd11d3..6c72f84 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -47,9 +47,14 @@ from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (init_ascend_soc_version, register_ascend_customop, sleep_mode_enabled, - try_register_lib) + try_register_lib, vllm_version_is) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +if not vllm_version_is("0.10.1.1"): + from vllm.v1.outputs import DraftTokenIds +else: + DraftTokenIds = None + class NPUWorker(WorkerBase): @@ -343,3 +348,6 @@ class NPUWorker(WorkerBase): def get_supported_tasks(self) -> "tuple[SupportedTask, ...]": return self.model_runner.get_supported_tasks() + + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + return self.model_runner.take_draft_token_ids() \ No newline at end of file