[Misc][V0 Deprecation] Remove multi-step worker (#1809)
### What this PR does / why we need it?
Remove multi-step worker
This PR is a part of
https://github.com/vllm-project/vllm-ascend/issues/1620.
- vLLM version: v0.9.2
- vLLM main:
235bfd5dfe
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -73,23 +73,6 @@
|
|||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Keep this patch in vllm-ascend.
|
# Keep this patch in vllm-ascend.
|
||||||
#
|
#
|
||||||
# ** File: worker/patch_common/patch_multi_step_worker.py **
|
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
# 1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
|
|
||||||
# Why:
|
|
||||||
# There are cuda hard code (current_platform.is_cuda_alike()) in
|
|
||||||
# `MultiStepWorker.sampler_output`, and we need to use the patched `TP1DraftModelRunner` in it.
|
|
||||||
# How:
|
|
||||||
# Make speculative decoding extensible to different backends.
|
|
||||||
# - support attention metadata register to the set supported spec decode
|
|
||||||
# - offer a api in platform to determine whether spec decode is supported,
|
|
||||||
# and deprecate is_cuda_alike in it.
|
|
||||||
# Related PR (if no, explain why):
|
|
||||||
# - https://github.com/vllm-project/vllm/pull/15195
|
|
||||||
# - https://github.com/vllm-project/vllm-ascend/pull/395
|
|
||||||
# Future Plan:
|
|
||||||
# Revert it when the related pr is merged in vllm and vllm-ascend.
|
|
||||||
#
|
|
||||||
# ** File: worker/patch_common/patch_spec_decode_worker.py **
|
# ** File: worker/patch_common/patch_spec_decode_worker.py **
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
|
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
|
||||||
|
|||||||
@@ -20,6 +20,5 @@
|
|||||||
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
|
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
|
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
|
||||||
|
|||||||
@@ -1,91 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
||||||
# This file is a part of the vllm-ascend project.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
|
|
||||||
from typing import List, Set, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
|
||||||
from vllm.sequence import ExecuteModelRequest
|
|
||||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
|
||||||
|
|
||||||
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
|
|
||||||
|
|
||||||
|
|
||||||
def sampler_output(
|
|
||||||
self,
|
|
||||||
execute_model_req: ExecuteModelRequest,
|
|
||||||
sample_len: int,
|
|
||||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
|
||||||
) -> Tuple[List[SamplerOutput], bool]:
|
|
||||||
"""Run the model forward pass sample_len times. Returns the list of
|
|
||||||
sampler output, one per model forward pass, along with indicator of
|
|
||||||
whether torch tensor in sampler output need to be transposed in latter
|
|
||||||
sampler_output_to_torch logic.
|
|
||||||
|
|
||||||
For multi step worker, this indicator shall be True.
|
|
||||||
"""
|
|
||||||
self._raise_if_unsupported(execute_model_req)
|
|
||||||
# Expand the batch for sequences with a bonus token.
|
|
||||||
# Perform a forward pass on the expanded batch and filter the
|
|
||||||
# response to retain only the original sequences' responses.
|
|
||||||
expanded_request, indices_of_seq_with_bonus_tokens =\
|
|
||||||
self._expand_execute_model_request(
|
|
||||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
|
||||||
|
|
||||||
# Run model sample_len times.
|
|
||||||
model_outputs: List[SamplerOutput] = []
|
|
||||||
|
|
||||||
# TODO: supports_gpu_multi_step is False in ASCEND
|
|
||||||
if isinstance(self.model_runner, TP1DraftModelRunner) and \
|
|
||||||
self.model_runner.supports_gpu_multi_step(expanded_request):
|
|
||||||
# Here we run the draft_model_runner with multi-step prepare
|
|
||||||
# on the GPU directly
|
|
||||||
expanded_request.num_steps = sample_len
|
|
||||||
self.model_runner.set_indices_of_seq_with_bonus_tokens(
|
|
||||||
indices_of_seq_with_bonus_tokens)
|
|
||||||
model_outputs = self.execute_model(execute_model_req=expanded_request)
|
|
||||||
else:
|
|
||||||
# Here we run multi-step directly, with every step prepared
|
|
||||||
# on the CPU.
|
|
||||||
# TODO Remove this branch once DraftModelRunner supports TP>1
|
|
||||||
# and other restrictions that are part of DraftModelRunner's
|
|
||||||
# supports_gpu_multi_step(..)
|
|
||||||
if expanded_request.previous_hidden_states is not None:
|
|
||||||
self.worker.model_runner.return_hidden_states = True
|
|
||||||
for _ in range(sample_len):
|
|
||||||
model_output: List[SamplerOutput] = self.worker.execute_model(
|
|
||||||
execute_model_req=expanded_request)
|
|
||||||
assert (len(model_output) == 1
|
|
||||||
), "composing multistep workers not supported"
|
|
||||||
model_output = model_output[0]
|
|
||||||
self._maybe_update_previous_hidden_states(model_output,
|
|
||||||
expanded_request)
|
|
||||||
|
|
||||||
self._append_new_tokens(model_output,
|
|
||||||
expanded_request.seq_group_metadata_list,
|
|
||||||
indices_of_seq_with_bonus_tokens)
|
|
||||||
model_outputs.append(model_output)
|
|
||||||
|
|
||||||
# move indices to device to avoid stream sync
|
|
||||||
indices_of_seq_with_bonus_tokens = torch.tensor(
|
|
||||||
indices_of_seq_with_bonus_tokens, device=self.device)
|
|
||||||
filtered_model_outputs = self._filter_model_output(
|
|
||||||
model_outputs, indices_of_seq_with_bonus_tokens)
|
|
||||||
return filtered_model_outputs, True
|
|
||||||
|
|
||||||
|
|
||||||
MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
|
|
||||||
@@ -1,194 +0,0 @@
|
|||||||
import dataclasses
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from vllm.distributed import broadcast_tensor_dict, get_pp_group
|
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
|
||||||
from vllm.sequence import ExecuteModelRequest
|
|
||||||
from vllm.worker.model_runner_base import BroadcastableModelInput
|
|
||||||
from vllm.worker.multi_step_model_runner import StatefulModelInput
|
|
||||||
|
|
||||||
from vllm_ascend.worker.multi_step_runner import MultiStepModelNPURunner
|
|
||||||
from vllm_ascend.worker.worker import NPUWorker, WorkerInput
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MultiStepState:
|
|
||||||
worker_input: WorkerInput
|
|
||||||
model_input: StatefulModelInput
|
|
||||||
|
|
||||||
|
|
||||||
class MultiStepWorker(NPUWorker):
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
base_model_runner = self.model_runner
|
|
||||||
# for multi-step model, wrap the model runner with MultiStepModelRunner
|
|
||||||
self.model_runner = MultiStepModelNPURunner(
|
|
||||||
base_model_runner,
|
|
||||||
vllm_config=base_model_runner.vllm_config,
|
|
||||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
|
||||||
is_driver_worker=base_model_runner.is_driver_worker,
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
|
|
||||||
self.multi_step_states: List[
|
|
||||||
Optional[MultiStepState]] = [None] * pipeline_parallel_size
|
|
||||||
self.temp_output = None
|
|
||||||
|
|
||||||
def _get_driver_input_and_broadcast(
|
|
||||||
self, execute_model_req: ExecuteModelRequest
|
|
||||||
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
Get the driver input and broadcast it to other workers.
|
|
||||||
"""
|
|
||||||
assert self.is_driver_worker
|
|
||||||
virtual_engine = execute_model_req.virtual_engine
|
|
||||||
is_first_multi_step = execute_model_req.is_first_multi_step
|
|
||||||
if is_first_multi_step:
|
|
||||||
# on first step we prepare the worker input and model input normally
|
|
||||||
worker_input: WorkerInput = self.prepare_worker_input(
|
|
||||||
execute_model_req=execute_model_req)
|
|
||||||
model_input: StatefulModelInput = (
|
|
||||||
self.model_runner.prepare_model_input(
|
|
||||||
execute_model_req.seq_group_metadata_list,
|
|
||||||
execute_model_req.virtual_engine,
|
|
||||||
execute_model_req.finished_requests_ids))
|
|
||||||
|
|
||||||
if execute_model_req.async_callback:
|
|
||||||
model_input.frozen_model_input = dataclasses.replace( # type: ignore
|
|
||||||
model_input.frozen_model_input,
|
|
||||||
async_callback=execute_model_req.async_callback)
|
|
||||||
else:
|
|
||||||
# on subsequent steps we reuse the worker input and model input
|
|
||||||
multi_step_state = self.multi_step_states[virtual_engine]
|
|
||||||
worker_input = multi_step_state.worker_input
|
|
||||||
model_input = multi_step_state.model_input
|
|
||||||
frozen_model_input = model_input.frozen_model_input
|
|
||||||
assert frozen_model_input is not None
|
|
||||||
assert frozen_model_input.attn_metadata is not None
|
|
||||||
# clear the cached metadata so that it can be recomputed on
|
|
||||||
# the workers.
|
|
||||||
frozen_model_input.attn_metadata._cached_prefill_metadata = None
|
|
||||||
frozen_model_input.attn_metadata._cached_decode_metadata = None
|
|
||||||
|
|
||||||
model_input.is_first_multi_step = is_first_multi_step
|
|
||||||
model_input.is_last_step = execute_model_req.is_last_step
|
|
||||||
|
|
||||||
if not is_first_multi_step:
|
|
||||||
# we broadcast the last sampled token ids to all TP workers so they
|
|
||||||
# can update their model input metadata in-place.
|
|
||||||
self._prepare_last_sampled_token_ids_for_tp_workers(
|
|
||||||
execute_model_req=execute_model_req, model_input=model_input)
|
|
||||||
|
|
||||||
if self.do_metadata_broadcast:
|
|
||||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
|
||||||
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
|
||||||
broadcast_tensor_dict(broadcast_data, src=0)
|
|
||||||
|
|
||||||
# Retuning empty dict here to keep this compatible with
|
|
||||||
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
|
|
||||||
return model_input, worker_input, {}
|
|
||||||
|
|
||||||
def _prepare_last_sampled_token_ids_for_tp_workers(
|
|
||||||
self,
|
|
||||||
execute_model_req: ExecuteModelRequest,
|
|
||||||
model_input: StatefulModelInput,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Prepare the last sampled token ids for TP workers. If it's the last
|
|
||||||
PP rank, then the last sampled token ids are already in the model_input.
|
|
||||||
If it is NOT the last PP rank, then we need to get the last sampled
|
|
||||||
token that is cached in the execute_model_req.
|
|
||||||
"""
|
|
||||||
if get_pp_group().is_last_rank:
|
|
||||||
assert model_input.cached_outputs[
|
|
||||||
-1].sampler_output.sampled_token_ids is None
|
|
||||||
assert model_input.cached_outputs[-1].sampled_token_ids is not None
|
|
||||||
model_input.last_sampled_token_ids = model_input.cached_outputs[
|
|
||||||
-1].sampled_token_ids
|
|
||||||
# free sampled token ids from the previous step if it has been
|
|
||||||
# pythonized. Cannot free the last sampled token ids because
|
|
||||||
# we need it for GPU advance_step.
|
|
||||||
for output in model_input.cached_outputs[:-1]:
|
|
||||||
if output.pythonized:
|
|
||||||
output.sampled_token_ids = None
|
|
||||||
else:
|
|
||||||
# otherwise we need to get the cached sampled token ids from the
|
|
||||||
# execute_model_req
|
|
||||||
assert execute_model_req.last_sampled_token_ids is not None
|
|
||||||
model_input.last_sampled_token_ids = (
|
|
||||||
execute_model_req.last_sampled_token_ids.npu())
|
|
||||||
model_input.add_sampler_output(
|
|
||||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
|
||||||
model_input.last_sampled_token_ids)
|
|
||||||
|
|
||||||
# free sampled token ids from the previous step.
|
|
||||||
# TODO(will) we could reuse the sampled token ids tensor from
|
|
||||||
# the previous step instead.
|
|
||||||
for output in model_input.cached_outputs[:-1]:
|
|
||||||
output.sampled_token_ids = None
|
|
||||||
assert model_input.cached_outputs[-1].sampled_token_ids is not None
|
|
||||||
|
|
||||||
def prepare_input(
|
|
||||||
self,
|
|
||||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
|
||||||
) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
|
|
||||||
torch.Tensor]]]:
|
|
||||||
"""
|
|
||||||
Depending on the current state of the request and multi step worker,
|
|
||||||
this method may skip the normal _prepare_model_input and
|
|
||||||
_prepare_worker_input methods and instead used cached values.
|
|
||||||
"""
|
|
||||||
if self.is_driver_worker:
|
|
||||||
if execute_model_req is None:
|
|
||||||
if self.do_metadata_broadcast:
|
|
||||||
# This signals that there's no more requests to process for
|
|
||||||
# now. All workers are running infinite loop with
|
|
||||||
# broadcast_tensor_dict, and it stops the loop when the
|
|
||||||
# driver broadcasts an empty input. Send an empty input to
|
|
||||||
# notify all other workers to stop their execution loop.
|
|
||||||
broadcast_tensor_dict({}, src=0)
|
|
||||||
return None
|
|
||||||
|
|
||||||
virtual_engine = execute_model_req.virtual_engine
|
|
||||||
(model_input, worker_input,
|
|
||||||
kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
|
|
||||||
assert isinstance(model_input, StatefulModelInput)
|
|
||||||
if execute_model_req.is_first_multi_step:
|
|
||||||
# cache the worker input and model input for the next steps
|
|
||||||
self.multi_step_states[virtual_engine] = MultiStepState(
|
|
||||||
worker_input=worker_input, model_input=model_input)
|
|
||||||
# if TP workers
|
|
||||||
else:
|
|
||||||
broadcast_data = self._get_worker_input_from_broadcast()
|
|
||||||
# if the driver has sent an empty input, we should stop the worker
|
|
||||||
# loop
|
|
||||||
if broadcast_data is None:
|
|
||||||
return None
|
|
||||||
model_input, worker_input, kwargs = broadcast_data
|
|
||||||
assert isinstance(model_input, StatefulModelInput)
|
|
||||||
virtual_engine = worker_input.virtual_engine
|
|
||||||
if model_input.is_first_multi_step:
|
|
||||||
pass
|
|
||||||
# TODO(will) Can cache the worker input and model input for the
|
|
||||||
# next steps. See below for details
|
|
||||||
else:
|
|
||||||
# TODO(will) possible to also cache and reuse the cached worker
|
|
||||||
# input and model input. The idea is essentially the delta
|
|
||||||
# optimization for model_inputs. Where the TP workers can cache
|
|
||||||
# the model input states and we only broadcast the delta need
|
|
||||||
# for the next step (sampled_token_ids from the previous step)
|
|
||||||
|
|
||||||
assert isinstance(model_input, StatefulModelInput)
|
|
||||||
# we need to update the last sampled token ids in the model
|
|
||||||
# input for the workers so that they can run inplace
|
|
||||||
# advance_step
|
|
||||||
model_input.add_sampler_output(
|
|
||||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
|
||||||
model_input.last_sampled_token_ids)
|
|
||||||
|
|
||||||
assert model_input is not None
|
|
||||||
assert worker_input is not None
|
|
||||||
return model_input, worker_input, kwargs
|
|
||||||
Reference in New Issue
Block a user