[MTP] follow custom deepseek modeling changes to support graph mode (#636)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?

The custom DeepSeek modeling was changed to support graph mode in
https://github.com/vllm-project/vllm-ascend/pull/585, so I follow it and
change the custom deepseek_mtp modeling accordingly.

Also, some modifications for k>1 were not carried over by
https://github.com/vllm-project/vllm-ascend/pull/429; I add them now.

In order to better take care of the MTP feature in the vllm-ascend
repository, I added cases related to graph mode (torchair), but I skip them
since torchair cannot correctly clean up memory in VllmRunner.

I also added some cases for MTP quantization weights, but the test weights
are not ready, so I skip them for now and will enable them once the
quantized test weights are ready.

https://github.com/vllm-project/vllm-ascend/pull/648 did not completely
fix the sampler change issue
(https://github.com/vllm-project/vllm-ascend/issues/660), so I added the
relevant changes here.

### Does this PR introduce _any_ user-facing change?
Now you can use the following method to run MTP with DeepSeek V3/R1 float
or quantized weights in eager mode.
```python
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 1,
    },
    enforce_eager=True,
    trust_remote_code=True,
    disable_log_stats=False,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```

Or run MTP with DeepSeek V3/R1 float or quantized weights in graph
mode (torchair):
```python
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 1,
    },
    trust_remote_code=True,
    additional_config={
        'enable_graph_mode': True,
    },
    disable_log_stats=False,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```

Additional notes:
1. We now support k>1, so you can set num_speculative_tokens > 1 if there
is sufficient redundant computing power;
2. MTP is not supported in V1; we will support it when vLLM does in
https://github.com/vllm-project/vllm/issues/13500.
3. If running MTP fails with a `segmentation fault`, you can follow the
v0.7.3 patch https://github.com/vllm-project/vllm-ascend/pull/236, file
`vllm_ascend/patch/patch_metrics.py`, method
`__npu_async_metrics_collector_init__`.

### How was this patch tested?
Tested locally (passed) and verified by CI.

Signed-off-by: mengwei805 <mengwei25@huawei.com>
This commit is contained in:
wemaster
2025-04-28 21:18:53 +08:00
committed by GitHub
parent be9e3e8545
commit 54c0e63df7
15 changed files with 288 additions and 39 deletions

View File

@@ -138,13 +138,18 @@ jobs:
speculative_tests_changed:
- "tests/singlecard/spec_decode/**"
- "tests/multicard/spec_decode_e2e/**"
- "vllm_ascend/worker/worker.py"
- "vllm_ascend/worker/model_runner.py"
- "vllm_ascend/worker/multi_step_runner.py"
- "vllm_ascend/worker/multi_step_worker.py"
- "vllm_ascend/patch/patch_rejection_sampler.py"
- "vllm_ascend/patch/patch_spec_decode_worker.py"
- "vllm_ascend/patch/patch_multi_step_worker.py"
- "vllm_ascend/worker/draft_model_runner.py"
- "vllm_ascend/patch/worker/patch_common/patch_metrics.py"
- "vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py"
- "vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py"
- name: Run vllm-project/vllm-ascend Speculative Decode test
env:
VLLM_USE_V1: 0
if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true'
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then

View File

@@ -0,0 +1,18 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm_ascend.patch import worker # noqa: F401

View File

@@ -17,7 +17,9 @@
# limitations under the License.
#
import shutil
from itertools import cycle
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
import pytest
@@ -177,6 +179,12 @@ def _check_logprobs_when_output_disabled(
assert spec_pos_logprob_token_id in baseline_pos_logprobs
def _clean_torchair_cache():
cache_path = Path.cwd() / '.torchair_cache'
if cache_path.exists() and cache_path.is_dir():
shutil.rmtree(cache_path)
def run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
@@ -219,10 +227,20 @@ def run_equality_correctness_test(
logprobs=logprobs,
prompt_logprobs=prompt_logprobs)
# TODO current torchair graph mode needs clean torchair cache.
# if do not clean, it will raise error
additional_config = common_llm_kwargs.get("additional_config")
enable_graph_mode = additional_config.get(
"enable_graph_mode") if additional_config else False
with vllm_runner(**org_args) as vllm_model:
if enable_graph_mode:
_clean_torchair_cache()
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
with vllm_runner(**sd_args) as vllm_model:
if enable_graph_mode:
_clean_torchair_cache()
if ensure_all_accepted or expected_acceptance_rate is not None:
# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers[

View File

@@ -36,14 +36,20 @@ However, we still need to verify below scenario could be passed:
With those tests, we can say at least that MTP would not break the
correctness of the target model outputs.
"""
import os
import pytest
from .conftest import run_equality_correctness_test
# main model
# NOTE vLLM use fp8 model, vllm-ascend use bf16 model
MAIN_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
# NOTE both main model and MTP are bfloat16
FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
# NOTE main model is w8a8, MTP is bfloat16
QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part"
# TODO when msmodelslim can quantify both main and MTP model
# This UT should use w8a8 fully weights.
# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
@@ -51,8 +57,11 @@ MAX_SPEC_TOKENS = 1
# precision
PRECISION = "bfloat16"
os.environ["VLLM_USE_MODELSCOPE"] = "True"
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -66,7 +75,7 @@ PRECISION = "bfloat16"
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
@@ -97,6 +106,7 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -110,7 +120,53 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": QUANT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
@@ -158,15 +214,13 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
["disable_logprobs"])
@pytest.mark.skipif(
True,
reason=
"Open it when vllm-ascend support graph mode and support enforce_eager status is False to run model in graph mode"
)
@pytest.mark.skipif(True, reason="torchair ut can not clean mem.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
"additional_config": {
'enable_graph_mode': True,
},
# Print spec metrics.
"disable_log_stats": False,
@@ -175,7 +229,7 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@@ -192,20 +246,64 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
output_len: int, seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
def test_mtp_e2e_greedy_correctness_torchair_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with torchair graph enabled and different
batch sizes using bfloat16 weights."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"additional_config": {
'enable_graph_mode': True,
},
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": QUANT_MODEL,
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with torchair graph enabled and different
batch sizes using quant weights."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -221,7 +319,7 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
@@ -256,6 +354,8 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -266,7 +366,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
@@ -305,6 +405,8 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -315,7 +417,7 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9

View File

@@ -29,7 +29,6 @@ from vllm.spec_decode.top1_proposer import Top1Proposer
from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import create_batch, mock_worker
from vllm_ascend.patch.worker import patch_common # noqa: F401
@pytest.mark.parametrize('queue_size', [4])

View File

@@ -33,7 +33,6 @@ from tests.singlecard.spec_decode.utils import (
assert_logprobs_dict_allclose, create_batch,
create_seq_group_metadata_from_prompts, create_worker,
patch_execute_model_with_seeds, zero_kv_cache)
from vllm_ascend.patch.worker import patch_common # noqa: F401
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker

View File

@@ -24,7 +24,6 @@ from vllm.spec_decode.top1_proposer import Top1Proposer
from tests.singlecard.spec_decode.utils import (
create_seq_group_metadata_from_prompts, create_worker)
from vllm_ascend.patch.worker import patch_common # noqa: F401
def test_ngram_algo_correctness_for_single_no_match():

View File

@@ -39,8 +39,6 @@ from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import (create_batch,
create_sampler_output_list,
create_worker, mock_worker)
# patch SpecDecodeWorker, AsyncMetricsCollector
from vllm_ascend.patch.worker import patch_common # noqa: F401
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker

View File

@@ -17,11 +17,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -34,6 +35,7 @@ from vllm.model_executor.models.deepseek_mtp import (
SharedHead)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
@@ -69,6 +71,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
@@ -88,6 +92,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=None)
hidden_states = residual + hidden_states
return hidden_states
@@ -125,14 +131,20 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
current_step_idx = (spec_step_idx % self.num_mtp_layers)
step_kv_cache = kv_caches[
current_step_idx] if kv_caches is not None else None
return self.layers_list[current_step_idx](
input_ids,
positions,
step_kv_cache,
attn_metadata,
previous_hidden_states,
inputs_embeds,
current_step_idx,
@@ -170,3 +182,19 @@ class CustomDeepSeekMTP(DeepSeekMTP):
prefix, "model"))
self.sampler = get_sampler()
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: Optional[List[torch.Tensor]] = None,
    attn_metadata: Optional[AttentionMetadata] = None,
    previous_hidden_states: Optional[torch.Tensor] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    spec_step_idx: int = 0,
) -> torch.Tensor:
    """Run the MTP predictor and return its hidden states.

    Delegates to ``self.model`` (the multi-token predictor), threading
    through the per-step kv caches and attention metadata needed for
    graph mode.

    NOTE(review): ``intermediate_tensors`` is accepted for interface
    compatibility but is NOT forwarded to the model — confirm this is
    intentional (pipeline parallelism would need it).
    """
    hidden_states = self.model(input_ids, positions, kv_caches,
                               attn_metadata, previous_hidden_states,
                               inputs_embeds, spec_step_idx)
    return hidden_states

View File

@@ -114,6 +114,19 @@
# Future Plan:
# Revert it when the related pr is merged in vllm.
#
# ** File: worker/patch_0_8_4/patch_spec_decode_worker.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker._configure_model_sampler_for_spec_decode`
# Why:
# vLLM `Remove Sampler from Model Code` so vllm-ascend needs a patch to run in v0.8.4.
# How
# Use vLLM 0.8.4 method to patch it.
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# - https://github.com/vllm-project/vllm/pull/17084
# - https://github.com/vllm-project/vllm-ascend/pull/636
# Future Plan:
# Follow v0.8.4 version strategy.
#
# ** File: worker/patch_common/patch_metrics.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics`
@@ -158,7 +171,19 @@
# Future Plan:
# Revert it when the related pr is merged in vllm and vllm-ascend.
#
# ** File: worker/patch_common/patch_multi_step_worker.py **
# 2. `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_include_gpu_probs_tensor` and
# `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_should_modify_greedy_probs_inplace`
# Why:
# vLLM `Remove Sampler from Model Code` so vllm-ascend needs adapt to this change.
# How
# Use vLLM 0.8.4 method to patch it.
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# - https://github.com/vllm-project/vllm/pull/15195
# - https://github.com/vllm-project/vllm-ascend/pull/395
# Future Plan:
# Remove it when we identify the reasons clearly.
#
# ** File: worker/patch_common/patch_spec_decode_worker.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
# Why:

View File

@@ -0,0 +1,30 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
def _configure_model_sampler_for_spec_decode(self):
    """Configure scorer and proposer samplers for speculative decoding.

    Restores the vLLM 0.8.4 behavior (vLLM later removed the sampler from
    model code): the scorer's sampler must keep sampled probs on the
    device and modify greedy probs in place, and the proposer is told to
    do the same via its setter hooks.
    """
    (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
     ) = True
    (self.scorer_worker.model_runner.model.sampler.
     should_modify_greedy_probs_inplace) = True
    # Propagate the same sampler flags to the draft/proposer worker.
    self.proposer_worker.set_include_gpu_probs_tensor()
    self.proposer_worker.set_should_modify_greedy_probs_inplace()


# Monkey-patch SpecDecodeWorker so the 0.8.4-style configuration is used.
SpecDecodeWorker._configure_model_sampler_for_spec_decode = _configure_model_sampler_for_spec_decode

View File

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm_ascend.utils import vllm_version_is
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
@@ -61,15 +62,19 @@ def sampler_output(
else:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# TODO Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
if expanded_request.previous_hidden_states is not None:
self.worker.model_runner.return_hidden_states = True
for _ in range(sample_len):
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
self._maybe_update_previous_hidden_states(model_output,
expanded_request)
self._append_new_tokens(model_output,
expanded_request.seq_group_metadata_list,
@@ -84,4 +89,22 @@ def sampler_output(
return filtered_model_outputs, True
def set_include_gpu_probs_tensor(self) -> None:
    """Enable on-device probs tensors for this worker's sampler(s)."""
    # Need include_gpu_probs_tensor for MultiStepWorker.
    if hasattr(self.model_runner.model, "sampler"):
        self.model_runner.model.sampler.include_gpu_probs_tensor = True
    if not vllm_version_is("0.8.4"):
        # Newer vLLM moved the sampler off the model onto the model
        # runner itself, so set the flag there as well.
        self.model_runner.sampler.include_gpu_probs_tensor = True


def set_should_modify_greedy_probs_inplace(self) -> None:
    """Make this worker's sampler(s) modify greedy probs in place."""
    if hasattr(self.model_runner.model, "sampler"):
        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
            True)
    if not vllm_version_is("0.8.4"):
        # Mirror the flag on the runner-level sampler for newer vLLM.
        self.model_runner.sampler.should_modify_greedy_probs_inplace = True


# Monkey-patch MultiStepWorker: wrap sampler_output in inference mode and
# install the two sampler-flag setters defined above.
MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
MultiStepWorker.set_include_gpu_probs_tensor = set_include_gpu_probs_tensor
MultiStepWorker.set_should_modify_greedy_probs_inplace = set_should_modify_greedy_probs_inplace

View File

@@ -93,7 +93,7 @@ def create_worker(
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
if draft_model_config.hf_config.model_type == "deepseek_mtp":
num_spec_prefill_steps = num_speculative_tokens
num_spec_prefill_steps = draft_model_config.hf_config.n_predict
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
proposer_worker, draft_tp, target_tp)

View File

@@ -293,8 +293,8 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
sampling_metadata=model_input.sampling_metadata,
)
else:
assert self.sampler is not None
output = self.sampler(
assert self.model_runner.sampler is not None
output = self.model_runner.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)

View File

@@ -23,6 +23,7 @@ from vllm.worker.multi_step_model_runner import (ModelOutput,
PythonizationCache,
StatefulModelInput)
from vllm_ascend.utils import vllm_version_is
from vllm_ascend.worker.model_runner import (
ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
@@ -318,8 +319,12 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[StatefulModelInputForNPU]):
device="cpu",
pin_memory=True)
self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
True)
if vllm_version_is("0.8.4"):
self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
True)
else:
assert self._base_model_runner.sampler is not None
self._base_model_runner.sampler.include_gpu_probs_tensor = True
if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
True)