[Misc][V0 Deprecation] Remove multi-step worker (#1809)

### What this PR does / why we need it?
Remove the multi-step worker and its related patch as part of the V0 code path deprecation.

This PR is a part of
https://github.com/vllm-project/vllm-ascend/issues/1620.

- vLLM version: v0.9.2
- vLLM main: 235bfd5dfe

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Authored by Shanshan Shen on 2025-07-15 19:48:47 +08:00, committed by GitHub.
Parent: bf2549856f
Commit: a929699e98
4 changed files with 0 additions and 303 deletions


@@ -73,23 +73,6 @@
# Future Plan:
# Keep this patch in vllm-ascend.
#
# ** File: worker/patch_common/patch_multi_step_worker.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#   1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
#    Why:
#       There is hard-coded CUDA logic (`current_platform.is_cuda_alike()`) in
#       `MultiStepWorker.sampler_output`, and we need to use the patched `TP1DraftModelRunner` in it.
#    How:
#       Make speculative decoding extensible to different backends:
#       - support registering attention metadata for the set of supported spec decode backends
#       - offer an API in the platform to determine whether spec decode is supported,
#         and deprecate is_cuda_alike there.
#    Related PR (if no, explain why):
#       - https://github.com/vllm-project/vllm/pull/15195
#       - https://github.com/vllm-project/vllm-ascend/pull/395
#    Future Plan:
#       Revert it when the related PRs are merged in vllm and vllm-ascend.
#
# ** File: worker/patch_common/patch_spec_decode_worker.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`

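The `How:` notes in the patch docstring above describe replacing the hard-coded `is_cuda_alike()` branch with a platform-level check for spec decode support. A minimal sketch of that idea, assuming a hypothetical `supports_spec_decode` platform hook (not an existing vLLM API); only `supports_gpu_multi_step` comes from the draft runner shown in this PR:

```python
# Hypothetical sketch only: `supports_spec_decode` is an assumed platform hook.
from vllm.platforms import current_platform


def draft_runner_can_multi_step(model_runner, expanded_request) -> bool:
    """Decide whether the draft runner may run multi-step on-device."""
    platform_check = getattr(current_platform, "supports_spec_decode", None)
    if platform_check is not None and not platform_check():
        # Platform opted out of device-side multi-step spec decode.
        return False
    return model_runner.supports_gpu_multi_step(expanded_request)
```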

@@ -20,6 +20,5 @@
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa


@@ -1,91 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Set, Tuple

import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.multi_step_worker import MultiStepWorker

from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner


def sampler_output(
    self,
    execute_model_req: ExecuteModelRequest,
    sample_len: int,
    seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
    """Run the model forward pass sample_len times. Returns the list of
    sampler output, one per model forward pass, along with indicator of
    whether torch tensor in sampler output need to be transposed in latter
    sampler_output_to_torch logic.

    For multi step worker, this indicator shall be True.
    """
    self._raise_if_unsupported(execute_model_req)

    # Expand the batch for sequences with a bonus token.
    # Perform a forward pass on the expanded batch and filter the
    # response to retain only the original sequences' responses.
    expanded_request, indices_of_seq_with_bonus_tokens =\
        self._expand_execute_model_request(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    # Run model sample_len times.
    model_outputs: List[SamplerOutput] = []

    # TODO: supports_gpu_multi_step is False in ASCEND
    if isinstance(self.model_runner, TP1DraftModelRunner) and \
            self.model_runner.supports_gpu_multi_step(expanded_request):
        # Here we run the draft_model_runner with multi-step prepare
        # on the GPU directly
        expanded_request.num_steps = sample_len
        self.model_runner.set_indices_of_seq_with_bonus_tokens(
            indices_of_seq_with_bonus_tokens)
        model_outputs = self.execute_model(execute_model_req=expanded_request)
    else:
        # Here we run multi-step directly, with every step prepared
        # on the CPU.
        # TODO Remove this branch once DraftModelRunner supports TP>1
        # and other restrictions that are part of DraftModelRunner's
        # supports_gpu_multi_step(..)
        if expanded_request.previous_hidden_states is not None:
            self.worker.model_runner.return_hidden_states = True
        for _ in range(sample_len):
            model_output: List[SamplerOutput] = self.worker.execute_model(
                execute_model_req=expanded_request)
            assert (len(model_output) == 1
                    ), "composing multistep workers not supported"
            model_output = model_output[0]

            self._maybe_update_previous_hidden_states(model_output,
                                                      expanded_request)

            self._append_new_tokens(model_output,
                                    expanded_request.seq_group_metadata_list,
                                    indices_of_seq_with_bonus_tokens)
            model_outputs.append(model_output)

    # move indices to device to avoid stream sync
    indices_of_seq_with_bonus_tokens = torch.tensor(
        indices_of_seq_with_bonus_tokens, device=self.device)
    filtered_model_outputs = self._filter_model_output(
        model_outputs, indices_of_seq_with_bonus_tokens)
    return filtered_model_outputs, True


MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)

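For context, the removed patch applied itself by rebinding `MultiStepWorker.sampler_output` to a replacement function wrapped in `torch.inference_mode()`. A minimal, self-contained sketch of that monkey-patching pattern, with made-up class and method names:

```python
import torch


class Worker:
    def step(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


def patched_step(self, x: torch.Tensor) -> torch.Tensor:
    # Replacement body; runs under inference mode once wrapped below.
    return x + 1


# Rebind the method on the class. torch.inference_mode() used as a
# decorator disables autograd tracking for the wrapped call.
Worker.step = torch.inference_mode()(patched_step)

w = Worker()
print(w.step(torch.ones(2)))  # tensor([2., 2.]), produced in inference mode
```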

@@ -1,194 +0,0 @@
import dataclasses
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.model_runner_base import BroadcastableModelInput
from vllm.worker.multi_step_model_runner import StatefulModelInput

from vllm_ascend.worker.multi_step_runner import MultiStepModelNPURunner
from vllm_ascend.worker.worker import NPUWorker, WorkerInput


@dataclass
class MultiStepState:
    worker_input: WorkerInput
    model_input: StatefulModelInput


class MultiStepWorker(NPUWorker):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        base_model_runner = self.model_runner
        # for multi-step model, wrap the model runner with MultiStepModelRunner
        self.model_runner = MultiStepModelNPURunner(
            base_model_runner,
            vllm_config=base_model_runner.vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=base_model_runner.is_driver_worker,
        )

        pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
        self.multi_step_states: List[
            Optional[MultiStepState]] = [None] * pipeline_parallel_size
        self.temp_output = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """
        Get the driver input and broadcast it to other workers.
        """
        assert self.is_driver_worker
        virtual_engine = execute_model_req.virtual_engine
        is_first_multi_step = execute_model_req.is_first_multi_step
        if is_first_multi_step:
            # on first step we prepare the worker input and model input normally
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            model_input: StatefulModelInput = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))

            if execute_model_req.async_callback:
                model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                    model_input.frozen_model_input,
                    async_callback=execute_model_req.async_callback)
        else:
            # on subsequent steps we reuse the worker input and model input
            multi_step_state = self.multi_step_states[virtual_engine]
            worker_input = multi_step_state.worker_input
            model_input = multi_step_state.model_input
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None
            assert frozen_model_input.attn_metadata is not None
            # clear the cached metadata so that it can be recomputed on
            # the workers.
            frozen_model_input.attn_metadata._cached_prefill_metadata = None
            frozen_model_input.attn_metadata._cached_decode_metadata = None

        model_input.is_first_multi_step = is_first_multi_step
        model_input.is_last_step = execute_model_req.is_last_step

        if not is_first_multi_step:
            # we broadcast the last sampled token ids to all TP workers so they
            # can update their model input metadata in-place.
            self._prepare_last_sampled_token_ids_for_tp_workers(
                execute_model_req=execute_model_req, model_input=model_input)

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_tensor_dict(broadcast_data, src=0)

        # Returning empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
        return model_input, worker_input, {}

    def _prepare_last_sampled_token_ids_for_tp_workers(
        self,
        execute_model_req: ExecuteModelRequest,
        model_input: StatefulModelInput,
    ) -> None:
        """
        Prepare the last sampled token ids for TP workers. If it's the last
        PP rank, then the last sampled token ids are already in the model_input.
        If it is NOT the last PP rank, then we need to get the last sampled
        token that is cached in the execute_model_req.
        """
        if get_pp_group().is_last_rank:
            assert model_input.cached_outputs[
                -1].sampler_output.sampled_token_ids is None
            assert model_input.cached_outputs[-1].sampled_token_ids is not None
            model_input.last_sampled_token_ids = model_input.cached_outputs[
                -1].sampled_token_ids
            # free sampled token ids from the previous step if it has been
            # pythonized. Cannot free the last sampled token ids because
            # we need it for GPU advance_step.
            for output in model_input.cached_outputs[:-1]:
                if output.pythonized:
                    output.sampled_token_ids = None
        else:
            # otherwise we need to get the cached sampled token ids from the
            # execute_model_req
            assert execute_model_req.last_sampled_token_ids is not None
            model_input.last_sampled_token_ids = (
                execute_model_req.last_sampled_token_ids.npu())
            model_input.add_sampler_output(
                SamplerOutput(outputs=[], sampled_token_ids=None),
                model_input.last_sampled_token_ids)

            # free sampled token ids from the previous step.
            # TODO(will) we could reuse the sampled token ids tensor from
            # the previous step instead.
            for output in model_input.cached_outputs[:-1]:
                output.sampled_token_ids = None
            assert model_input.cached_outputs[-1].sampled_token_ids is not None

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
                                                              torch.Tensor]]]:
        """
        Depending on the current state of the request and multi step worker,
        this method may skip the normal _prepare_model_input and
        _prepare_worker_input methods and instead use cached values.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there's no more requests to process for
                    # now. All workers are running infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None

            virtual_engine = execute_model_req.virtual_engine
            (model_input, worker_input,
             kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
            assert isinstance(model_input, StatefulModelInput)
            if execute_model_req.is_first_multi_step:
                # cache the worker input and model input for the next steps
                self.multi_step_states[virtual_engine] = MultiStepState(
                    worker_input=worker_input, model_input=model_input)
        # if TP workers
        else:
            broadcast_data = self._get_worker_input_from_broadcast()
            # if the driver has sent an empty input, we should stop the worker
            # loop
            if broadcast_data is None:
                return None
            model_input, worker_input, kwargs = broadcast_data
            assert isinstance(model_input, StatefulModelInput)
            virtual_engine = worker_input.virtual_engine
            if model_input.is_first_multi_step:
                pass
                # TODO(will) Can cache the worker input and model input for the
                # next steps. See below for details
            else:
                # TODO(will) possible to also cache and reuse the cached worker
                # input and model input. The idea is essentially the delta
                # optimization for model_inputs. Where the TP workers can cache
                # the model input states and we only broadcast the delta need
                # for the next step (sampled_token_ids from the previous step)
                assert isinstance(model_input, StatefulModelInput)

                # we need to update the last sampled token ids in the model
                # input for the workers so that they can run inplace
                # advance_step
                model_input.add_sampler_output(
                    SamplerOutput(outputs=[], sampled_token_ids=None),
                    model_input.last_sampled_token_ids)

        assert model_input is not None
        assert worker_input is not None
        return model_input, worker_input, kwargs
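The removed `prepare_input` only pays for full input preparation on the first of the N steps; later steps reuse the cached `WorkerInput`/`StatefulModelInput` for that virtual engine and just refresh the last sampled token ids. A stripped-down sketch of that caching pattern, with generic names rather than the actual vLLM classes:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class StepState:
    # Cached inputs from the first step of a multi-step run.
    worker_input: dict
    model_input: dict


class StepCache:
    """Per-virtual-engine cache mirroring how the removed worker reused inputs."""

    def __init__(self, num_virtual_engines: int) -> None:
        self._states: List[Optional[StepState]] = [None] * num_virtual_engines

    def prepare(self,
                virtual_engine: int,
                is_first_step: bool,
                fresh_inputs: Optional[Dict[str, dict]] = None) -> StepState:
        if is_first_step:
            # First step: build the inputs normally and cache them.
            assert fresh_inputs is not None
            state = StepState(worker_input=fresh_inputs["worker"],
                              model_input=fresh_inputs["model"])
            self._states[virtual_engine] = state
            return state
        # Later steps: reuse the cached inputs; only the last sampled token
        # ids need to be refreshed by the caller.
        state = self._states[virtual_engine]
        assert state is not None, "first step must run before later steps"
        return state
```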