Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -1,67 +1,90 @@
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import cast
import jinja2
from fastapi import Request
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
CompletionLogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.utils import merge_async_iterators, random_uuid
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
logger = init_logger(__name__)
TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
def parse_prompt_format(prompt) -> tuple[bool, list]:
    """Normalize an OpenAI-style ``prompt`` value into a list of prompts.

    OpenAI accepts "a string, array of strings, array of tokens, or array
    of token arrays". Only the first element of a list is inspected to
    decide which of those shapes was supplied.

    Args:
        prompt: The raw ``prompt`` value taken from the request.

    Returns:
        A ``(prompt_is_tokens, prompts)`` pair. ``prompt_is_tokens`` is
        ``True`` when the prompts are token-ID lists rather than text, and
        ``prompts`` always holds one entry per individual prompt.

    Raises:
        ValueError: If ``prompt`` is an empty list, or a list whose first
            element matches none of the supported shapes (including an
            empty inner list such as ``[[]]``).
    """
    if not isinstance(prompt, list):
        # case 1: a string (any non-list value is wrapped unchanged)
        return False, [prompt]

    if len(prompt) == 0:
        raise ValueError("please provide at least one prompt")

    first = prompt[0]
    if isinstance(first, str):
        # case 2: array of strings
        return False, prompt
    if isinstance(first, int):
        # case 3: array of tokens -> a single tokenized prompt
        return True, [prompt]
    # The emptiness check keeps `first[0]` from raising IndexError on `[[]]`;
    # such input falls through to the ValueError below instead.
    if isinstance(first, list) and first and isinstance(first[0], int):
        # case 4: array of token arrays
        return True, prompt

    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]] = None):
