init
This commit is contained in:
150
vllm/outputs.py
Normal file
150
vllm/outputs.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import time
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
|
||||
SequenceGroup, SequenceStatus)
|
||||
|
||||
|
||||
class CompletionOutput:
|
||||
"""The output data of one completion output of a request.
|
||||
|
||||
Args:
|
||||
index: The index of the output in the request.
|
||||
text: The generated output text.
|
||||
token_ids: The token IDs of the generated output text.
|
||||
cumulative_logprob: The cumulative log probability of the generated
|
||||
output text.
|
||||
logprobs: The log probabilities of the top probability words at each
|
||||
position if the logprobs are requested.
|
||||
finish_reason: The reason why the sequence is finished.
|
||||
stop_reason: The stop string or token id that caused the completion
|
||||
to stop, None if the completion finished for some other reason
|
||||
including encountering the EOS token.
|
||||
lora_request: The LoRA request that was used to generate the output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: int,
|
||||
text: str,
|
||||
token_ids: List[int],
|
||||
cumulative_logprob: float,
|
||||
logprobs: Optional[SampleLogprobs],
|
||||
finish_reason: Optional[str] = None,
|
||||
stop_reason: Union[int, str, None] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> None:
|
||||
self.index = index
|
||||
self.text = text
|
||||
self.token_ids = token_ids
|
||||
self.cumulative_logprob = cumulative_logprob
|
||||
self.logprobs = logprobs
|
||||
self.finish_reason = finish_reason
|
||||
self.stop_reason = stop_reason
|
||||
self.lora_request = lora_request
|
||||
|
||||
def finished(self) -> bool:
|
||||
return self.finish_reason is not None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"CompletionOutput(index={self.index}, "
|
||||
f"text={self.text!r}, "
|
||||
f"token_ids={self.token_ids}, "
|
||||
f"cumulative_logprob={self.cumulative_logprob}, "
|
||||
f"logprobs={self.logprobs}, "
|
||||
f"finish_reason={self.finish_reason}, "
|
||||
f"stop_reason={self.stop_reason})")
|
||||
|
||||
|
||||
class RequestOutput:
|
||||
"""The output data of a request to the LLM.
|
||||
|
||||
Args:
|
||||
request_id: The unique ID of the request.
|
||||
prompt: The prompt string of the request.
|
||||
prompt_token_ids: The token IDs of the prompt.
|
||||
prompt_logprobs: The log probabilities to return per prompt token.
|
||||
outputs: The output sequences of the request.
|
||||
finished: Whether the whole request is finished.
|
||||
metrics: Metrics associated with the request.
|
||||
lora_request: The LoRA request that was used to generate the output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
prompt_token_ids: List[int],
|
||||
prompt_logprobs: Optional[PromptLogprobs],
|
||||
outputs: List[CompletionOutput],
|
||||
finished: bool,
|
||||
metrics: Optional[RequestMetrics] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
self.outputs = outputs
|
||||
self.finished = finished
|
||||
self.metrics = metrics
|
||||
self.lora_request = lora_request
|
||||
|
||||
@classmethod
|
||||
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
|
||||
seqs = seq_group.get_seqs()
|
||||
if len(seqs) == 1:
|
||||
top_n_seqs = seqs
|
||||
else:
|
||||
# Get the top-n sequences.
|
||||
n = seq_group.sampling_params.n
|
||||
if seq_group.sampling_params.use_beam_search:
|
||||
sorting_key = lambda seq: seq.get_beam_search_score(
|
||||
seq_group.sampling_params.length_penalty)
|
||||
else:
|
||||
sorting_key = lambda seq: seq.get_cumulative_logprob()
|
||||
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
|
||||
top_n_seqs = sorted_seqs[:n]
|
||||
|
||||
# Create the outputs.
|
||||
# NOTE: We need omit logprobs here explicitly because the sequence
|
||||
# always has the logprobs of the sampled tokens even if the
|
||||
# logprobs are not requested.
|
||||
include_logprobs = seq_group.sampling_params.logprobs is not None
|
||||
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
|
||||
outputs = [
|
||||
CompletionOutput(seqs.index(seq),
|
||||
seq.get_output_text_to_return(text_buffer_length),
|
||||
seq.get_output_token_ids(),
|
||||
seq.get_cumulative_logprob(),
|
||||
seq.output_logprobs if include_logprobs else None,
|
||||
SequenceStatus.get_finished_reason(seq.status),
|
||||
seq.stop_reason) for seq in top_n_seqs
|
||||
]
|
||||
|
||||
# Every sequence in the sequence group should have the same prompt.
|
||||
prompt = seq_group.prompt
|
||||
prompt_token_ids = seq_group.prompt_token_ids
|
||||
prompt_logprobs = seq_group.prompt_logprobs
|
||||
finished = seq_group.is_finished()
|
||||
finished_time = time.time() if finished else None
|
||||
seq_group.set_finished_time(finished_time)
|
||||
return cls(seq_group.request_id,
|
||||
prompt,
|
||||
prompt_token_ids,
|
||||
prompt_logprobs,
|
||||
outputs,
|
||||
finished,
|
||||
seq_group.metrics,
|
||||
lora_request=seq_group.lora_request)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"RequestOutput(request_id={self.request_id}, "
|
||||
f"prompt={self.prompt!r}, "
|
||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs}, "
|
||||
f"outputs={self.outputs}, "
|
||||
f"finished={self.finished}, "
|
||||
f"metrics={self.metrics}, "
|
||||
f"lora_request={self.lora_request})")
|
||||
Reference in New Issue
Block a user