[MTP] follow custom deepseek modeling changes to support graph mode (#636)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?

The custom DeepSeek modeling was changed in
https://github.com/vllm-project/vllm-ascend/pull/585 to support graph
mode, so this PR applies the same changes to the custom deepseek_mtp
modeling.

Some modifications needed for k>1 were not carried over in
https://github.com/vllm-project/vllm-ascend/pull/429; they are added
here.

To better cover the MTP feature in the vllm-ascend repository, I added
test cases for graph mode (torchair), but they are skipped for now
because torchair cannot correctly clean up memory inside the
vllm_runner fixture.

I also added test cases for MTP with quantized weights, but the test
weights are not ready yet, so those cases are skipped as well; they
will be enabled once the quantized test weights are available.

https://github.com/vllm-project/vllm-ascend/pull/648 did not completely
fix the sample change issue
(https://github.com/vllm-project/vllm-ascend/issues/660), so the
remaining changes are added here.

### Does this PR introduce _any_ user-facing change?
You can now enable MTP for DeepSeek V3/R1 with float or quantized
weights in eager mode as follows:
```python
from vllm import LLM

llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 1,
    },
    enforce_eager=True,
    trust_remote_code=True,
    disable_log_stats=False,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```

Or enable MTP for DeepSeek V3/R1 with float or quantized weights in
graph mode (torchair):
```python
from vllm import LLM

llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 1,
    },
    trust_remote_code=True,
    additional_config={
        'enable_graph_mode': True,
    },
    disable_log_stats=False,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```
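Either configuration produces a regular vLLM `LLM` object, so inference
works the same as without MTP. A minimal usage sketch (the prompt and
sampling settings below are placeholders, not part of this PR),
continuing from either snippet above:
```python
from vllm import SamplingParams

# Greedy sampling keeps the output deterministic, so the result can be
# compared against a run of the same model without speculative decoding.
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# `llm` is the MTP-enabled engine built by either snippet above.
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)
```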

Notes:
1. k>1 is now supported, so you can set num_speculative_tokens > 1 if
there is enough spare compute (see the sketch after this list);
2. MTP is not supported in V1; we will support it once vLLM does in
https://github.com/vllm-project/vllm/issues/13500.
3. If running MTP fails with a `segmentation fault`, follow the v0.7.3
patch https://github.com/vllm-project/vllm-ascend/pull/236, file
`vllm_ascend/patch/patch_metrics.py`, method
`__npu_async_metrics_collector_init__`.
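As mentioned in note 1, a minimal sketch of a k>1 configuration (the
value 2 below is only an illustration and assumes there is spare
compute for the extra draft tokens):
```python
from vllm import LLM

# Same random-weight test model as in the examples above; only the
# speculative depth changes.
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        # k > 1: propose two draft tokens per decoding step instead of one.
        "num_speculative_tokens": 2,
    },
    enforce_eager=True,
    trust_remote_code=True,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```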

### How was this patch tested?
Tested locally (passed) and verified by CI.

Signed-off-by: mengwei805 <mengwei25@huawei.com>
Author: wemaster
Date: 2025-04-28 21:18:53 +08:00
Committed by: GitHub
Parent: be9e3e8545
Commit: 54c0e63df7
15 changed files with 288 additions and 39 deletions

View File

@@ -0,0 +1,18 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm_ascend.patch import worker # noqa: F401

View File

@@ -17,7 +17,9 @@
# limitations under the License.
#
import shutil
from itertools import cycle
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
import pytest
@@ -177,6 +179,12 @@ def _check_logprobs_when_output_disabled(
assert spec_pos_logprob_token_id in baseline_pos_logprobs
def _clean_torchair_cache():
cache_path = Path.cwd() / '.torchair_cache'
if cache_path.exists() and cache_path.is_dir():
shutil.rmtree(cache_path)
def run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
@@ -219,10 +227,20 @@ def run_equality_correctness_test(
logprobs=logprobs,
prompt_logprobs=prompt_logprobs)
# TODO: the current torchair graph mode needs the torchair cache to be cleaned;
# otherwise it will raise an error.
additional_config = common_llm_kwargs.get("additional_config")
enable_graph_mode = additional_config.get(
"enable_graph_mode") if additional_config else False
with vllm_runner(**org_args) as vllm_model:
if enable_graph_mode:
_clean_torchair_cache()
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
with vllm_runner(**sd_args) as vllm_model:
if enable_graph_mode:
_clean_torchair_cache()
if ensure_all_accepted or expected_acceptance_rate is not None:
# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers[

View File

@@ -36,14 +36,20 @@ However, we still need to verify below scenario could be passed:
With those tests, we can say at least, mtp would not break the
correctness for the target model outputs.
"""
import os
import pytest
from .conftest import run_equality_correctness_test
# main model
# NOTE vLLM uses an fp8 model, vllm-ascend uses a bf16 model
MAIN_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
# NOTE both the main model and MTP are bfloat16
FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
# NOTE the main model is w8a8, MTP is bfloat16
QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part"
# TODO when msmodelslim can quantize both the main and MTP models,
# this UT should use fully w8a8 weights.
# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
@@ -51,8 +57,11 @@ MAX_SPEC_TOKENS = 1
# precision
PRECISION = "bfloat16"
os.environ["VLLM_USE_MODELSCOPE"] = "True"
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -66,7 +75,7 @@ PRECISION = "bfloat16"
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
@@ -97,6 +106,7 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -110,7 +120,53 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": QUANT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
@@ -158,15 +214,13 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
["disable_logprobs"])
@pytest.mark.skipif(
True,
reason=
"Open it when vllm-ascend support graph mode and support enforce_eager status is False to run model in graph mode"
)
@pytest.mark.skipif(True, reason="torchair ut can not clean mem.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
"additional_config": {
'enable_graph_mode': True,
},
# Print spec metrics.
"disable_log_stats": False,
@@ -175,7 +229,7 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@@ -192,20 +246,64 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
output_len: int, seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
def test_mtp_e2e_greedy_correctness_torchair_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with torchair graph enabled and different
batch sizes using bfloat16 weights."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"additional_config": {
'enable_graph_mode': True,
},
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": QUANT_MODEL,
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with torchair graph enabled and different
batch sizes using quant weights."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -221,7 +319,7 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
@@ -256,6 +354,8 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -266,7 +366,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
@@ -305,6 +405,8 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -315,7 +417,7 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9

View File

@@ -29,7 +29,6 @@ from vllm.spec_decode.top1_proposer import Top1Proposer
from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import create_batch, mock_worker
from vllm_ascend.patch.worker import patch_common # noqa: F401
@pytest.mark.parametrize('queue_size', [4])

View File

@@ -33,7 +33,6 @@ from tests.singlecard.spec_decode.utils import (
assert_logprobs_dict_allclose, create_batch,
create_seq_group_metadata_from_prompts, create_worker,
patch_execute_model_with_seeds, zero_kv_cache)
from vllm_ascend.patch.worker import patch_common # noqa: F401
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker

View File

@@ -24,7 +24,6 @@ from vllm.spec_decode.top1_proposer import Top1Proposer
from tests.singlecard.spec_decode.utils import (
create_seq_group_metadata_from_prompts, create_worker)
from vllm_ascend.patch.worker import patch_common # noqa: F401
def test_ngram_algo_correctness_for_single_no_match():

View File

@@ -39,8 +39,6 @@ from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import (create_batch,
create_sampler_output_list,
create_worker, mock_worker)
# patch SpecDecodeWorker, AsyncMetricsCollector
from vllm_ascend.patch.worker import patch_common # noqa: F401
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker