# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional

from vllm.logger import init_logger
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.detokenizer_utils import (
    AnyTokenizer, convert_ids_list_to_tokens)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

logger = init_logger(__name__)

# Endless stream of None, used in place of decoded tokens when
# detokenization is disabled.
NONES = itertools.repeat(None)


@dataclass
class LogprobsProcessor:
    """Accumulates sample and prompt logprobs for a single request."""

    # Tokenizer for this request,
    # None if detokenization is disabled.
    tokenizer: Optional[AnyTokenizer]

    # Logprobs for this request
    logprobs: Optional[SampleLogprobs]
    prompt_logprobs: Optional[PromptLogprobs]
    cumulative_logprob: Optional[float]
    num_logprobs: Optional[int]
    num_prompt_logprobs: Optional[int]

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "LogprobsProcessor":
        """Build a processor from a request's sampling parameters.

        Sample-logprob state (list and cumulative sum) is only created when
        `sampling_params.logprobs` is set; prompt-logprob state is only
        created when `sampling_params.prompt_logprobs` is set.

        Args:
            tokenizer: tokenizer for the request, or None to skip
              detokenization.
            request: the new request; only its sampling_params are read.

        Returns:
            A freshly initialised LogprobsProcessor.
        """
        sampling_params = request.sampling_params
        num_logprobs = sampling_params.logprobs
        num_prompt_logprobs = sampling_params.prompt_logprobs

        if num_logprobs is None:
            sample_logprobs = None
            cumulative_logprob = None
        else:
            sample_logprobs = []
            cumulative_logprob = 0.

        if num_prompt_logprobs is None:
            prompt_logprobs = None
        else:
            # NOTE: logprob of first prompt token is None.
            prompt_logprobs = [None]

        return cls(
            tokenizer=tokenizer,
            logprobs=sample_logprobs,
            prompt_logprobs=prompt_logprobs,
            cumulative_logprob=cumulative_logprob,
            num_logprobs=num_logprobs,
            num_prompt_logprobs=num_prompt_logprobs,
        )

    def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
        """Append sample logprobs received from EngineCore.

        The outer lists hold more than one entry only when EngineCore
        produced more than one token in the prior step (e.g. in spec
        decoding).

        Args:
          logprobs_lists: the lists of logprob tokens, logprobs, and ranks.
        """
        assert self.num_logprobs is not None
        assert self.logprobs is not None
        assert self.cumulative_logprob is not None

        token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists

        for pos_token_ids, pos_logprobs, pos_rank in zip(
                token_ids_lst, logprobs_lst, ranks_lst):
            # Non-incremental detokenization of this position's candidates.
            decoded_tokens: Iterable[Optional[str]]
            if self.tokenizer is None:
                decoded_tokens = NONES
            else:
                decoded_tokens = convert_ids_list_to_tokens(
                    self.tokenizer, pos_token_ids)

            # The sampler places the sampled token's logprob at index 0.
            self.cumulative_logprob += pos_logprobs[0]

            # Record the Logprob dictionary for this position.
            logprob_dict = self._make_logprob_dict(
                pos_logprobs,
                pos_token_ids,
                decoded_tokens,
                pos_rank,
                self.num_logprobs,
            )
            self.logprobs.append(logprob_dict)

    def _update_prompt_logprobs(
        self,
        prompt_logprobs_tensors: LogprobsTensors,
    ) -> None:
        """Append prompt logprobs received from EngineCore.

        Args:
          prompt_logprobs_tensors: tuple containing the prompt logprobs
            tensors (token ids, logprobs, ranks).
        """
        # Only reachable when prompt logprobs are enabled.
        assert self.num_prompt_logprobs is not None
        assert self.prompt_logprobs is not None

        token_ids, logprobs, ranks = prompt_logprobs_tensors

        # Detokenize everything in one non-incremental pass over the
        # flattened ids: [num_tok, num_lps] -> [num_tok * num_lps].
        if self.tokenizer is None:
            decoded_tokens = None
        else:
            decoded_tokens = convert_ids_list_to_tokens(
                self.tokenizer,
                token_ids.flatten().tolist())

        # Row width, needed to slice the flat decoded list back into rows.
        _, num_logprobs = logprobs.shape

        # Convert the torch tensors to plain Python lists.
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        prompt_token_ids = token_ids.tolist()

        # Build one Logprob dictionary per prompt position.
        start = 0
        for pos, pos_logprobs in enumerate(prompt_logprobs):
            stop = start + num_logprobs
            decoded_tokens_for_pos: Iterable[Optional[str]]
            if decoded_tokens is None:
                decoded_tokens_for_pos = NONES
            else:
                decoded_tokens_for_pos = decoded_tokens[start:stop]

            logprob_dict = self._make_logprob_dict(
                pos_logprobs,
                prompt_token_ids[pos],
                decoded_tokens_for_pos,
                prompt_token_ranks[pos],
                self.num_prompt_logprobs,
            )
            self.prompt_logprobs.append(logprob_dict)
            start = stop

    def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
        """Return the accumulated prompt logprobs and forget them.

        Prompt logprobs are aggregated across one or more prefill chunks.
        Handing them all back in one go and then clearing the buffer gives
        the RequestOutputKind.DELTA semantics in which every prompt logprob
        is delivered together once prefill has finished.

        Returns:
          None if prompt logprobs are disabled for this request.
          List of all prompt logprobs, otherwise.
        """
        popped = self.prompt_logprobs
        if not popped:
            # Disabled (None) or nothing buffered: nothing to reset.
            return popped
        self.prompt_logprobs = []
        return popped

    @staticmethod
    def _make_logprob_dict(
        logprobs: list[float],
        logprob_token_ids: list[int],
        decoded_tokens: Iterable[Optional[str]],
        rank: int,
        num_logprobs: int,
    ) -> dict[int, Logprob]:
        """Build the Logprob dictionary for one position.

        Args:
          logprobs: list of log probabilities
          logprob_token_ids: list of top token ids
          decoded_tokens: list of decoded top tokens
          rank: rank of the sampled token
          num_logprobs: number of logprobs requested
            by the user (in addition to sampled logprob)

        Returns:
          dict[token id, Logprob]
        """
        # The first entry carries the sampled token's own rank; the
        # remaining entries are ranked 1..num_logprobs. A sampled token
        # that also shows up in the top-k needs no special handling:
        # writing the same key into the dictionary twice is harmless.
        ranks = itertools.chain((rank, ), range(1, num_logprobs + 1))

        logprob_dict: dict[int, Logprob] = {}
        for token_id, logprob, token_rank, token in zip(
                logprob_token_ids, logprobs, ranks, decoded_tokens):
            logprob_dict[token_id] = Logprob(
                logprob=logprob,
                rank=token_rank,
                decoded_token=token,
            )
        return logprob_dict

    def update_from_output(self, output: EngineCoreOutput) -> None:
        """Fold any logprobs carried by an EngineCore output into state.

        Args:
          output: the EngineCore output; its new_logprobs and
            new_prompt_logprobs_tensors fields are consumed when not None.
        """
        sample_logprobs = output.new_logprobs
        if sample_logprobs is not None:
            self._update_sample_logprobs(sample_logprobs)

        prompt_tensors = output.new_prompt_logprobs_tensors
        if prompt_tensors is not None:
            self._update_prompt_logprobs(prompt_tensors)