# enginex-mlu370-vllm/vllm-v0.6.2/vllm/outputs.py
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Union

from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                           SequenceGroup, SequenceGroupBase, SequenceStatus)


@dataclass
class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        stop_reason: The stop string or token id that caused the completion
            to stop, None if the completion finished for some other reason
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    index: int
    text: str
    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        return self.finish_reason is not None

    def __repr__(self) -> str:
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")


@dataclass
class EmbeddingOutput:
    """The output data of one embedding output of a request.

    Args:
        embedding: The embedding vector, which is a list of floats. The
            length of the vector depends on the model as listed in the
            embedding guide.
    """

    embedding: List[float]

    def __repr__(self) -> str:
        return (f"EmbeddingOutput("
                f"embedding={len(self.embedding)})")


class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
            For encoder/decoder models, this is the decoder input prompt.
        prompt_token_ids: The token IDs of the prompt.
            For encoder/decoder models, this is the decoder input prompt
            token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
        encoder_prompt: The encoder prompt string of the request.
            None if decoder-only.
        encoder_prompt_token_ids: The token IDs of the encoder prompt.
            None if decoder-only.
        num_cached_tokens: The number of tokens with prefix cache hit.
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[List[int]] = None,
        num_cached_tokens: Optional[int] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids
        self.num_cached_tokens = num_cached_tokens

    @classmethod
    def new(
        cls,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]],
        text: str,
        token_ids: List[int],
        finished: bool = False,
    ) -> "RequestOutput":
        """Initialize a new RequestOutput object."""

        # TODO: Support `n` > 1.
        completion_output = CompletionOutput(
            index=0,
            text=text,
            token_ids=token_ids,
            cumulative_logprob=None,
            logprobs=None,  # TODO
        )

        return RequestOutput(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,  # TODO
            outputs=[completion_output],
            finished=finished,
        )
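
    # Example (illustrative sketch): `new` builds a single-completion output,
    # e.g. while incrementally streaming one sequence. The token id values
    # below are made up for illustration:
    #
    #     out = RequestOutput.new(request_id="req-0",
    #                             prompt="Hello",
    #                             prompt_token_ids=[9906],
    #                             text=", world",
    #                             token_ids=[11, 1917],
    #                             finished=False)
    #     assert out.outputs[0].index == 0 and not out.finished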

    @classmethod
    def from_seq_group(
        cls, seq_group: SequenceGroup, use_cache: bool,
        seq_id_to_seq_group: Dict[str, SequenceGroupBase]
    ) -> Optional["RequestOutput"]:
        finished = seq_group.is_finished()

        if seq_group.request_id in seq_id_to_seq_group:
            group: SequenceGroupBase = seq_id_to_seq_group[
                seq_group.request_id]
            if finished:
                group.finish_seq(seq_group)
            assembled_seq_group = group.maybe_assemble_group(seq_group)
            if assembled_seq_group is None:
                return None
            return cls.from_seq_group(assembled_seq_group, use_cache,
                                      seq_id_to_seq_group)

        sampling_params = seq_group.sampling_params
        if sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                not finished):
            return None

        # Init cache (if needed)
        if use_cache and seq_group.cached_request_output is None:
            seq_group.cached_request_output = RequestOutput(  # type: ignore
                request_id="",
                prompt=None,
                prompt_token_ids=[],
                prompt_logprobs=None,
                outputs=[],
                finished=False)

        top_n_seqs = seq_group.get_seqs()

        # Create the outputs.
        # NOTE: We need to omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = sampling_params.logprobs is not None
        text_buffer_length = sampling_params.output_text_buffer_length
        delta = sampling_params.output_kind == RequestOutputKind.DELTA

        outputs = []
        include_prompt = True
        # num_cached_tokens should be the same for all the sequences
        num_cached_tokens = None
        for i, seq in enumerate(top_n_seqs):
            output_text = seq.get_output_text_to_return(
                text_buffer_length, delta)

            output_token_ids = seq.get_output_token_ids_to_return(delta)
            num_output_tokens = 1 if isinstance(output_token_ids,
                                                int) else len(output_token_ids)
            num_cached_tokens = seq.data.get_num_cached_tokens()

            output_logprobs = seq.output_logprobs if include_logprobs else None

            if delta:
                # Slice logprobs delta if applicable
                if output_logprobs:
                    output_logprobs = output_logprobs[-num_output_tokens:]
                # Don't include prompt if this is after the first output
                # containing decode token ids
                if include_prompt and seq.get_output_len() > num_output_tokens:
                    include_prompt = False

            if use_cache:
                # Get cached output object
                cached_outputs = seq_group.cached_request_output.outputs  # type: ignore
                if i >= len(cached_outputs):
                    cached_outputs.append(
                        CompletionOutput(index=i,
                                         text="",
                                         token_ids=[],
                                         cumulative_logprob=None,
                                         logprobs=None,
                                         finish_reason=None,
                                         stop_reason=None))
                output = cached_outputs[i]

                # Init cached output object
                assert output.index == i
                output.text = output_text

                if isinstance(output_token_ids, int):
                    output.token_ids.clear()
                    output.token_ids.append(output_token_ids)
                else:
                    output.token_ids = output_token_ids

                output.cumulative_logprob = seq.get_cumulative_logprob() \
                    if include_logprobs else None
                output.logprobs = output_logprobs
                output.finish_reason = SequenceStatus.get_finished_reason(
                    seq.status)
                output.stop_reason = seq.stop_reason

            else:
                output = CompletionOutput(
                    top_n_seqs.index(seq), output_text, [output_token_ids]
                    if isinstance(output_token_ids, int) else output_token_ids,
                    seq.get_cumulative_logprob() if include_logprobs else None,
                    output_logprobs,
                    SequenceStatus.get_finished_reason(seq.status),
                    seq.stop_reason)

            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        if include_prompt:
            prompt = seq_group.prompt
            prompt_token_ids = seq_group.prompt_token_ids
            encoder_prompt = seq_group.encoder_prompt
            encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
            prompt_logprobs = seq_group.prompt_logprobs
        else:
            prompt = None
            prompt_token_ids = None
            encoder_prompt = None
            encoder_prompt_token_ids = None
            prompt_logprobs = None
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)

        init_args = (seq_group.request_id, prompt, prompt_token_ids,
                     prompt_logprobs, outputs, finished, seq_group.metrics,
                     seq_group.lora_request, encoder_prompt,
                     encoder_prompt_token_ids, num_cached_tokens)

        if use_cache:
            # Reuse the cached RequestOutput object in place instead of
            # allocating a new one on every scheduler step.
            request_output = seq_group.cached_request_output
            request_output.__init__(*init_args)  # type: ignore
        else:
            request_output = cls(*init_args)

        return request_output
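
    # Example (illustrative sketch): with RequestOutputKind.DELTA, each
    # from_seq_group call carries only the text and token ids produced since
    # the previous call, so a streaming consumer concatenates the pieces.
    # `stream` is a placeholder for successive results:
    #
    #     full_text = ""
    #     for request_output in stream:
    #         full_text += request_output.outputs[0].text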

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"encoder_prompt={self.encoder_prompt!r}, "
                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request}, "
                f"num_cached_tokens={self.num_cached_tokens})")


class EmbeddingRequestOutput:
    """
    The output data of an embedding request to the LLM.

    Args:
        request_id (str): A unique identifier for the embedding request.
        outputs (EmbeddingOutput): The embedding results for the given input.
        prompt_token_ids (List[int]): A list of token IDs used in the prompt.
        finished (bool): A flag indicating whether the embedding is completed.
    """

    def __init__(self, request_id: str, outputs: "EmbeddingOutput",
                 prompt_token_ids: List[int], finished: bool):
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    @classmethod
    def from_seq_group(
            cls, seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
        if seq_group.embeddings is None:
            raise ValueError(
                "Embeddings are missing in seq_group for EmbeddingRequest.")
        output = EmbeddingOutput(seq_group.embeddings)
        prompt_token_ids = seq_group.prompt_token_ids
        finished = seq_group.is_finished()

        return cls(seq_group.request_id, output, prompt_token_ids, finished)

    def __repr__(self):
        """
        Returns a string representation of an EmbeddingRequestOutput instance.

        The representation includes the request_id and the number of outputs,
        providing a quick overview of the embedding request's results.

        Returns:
            str: A string representation of the EmbeddingRequestOutput
                instance.
        """
        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
                f"outputs={repr(self.outputs)}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"finished={self.finished})")


class RequestOutputFactory:

    @staticmethod
    def create(seq_group: SequenceGroup,
               seq_id_to_seq_group: Dict[str, SequenceGroupBase],
               use_cache: bool = False):
        # Dispatch on the presence of embeddings: pooling/embedding models
        # produce an EmbeddingRequestOutput, everything else a RequestOutput.
        if hasattr(seq_group,
                   'embeddings') and seq_group.embeddings is not None:
            return EmbeddingRequestOutput.from_seq_group(seq_group)
        else:
            return RequestOutput.from_seq_group(seq_group, use_cache,
                                                seq_id_to_seq_group)
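
# Example (illustrative sketch): the factory lets engine code stay agnostic
# about the request type; it inspects the sequence group itself:
#
#     output = RequestOutputFactory.create(seq_group, seq_id_to_seq_group)
#     # -> EmbeddingRequestOutput if seq_group.embeddings is set,
#     #    RequestOutput otherwise.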