update
This commit is contained in:
109
vllm/beam_search.py
Normal file
109
vllm/beam_search.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from vllm.inputs import TokenInputs, token_inputs
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
|
||||
|
||||
|
||||
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Holds the token ids and the accumulated log probability of the
    candidate. ``text`` stays ``None`` by default; it is optional and is
    only meant to be populated right before the sequence is handed back
    to the user.
    """

    # The prompt this beam started from; consulted by ``get_prompt``.
    orig_prompt: TokenInputs | MultiModalInputs

    # Token ids of the sequence, prompt tokens included.
    tokens: list[int]
    # One dict per entry, keyed by int with ``Logprob`` values
    # (presumably token id -> logprob per position; verify with callers).
    logprobs: list[dict[int, Logprob]]
    lora_request: LoRARequest | None = None
    # Accumulated log probability of the sequence.
    cum_logprob: float = 0.0
    text: str | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None

    def get_prompt(self):
        """Build an engine input that pairs this beam's tokens with the
        metadata of the original prompt.

        The ``"prompt"`` and ``"cache_salt"`` entries are read with
        ``.get`` so they may be absent from the original prompt. When the
        original prompt's ``"type"`` is ``"token"`` the result comes from
        ``token_inputs``; otherwise it comes from ``mm_inputs`` and the
        original ``mm_kwargs``, ``mm_hashes`` and ``mm_placeholders``
        entries are forwarded unchanged.
        """
        original = self.orig_prompt

        # Metadata forwarded identically to both input constructors.
        shared = {
            "prompt": original.get("prompt"),
            "cache_salt": original.get("cache_salt"),
        }

        if original["type"] != "token":
            return mm_inputs(
                prompt_token_ids=self.tokens,
                mm_kwargs=original["mm_kwargs"],
                mm_hashes=original["mm_hashes"],
                mm_placeholders=original["mm_placeholders"],
                **shared,
            )

        return token_inputs(self.tokens, **shared)
|
||||
|
||||
|
||||
@dataclass
class BeamSearchOutput:
    """Result container for a beam search run.

    Wraps the list of best beam search sequences; that list has as many
    entries as the beam width.
    """

    # The best sequences found by the search.
    sequences: list[BeamSearchSequence]
|
||||
|
||||
|
||||
class BeamSearchInstance:
    """Per-prompt beam search state: the live beams and the completed ones."""

    def __init__(
        self,
        prompt: TokenInputs | MultiModalInputs,
        lora_request: LoRARequest | None = None,
        logprobs: list[dict[int, Logprob]] | None = None,
        **kwargs,
    ):
        """Start the search with a single beam built from ``prompt``.

        Args:
            prompt: The prompt to search from; its ``"prompt_token_ids"``
                entry becomes the first beam's token list (the list object
                is used as-is, no copy is made).
            lora_request: Stored on the first beam.
            logprobs: Optional initial logprobs; copied into a new list,
                or replaced by an empty list when ``None``.
            **kwargs: Passed through to ``BeamSearchSequence``.
        """
        # Copy so the seed beam never shares the caller's list object.
        seed_logprobs = list(logprobs) if logprobs is not None else []

        seed_beam = BeamSearchSequence(
            orig_prompt=prompt,
            tokens=prompt["prompt_token_ids"],
            logprobs=seed_logprobs,
            lora_request=lora_request,
            **kwargs,
        )

        self.beams: list[BeamSearchSequence] = [seed_beam]
        self.completed: list[BeamSearchSequence] = []
|
||||
|
||||
|
||||
def get_beam_search_score(
    tokens: list[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    The score is ``cumulative_logprob / (seq_len ** length_penalty)``,
    where ``seq_len`` is the number of tokens, not counting a trailing
    EOS token.

    Adapted from

    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

    Args:
        tokens: Token ids of the sequence being scored.
        cumulative_logprob: Sum of log probabilities for the sequence.
        eos_token_id: Token id that marks end-of-sequence; a trailing
            occurrence is excluded from the length.
        length_penalty: Exponent applied to the length before dividing.

    Returns:
        The length-normalised score.
    """
    seq_len = len(tokens)
    # Guard the lookup so an empty token list does not raise IndexError.
    if tokens and tokens[-1] == eos_token_id:
        seq_len -= 1

    # An empty sequence, or one holding only EOS, has an effective length
    # of 0; clamp to 1 so the normalisation cannot divide by zero.
    seq_len = max(seq_len, 1)

    return cumulative_logprob / (seq_len**length_penalty)
|
||||
|
||||
|
||||
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
    """Return a key function that scores a beam for sorting.

    The returned callable closes over ``eos_token_id`` and
    ``length_penalty`` and maps a ``BeamSearchSequence`` to the value of
    ``get_beam_search_score`` for its tokens and cumulative logprob.
    """

    def sort_beams_key(x: BeamSearchSequence) -> float:
        """Score one beam with the captured EOS id and length penalty."""
        score = get_beam_search_score(
            tokens=x.tokens,
            cumulative_logprob=x.cum_logprob,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
        )
        return score

    return sort_beams_key
|
||||
Reference in New Issue
Block a user