# Source: sglang/python/sglang/srt/entrypoints/openai/serving_completions.py
# (file listing metadata: 461 lines, 18 KiB, Python)

from __future__ import annotations
import logging
import time
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import (
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.parser.code_completion_parser import (
generate_completion_prompt_from_request,
)
from sglang.utils import convert_json_schema_to_str
# Imported only for static type checking (annotations are lazy strings via
# ``from __future__ import annotations``), avoiding runtime import cycles.
if TYPE_CHECKING:
    from sglang.srt.managers.template_manager import TemplateManager
    from sglang.srt.managers.tokenizer_manager import TokenizerManager

logger = logging.getLogger(__name__)
class OpenAIServingCompletion(OpenAIServingBase):
    """Handler for /v1/completion requests.

    Converts OpenAI-style ``CompletionRequest`` objects into sglang's
    ``GenerateReqInput``, runs them through the tokenizer manager, and turns
    the results back into OpenAI-compatible completion responses (either a
    single ``CompletionResponse`` or a stream of server-sent events).
    """

    def __init__(
        self,
        tokenizer_manager: TokenizerManager,
        template_manager: TemplateManager,
    ):
        super().__init__(tokenizer_manager)
        # Consulted to decide whether a completion template rewrites the prompt.
        self.template_manager = template_manager

    def _request_id_prefix(self) -> str:
        """Return the prefix used for completion request ids."""
        return "cmpl-"

    def _validate_request(self, request: CompletionRequest) -> Optional[str]:
        """Validate that the prompt is not empty.

        Returns:
            An error message if the prompt is empty, otherwise ``None``.

        A single token-id prompt (a list of ints) is rejected only when the
        list itself is empty: token id ``0`` is a real token and must not be
        mistaken for an "empty" entry.
        """
        prompt = request.prompt
        if not prompt:
            return "Prompt cannot be empty"
        if (
            isinstance(prompt, list)
            and not isinstance(prompt[0], int)
            and all(not p for p in prompt)
        ):
            # Batched prompts (strings or token-id lists) where every entry is empty.
            return "Prompt cannot be empty"
        return None

    def _convert_to_internal_request(
        self,
        request: CompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> tuple[GenerateReqInput, CompletionRequest]:
        """Convert an OpenAI completion request to sglang's internal format.

        Args:
            request: The parsed OpenAI completion request.
            raw_request: The underlying HTTP request; its headers are used to
                extract customer labels.

        Returns:
            A ``(GenerateReqInput, CompletionRequest)`` pair: the internal
            request to submit and the original request, which the response
            builders still need.
        """
        # ``logprobs=0`` is a legal OpenAI value (only the sampled token's
        # logprob), so "logprobs requested" is tested against ``None``. This
        # keeps logprob_start_len consistent with ``return_logprob`` below and
        # with the response builders, which read prompt logprobs on echo.
        logprobs_requested = request.logprobs is not None

        if request.echo and logprobs_requested:
            logger.warning(
                "Echo is not compatible with logprobs. "
                "To compute logprobs of input prompt, please use the native /generate API."
            )

        # Process prompt: optionally rewrite it with the configured completion template.
        prompt = request.prompt
        if self.template_manager.completion_template_name is not None:
            prompt = generate_completion_prompt_from_request(request)

        # Prompt logprobs are requested from position 0 only when the prompt is
        # echoed and logprobs were asked for; otherwise the sentinel -1 is passed.
        logprob_start_len = 0 if request.echo and logprobs_requested else -1

        sampling_params = self._build_sampling_params(request)

        # Text prompts (single or batched) go in ``text``; token ids in ``input_ids``.
        if isinstance(prompt, str) or (
            isinstance(prompt, list) and isinstance(prompt[0], str)
        ):
            prompt_kwargs = {"text": prompt}
        else:
            prompt_kwargs = {"input_ids": prompt}

        # Extract customer labels from raw request headers
        customer_labels = self.extract_customer_labels(raw_request)

        adapted_request = GenerateReqInput(
            **prompt_kwargs,
            sampling_params=sampling_params,
            return_logprob=logprobs_requested,
            top_logprobs_num=request.logprobs if logprobs_requested else 0,
            logprob_start_len=logprob_start_len,
            return_text_in_logprobs=True,
            stream=request.stream,
            lora_path=request.lora_path,
            bootstrap_host=request.bootstrap_host,
            bootstrap_port=request.bootstrap_port,
            bootstrap_room=request.bootstrap_room,
            return_hidden_states=request.return_hidden_states,
            rid=request.rid,
            customer_labels=customer_labels,
        )
        return adapted_request, request

    def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
        """Build the sampling-parameter dict for ``GenerateReqInput``.

        Copies the request's sampling fields, then applies any
        ``response_format`` constraint (JSON schema, generic JSON object, or
        structural tag) on top.
        """
        sampling_params = {
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "min_new_tokens": request.min_tokens,
            "stop": request.stop,
            "stop_token_ids": request.stop_token_ids,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "min_p": request.min_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "repetition_penalty": request.repetition_penalty,
            "regex": request.regex,
            "json_schema": request.json_schema,
            "ebnf": request.ebnf,
            "n": request.n,
            "no_stop_trim": request.no_stop_trim,
            "ignore_eos": request.ignore_eos,
            "skip_special_tokens": request.skip_special_tokens,
            "logit_bias": request.logit_bias,
        }

        # A response_format overrides the json_schema / sets structural_tag.
        response_format = request.response_format
        if response_format and response_format.type == "json_schema":
            sampling_params["json_schema"] = convert_json_schema_to_str(
                response_format.json_schema.schema_
            )
        elif response_format and response_format.type == "json_object":
            sampling_params["json_schema"] = '{"type": "object"}'
        elif response_format and response_format.type == "structural_tag":
            sampling_params["structural_tag"] = convert_json_schema_to_str(
                response_format.model_dump(by_alias=True)
            )

        return sampling_params

    async def _handle_streaming_request(
        self,
        adapted_request: GenerateReqInput,
        request: CompletionRequest,
        raw_request: Request,
    ) -> StreamingResponse:
        """Wrap the completion stream in an SSE ``StreamingResponse``.

        The background task created by the tokenizer manager runs after the
        response finishes.
        """
        return StreamingResponse(
            self._generate_completion_stream(adapted_request, request, raw_request),
            media_type="text/event-stream",
            background=self.tokenizer_manager.create_abort_task(adapted_request),
        )

    async def _generate_completion_stream(
        self,
        adapted_request: GenerateReqInput,
        request: CompletionRequest,
        raw_request: Request,
    ) -> AsyncGenerator[str, None]:
        """Yield the streaming completion response as SSE ``data:`` lines.

        Each result from the tokenizer manager carries the text and logprobs
        generated so far for one choice ``index``; this generator turns them
        into per-chunk deltas. After generation it optionally emits per-choice
        hidden states and a usage chunk, and it always finishes with
        ``data: [DONE]``. Exceptions are reported as a streamed error event.
        """
        created = int(time.time())

        # Per-choice generated text already sent. Echoed prompt text is NOT
        # stored here, so the buffer stays aligned with ``content["text"]``
        # and later deltas are sliced at the correct offset.
        stream_buffers: Dict[int, str] = {}
        # Per-choice count of output-token logprobs already sent.
        n_prev_tokens: Dict[int, int] = {}

        # Usage tracking (latest values per choice index)
        prompt_tokens: Dict[int, Any] = {}
        completion_tokens: Dict[int, Any] = {}
        cached_tokens: Dict[int, Any] = {}
        hidden_states: Dict[int, Any] = {}

        response_id = None
        try:
            async for content in self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            ):
                index = content.get("index", 0)
                meta_info = content["meta_info"]
                text = content["text"]
                response_id = meta_info["id"]

                prompt_tokens[index] = meta_info["prompt_tokens"]
                completion_tokens[index] = meta_info["completion_tokens"]
                cached_tokens[index] = meta_info.get("cached_tokens", 0)
                hidden_states[index] = meta_info.get("hidden_states", None)

                # Key presence (not buffer emptiness) marks the first chunk, so
                # an empty first chunk does not make the next one "first" again.
                is_first_chunk = index not in stream_buffers
                stream_buffer = stream_buffers.get(index, "")

                # Handle logprobs
                logprobs = None
                if request.logprobs is not None:
                    # Prompt logprobs are attached once, to the first chunk, when echoing.
                    if is_first_chunk and request.echo:
                        input_token_logprobs = meta_info["input_token_logprobs"]
                        input_top_logprobs = meta_info["input_top_logprobs"]
                    else:
                        input_token_logprobs = None
                        input_top_logprobs = None

                    # Only send output logprobs not sent in previous chunks.
                    n_prev_token = n_prev_tokens.get(index, 0)
                    logprobs = to_openai_style_logprobs(
                        input_token_logprobs=input_token_logprobs,
                        input_top_logprobs=input_top_logprobs,
                        output_token_logprobs=meta_info["output_token_logprobs"][
                            n_prev_token:
                        ],
                        output_top_logprobs=meta_info["output_top_logprobs"][
                            n_prev_token:
                        ],
                    )
                    n_prev_tokens[index] = len(meta_info["output_token_logprobs"])

                # Generate delta against the generated text sent so far.
                delta = text[len(stream_buffer) :]
                stream_buffers[index] = stream_buffer + delta
                if is_first_chunk and request.echo:
                    # Echo goes only into the outgoing delta, never the buffer.
                    delta = self._get_echo_text(request, index) + delta

                finish_reason = meta_info["finish_reason"]
                choice_data = CompletionResponseStreamChoice(
                    index=index,
                    text=delta,
                    logprobs=logprobs,
                    finish_reason=finish_reason["type"] if finish_reason else None,
                    matched_stop=(
                        finish_reason["matched"]
                        if finish_reason and "matched" in finish_reason
                        else None
                    ),
                )
                chunk = CompletionStreamResponse(
                    id=response_id,
                    created=created,
                    object="text_completion",
                    choices=[choice_data],
                    model=request.model,
                )
                yield f"data: {chunk.model_dump_json()}\n\n"

            # After generation: one hidden-states chunk per choice.
            if request.return_hidden_states and hidden_states:
                for index, choice_hidden_states in hidden_states.items():
                    if not choice_hidden_states:
                        continue
                    last_token_hidden_states = (
                        choice_hidden_states[-1]
                        if len(choice_hidden_states) > 1
                        else []
                    )
                    hidden_states_chunk = CompletionStreamResponse(
                        id=response_id,
                        created=created,
                        object="text_completion",
                        choices=[
                            CompletionResponseStreamChoice(
                                index=index,
                                text="",
                                hidden_states=last_token_hidden_states,
                                finish_reason=None,
                            )
                        ],
                        model=request.model,
                    )
                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"

            # Handle final usage chunk
            if request.stream_options and request.stream_options.include_usage:
                usage = UsageProcessor.calculate_streaming_usage(
                    prompt_tokens,
                    completion_tokens,
                    cached_tokens,
                    n_choices=request.n,
                    enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
                )
                final_usage_chunk = CompletionStreamResponse(
                    id=response_id,
                    created=created,
                    choices=[],
                    model=request.model,
                    usage=usage,
                )
                final_usage_data = final_usage_chunk.model_dump_json(exclude_none=True)
                yield f"data: {final_usage_data}\n\n"

        except Exception as e:
            error = self.create_streaming_error_response(str(e))
            yield f"data: {error}\n\n"

        yield "data: [DONE]\n\n"

    async def _handle_non_streaming_request(
        self,
        adapted_request: GenerateReqInput,
        request: CompletionRequest,
        raw_request: Request,
    ) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]:
        """Run a non-streaming completion and build the full response.

        A ``ValueError`` from the first generator step is returned as an
        error response instead of propagating.
        """
        try:
            generator = self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            )
            ret = await generator.__anext__()
        except ValueError as e:
            return self.create_error_response(str(e))

        # A single request yields a dict; batched / n>1 requests yield a list.
        if not isinstance(ret, list):
            ret = [ret]

        return self._build_completion_response(
            request,
            ret,
            int(time.time()),
        )

    def _build_completion_response(
        self,
        request: CompletionRequest,
        ret: List[Dict[str, Any]],
        created: int,
    ) -> CompletionResponse:
        """Build a ``CompletionResponse`` from generation results.

        Args:
            request: The original OpenAI completion request.
            ret: One result dict per choice, each with ``text`` and ``meta_info``.
            created: Unix timestamp for the response.
        """
        # Choice ``idx`` belongs to prompt ``idx // request.n``.
        echo_prompts = self._prepare_echo_prompts(request) if request.echo else []

        choices = []
        for idx, ret_item in enumerate(ret):
            meta_info = ret_item["meta_info"]

            text = ret_item["text"]
            if request.echo:
                text = echo_prompts[idx // request.n] + text

            logprobs = None
            if request.logprobs is not None:
                if request.echo:
                    input_token_logprobs = meta_info["input_token_logprobs"]
                    input_top_logprobs = meta_info["input_top_logprobs"]
                else:
                    input_token_logprobs = None
                    input_top_logprobs = None
                logprobs = to_openai_style_logprobs(
                    input_token_logprobs=input_token_logprobs,
                    input_top_logprobs=input_top_logprobs,
                    output_token_logprobs=meta_info["output_token_logprobs"],
                    output_top_logprobs=meta_info["output_top_logprobs"],
                )

            hidden_states = process_hidden_states_from_ret(ret_item, request)

            finish_reason = meta_info["finish_reason"]
            choices.append(
                CompletionResponseChoice(
                    index=idx,
                    text=text,
                    logprobs=logprobs,
                    finish_reason=finish_reason["type"] if finish_reason else None,
                    matched_stop=(
                        finish_reason["matched"]
                        if finish_reason and "matched" in finish_reason
                        else None
                    ),
                    hidden_states=hidden_states,
                )
            )

        # Calculate usage
        cache_report = self.tokenizer_manager.server_args.enable_cache_report
        usage = UsageProcessor.calculate_response_usage(
            ret, n_choices=request.n, enable_cache_report=cache_report
        )

        return CompletionResponse(
            id=ret[0]["meta_info"]["id"],
            model=request.model,
            created=created,
            choices=choices,
            usage=usage,
            metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
        )

    def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
        """Return the prompt text to echo for streaming choice ``index``.

        Token-id prompts are decoded with ``skip_special_tokens=True``. Returns
        ``""`` for prompt shapes not matched below.
        """
        if isinstance(request.prompt, str):
            # for the case of single str prompts
            return request.prompt
        elif isinstance(request.prompt, list):
            if isinstance(request.prompt[0], str):
                # for the case of multiple str prompts
                return request.prompt[index // request.n]
            elif isinstance(request.prompt[0], int):
                # for the case of single token ids prompt
                return self.tokenizer_manager.tokenizer.decode(
                    request.prompt, skip_special_tokens=True
                )
            elif isinstance(request.prompt[0], list) and isinstance(
                request.prompt[0][0], int
            ):
                # for the case of multiple token ids prompts
                return self.tokenizer_manager.tokenizer.decode(
                    request.prompt[index // request.n],
                    skip_special_tokens=True,
                )
        return ""

    def _prepare_echo_prompts(self, request: CompletionRequest) -> List[str]:
        """Return one echo string per prompt for a non-streaming response.

        Handles a single string, a list of strings, a single token-id list,
        and a list of token-id lists. Token ids are decoded with
        ``skip_special_tokens=True``.
        """
        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
            # for the case of multiple str prompts
            return request.prompt
        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
            # for the case of multiple token ids prompts
            return [
                self.tokenizer_manager.tokenizer.decode(
                    prompt, skip_special_tokens=True
                )
                for prompt in request.prompt
            ]
        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
            # for the case of single token ids prompt
            return [
                self.tokenizer_manager.tokenizer.decode(
                    request.prompt, skip_special_tokens=True
                )
            ]
        else:
            # for the case of single str prompt
            return [request.prompt]