diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index d6b7ced..a5d3e24 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -138,13 +138,18 @@ jobs: speculative_tests_changed: - "tests/singlecard/spec_decode/**" - "tests/multicard/spec_decode_e2e/**" + - "vllm_ascend/worker/worker.py" + - "vllm_ascend/worker/model_runner.py" - "vllm_ascend/worker/multi_step_runner.py" - "vllm_ascend/worker/multi_step_worker.py" - - "vllm_ascend/patch/patch_rejection_sampler.py" - - "vllm_ascend/patch/patch_spec_decode_worker.py" - - "vllm_ascend/patch/patch_multi_step_worker.py" + - "vllm_ascend/worker/draft_model_runner.py" + - "vllm_ascend/patch/worker/patch_common/patch_metrics.py" + - "vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py" + - "vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py" - name: Run vllm-project/vllm-ascend Speculative Decode test + env: + VLLM_USE_V1: 0 if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true' run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then diff --git a/tests/singlecard/spec_decode/__init__.py b/tests/singlecard/spec_decode/__init__.py index e69de29..4caac44 100644 --- a/tests/singlecard/spec_decode/__init__.py +++ b/tests/singlecard/spec_decode/__init__.py @@ -0,0 +1,18 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from vllm_ascend.patch import worker # noqa: F401 diff --git a/tests/singlecard/spec_decode/e2e/conftest.py b/tests/singlecard/spec_decode/e2e/conftest.py index c61ce1c..ce26b6c 100644 --- a/tests/singlecard/spec_decode/e2e/conftest.py +++ b/tests/singlecard/spec_decode/e2e/conftest.py @@ -17,7 +17,9 @@ # limitations under the License. # +import shutil from itertools import cycle +from pathlib import Path from typing import List, Optional, Sequence, Tuple, Union import pytest @@ -177,6 +179,12 @@ def _check_logprobs_when_output_disabled( assert spec_pos_logprob_token_id in baseline_pos_logprobs +def _clean_torchair_cache(): + cache_path = Path.cwd() / '.torchair_cache' + if cache_path.exists() and cache_path.is_dir(): + shutil.rmtree(cache_path) + + def run_equality_correctness_test( vllm_runner, common_llm_kwargs, @@ -219,10 +227,20 @@ def run_equality_correctness_test( logprobs=logprobs, prompt_logprobs=prompt_logprobs) + # TODO current torchair graph mode needs clean torchair cache. 
+ # if do not clean, it will raise error + additional_config = common_llm_kwargs.get("additional_config") + enable_graph_mode = additional_config.get( + "enable_graph_mode") if additional_config else False + with vllm_runner(**org_args) as vllm_model: + if enable_graph_mode: + _clean_torchair_cache() org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) with vllm_runner(**sd_args) as vllm_model: + if enable_graph_mode: + _clean_torchair_cache() if ensure_all_accepted or expected_acceptance_rate is not None: # Force log interval to be 0 to catch all metrics. stat_logger = vllm_model.model.llm_engine.stat_loggers[ diff --git a/tests/singlecard/spec_decode/e2e/test_mtp_correctness.py b/tests/singlecard/spec_decode/e2e/test_mtp_correctness.py index 18841fb..5c28269 100644 --- a/tests/singlecard/spec_decode/e2e/test_mtp_correctness.py +++ b/tests/singlecard/spec_decode/e2e/test_mtp_correctness.py @@ -36,14 +36,20 @@ However, we still need to verify below scenario could be passed: With those tests, we can say at least, mtp would not break the correctess for the target model outputs. """ +import os import pytest from .conftest import run_equality_correctness_test -# main model -# NOTE vLLM use fp8 model, vllm-ascend use bf16 model -MAIN_MODEL = "wemaster/deepseek_mtp_main_random_bf16" +# NOTE both main model and MTP are bfloat16 +FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16" + +# NOTE main model is w8a8, MTP is bfloat16 +QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part" + +# TODO when msmodelslim can quantify both main and MTP model +# This UT should use w8a8 fully weights. # max. number of speculative tokens: this corresponds to # num_nextn_predict_layers in the config.json of the speculator model. @@ -51,8 +57,11 @@ MAX_SPEC_TOKENS = 1 # precision PRECISION = "bfloat16" +os.environ["VLLM_USE_MODELSCOPE"] = "True" +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1", + reason="mtp is not supported on v1") @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -66,7 +75,7 @@ PRECISION = "bfloat16" "dtype": PRECISION, # Main model - "model_name": MAIN_MODEL, + "model_name": FLOAT_MODEL, # GPU memory utilization "gpu_memory_utilization": 0.85 @@ -97,6 +106,7 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, batch_size, output_len, seed) +@pytest.mark.skipif(True, reason="quant model is not ready.") @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -110,7 +120,53 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, "dtype": PRECISION, # Main model - "model_name": MAIN_MODEL, + "model_name": QUANT_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int): + + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1", + reason="mtp is not supported on v1") 
+@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": FLOAT_MODEL, # GPU memory utilization "gpu_memory_utilization": 0.85 @@ -158,15 +214,13 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ["disable_logprobs"]) -@pytest.mark.skipif( - True, - reason= - "Open it when vllm-ascend support graph mode and support enforce_eager status is False to run model in graph mode" -) +@pytest.mark.skipif(True, reason="torchair ut can not clean mem.") @pytest.mark.parametrize( "common_llm_kwargs", [{ - "enforce_eager": False, + "additional_config": { + 'enable_graph_mode': True, + }, # Print spec metrics. "disable_log_stats": False, @@ -175,7 +229,7 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, "dtype": PRECISION, # Main model - "model_name": MAIN_MODEL, + "model_name": FLOAT_MODEL, "gpu_memory_utilization": 0.85 }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @@ -192,20 +246,64 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size: int, - output_len: int, seed: int): - """Verify greedy equality with cuda graph enabled and different - batch sizes.""" +def test_mtp_e2e_greedy_correctness_torchair_graph( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality with torchair graph enabled and different + batch sizes using bfloat16 weights.""" run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size, output_len, seed) +@pytest.mark.skipif(True, reason="quant model is not ready.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "additional_config": { + 'enable_graph_mode': True, + }, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": QUANT_MODEL, + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_quant_greedy_correctness_torchair_graph( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality with torchair graph enabled and different + batch sizes using quant weights.""" + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1", + reason="mtp is not supported on v1") @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -221,7 +319,7 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, "dtype": PRECISION, # Main model - "model_name": MAIN_MODEL, + "model_name": FLOAT_MODEL, # GPU memory utilization "gpu_memory_utilization": 0.9 @@ -256,6 +354,8 @@ def test_mtp_e2e_greedy_correctness_with_preemption( batch_size, output_len, seed) +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1", + reason="mtp is not supported on v1") @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -266,7 +366,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption( "dtype": PRECISION, # Main model - "model_name": MAIN_MODEL, + "model_name": FLOAT_MODEL, # GPU memory utilization "gpu_memory_utilization": 0.9 @@ -305,6 +405,8 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs, batch_size, output_len, seed) +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1", + reason="mtp is not supported on v1") @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -315,7 +417,7 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs, "dtype": PRECISION, # Main model - "model_name": MAIN_MODEL, + "model_name": FLOAT_MODEL, # GPU memory utilization "gpu_memory_utilization": 0.9 diff --git a/tests/singlecard/spec_decode/test_dynamic_spec_decode.py b/tests/singlecard/spec_decode/test_dynamic_spec_decode.py index a8e0504..b5f9ed6 100644 --- a/tests/singlecard/spec_decode/test_dynamic_spec_decode.py +++ b/tests/singlecard/spec_decode/test_dynamic_spec_decode.py @@ -29,7 +29,6 @@ from vllm.spec_decode.top1_proposer import Top1Proposer from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler from tests.singlecard.spec_decode.utils import create_batch, mock_worker -from vllm_ascend.patch.worker import patch_common # noqa: F401 @pytest.mark.parametrize('queue_size', [4]) diff --git a/tests/singlecard/spec_decode/test_multi_step_worker.py b/tests/singlecard/spec_decode/test_multi_step_worker.py index 90d5e65..b7b4c72 100644 --- a/tests/singlecard/spec_decode/test_multi_step_worker.py +++ b/tests/singlecard/spec_decode/test_multi_step_worker.py @@ -33,7 +33,6 @@ from tests.singlecard.spec_decode.utils import ( assert_logprobs_dict_allclose, create_batch, create_seq_group_metadata_from_prompts, create_worker, patch_execute_model_with_seeds, zero_kv_cache) -from vllm_ascend.patch.worker import patch_common # noqa: F401 from 
vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner from vllm_ascend.worker.worker import NPUWorker diff --git a/tests/singlecard/spec_decode/test_ngram_worker.py b/tests/singlecard/spec_decode/test_ngram_worker.py index 0226ac8..f8f7bf2 100644 --- a/tests/singlecard/spec_decode/test_ngram_worker.py +++ b/tests/singlecard/spec_decode/test_ngram_worker.py @@ -24,7 +24,6 @@ from vllm.spec_decode.top1_proposer import Top1Proposer from tests.singlecard.spec_decode.utils import ( create_seq_group_metadata_from_prompts, create_worker) -from vllm_ascend.patch.worker import patch_common # noqa: F401 def test_ngram_algo_correctness_for_single_no_match(): diff --git a/tests/singlecard/spec_decode/test_spec_decode_worker.py b/tests/singlecard/spec_decode/test_spec_decode_worker.py index d049d28..b44a1f3 100644 --- a/tests/singlecard/spec_decode/test_spec_decode_worker.py +++ b/tests/singlecard/spec_decode/test_spec_decode_worker.py @@ -39,8 +39,6 @@ from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler from tests.singlecard.spec_decode.utils import (create_batch, create_sampler_output_list, create_worker, mock_worker) -# patch SpecDecodeWorker, AsyncMetricsCollector -from vllm_ascend.patch.worker import patch_common # noqa: F401 from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner from vllm_ascend.worker.worker import NPUWorker diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py index a19d666..979a609 100644 --- a/vllm_ascend/models/deepseek_mtp.py +++ b/vllm_ascend/models/deepseek_mtp.py @@ -17,11 +17,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import List, Optional import torch import torch.nn as nn from transformers import PretrainedConfig +from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -34,6 +35,7 @@ from vllm.model_executor.models.deepseek_mtp import ( SharedHead) from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors from .deepseek_v2 import CustomDeepseekV2DecoderLayer @@ -69,6 +71,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, spec_step_index: int = 0, @@ -88,6 +92,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): hidden_states, residual = self.mtp_block(positions=positions, hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, residual=None) hidden_states = residual + hidden_states return hidden_states @@ -125,14 +131,20 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: torch.Tensor, + attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: current_step_idx = (spec_step_idx % self.num_mtp_layers) + step_kv_cache = kv_caches[ + current_step_idx] if kv_caches is not None else None return 
self.layers_list[current_step_idx]( input_ids, positions, + step_kv_cache, + attn_metadata, previous_hidden_states, inputs_embeds, current_step_idx, @@ -170,3 +182,19 @@ class CustomDeepSeekMTP(DeepSeekMTP): prefix, "model")) self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + previous_hidden_states: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, previous_hidden_states, + inputs_embeds, spec_step_idx) + return hidden_states diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index ef0813a..52bfe13 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -114,6 +114,19 @@ # Future Plan: # Revert it when the related pr is merged in vllm. # +# ** File: worker/patch_0_8_4/patch_spec_decode_worker.py ** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker._configure_model_sampler_for_spec_decode` +# Why: +# vLLM `Remove Sampler from Model Code` so vllm-ascend needs a patch to run in v0.8.4. +# How: +# Use vLLM 0.8.4 method tp patch it. +# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit.... +# - https://github.com/vllm-project/vllm/pull/17084 +# - https://github.com/vllm-project/vllm-ascend/pull/636 +# Future Plan: +# Follow v0.8.4 version strategy. +# # ** File: worker/patch_common/patch_metrics.py ** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics` @@ -158,7 +171,19 @@ # Future Plan: # Revert it when the related pr is merged in vllm and vllm-ascend. # -# ** File: worker/patch_common/patch_multi_step_worker.py ** +# 2. `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_include_gpu_probs_tensor` and +# `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_should_modify_greedy_probs_inplace` +# Why: +# vLLM `Remove Sampler from Model Code` so vllm-ascend needs adapt to this change. +# How: +# Use vLLM 0.8.4 method to patch it. +# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit.... +# - https://github.com/vllm-project/vllm/pull/15195 +# - https://github.com/vllm-project/vllm-ascend/pull/395 +# Future Plan: +# Remove it when we identify the reasons clearly. +# +# ** File: worker/patch_common/patch_spec_decode_worker.py ** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker` # Why: diff --git a/vllm_ascend/patch/worker/patch_0_8_4/patch_spec_decode_worker.py b/vllm_ascend/patch/worker/patch_0_8_4/patch_spec_decode_worker.py new file mode 100644 index 0000000..710894e --- /dev/null +++ b/vllm_ascend/patch/worker/patch_0_8_4/patch_spec_decode_worker.py @@ -0,0 +1,30 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + + +def _configure_model_sampler_for_spec_decode(self): + (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor + ) = True + (self.scorer_worker.model_runner.model.sampler. + should_modify_greedy_probs_inplace) = True + self.proposer_worker.set_include_gpu_probs_tensor() + self.proposer_worker.set_should_modify_greedy_probs_inplace() + + +SpecDecodeWorker._configure_model_sampler_for_spec_decode = _configure_model_sampler_for_spec_decode diff --git a/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py b/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py index 6adbf2d..2ae6cab 100644 --- a/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py +++ b/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm_ascend.utils import vllm_version_is from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner @@ -61,15 +62,19 @@ def sampler_output( else: # Here we run multi-step directly, with every step prepared # on the CPU. - # TODO: Remove this branch once DraftModelRunner supports TP>1 + # TODO Remove this branch once DraftModelRunner supports TP>1 # and other restrictions that are part of DraftModelRunner's # supports_gpu_multi_step(..) 
+ if expanded_request.previous_hidden_states is not None: + self.worker.model_runner.return_hidden_states = True for _ in range(sample_len): model_output: List[SamplerOutput] = self.worker.execute_model( execute_model_req=expanded_request) assert (len(model_output) == 1 ), "composing multistep workers not supported" model_output = model_output[0] + self._maybe_update_previous_hidden_states(model_output, + expanded_request) self._append_new_tokens(model_output, expanded_request.seq_group_metadata_list, @@ -84,4 +89,22 @@ def sampler_output( return filtered_model_outputs, True +def set_include_gpu_probs_tensor(self) -> None: + # Need include_gpu_probs_tensor for MultiSteoWorker + if hasattr(self.model_runner.model, "sampler"): + self.model_runner.model.sampler.include_gpu_probs_tensor = True + if not vllm_version_is("0.8.4"): + self.model_runner.sampler.include_gpu_probs_tensor = True + + +def set_should_modify_greedy_probs_inplace(self) -> None: + if hasattr(self.model_runner.model, "sampler"): + self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( + True) + if not vllm_version_is("0.8.4"): + self.model_runner.sampler.should_modify_greedy_probs_inplace = True + + MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output) +MultiStepWorker.set_include_gpu_probs_tensor = set_include_gpu_probs_tensor +MultiStepWorker.set_should_modify_greedy_probs_inplace = set_should_modify_greedy_probs_inplace diff --git a/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py b/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py index 040e62e..8af68c1 100644 --- a/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py +++ b/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py @@ -93,7 +93,7 @@ def create_worker( proposer_worker = MultiStepWorker(**draft_worker_kwargs) if draft_model_config.hf_config.model_type == "deepseek_mtp": - num_spec_prefill_steps = num_speculative_tokens + num_spec_prefill_steps = draft_model_config.hf_config.n_predict proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( proposer_worker, draft_tp, target_tp) diff --git a/vllm_ascend/worker/draft_model_runner.py b/vllm_ascend/worker/draft_model_runner.py index 504d94e..7122569 100644 --- a/vllm_ascend/worker/draft_model_runner.py +++ b/vllm_ascend/worker/draft_model_runner.py @@ -293,8 +293,8 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): sampling_metadata=model_input.sampling_metadata, ) else: - assert self.sampler is not None - output = self.sampler( + assert self.model_runner.sampler is not None + output = self.model_runner.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm_ascend/worker/multi_step_runner.py b/vllm_ascend/worker/multi_step_runner.py index 2ac9561..7753604 100644 --- a/vllm_ascend/worker/multi_step_runner.py +++ b/vllm_ascend/worker/multi_step_runner.py @@ -23,6 +23,7 @@ from vllm.worker.multi_step_model_runner import (ModelOutput, PythonizationCache, StatefulModelInput) +from vllm_ascend.utils import vllm_version_is from vllm_ascend.worker.model_runner import ( ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase) @@ -318,8 +319,12 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[StatefulModelInputForNPU]): device="cpu", pin_memory=True) - self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( - True) + if vllm_version_is("0.8.4"): + self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( + True) + else: + assert self._base_model_runner.sampler is not 
None + self._base_model_runner.sampler.include_gpu_probs_tensor = True if frozen_model_input.sampling_metadata: frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( True)
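
Note on the version-gated sampler configuration introduced in the last hunks: because newer vLLM releases move the Sampler off the model object, the patched set_include_gpu_probs_tensor helper and the MultiStepModelNPURunner change set the probability-tensor flag on whichever object actually owns the sampler for the running vLLM version. Below is a minimal, self-contained sketch of that pattern under stated assumptions: _Sampler, _Model, _ModelRunner and _vllm_version_is are illustrative stand-ins, not vLLM or vllm-ascend APIs (the real patches use vllm_ascend.utils.vllm_version_is and vLLM's MultiStepWorker / model runner objects).

# Sketch only: stand-in classes mimicking where the sampler lives on
# vLLM 0.8.4 (on the model) versus later releases (on the model runner).

class _Sampler:
    include_gpu_probs_tensor = False
    should_modify_greedy_probs_inplace = False


class _Model:
    def __init__(self, with_sampler: bool):
        # On vLLM 0.8.4 the sampler is still an attribute of the model.
        if with_sampler:
            self.sampler = _Sampler()


class _ModelRunner:
    def __init__(self, vllm_version: str):
        self._version = vllm_version
        self.model = _Model(with_sampler=(vllm_version == "0.8.4"))
        # Newer vLLM keeps the sampler on the model runner instead.
        self.sampler = _Sampler()


def _vllm_version_is(runner: _ModelRunner, version: str) -> bool:
    # Stand-in for vllm_ascend.utils.vllm_version_is.
    return runner._version == version


def set_include_gpu_probs_tensor(runner: _ModelRunner) -> None:
    # Mirrors the patched MultiStepWorker helper: enable the flag wherever
    # the sampler actually lives for the running vLLM version.
    if hasattr(runner.model, "sampler"):
        runner.model.sampler.include_gpu_probs_tensor = True
    if not _vllm_version_is(runner, "0.8.4"):
        runner.sampler.include_gpu_probs_tensor = True


if __name__ == "__main__":
    for version in ("0.8.4", "0.8.5"):
        runner = _ModelRunner(version)
        set_include_gpu_probs_tensor(runner)
        print(version, runner.sampler.include_gpu_probs_tensor)

The same gate appears in multi_step_runner.py above: on 0.8.4 the flag is set through self._base_model_runner.model.sampler, otherwise through self._base_model_runner.sampler, so the behavior stays identical across the supported vLLM versions.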