super().__init__(engine=engine,
served_model_names=served_model_names,
lora_modules=lora_modules)
# set up logits processors
self.logits_processors = self.model_config.logits_processors
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.enable_force_include_usage = enable_force_include_usage
if self.default_sampling_params:
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info(
"Using default completion sampling params from %s: %s",
source,
self.default_sampling_params,
)
async def create_completion(
self,
request: CompletionRequest,
raw_request: Request | None = None,
) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
@@ -75,90 +98,214 @@ class OpenAIServingCompletion(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
return self.create_error_response("suffix is not currently supported")
model_name = self.served_model_names[0]
request_id = f"cmpl-{random_uuid()}"
if request.echo and request.prompt_embeds is not None:
return self.create_error_response("Echo is unsupported with prompt embeds.")
if request.prompt_logprobs is not None and request.prompt_embeds is not None:
return self.create_error_response(
"prompt_logprobs is not compatible with prompt embeds."
)
request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
created_time = int(time.time())
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logit_processor)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
lora_request = self._maybe_get_adapters(request)
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
prompt_embeds=request.prompt_embeds,
config=self._build_render_config(request),
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except RuntimeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt)
)
input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats
raise NotImplementedError
generators.append(
self.engine.generate(prompt_text,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids,
lora_request=lora_request))
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params,
)
sampling_params: SamplingParams | BeamSearchParams
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params
)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
validate_logits_processors_parameters(
self.logits_processors,
sampling_params,
)
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generator = self.engine_client.generate(
engine_request,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, RequestOutput]] = merge_async_iterators(*generators)
result_generator = merge_async_iterators(*generators)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when using
# beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
# We do not stream the results when using beam search.
stream = request.stream and not request.use_beam_search
# Streaming response
if stream:
return self.completion_stream_generator(request,
raw_request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts))
return self.completion_stream_generator(
request,
engine_prompts,
result_generator,
request_id,
created_time,
model_name,
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata,
)
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
final_res_batch: list[RequestOutput | None] = [None] * num_prompts
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
# The output should contain the input text
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = (
None
if is_embeds_prompt(engine_prompt)
else engine_prompt.get("prompt")
)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
response = self.request_output_to_completion_response(
final_res_batch, request, request_id, created_time, model_name)
final_res_batch_checked,
request,
request_id,
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -179,80 +326,126 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
raw_request: Request,
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
engine_prompts: list[TokensPrompt | EmbedsPrompt],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
assert request.n is not None
previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
num_cached_tokens = None
first_iteration = True
stream_options = request.stream_options
include_usage, include_continuous_usage = should_include_usage(
stream_options, self.enable_force_include_usage
)
try:
async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
# Abort the request if the client disconnects.
if await raw_request.is_disconnected():
await self.engine.abort(f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
if first_iteration:
num_cached_tokens = res.num_cached_tokens
first_iteration = False
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = (
None
if is_embeds_prompt(engine_prompt)
else engine_prompt.get("prompt")
)
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
delta_token_ids: GenericSequence[int]
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in res.outputs:
i = output.index + prompt_idx * request.n
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
i = output.index + prompt_idx * num_choices
# Useful when request.return_token_ids is True
# Returning prompt token IDs shares the same logic
# with the echo implementation.
prompt_token_ids_to_return: list[int] | None = None
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
# only return the prompt
delta_text = res.prompt
delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = (res.prompt_token_ids +
output.token_ids)
top_logprobs = res.prompt_logprobs + (output.logprobs
or [])
if request.echo and not has_echoed[i]:
assert prompt_token_ids is not None
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
# only return the prompt
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
else:
# echo the prompt and first token
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids,
*output.token_ids,
]
out_logprobs = [
*(prompt_logprobs or []),
*(output.logprobs or []),
]
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[
previous_num_tokens[i]:]
top_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None
delta_text = output.text
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
# has_echoed[i] is reused here to indicate whether
# we have already returned the prompt token IDs.
if not has_echoed[i] and request.return_token_ids:
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
if (
not delta_text
and not delta_token_ids
and not previous_num_tokens[i]
):
# Chunked prefill case, don't return empty chunks
continue
if request.logprobs is not None:
logprobs = self._create_logprobs(
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_completion_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
tokenizer=tokenizer,
initial_text_offset=previous_text_lens[i],
return_as_token_id=request.return_tokens_as_token_ids,
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason
if output.finish_reason is not None: # return final usage
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
else:
final_usage = None
response_json = CompletionStreamResponse(
self._raise_if_error(finish_reason, request_id)
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
@@ -263,58 +456,129 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs=logprobs,
finish_reason=finish_reason,
stop_reason=stop_reason,
prompt_token_ids=prompt_token_ids_to_return,
token_ids=(
as_list(output.token_ids)
if request.return_token_ids
else None
),
)
],
usage=final_usage,
).model_dump_json(exclude_unset=True)
)
if include_continuous_usage:
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
chunk.usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = chunk.model_dump_json(exclude_unset=False)
yield f"data: {response_json}\n\n"
except ValueError as e:
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens
)
if include_usage:
final_usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=final_usage_info,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = final_usage_info
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in completion stream generator.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
def request_output_to_completion_response(
self,
final_res_batch: List[RequestOutput],
final_res_batch: list[RequestOutput],
request: CompletionRequest,
request_id: str,
created_time: int,
model_name: str,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
choices: list[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
kv_transfer_params = None
last_final_res = None
for final_res in final_res_batch:
assert final_res is not None
last_final_res = final_res
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
assert prompt_token_ids is not None
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in final_res.outputs:
self._raise_if_error(output.finish_reason, request_id)
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
top_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs else None)
output_text = prompt_text + output.text
if request.echo:
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
else:
token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
top_logprobs = output.logprobs
out_logprobs = output.logprobs
output_text = output.text
if request.logprobs is not None:
assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_logprobs(
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.logprobs,
return_as_token_id=request.return_tokens_as_token_ids,
)
else:
logprobs = None
@@ -325,12 +589,19 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
prompt_logprobs=final_res.prompt_logprobs,
prompt_token_ids=(
prompt_token_ids if request.return_token_ids else None
),
token_ids=(
as_list(output.token_ids) if request.return_token_ids else None
),
)
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
@@ -338,10 +609,121 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens=num_prompt_tokens + num_generated_tokens,
)
if (
self.enable_prompt_tokens_details
and last_final_res
and last_final_res.num_cached_tokens
):
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=last_final_res.num_cached_tokens
)
request_metadata.final_usage_info = usage
if final_res_batch:
kv_transfer_params = final_res_batch[0].kv_transfer_params
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
kv_transfer_params=kv_transfer_params,
)
def _create_completion_logprobs(
    self,
    token_ids: GenericSequence[int],
    top_logprobs: GenericSequence[dict[int, Logprob] | None],
    num_output_top_logprobs: int,
    tokenizer: TokenizerLike | None,
    initial_text_offset: int = 0,
    return_as_token_id: bool | None = None,
) -> CompletionLogProbs:
    """Build the Completion-API logprobs payload for a run of tokens.

    Args:
        token_ids: Token IDs to report, in order.
        top_logprobs: Per-position mapping of token ID to ``Logprob``; an
            entry may be ``None`` when no logprobs exist for that position.
        num_output_top_logprobs: Number of alternatives requested; up to
            ``num_output_top_logprobs + 1`` entries are emitted per position.
        tokenizer: Used to decode tokens to text; may be ``None``.
        initial_text_offset: Text offset assigned to the first token.
        return_as_token_id: Per-call override; when ``None`` the instance's
            ``return_tokens_as_token_ids`` setting is used.

    Returns:
        A ``CompletionLogProbs`` with parallel lists of text offsets,
        token logprobs, token strings and top-logprob dictionaries.

    Raises:
        ValueError: If a position has no logprobs, tokens are not being
            returned as IDs, and ``tokenizer`` is ``None``.
    """
    if return_as_token_id is None:
        as_token_id = self.return_tokens_as_token_ids
    else:
        as_token_id = return_as_token_id

    offsets: list[int] = []
    token_logprobs: list[float | None] = []
    tokens: list[str] = []
    top_per_position: list[dict[str, float] | None] = []

    # Offset of the token about to be emitted; advanced by each token's
    # string length after it has been recorded.
    next_offset = initial_text_offset

    for position, token_id in enumerate(token_ids):
        candidates = top_logprobs[position]

        if candidates is not None:
            chosen = candidates[token_id]
            token = self._get_decoded_token(
                chosen,
                token_id,
                tokenizer,
                return_as_token_id=as_token_id,
            )
            tokens.append(token)
            token_logprobs.append(max(chosen.logprob, -9999.0))

            # makes sure to add the top num_output_top_logprobs + 1
            # logprobs, as defined in the openai API
            # (cf. https://github.com/openai/openai-openapi/blob/
            # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
            alternatives: dict[str, float] = {}
            for rank, (cand_id, cand_logprob) in enumerate(candidates.items()):
                if rank > num_output_top_logprobs:
                    break
                cand_token = self._get_decoded_token(
                    cand_logprob,
                    cand_id,
                    tokenizer,
                    return_as_token_id=as_token_id,
                )
                # Convert float("-inf") to the
                # JSON-serializable float that OpenAI uses
                alternatives[cand_token] = max(cand_logprob.logprob, -9999.0)
            top_per_position.append(alternatives)
        else:
            # No logprob data for this position: emit the token with
            # null logprob entries.
            if as_token_id:
                token = f"token_id:{token_id}"
            elif tokenizer is None:
                raise ValueError(
                    "Unable to get tokenizer because `skip_tokenizer_init=True`"
                )
            else:
                token = tokenizer.decode(token_id)
            tokens.append(token)
            token_logprobs.append(None)
            top_per_position.append(None)

        offsets.append(next_offset)
        next_offset += len(token)

    return CompletionLogProbs(
        text_offset=offsets,
        token_logprobs=token_logprobs,
        tokens=tokens,
        top_logprobs=top_per_position,
    )
def _build_render_config(
    self,
    request: CompletionRequest,
    max_input_length: int | None = None,
) -> RenderConfig:
    """Assemble the ``RenderConfig`` used to render this request's prompts.

    The prompt length limit is the model's maximum length minus the
    requested ``max_tokens`` (treated as 0 when unset or zero).
    Detokenization is requested only when the prompt is echoed and token
    IDs are not being returned.

    Args:
        request: The completion request being served.
        max_input_length: Accepted but not read by this method.

    Returns:
        The ``RenderConfig`` built from the request's fields.
    """
    model_limit = self.max_model_len
    output_budget = request.max_tokens or 0
    return RenderConfig(
        max_length=model_limit - output_budget,
        truncate_prompt_tokens=request.truncate_prompt_tokens,
        add_special_tokens=request.add_special_tokens,
        cache_salt=request.cache_salt,
        needs_detokenization=bool(request.echo and not request.return_token_ids),
    )