[Misc][V0 Deprecation] Remove multi-step worker (#1809)
### What this PR does / why we need it?
Remove multi-step worker
This PR is a part of
https://github.com/vllm-project/vllm-ascend/issues/1620.
- vLLM version: v0.9.2
- vLLM main:
235bfd5dfe
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -73,23 +73,6 @@
|
||||
# Future Plan:
|
||||
# Keep this patch in vllm-ascend.
|
||||
#
|
||||
# ** File: worker/patch_common/patch_multi_step_worker.py **
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
|
||||
# Why:
|
||||
# There are cuda hard code (current_platform.is_cuda_alike()) in
|
||||
# `MultiStepWorker.sampler_output`, and we need to use the patched `TP1DraftModelRunner` in it.
|
||||
# How:
|
||||
# Make speculative decoding extensible to different backends.
|
||||
# - support attention metadata register to the set supported spec decode
|
||||
# - offer a api in platform to determine whether spec decode is supported,
|
||||
# and deprecate is_cuda_alike in it.
|
||||
# Related PR (if no, explain why):
|
||||
# - https://github.com/vllm-project/vllm/pull/15195
|
||||
# - https://github.com/vllm-project/vllm-ascend/pull/395
|
||||
# Future Plan:
|
||||
# Revert it when the related pr is merged in vllm and vllm-ascend.
|
||||
#
|
||||
# ** File: worker/patch_common/patch_spec_decode_worker.py **
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
|
||||
|
||||
@@ -20,6 +20,5 @@
|
||||
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
|
||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
|
||||
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
|
||||
|
||||
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass sample_len times. Returns the list of
|
||||
sampler output, one per model forward pass, along with indicator of
|
||||
whether torch tensor in sampler output need to be transposed in latter
|
||||
sampler_output_to_torch logic.
|
||||
|
||||
For multi step worker, this indicator shall be True.
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
# Expand the batch for sequences with a bonus token.
|
||||
# Perform a forward pass on the expanded batch and filter the
|
||||
# response to retain only the original sequences' responses.
|
||||
expanded_request, indices_of_seq_with_bonus_tokens =\
|
||||
self._expand_execute_model_request(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
# Run model sample_len times.
|
||||
model_outputs: List[SamplerOutput] = []
|
||||
|
||||
# TODO: supports_gpu_multi_step is False in ASCEND
|
||||
if isinstance(self.model_runner, TP1DraftModelRunner) and \
|
||||
self.model_runner.supports_gpu_multi_step(expanded_request):
|
||||
# Here we run the draft_model_runner with multi-step prepare
|
||||
# on the GPU directly
|
||||
expanded_request.num_steps = sample_len
|
||||
self.model_runner.set_indices_of_seq_with_bonus_tokens(
|
||||
indices_of_seq_with_bonus_tokens)
|
||||
model_outputs = self.execute_model(execute_model_req=expanded_request)
|
||||
else:
|
||||
# Here we run multi-step directly, with every step prepared
|
||||
# on the CPU.
|
||||
# TODO Remove this branch once DraftModelRunner supports TP>1
|
||||
# and other restrictions that are part of DraftModelRunner's
|
||||
# supports_gpu_multi_step(..)
|
||||
if expanded_request.previous_hidden_states is not None:
|
||||
self.worker.model_runner.return_hidden_states = True
|
||||
for _ in range(sample_len):
|
||||
model_output: List[SamplerOutput] = self.worker.execute_model(
|
||||
execute_model_req=expanded_request)
|
||||
assert (len(model_output) == 1
|
||||
), "composing multistep workers not supported"
|
||||
model_output = model_output[0]
|
||||
self._maybe_update_previous_hidden_states(model_output,
|
||||
expanded_request)
|
||||
|
||||
self._append_new_tokens(model_output,
|
||||
expanded_request.seq_group_metadata_list,
|
||||
indices_of_seq_with_bonus_tokens)
|
||||
model_outputs.append(model_output)
|
||||
|
||||
# move indices to device to avoid stream sync
|
||||
indices_of_seq_with_bonus_tokens = torch.tensor(
|
||||
indices_of_seq_with_bonus_tokens, device=self.device)
|
||||
filtered_model_outputs = self._filter_model_output(
|
||||
model_outputs, indices_of_seq_with_bonus_tokens)
|
||||
return filtered_model_outputs, True
|
||||
|
||||
|
||||
MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
|
||||
@@ -1,194 +0,0 @@
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.distributed import broadcast_tensor_dict, get_pp_group
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.worker.model_runner_base import BroadcastableModelInput
|
||||
from vllm.worker.multi_step_model_runner import StatefulModelInput
|
||||
|
||||
from vllm_ascend.worker.multi_step_runner import MultiStepModelNPURunner
|
||||
from vllm_ascend.worker.worker import NPUWorker, WorkerInput
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiStepState:
|
||||
worker_input: WorkerInput
|
||||
model_input: StatefulModelInput
|
||||
|
||||
|
||||
class MultiStepWorker(NPUWorker):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
base_model_runner = self.model_runner
|
||||
# for multi-step model, wrap the model runner with MultiStepModelRunner
|
||||
self.model_runner = MultiStepModelNPURunner(
|
||||
base_model_runner,
|
||||
vllm_config=base_model_runner.vllm_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=base_model_runner.is_driver_worker,
|
||||
)
|
||||
|
||||
pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||
self.multi_step_states: List[
|
||||
Optional[MultiStepState]] = [None] * pipeline_parallel_size
|
||||
self.temp_output = None
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Get the driver input and broadcast it to other workers.
|
||||
"""
|
||||
assert self.is_driver_worker
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
is_first_multi_step = execute_model_req.is_first_multi_step
|
||||
if is_first_multi_step:
|
||||
# on first step we prepare the worker input and model input normally
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
model_input: StatefulModelInput = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
|
||||
if execute_model_req.async_callback:
|
||||
model_input.frozen_model_input = dataclasses.replace( # type: ignore
|
||||
model_input.frozen_model_input,
|
||||
async_callback=execute_model_req.async_callback)
|
||||
else:
|
||||
# on subsequent steps we reuse the worker input and model input
|
||||
multi_step_state = self.multi_step_states[virtual_engine]
|
||||
worker_input = multi_step_state.worker_input
|
||||
model_input = multi_step_state.model_input
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
# clear the cached metadata so that it can be recomputed on
|
||||
# the workers.
|
||||
frozen_model_input.attn_metadata._cached_prefill_metadata = None
|
||||
frozen_model_input.attn_metadata._cached_decode_metadata = None
|
||||
|
||||
model_input.is_first_multi_step = is_first_multi_step
|
||||
model_input.is_last_step = execute_model_req.is_last_step
|
||||
|
||||
if not is_first_multi_step:
|
||||
# we broadcast the last sampled token ids to all TP workers so they
|
||||
# can update their model input metadata in-place.
|
||||
self._prepare_last_sampled_token_ids_for_tp_workers(
|
||||
execute_model_req=execute_model_req, model_input=model_input)
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
# Retuning empty dict here to keep this compatible with
|
||||
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
|
||||
return model_input, worker_input, {}
|
||||
|
||||
def _prepare_last_sampled_token_ids_for_tp_workers(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
model_input: StatefulModelInput,
|
||||
) -> None:
|
||||
"""
|
||||
Prepare the last sampled token ids for TP workers. If it's the last
|
||||
PP rank, then the last sampled token ids are already in the model_input.
|
||||
If it is NOT the last PP rank, then we need to get the last sampled
|
||||
token that is cached in the execute_model_req.
|
||||
"""
|
||||
if get_pp_group().is_last_rank:
|
||||
assert model_input.cached_outputs[
|
||||
-1].sampler_output.sampled_token_ids is None
|
||||
assert model_input.cached_outputs[-1].sampled_token_ids is not None
|
||||
model_input.last_sampled_token_ids = model_input.cached_outputs[
|
||||
-1].sampled_token_ids
|
||||
# free sampled token ids from the previous step if it has been
|
||||
# pythonized. Cannot free the last sampled token ids because
|
||||
# we need it for GPU advance_step.
|
||||
for output in model_input.cached_outputs[:-1]:
|
||||
if output.pythonized:
|
||||
output.sampled_token_ids = None
|
||||
else:
|
||||
# otherwise we need to get the cached sampled token ids from the
|
||||
# execute_model_req
|
||||
assert execute_model_req.last_sampled_token_ids is not None
|
||||
model_input.last_sampled_token_ids = (
|
||||
execute_model_req.last_sampled_token_ids.npu())
|
||||
model_input.add_sampler_output(
|
||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
||||
model_input.last_sampled_token_ids)
|
||||
|
||||
# free sampled token ids from the previous step.
|
||||
# TODO(will) we could reuse the sampled token ids tensor from
|
||||
# the previous step instead.
|
||||
for output in model_input.cached_outputs[:-1]:
|
||||
output.sampled_token_ids = None
|
||||
assert model_input.cached_outputs[-1].sampled_token_ids is not None
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
|
||||
torch.Tensor]]]:
|
||||
"""
|
||||
Depending on the current state of the request and multi step worker,
|
||||
this method may skip the normal _prepare_model_input and
|
||||
_prepare_worker_input methods and instead used cached values.
|
||||
"""
|
||||
if self.is_driver_worker:
|
||||
if execute_model_req is None:
|
||||
if self.do_metadata_broadcast:
|
||||
# This signals that there's no more requests to process for
|
||||
# now. All workers are running infinite loop with
|
||||
# broadcast_tensor_dict, and it stops the loop when the
|
||||
# driver broadcasts an empty input. Send an empty input to
|
||||
# notify all other workers to stop their execution loop.
|
||||
broadcast_tensor_dict({}, src=0)
|
||||
return None
|
||||
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
(model_input, worker_input,
|
||||
kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
if execute_model_req.is_first_multi_step:
|
||||
# cache the worker input and model input for the next steps
|
||||
self.multi_step_states[virtual_engine] = MultiStepState(
|
||||
worker_input=worker_input, model_input=model_input)
|
||||
# if TP workers
|
||||
else:
|
||||
broadcast_data = self._get_worker_input_from_broadcast()
|
||||
# if the driver has sent an empty input, we should stop the worker
|
||||
# loop
|
||||
if broadcast_data is None:
|
||||
return None
|
||||
model_input, worker_input, kwargs = broadcast_data
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
if model_input.is_first_multi_step:
|
||||
pass
|
||||
# TODO(will) Can cache the worker input and model input for the
|
||||
# next steps. See below for details
|
||||
else:
|
||||
# TODO(will) possible to also cache and reuse the cached worker
|
||||
# input and model input. The idea is essentially the delta
|
||||
# optimization for model_inputs. Where the TP workers can cache
|
||||
# the model input states and we only broadcast the delta need
|
||||
# for the next step (sampled_token_ids from the previous step)
|
||||
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
# we need to update the last sampled token ids in the model
|
||||
# input for the workers so that they can run inplace
|
||||
# advance_step
|
||||
model_input.add_sampler_output(
|
||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
||||
model_input.last_sampled_token_ids)
|
||||
|
||||
assert model_input is not None
|
||||
assert worker_input is not None
|
||||
return model_input, worker_input, kwargs
|
||||
Reference in New Issue
Block a user