from typing import List, Set, Tuple

import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker_base import DelegateWorkerBase

from .var import *  # provides USE_FUSED_MTP_SAMPLER


class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ):
        """Run the model forward pass sample_len times.

        Returns the list of sampler outputs, one per model forward pass,
        along with an indicator of whether the torch tensors in the sampler
        output need to be transposed in the later sampler_output_to_torch
        logic. For the multi-step worker, this indicator is always True.
        """
        self._raise_if_unsupported(execute_model_req)

        # Expand the batch for sequences with a bonus token.
        # Perform a forward pass on the expanded batch and filter the
        # response to retain only the original sequences' responses.
        expanded_request, indices_of_seq_with_bonus_tokens = \
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

        # Run the model sample_len times.
        model_outputs: List[SamplerOutput] = []

        # Here we run multi-step directly, with every step prepared
        # on the CPU.
        # TODO: Remove this branch once DraftModelRunner supports TP>1
        # and other restrictions that are part of DraftModelRunner's
        # supports_gpu_multi_step(..)
        if expanded_request.previous_hidden_states is not None:
            self.worker.model_runner.return_hidden_states = True
        for _ in range(sample_len):
            model_output: List[SamplerOutput] = self.worker.execute_model(
                execute_model_req=expanded_request)
            assert (len(model_output) == 1
                    ), "composing multistep workers not supported"
            model_output = model_output[0]
            self._maybe_update_previous_hidden_states(
                model_output, expanded_request)

            self._append_new_tokens(
                model_output, expanded_request.seq_group_metadata_list,
                indices_of_seq_with_bonus_tokens)
            model_outputs.append(model_output)

        # With the fused MTP sampler enabled, the selection of the
        # outputs-related tensors is performed inside the fused operator,
        # so return the unfiltered outputs together with the bonus-token
        # indices.
        if USE_FUSED_MTP_SAMPLER:
            return model_outputs, indices_of_seq_with_bonus_tokens, True

        # Move the indices to the device to avoid a stream sync.
        indices_of_seq_with_bonus_tokens = torch.tensor(
            indices_of_seq_with_bonus_tokens, device=self.device)
        filtered_model_outputs = self._filter_model_output(
            model_outputs, indices_of_seq_with_bonus_tokens)
        return filtered_model_outputs, True
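

# ---------------------------------------------------------------------------
# Hedged sketch (not part of vLLM): illustrates the expand-then-filter pattern
# used by sampler_output above. Sequences that received a bonus token in the
# last step are duplicated in the expanded batch; after the forward passes,
# only the rows at `indices_of_seq_with_bonus_tokens` are kept. The row
# indices and tensor shapes below are made up for illustration only.
if __name__ == "__main__":
    # Pretend the expanded batch has 5 rows and the original sequences map to
    # rows [0, 2, 4] (the remaining rows are the duplicated copies).
    sampled_token_ids = torch.arange(5).unsqueeze(-1)  # [expanded_batch, 1]
    bonus_indices = torch.tensor([0, 2, 4])

    # Keep one row per original sequence, mirroring what _filter_model_output
    # does with the real SamplerOutput tensors.
    filtered = sampled_token_ids.index_select(0, bonus_indices)
    print(filtered.squeeze(-1).tolist())  # -> [0, 2, 4]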