from typing import List, Set, Tuple

import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker_base import DelegateWorkerBase

from .var import *
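# NOTE: the star import above is assumed to provide USE_FUSED_MTP_SAMPLER,
# the flag checked at the end of sampler_output() below.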


class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ):
        """Run the model forward pass sample_len times. Returns the list of
        sampler outputs, one per model forward pass, along with an indicator
        of whether the torch tensors in the sampler output need to be
        transposed in the later sampler_output_to_torch logic.

        For the multi-step worker, this indicator is always True.
        """
        self._raise_if_unsupported(execute_model_req)

        # Expand the batch for sequences with a bonus token.
        # Perform a forward pass on the expanded batch and filter the
        # response to retain only the original sequences' responses.
        expanded_request, indices_of_seq_with_bonus_tokens = \
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)
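
        # Illustrative sketch (hypothetical ids): if sequence 8 in a batch
        # [7, 8, 9] carried a bonus token in the last step, the expanded
        # request additionally holds a copy of sequence 8 without its bonus
        # token, and indices_of_seq_with_bonus_tokens records where the
        # original sequences sit in the expanded batch so that
        # _filter_model_output() can retain only their responses below.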

        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []

        # Here we run multi-step directly, with every step prepared
        # on the CPU.
        # TODO: Remove this branch once DraftModelRunner supports TP>1
        # and other restrictions that are part of DraftModelRunner's
        # supports_gpu_multi_step(..)
        if expanded_request.previous_hidden_states is not None:
            self.worker.model_runner.return_hidden_states = True
        for _ in range(sample_len):
            model_output: List[SamplerOutput] = self.worker.execute_model(
                execute_model_req=expanded_request)
            assert (len(model_output) == 1
                    ), "composing multistep workers not supported"
            model_output = model_output[0]
            self._maybe_update_previous_hidden_states(
                model_output, expanded_request)

            self._append_new_tokens(
                model_output, expanded_request.seq_group_metadata_list,
                indices_of_seq_with_bonus_tokens)
            model_outputs.append(model_output)
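
        # Each iteration above feeds the sampled tokens back into the
        # sequence metadata via _append_new_tokens(), so the next forward
        # pass sees the tokens drafted in the previous step.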

        # The selection of the output-related tensors is done inside the
        # fused operator.
        if USE_FUSED_MTP_SAMPLER:
            return model_outputs, indices_of_seq_with_bonus_tokens, True
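
        # NOTE: unlike the default path below, the fused-sampler path returns
        # three elements (including the bonus-token indices), leaving the
        # filtering that _filter_model_output() performs to the fused sampler.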

        # move indices to device to avoid stream sync
        indices_of_seq_with_bonus_tokens = torch.tensor(
            indices_of_seq_with_bonus_tokens, device=self.device)
        filtered_model_outputs = self._filter_model_output(
            model_outputs, indices_of_seq_with_bonus_tokens)
        return filtered_model_outputs, True
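
    # Usage sketch (hypothetical, not part of this module): in upstream vLLM
    # the speculative-decoding driver calls this once per proposal round,
    # roughly:
    #
    #     outputs, transposed = worker.sampler_output(
    #         execute_model_req=req,
    #         sample_len=num_speculative_tokens,
    #         seq_ids_with_bonus_token_in_last_step=bonus_ids)
    #
    # where `worker`, `req`, and `bonus_ids` are prepared by the caller
    # (e.g. Top1Proposer / SpecDecodeWorker), and the three-element
    # fused-sampler return must be handled when USE_FUSED_MTP_SAMPLER is set.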