[MTP] follow custom deepseek modeling changes to support graph mode (#636)
### What this PR does / why we need it?
The custom DeepSeek modeling was changed to support graph mode in https://github.com/vllm-project/vllm-ascend/pull/585, so this PR applies the same changes to the custom deepseek_mtp modeling. Some modifications needed for k>1 were not carried over by https://github.com/vllm-project/vllm-ascend/pull/429; they are added here.

To better maintain the MTP feature in the vllm-ascend repository, this PR adds test cases for graph mode (torchair), but they are skipped for now because torchair cannot correctly clean up memory inside the vllm runner. It also adds test cases for MTP quantization weights, but the test weights are not ready yet, so those cases are skipped and will be enabled once the quantized test weights are available.

https://github.com/vllm-project/vllm-ascend/pull/648 did not completely fix the sampler change issue (https://github.com/vllm-project/vllm-ascend/issues/660), so the remaining changes are included here.

### Does this PR introduce _any_ user-facing change?
You can now use MTP with DeepSeek V3/R1 float or quantized weights in eager mode:

```python
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 1,
    },
    enforce_eager=True,
    trust_remote_code=True,
    disable_log_stats=False,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```

or in graph mode (torchair):

```python
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 1,
    },
    trust_remote_code=True,
    additional_config={
        'enable_graph_mode': True,
    },
    disable_log_stats=False,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```

Notes:
1. k>1 is now supported, so you can set `num_speculative_tokens > 1` if there is sufficient redundant computing power (a sketch is included at the end of this description).
2. MTP is not supported in V1; we will support it once vLLM does in https://github.com/vllm-project/vllm/issues/13500.
3. If MTP fails with a `segmentation fault`, you can follow the v0.7.3 patch https://github.com/vllm-project/vllm-ascend/pull/236, file `vllm_ascend/patch/patch_metrics.py`, method `__npu_async_metrics_collector_init__`.

### How was this patch tested?
Tested locally and by CI.

Signed-off-by: mengwei805 <mengwei25@huawei.com>
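As a minimal sketch of note 1 above (not part of this PR's test suite): it reuses the eager-mode configuration shown earlier and only raises `num_speculative_tokens`; the model name and the value k=2 are illustrative, and the right k depends on available compute.

```python
from vllm import LLM, SamplingParams

# Sketch: same eager-mode setup as the example above, but with k > 1
# speculative tokens. Model name and k value are illustrative only.
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 2,  # k > 1; choose based on spare compute
    },
    enforce_eager=True,
    trust_remote_code=True,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```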
Changed file: .github/workflows/vllm_ascend_test.yaml (vendored)
@@ -138,13 +138,18 @@ jobs:
          speculative_tests_changed:
            - "tests/singlecard/spec_decode/**"
            - "tests/multicard/spec_decode_e2e/**"
            - "vllm_ascend/worker/worker.py"
            - "vllm_ascend/worker/model_runner.py"
            - "vllm_ascend/worker/multi_step_runner.py"
            - "vllm_ascend/worker/multi_step_worker.py"
            - "vllm_ascend/patch/patch_rejection_sampler.py"
            - "vllm_ascend/patch/patch_spec_decode_worker.py"
            - "vllm_ascend/patch/patch_multi_step_worker.py"
            - "vllm_ascend/worker/draft_model_runner.py"
            - "vllm_ascend/patch/worker/patch_common/patch_metrics.py"
            - "vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py"
            - "vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py"

      - name: Run vllm-project/vllm-ascend Speculative Decode test
        env:
          VLLM_USE_V1: 0
        if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true'
        run: |
          if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
@@ -0,0 +1,18 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vllm_ascend.patch import worker  # noqa: F401
@@ -17,7 +17,9 @@
# limitations under the License.
#

import shutil
from itertools import cycle
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union

import pytest
@@ -177,6 +179,12 @@ def _check_logprobs_when_output_disabled(
        assert spec_pos_logprob_token_id in baseline_pos_logprobs


def _clean_torchair_cache():
    cache_path = Path.cwd() / '.torchair_cache'
    if cache_path.exists() and cache_path.is_dir():
        shutil.rmtree(cache_path)


def run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
@@ -219,10 +227,20 @@ def run_equality_correctness_test(
                                     logprobs=logprobs,
                                     prompt_logprobs=prompt_logprobs)

    # TODO current torchair graph mode needs clean torchair cache.
    # if do not clean, it will raise error
    additional_config = common_llm_kwargs.get("additional_config")
    enable_graph_mode = additional_config.get(
        "enable_graph_mode") if additional_config else False

    with vllm_runner(**org_args) as vllm_model:
        if enable_graph_mode:
            _clean_torchair_cache()
        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    with vllm_runner(**sd_args) as vllm_model:
        if enable_graph_mode:
            _clean_torchair_cache()
        if ensure_all_accepted or expected_acceptance_rate is not None:
            # Force log interval to be 0 to catch all metrics.
            stat_logger = vllm_model.model.llm_engine.stat_loggers[
@@ -36,14 +36,20 @@ However, we still need to verify below scenario could be passed:
With those tests, we can say at least, mtp would not break the
correctess for the target model outputs.
"""
import os

import pytest

from .conftest import run_equality_correctness_test

# main model
# NOTE vLLM use fp8 model, vllm-ascend use bf16 model
MAIN_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
# NOTE both main model and MTP are bfloat16
FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16"

# NOTE main model is w8a8, MTP is bfloat16
QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part"

# TODO when msmodelslim can quantify both main and MTP model
# This UT should use w8a8 fully weights.

# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
@@ -51,8 +57,11 @@ MAX_SPEC_TOKENS = 1

# precision
PRECISION = "bfloat16"
os.environ["VLLM_USE_MODELSCOPE"] = "True"


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
                    reason="mtp is not supported on v1")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
@@ -66,7 +75,7 @@ PRECISION = "bfloat16"
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
        "model_name": FLOAT_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
@@ -97,6 +106,7 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
@@ -110,7 +120,53 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
        "model_name": QUANT_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs,
                                          per_test_common_llm_kwargs,
                                          baseline_llm_kwargs, test_llm_kwargs,
                                          batch_size: int, output_len: int,
                                          seed: int):

    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
                    reason="mtp is not supported on v1")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": FLOAT_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
@@ -158,15 +214,13 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                             ["disable_logprobs"])


@pytest.mark.skipif(
    True,
    reason=
    "Open it when vllm-ascend support graph mode and support enforce_eager status is False to run model in graph mode"
)
@pytest.mark.skipif(True, reason="torchair ut can not clean mem.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,
        "additional_config": {
            'enable_graph_mode': True,
        },

        # Print spec metrics.
        "disable_log_stats": False,
@@ -175,7 +229,7 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
        "model_name": FLOAT_MODEL,
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@@ -192,20 +246,64 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
                                               per_test_common_llm_kwargs,
                                               baseline_llm_kwargs,
                                               test_llm_kwargs,
                                               batch_size: int,
                                               output_len: int, seed: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
def test_mtp_e2e_greedy_correctness_torchair_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality with torchair graph enabled and different
    batch sizes using bfloat16 weights."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "additional_config": {
            'enable_graph_mode': True,
        },

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": QUANT_MODEL,
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality with torchair graph enabled and different
    batch sizes using quant weights."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
                    reason="mtp is not supported on v1")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
@@ -221,7 +319,7 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
        "model_name": FLOAT_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
@@ -256,6 +354,8 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
                                  batch_size, output_len, seed)


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
                    reason="mtp is not supported on v1")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
@@ -266,7 +366,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
        "model_name": FLOAT_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
@@ -305,6 +405,8 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
                    reason="mtp is not supported on v1")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
@@ -315,7 +417,7 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
        "model_name": FLOAT_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
@@ -29,7 +29,6 @@ from vllm.spec_decode.top1_proposer import Top1Proposer

from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import create_batch, mock_worker
from vllm_ascend.patch.worker import patch_common  # noqa: F401


@pytest.mark.parametrize('queue_size', [4])
@@ -33,7 +33,6 @@ from tests.singlecard.spec_decode.utils import (
    assert_logprobs_dict_allclose, create_batch,
    create_seq_group_metadata_from_prompts, create_worker,
    patch_execute_model_with_seeds, zero_kv_cache)
from vllm_ascend.patch.worker import patch_common  # noqa: F401
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker
@@ -24,7 +24,6 @@ from vllm.spec_decode.top1_proposer import Top1Proposer

from tests.singlecard.spec_decode.utils import (
    create_seq_group_metadata_from_prompts, create_worker)
from vllm_ascend.patch.worker import patch_common  # noqa: F401


def test_ngram_algo_correctness_for_single_no_match():
@@ -39,8 +39,6 @@ from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import (create_batch,
                                                create_sampler_output_list,
                                                create_worker, mock_worker)
# patch SpecDecodeWorker, AsyncMetricsCollector
from vllm_ascend.patch.worker import patch_common  # noqa: F401
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker
@@ -17,11 +17,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -34,6 +35,7 @@ from vllm.model_executor.models.deepseek_mtp import (
    SharedHead)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .deepseek_v2 import CustomDeepseekV2DecoderLayer
@@ -69,6 +71,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
@@ -88,6 +92,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):

        hidden_states, residual = self.mtp_block(positions=positions,
                                                 hidden_states=hidden_states,
                                                 kv_cache=kv_cache,
                                                 attn_metadata=attn_metadata,
                                                 residual=None)
        hidden_states = residual + hidden_states
        return hidden_states
@@ -125,14 +131,20 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        step_kv_cache = kv_caches[
            current_step_idx] if kv_caches is not None else None
        return self.layers_list[current_step_idx](
            input_ids,
            positions,
            step_kv_cache,
            attn_metadata,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
@@ -170,3 +182,19 @@ class CustomDeepSeekMTP(DeepSeekMTP):
                                                      prefix, "model"))

        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, previous_hidden_states,
                                   inputs_embeds, spec_step_idx)
        return hidden_states
@@ -114,6 +114,19 @@
#   Future Plan:
#       Revert it when the related pr is merged in vllm.
#
# ** File: worker/patch_0_8_4/patch_spec_decode_worker.py **
#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#   1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker._configure_model_sampler_for_spec_decode`
#       Why:
#           vLLM `Remove Sampler from Model Code` so vllm-ascend needs a patch to run in v0.8.4.
#       How:
#           Use vLLM 0.8.4 method tp patch it.
#       Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
#           - https://github.com/vllm-project/vllm/pull/17084
#           - https://github.com/vllm-project/vllm-ascend/pull/636
#       Future Plan:
#           Follow v0.8.4 version strategy.
#
# ** File: worker/patch_common/patch_metrics.py **
#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#   1. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics`
@@ -158,7 +171,19 @@
#   Future Plan:
#       Revert it when the related pr is merged in vllm and vllm-ascend.
#
# ** File: worker/patch_common/patch_multi_step_worker.py **
#   2. `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_include_gpu_probs_tensor` and
#      `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_should_modify_greedy_probs_inplace`
#       Why:
#           vLLM `Remove Sampler from Model Code` so vllm-ascend needs adapt to this change.
#       How:
#           Use vLLM 0.8.4 method to patch it.
#       Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
#           - https://github.com/vllm-project/vllm/pull/15195
#           - https://github.com/vllm-project/vllm-ascend/pull/395
#       Future Plan:
#           Remove it when we identify the reasons clearly.
#
# ** File: worker/patch_common/patch_spec_decode_worker.py **
#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#   1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
#       Why:
@@ -0,0 +1,30 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker


def _configure_model_sampler_for_spec_decode(self):
    (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
     ) = True
    (self.scorer_worker.model_runner.model.sampler.
     should_modify_greedy_probs_inplace) = True
    self.proposer_worker.set_include_gpu_probs_tensor()
    self.proposer_worker.set_should_modify_greedy_probs_inplace()


SpecDecodeWorker._configure_model_sampler_for_spec_decode = _configure_model_sampler_for_spec_decode
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.multi_step_worker import MultiStepWorker

from vllm_ascend.utils import vllm_version_is
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
@@ -61,15 +62,19 @@ def sampler_output(
        else:
            # Here we run multi-step directly, with every step prepared
            # on the CPU.
            # TODO: Remove this branch once DraftModelRunner supports TP>1
            # TODO Remove this branch once DraftModelRunner supports TP>1
            # and other restrictions that are part of DraftModelRunner's
            # supports_gpu_multi_step(..)
            if expanded_request.previous_hidden_states is not None:
                self.worker.model_runner.return_hidden_states = True
            for _ in range(sample_len):
                model_output: List[SamplerOutput] = self.worker.execute_model(
                    execute_model_req=expanded_request)
                assert (len(model_output) == 1
                        ), "composing multistep workers not supported"
                model_output = model_output[0]
                self._maybe_update_previous_hidden_states(model_output,
                                                          expanded_request)

                self._append_new_tokens(model_output,
                                        expanded_request.seq_group_metadata_list,
@@ -84,4 +89,22 @@ def sampler_output(
    return filtered_model_outputs, True


def set_include_gpu_probs_tensor(self) -> None:
    # Need include_gpu_probs_tensor for MultiSteoWorker
    if hasattr(self.model_runner.model, "sampler"):
        self.model_runner.model.sampler.include_gpu_probs_tensor = True
    if not vllm_version_is("0.8.4"):
        self.model_runner.sampler.include_gpu_probs_tensor = True


def set_should_modify_greedy_probs_inplace(self) -> None:
    if hasattr(self.model_runner.model, "sampler"):
        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
            True)
    if not vllm_version_is("0.8.4"):
        self.model_runner.sampler.should_modify_greedy_probs_inplace = True


MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
MultiStepWorker.set_include_gpu_probs_tensor = set_include_gpu_probs_tensor
MultiStepWorker.set_should_modify_greedy_probs_inplace = set_should_modify_greedy_probs_inplace
@@ -93,7 +93,7 @@ def create_worker(

        proposer_worker = MultiStepWorker(**draft_worker_kwargs)
        if draft_model_config.hf_config.model_type == "deepseek_mtp":
            num_spec_prefill_steps = num_speculative_tokens
            num_spec_prefill_steps = draft_model_config.hf_config.n_predict

        proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
            proposer_worker, draft_tp, target_tp)
@@ -293,8 +293,8 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
                sampling_metadata=model_input.sampling_metadata,
            )
        else:
            assert self.sampler is not None
            output = self.sampler(
            assert self.model_runner.sampler is not None
            output = self.model_runner.sampler(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata,
            )
@@ -23,6 +23,7 @@ from vllm.worker.multi_step_model_runner import (ModelOutput,
                                                  PythonizationCache,
                                                  StatefulModelInput)

from vllm_ascend.utils import vllm_version_is
from vllm_ascend.worker.model_runner import (
    ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
@@ -318,8 +319,12 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[StatefulModelInputForNPU]):
                device="cpu",
                pin_memory=True)

            self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
                True)
            if vllm_version_is("0.8.4"):
                self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
                    True)
            else:
                assert self._base_model_runner.sampler is not None
                self._base_model_runner.sampler.include_gpu_probs_tensor = True
            if frozen_model_input.sampling_metadata:
                frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
                    True)