[Refactor] OAI Server components (#7167)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
0
python/sglang/srt/entrypoints/openai/__init__.py
Normal file
539
python/sglang/srt/entrypoints/openai/protocol.py
Normal file
@@ -0,0 +1,539 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pydantic models for OpenAI API protocol"""

import time
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Literal


class ModelCard(BaseModel):
    """Model cards."""

    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "sglang"
    root: Optional[str] = None
    max_model_len: Optional[int] = None


class ModelList(BaseModel):
    """Model list consists of model cards."""

    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class TopLogprob(BaseModel):
    token: str
    bytes: List[int]
    logprob: float


class ChatCompletionTokenLogprob(BaseModel):
    token: str
    bytes: List[int]
    logprob: float
    top_logprobs: List[TopLogprob]


class ChoiceLogprobs(BaseModel):
    # build for v1/chat/completions response
    content: List[ChatCompletionTokenLogprob]


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
    # only used to return cached tokens when --enable-cache-report is set
    prompt_tokens_details: Optional[Dict[str, int]] = None


class StreamOptions(BaseModel):
    include_usage: Optional[bool] = False


class JsonSchemaResponseFormat(BaseModel):
    name: str
    description: Optional[str] = None
    # use alias to workaround pydantic conflict
    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
    strict: Optional[bool] = False


class FileRequest(BaseModel):
    # https://platform.openai.com/docs/api-reference/files/create
    file: bytes  # The File object (not file name) to be uploaded
    purpose: str = (
        "batch"  # The intended purpose of the uploaded file, default is "batch"
    )


class FileResponse(BaseModel):
    id: str
    object: str = "file"
    bytes: int
    created_at: int
    filename: str
    purpose: str


class FileDeleteResponse(BaseModel):
    id: str
    object: str = "file"
    deleted: bool


class BatchRequest(BaseModel):
    input_file_id: (
        str  # The ID of an uploaded file that contains requests for the new batch
    )
    endpoint: str  # The endpoint to be used for all requests in the batch
    completion_window: str  # The time frame within which the batch should be processed
    metadata: Optional[dict] = None  # Optional custom metadata for the batch


class BatchResponse(BaseModel):
    id: str
    object: str = "batch"
    endpoint: str
    errors: Optional[dict] = None
    input_file_id: str
    completion_window: str
    status: str = "validating"
    output_file_id: Optional[str] = None
    error_file_id: Optional[str] = None
    created_at: int
    in_progress_at: Optional[int] = None
    expires_at: Optional[int] = None
    finalizing_at: Optional[int] = None
    completed_at: Optional[int] = None
    failed_at: Optional[int] = None
    expired_at: Optional[int] = None
    cancelling_at: Optional[int] = None
    cancelled_at: Optional[int] = None
    request_counts: Optional[dict] = None
    metadata: Optional[dict] = None


class CompletionRequest(BaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: bool = False
    frequency_penalty: float = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: int = 16
    n: int = 1
    presence_penalty: float = 0.0
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: bool = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: float = 1.0
    top_p: float = 1.0
    user: Optional[str] = None

    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
    top_k: int = -1
    min_p: float = 0.0
    min_tokens: int = 0
    json_schema: Optional[str] = None
    regex: Optional[str] = None
    ebnf: Optional[str] = None
    repetition_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = None
    no_stop_trim: bool = False
    ignore_eos: bool = False
    skip_special_tokens: bool = True
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    session_params: Optional[Dict] = None

    # For PD disaggregation
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None

    @field_validator("max_tokens")
    @classmethod
    def validate_max_tokens_positive(cls, v):
        if v is not None and v <= 0:
            raise ValueError("max_tokens must be positive")
        return v


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Literal["stop", "length", "content_filter", "abort"]
    matched_stop: Union[None, int, str] = None


class CompletionResponse(BaseModel):
    id: str
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
    matched_stop: Union[None, int, str] = None


class CompletionStreamResponse(BaseModel):
    id: str
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = None


class ChatCompletionMessageContentTextPart(BaseModel):
    type: Literal["text"]
    text: str


class ChatCompletionMessageContentImageURL(BaseModel):
    url: str
    detail: Optional[Literal["auto", "low", "high"]] = "auto"


class ChatCompletionMessageContentAudioURL(BaseModel):
    url: str


class ChatCompletionMessageContentImagePart(BaseModel):
    type: Literal["image_url"]
    image_url: ChatCompletionMessageContentImageURL
    modalities: Optional[Literal["image", "multi-images", "video"]] = "image"


class ChatCompletionMessageContentAudioPart(BaseModel):
    type: Literal["audio_url"]
    audio_url: ChatCompletionMessageContentAudioURL


ChatCompletionMessageContentPart = Union[
    ChatCompletionMessageContentTextPart,
    ChatCompletionMessageContentImagePart,
    ChatCompletionMessageContentAudioPart,
]


class FunctionResponse(BaseModel):
    """Function response."""

    name: Optional[str] = None
    arguments: Optional[str] = None


class ToolCall(BaseModel):
    """Tool call response."""

    id: Optional[str] = None
    index: Optional[int] = None
    type: Literal["function"] = "function"
    function: FunctionResponse


class ChatCompletionMessageGenericParam(BaseModel):
    role: Literal["system", "assistant", "tool"]
    content: Union[str, List[ChatCompletionMessageContentTextPart], None]
    tool_call_id: Optional[str] = None
    name: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


class ChatCompletionMessageUserParam(BaseModel):
    role: Literal["user"]
    content: Union[str, List[ChatCompletionMessageContentPart]]


ChatCompletionMessageParam = Union[
    ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
]


class ResponseFormat(BaseModel):
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None


class StructuresResponseFormat(BaseModel):
    begin: str
    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
    end: str


class StructuralTagResponseFormat(BaseModel):
    type: Literal["structural_tag"]
    structures: List[StructuresResponseFormat]
    triggers: List[str]


class Function(BaseModel):
    """Function descriptions."""

    description: Optional[str] = Field(default=None, examples=[None])
    name: Optional[str] = None
    parameters: Optional[object] = None
    strict: bool = False


class Tool(BaseModel):
    """Function wrapper."""

    type: str = Field(default="function", examples=["function"])
    function: Function


class ToolChoiceFuncName(BaseModel):
    """The name of tool choice function."""

    name: Optional[str] = None


class ToolChoice(BaseModel):
    """The tool choice definition."""

    function: ToolChoiceFuncName
    type: Literal["function"] = Field(default="function", examples=["function"])


class ChatCompletionRequest(BaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: float = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: bool = False
    top_logprobs: Optional[int] = None
    max_tokens: Optional[int] = Field(
        default=None,
        deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
        description="The maximum number of tokens that can be generated in the chat completion. ",
    )
    max_completion_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of completion tokens for a chat completion request, "
        "including visible output tokens and reasoning tokens. Input tokens are not included. ",
    )
    n: int = 1
    presence_penalty: float = 0.0
    response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: bool = False
    stream_options: Optional[StreamOptions] = None
    temperature: float = 0.7
    top_p: float = 1.0
    user: Optional[str] = None
    tools: Optional[List[Tool]] = Field(default=None, examples=[None])
    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
        default="auto", examples=["none"]
    )  # noqa

    @model_validator(mode="before")
    @classmethod
    def set_tool_choice_default(cls, values):
        if isinstance(values, dict):
            if values.get("tool_choice") is None:
                if values.get("tools") is None:
                    values["tool_choice"] = "none"
                else:
                    values["tool_choice"] = "auto"
        return values

    @field_validator("messages")
    @classmethod
    def validate_messages_not_empty(cls, v):
        if not v:
            raise ValueError("Messages cannot be empty")
        return v

    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
    top_k: int = -1
    min_p: float = 0.0
    min_tokens: int = 0
    regex: Optional[str] = None
    ebnf: Optional[str] = None
    repetition_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = None
    no_stop_trim: bool = False
    ignore_eos: bool = False
    continue_final_message: bool = False
    skip_special_tokens: bool = True
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    session_params: Optional[Dict] = None
    separate_reasoning: bool = True
    stream_reasoning: bool = True
    chat_template_kwargs: Optional[Dict] = None

    # The request id.
    rid: Optional[str] = None

    # For PD disaggregation
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None


class ChatMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
    finish_reason: Literal[
        "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
    ]
    matched_stop: Union[None, int, str] = None


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
    finish_reason: Optional[
        Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
    ] = None
    matched_stop: Union[None, int, str] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = None


class MultimodalEmbeddingInput(BaseModel):
    text: Optional[str] = None
    image: Optional[str] = None


EmbeddingInput = Union[
    List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
]


class EmbeddingRequest(BaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings/create
    input: EmbeddingInput
    model: str
    encoding_format: str = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # The request id.
    rid: Optional[str] = None


class EmbeddingObject(BaseModel):
    embedding: List[float]
    index: int
    object: str = "embedding"


class EmbeddingResponse(BaseModel):
    data: List[EmbeddingObject]
    model: str
    object: str = "list"
    usage: Optional[UsageInfo] = None


class ScoringRequest(BaseModel):
    query: Optional[Union[str, List[int]]] = (
        None  # Query text or pre-tokenized token IDs
    )
    items: Optional[Union[str, List[str], List[List[int]]]] = (
        None  # Item text(s) or pre-tokenized token IDs
    )
    label_token_ids: Optional[List[int]] = (
        None  # Token IDs to compute probabilities for
    )
    apply_softmax: bool = False
    item_first: bool = False
    model: str


class ScoringResponse(BaseModel):
    scores: List[
        List[float]
    ]  # List of lists of probabilities, each in the order of label_token_ids
    model: str
    usage: Optional[UsageInfo] = None
    object: str = "scoring"


OpenAIServingRequest = Union[
    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ScoringRequest
]
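
For reference, a minimal sketch of how the request models above behave (the model name is a placeholder; this snippet is illustrative and not part of the files added by this commit):

from sglang.srt.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    Function,
    Tool,
)

# tool_choice defaults to "none" when no tools are given (set_tool_choice_default).
req = ChatCompletionRequest(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "hi"}],
)
assert req.tool_choice == "none"

# ...and to "auto" once tools are present.
req = ChatCompletionRequest(
    model="my-model",
    messages=[{"role": "user", "content": "hi"}],
    tools=[Tool(function=Function(name="get_weather"))],
)
assert req.tool_choice == "auto"

# An empty message list is rejected by the validate_messages_not_empty validator.
try:
    ChatCompletionRequest(model="my-model", messages=[])
except ValueError as exc:
    print(exc)
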
178
python/sglang/srt/entrypoints/openai/serving_base.py
Normal file
@@ -0,0 +1,178 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse

from sglang.srt.entrypoints.openai.protocol import (
    ErrorResponse,
    OpenAIServingRequest,
    UsageInfo,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager

logger = logging.getLogger(__name__)


# Base class for specific endpoint handlers
class OpenAIServingBase(ABC):
    """Abstract base class for OpenAI endpoint handlers"""

    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    async def handle_request(
        self, request: OpenAIServingRequest, raw_request: Request
    ) -> Union[Any, StreamingResponse, ErrorResponse]:
        """Handle the specific request type with common pattern"""
        try:
            # Validate request
            error_msg = self._validate_request(request)
            if error_msg:
                return self.create_error_response(error_msg)

            # Convert to internal format
            adapted_request, processed_request = self._convert_to_internal_request(
                [request], [self._generate_request_id_base(request)]
            )

            # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
            if hasattr(request, "stream") and request.stream:
                return await self._handle_streaming_request(
                    adapted_request, processed_request, raw_request
                )
            else:
                return await self._handle_non_streaming_request(
                    adapted_request, processed_request, raw_request
                )

        except Exception as e:
            logger.error(f"Error in request: {e}")
            return self.create_error_response(
                message=f"Internal server error: {str(e)}",
                err_type="InternalServerError",
                status_code=500,
            )

    @abstractmethod
    def _request_id_prefix(self) -> str:
        """Return the request ID prefix for this endpoint type"""
        pass

    def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
        """Generate request ID based on request type"""
        if rid := getattr(request, "rid", None):
            return rid

        return f"{self._request_id_prefix()}{uuid.uuid4().hex}"

    @abstractmethod
    def _convert_to_internal_request(
        self,
        all_requests: List[OpenAIServingRequest],
        request_ids: List[str],
    ) -> tuple[
        GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
    ]:
        """Convert OpenAI request to internal format"""
        pass

    async def _handle_streaming_request(
        self,
        adapted_request: GenerateReqInput,
        request: OpenAIServingRequest,
        raw_request: Request,
    ) -> StreamingResponse:
        """Handle streaming request

        Override this method in child classes that support streaming requests.
        """
        return self.create_error_response(
            message=f"{self.__class__.__name__} does not support streaming requests",
            err_type="NotImplementedError",
            status_code=501,
        )

    async def _handle_non_streaming_request(
        self,
        adapted_request: GenerateReqInput,
        request: OpenAIServingRequest,
        raw_request: Request,
    ) -> Union[Any, ErrorResponse]:
        """Handle non-streaming request

        Override this method in child classes that support non-streaming requests.
        """
        return self.create_error_response(
            message=f"{self.__class__.__name__} does not support non-streaming requests",
            err_type="NotImplementedError",
            status_code=501,
        )

    def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
        """Validate request"""
        pass

    def _calculate_streaming_usage_base(
        self,
        prompt_tokens: Dict[int, int],
        completion_tokens: Dict[int, int],
        cached_tokens: Dict[int, int],
        n_choices: int,
    ) -> UsageInfo:
        """Calculate usage information for streaming responses (common logic)"""
        total_prompt_tokens = sum(
            tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0
        )
        total_completion_tokens = sum(tokens for tokens in completion_tokens.values())

        cache_report = self.tokenizer_manager.server_args.enable_cache_report
        prompt_tokens_details = None
        if cache_report:
            cached_tokens_sum = sum(tokens for tokens in cached_tokens.values())
            if cached_tokens_sum > 0:
                prompt_tokens_details = {"cached_tokens": cached_tokens_sum}

        return UsageInfo(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=total_completion_tokens,
            total_tokens=total_prompt_tokens + total_completion_tokens,
            prompt_tokens_details=prompt_tokens_details,
        )

    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: int = 400,
        param: Optional[str] = None,
    ) -> ORJSONResponse:
        """Create an error response"""
        error = ErrorResponse(
            object="error",
            message=message,
            type=err_type,
            param=param,
            code=status_code,
        )
        return ORJSONResponse(content=error.model_dump(), status_code=status_code)

    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: int = 400,
    ) -> str:
        """Create a streaming error response"""
        error = ErrorResponse(
            object="error",
            message=message,
            type=err_type,
            param=None,
            code=status_code,
        )
        return json.dumps({"error": error.model_dump()})
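
For reference, a minimal sketch of what a new endpoint handler has to implement on top of OpenAIServingBase (the handler name and the request's prompt field are hypothetical; this snippet is illustrative and not part of this commit):

from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import GenerateReqInput


class OpenAIServingToy(OpenAIServingBase):
    """Hypothetical handler used only to illustrate the abstract surface."""

    def _request_id_prefix(self) -> str:
        return "toy-"

    def _convert_to_internal_request(self, all_requests, request_ids):
        # Map the OpenAI-style request onto the internal GenerateReqInput.
        request = all_requests[0]
        adapted = GenerateReqInput(text=request.prompt, rid=request_ids[0])
        return adapted, request

    async def _handle_non_streaming_request(self, adapted_request, request, raw_request):
        # Pull a single result from the tokenizer manager, mirroring the base-class flow.
        ret = await self.tokenizer_manager.generate_request(
            adapted_request, raw_request
        ).__anext__()
        return {"id": adapted_request.rid, "text": ret["text"]}
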
938
python/sglang/srt/entrypoints/openai/serving_chat.py
Normal file
@@ -0,0 +1,938 @@
import base64
import json
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Union

from fastapi import Request
from fastapi.responses import StreamingResponse

from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatCompletionTokenLogprob,
    ChatMessage,
    ChoiceLogprobs,
    DeltaMessage,
    ErrorResponse,
    FunctionResponse,
    LogProbs,
    ToolCall,
    TopLogprob,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.utils import (
    aggregate_token_usage,
    detect_template_content_format,
    process_content_for_template_format,
    to_openai_style_logprobs,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str

logger = logging.getLogger(__name__)


class OpenAIServingChat(OpenAIServingBase):
    """Handler for chat completion requests"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Instance-specific cache for template content format detection
        self._cached_chat_template = None
        self._cached_template_format = None

    def _request_id_prefix(self) -> str:
        return "chatcmpl-"

    def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]:
        """Validate chat messages format and content"""
        if not (messages := request.messages):
            return "Messages cannot be empty"

        # Check for alternating user/assistant pattern (optional validation)
        roles = [msg.role for msg in messages]

        # First message should typically be from user or system
        if roles[0] not in ["user", "system"]:
            return "First message should be from 'user' or 'system'"

        # Check for consecutive assistant messages (which might indicate an error)
        for i in range(1, len(roles)):
            if roles[i] == "assistant" and roles[i - 1] == "assistant":
                # This is actually allowed in some cases, so just warn
                pass

        # Validate message content
        for i, msg in enumerate(messages):
            if msg.role == "user":
                if not msg.content:
                    return f"User message at index {i} has no content"
            elif msg.role == "assistant":
                # Assistant messages can have no content if they have tool_calls
                if not msg.content and not getattr(msg, "tool_calls", None):
                    return (
                        f"Assistant message at index {i} has no content or tool calls"
                    )

        return None

    def _convert_to_internal_request(
        self,
        all_requests: List[ChatCompletionRequest],
        request_ids: List[str],
    ) -> tuple[
        GenerateReqInput, Union[ChatCompletionRequest, List[ChatCompletionRequest]]
    ]:
        """Convert OpenAI chat completion request to internal format"""
        input_ids = []
        prompts = []
        sampling_params_list = []
        image_data_list = []
        audio_data_list = []
        return_logprobs = []
        logprob_start_lens = []
        top_logprobs_nums = []
        modalities_list = []
        lora_paths = []

        is_multimodal = self.tokenizer_manager.model_config.is_multimodal

        for request in all_requests:
            # Process messages and apply chat template
            (
                prompt,
                prompt_ids,
                image_data,
                audio_data,
                modalities,
                stop,
                tool_call_constraint,
            ) = self._process_messages(request, is_multimodal)

            input_ids.append(prompt_ids)
            prompts.append(prompt)
            return_logprobs.append(request.logprobs)
            logprob_start_lens.append(-1)
            top_logprobs_nums.append(request.top_logprobs or 0)
            lora_paths.append(request.lora_path)

            # Build sampling parameters
            sampling_params = self._build_sampling_params(
                request, stop, tool_call_constraint
            )
            sampling_params_list.append(sampling_params)

            image_data_list.append(image_data)
            audio_data_list.append(audio_data)
            modalities_list.append(modalities)

        # Handle single vs multiple requests
        if len(all_requests) == 1:
            if is_multimodal:
                prompt_kwargs = {"text": prompts[0]}
            else:
                if isinstance(input_ids[0], str):
                    prompt_kwargs = {"text": input_ids[0]}
                else:
                    prompt_kwargs = {"input_ids": input_ids[0]}

            sampling_params_list = sampling_params_list[0]
            image_data_list = image_data_list[0]
            audio_data_list = audio_data_list[0]
            return_logprobs = return_logprobs[0]
            logprob_start_lens = logprob_start_lens[0]
            top_logprobs_nums = top_logprobs_nums[0]
            modalities_list = modalities_list[0]
            lora_paths = lora_paths[0]
            request_ids = request_ids[0]
        else:
            if is_multimodal:
                prompt_kwargs = {"text": prompts}
            else:
                if isinstance(input_ids[0], str):
                    prompt_kwargs = {"text": input_ids}
                else:
                    prompt_kwargs = {"input_ids": input_ids}

        adapted_request = GenerateReqInput(
            **prompt_kwargs,
            image_data=image_data_list,
            audio_data=audio_data_list,
            sampling_params=sampling_params_list,
            return_logprob=return_logprobs,
            logprob_start_len=logprob_start_lens,
            top_logprobs_num=top_logprobs_nums,
            stream=all_requests[0].stream,
            return_text_in_logprobs=True,
            rid=request_ids,
            modalities=modalities_list,
            lora_path=lora_paths,
            bootstrap_host=all_requests[0].bootstrap_host,
            bootstrap_port=all_requests[0].bootstrap_port,
            bootstrap_room=all_requests[0].bootstrap_room,
        )

        return adapted_request, (
            all_requests if len(all_requests) > 1 else all_requests[0]
        )

    def _process_messages(
        self, request: ChatCompletionRequest, is_multimodal: bool
    ) -> tuple[
        str,
        Union[str, List[int]],
        Optional[Any],
        Optional[Any],
        List[str],
        List[str],
        Optional[Any],
    ]:
        """Process chat messages and apply chat template"""
        tool_call_constraint = None
        prompt = ""
        prompt_ids = []

        if not isinstance(request.messages, str):
            # Apply chat template and its stop strings
            tools = None
            if request.tools and request.tool_choice != "none":
                request.skip_special_tokens = False
                if not isinstance(request.tool_choice, str):
                    tools = [
                        item.function.model_dump()
                        for item in request.tools
                        if item.function.name == request.tool_choice.function.name
                    ]
                else:
                    tools = [item.function.model_dump() for item in request.tools]

                tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
                parser = FunctionCallParser(request.tools, tool_call_parser)
                tool_call_constraint = parser.get_structure_constraint(
                    request.tool_choice
                )

            # Use chat template
            if (
                hasattr(self.tokenizer_manager, "chat_template_name")
                and self.tokenizer_manager.chat_template_name is None
            ):
                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
                    self._apply_jinja_template(request, tools, is_multimodal)
                )
            else:
                prompt, image_data, audio_data, modalities, stop = (
                    self._apply_conversation_template(request)
                )
                if not is_multimodal:
                    prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
        else:
            # Use raw prompt
            prompt_ids = request.messages
            stop = request.stop or []
            image_data = None
            audio_data = None
            modalities = []
            prompt = request.messages

        return (
            prompt,
            prompt_ids,
            image_data,
            audio_data,
            modalities,
            stop,
            tool_call_constraint,
        )

    def _apply_jinja_template(
        self,
        request: ChatCompletionRequest,
        tools: Optional[List[Dict]],
        is_multimodal: bool,
    ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
        """Apply Jinja chat template"""
        openai_compatible_messages = []
        image_data = []
        audio_data = []
        modalities = []

        # Detect template content format
        current_template = self.tokenizer_manager.tokenizer.chat_template
        if current_template != self._cached_chat_template:
            self._cached_chat_template = current_template
            self._cached_template_format = detect_template_content_format(
                current_template
            )
            logger.info(
                f"Detected chat template content format: {self._cached_template_format}"
            )

        template_content_format = self._cached_template_format

        for message in request.messages:
            if message.content is None:
                message.content = ""
            msg_dict = message.model_dump()

            # Process content based on detected template format
            processed_msg = process_content_for_template_format(
                msg_dict,
                template_content_format,
                image_data,
                audio_data,
                modalities,
            )
            openai_compatible_messages.append(processed_msg)

        # Handle assistant prefix for continue_final_message
        assistant_prefix = None
        if (
            openai_compatible_messages
            and openai_compatible_messages[-1]["role"] == "assistant"
        ):
            if request.continue_final_message:
                assistant_prefix = openai_compatible_messages[-1]["content"]
                openai_compatible_messages = openai_compatible_messages[:-1]

        try:
            prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template(
                openai_compatible_messages,
                tokenize=True,
                add_generation_prompt=True,
                tools=tools,
                **(
                    request.chat_template_kwargs if request.chat_template_kwargs else {}
                ),
            )
        except Exception:
            # This except branch will be triggered when the chosen model
            # has a different tools input format that is not compatible
            # with openAI's apply_chat_template tool_call format, like Mistral.
            tools = (
                [t if "function" in t else {"function": t} for t in tools]
                if tools
                else None
            )
            prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template(
                openai_compatible_messages,
                tokenize=True,
                add_generation_prompt=True,
                tools=tools,
                **(
                    request.chat_template_kwargs if request.chat_template_kwargs else {}
                ),
            )

        if assistant_prefix:
            encoded = self.tokenizer_manager.tokenizer.encode(assistant_prefix)
            if encoded and encoded[0] == self.tokenizer_manager.tokenizer.bos_token_id:
                encoded = encoded[1:]
            prompt_ids += encoded

        if is_multimodal:
            prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)

        stop = request.stop or []
        return prompt, prompt_ids, image_data, audio_data, modalities, stop

    def _apply_conversation_template(
        self, request: ChatCompletionRequest
    ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]:
        """Apply conversation template"""
        conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name)

        # If we should continue the final assistant message, adjust the conversation.
        if (
            request.continue_final_message
            and request.messages
            and request.messages[-1].role == "assistant"
        ):
            # Remove the auto-added blank assistant turn, if present.
            if conv.messages and conv.messages[-1][1] is None:
                conv.messages.pop()
            # Rebuild the prompt from the conversation.
            prompt = conv.get_prompt()
            # Strip trailing stop tokens or separators that indicate end-of-assistant.
            if isinstance(conv.stop_str, list):
                for stop_token in conv.stop_str:
                    if prompt.endswith(stop_token):
                        prompt = prompt[: -len(stop_token)]
            elif isinstance(conv.stop_str, str) and prompt.endswith(conv.stop_str):
                prompt = prompt[: -len(conv.stop_str)]
            if conv.sep and prompt.endswith(conv.sep):
                prompt = prompt[: -len(conv.sep)]
            if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2):
                prompt = prompt[: -len(conv.sep2)]
        else:
            prompt = conv.get_prompt()

        image_data = conv.image_data
        audio_data = conv.audio_data
        modalities = conv.modalities
        stop = conv.stop_str or [] if not request.ignore_eos else []

        if request.stop:
            if isinstance(request.stop, str):
                stop.append(request.stop)
            else:
                stop.extend(request.stop)

        return prompt, image_data, audio_data, modalities, stop

    def _build_sampling_params(
        self,
        request: ChatCompletionRequest,
        stop: List[str],
        tool_call_constraint: Optional[Any],
    ) -> Dict[str, Any]:
        """Build sampling parameters for the request"""

        sampling_params = {
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens or request.max_completion_tokens,
            "min_new_tokens": request.min_tokens,
            "stop": stop,
            "stop_token_ids": request.stop_token_ids,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "min_p": request.min_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "repetition_penalty": request.repetition_penalty,
            "regex": request.regex,
            "ebnf": request.ebnf,
            "n": request.n,
            "no_stop_trim": request.no_stop_trim,
            "ignore_eos": request.ignore_eos,
            "skip_special_tokens": request.skip_special_tokens,
            "logit_bias": request.logit_bias,
        }

        if request.response_format and request.response_format.type == "json_schema":
            sampling_params["json_schema"] = convert_json_schema_to_str(
                request.response_format.json_schema.schema_
            )
        elif request.response_format and request.response_format.type == "json_object":
            sampling_params["json_schema"] = '{"type": "object"}'
        elif (
            request.response_format and request.response_format.type == "structural_tag"
        ):
            sampling_params["structural_tag"] = convert_json_schema_to_str(
                request.response_format.model_dump(by_alias=True)
            )

        # Check if there are already existing output constraints
        has_existing_constraints = (
            sampling_params.get("regex")
            or sampling_params.get("ebnf")
            or sampling_params.get("structural_tag")
            or sampling_params.get("json_schema")
        )

        if tool_call_constraint and has_existing_constraints:
            logger.warning("Constrained decoding is not compatible with tool calls.")
        elif tool_call_constraint:
            constraint_type, constraint_value = tool_call_constraint
            if constraint_type == "structural_tag":
                sampling_params[constraint_type] = convert_json_schema_to_str(
                    constraint_value.model_dump(by_alias=True)
                )
            else:
                sampling_params[constraint_type] = constraint_value
        return sampling_params

    async def _handle_streaming_request(
        self,
        adapted_request: GenerateReqInput,
        request: ChatCompletionRequest,
        raw_request: Request,
    ) -> StreamingResponse:
        """Handle streaming chat completion request"""

        async def generate_stream_resp():
            parser_dict = {}
            reasoning_parser_dict = {}
            tool_call_first = True
            is_firsts = {}
            stream_buffers = {}
            n_prev_tokens = {}
            prompt_tokens = {}
            completion_tokens = {}
            cached_tokens = {}

            try:
                async for content in self.tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
                    index = content.get("index", 0)

                    is_first = is_firsts.get(index, True)
                    stream_buffer = stream_buffers.get(index, "")
                    n_prev_token = n_prev_tokens.get(index, 0)

                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                    cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)

                    # Handle logprobs
                    choice_logprobs = None
                    if request.logprobs:
                        choice_logprobs = self._process_streaming_logprobs(
                            content, n_prev_token
                        )
                        n_prev_token = len(
                            content["meta_info"]["output_token_logprobs"]
                        )

                    finish_reason = content["meta_info"]["finish_reason"]
                    finish_reason_type = (
                        finish_reason["type"] if finish_reason else None
                    )

                    # First chunk with role
                    if is_first:
                        is_first = False
                        delta = DeltaMessage(role="assistant")
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=delta,
                            finish_reason=finish_reason_type,
                            matched_stop=(
                                finish_reason["matched"]
                                if finish_reason and "matched" in finish_reason
                                else None
                            ),
                            logprobs=choice_logprobs,
                        )
                        chunk = ChatCompletionStreamResponse(
                            id=content["meta_info"]["id"],
                            created=int(time.time()),
                            choices=[choice_data],
                            model=request.model,
                        )
                        yield f"data: {chunk.model_dump_json()}\n\n"

                    # Process content delta
                    delta = content["text"][len(stream_buffer) :]
                    new_stream_buffer = stream_buffer + delta

                    # Handle reasoning content
                    enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
                        "enable_thinking", True
                    )
                    if (
                        self.tokenizer_manager.server_args.reasoning_parser
                        and request.separate_reasoning
                        and enable_thinking
                    ):
                        reasoning_text, delta = self._process_reasoning_stream(
                            index, delta, reasoning_parser_dict, content, request
                        )
                        if reasoning_text:
                            choice_data = ChatCompletionResponseStreamChoice(
                                index=index,
                                delta=DeltaMessage(reasoning_content=reasoning_text),
                                finish_reason=finish_reason_type,
                            )
                            chunk = ChatCompletionStreamResponse(
                                id=content["meta_info"]["id"],
                                created=int(time.time()),
                                choices=[choice_data],
                                model=request.model,
                            )
                            yield f"data: {chunk.model_dump_json()}\n\n"

                    if not delta:
                        stream_buffers[index] = new_stream_buffer
                        is_firsts[index] = is_first
                        n_prev_tokens[index] = n_prev_token
                        continue

                    # Handle tool calls
                    if request.tool_choice != "none" and request.tools:
                        async for chunk in self._process_tool_call_stream(
                            index,
                            delta,
                            parser_dict,
                            content,
                            request,
                            finish_reason_type,
                        ):
                            yield chunk
                    else:
                        # Regular content
                        if delta or not (
                            request.stream_options
                            and request.stream_options.include_usage
                        ):
                            choice_data = ChatCompletionResponseStreamChoice(
                                index=index,
                                delta=DeltaMessage(content=delta if delta else None),
                                finish_reason=(
                                    None
                                    if request.stream_options
                                    and request.stream_options.include_usage
                                    else finish_reason_type
                                ),
                                matched_stop=(
                                    finish_reason["matched"]
                                    if finish_reason and "matched" in finish_reason
                                    else None
                                ),
                                logprobs=choice_logprobs,
                            )
                            chunk = ChatCompletionStreamResponse(
                                id=content["meta_info"]["id"],
                                created=int(time.time()),
                                choices=[choice_data],
                                model=request.model,
                            )
                            yield f"data: {chunk.model_dump_json()}\n\n"

                    stream_buffers[index] = new_stream_buffer
                    is_firsts[index] = is_first
                    n_prev_tokens[index] = n_prev_token

                # Final chunk with usage
                if request.stream_options and request.stream_options.include_usage:
                    usage = self._calculate_streaming_usage_base(
                        prompt_tokens, completion_tokens, cached_tokens, request.n
                    )
                else:
                    usage = None

                final_chunk = ChatCompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    created=int(time.time()),
                    choices=[
                        ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=DeltaMessage(),
                            finish_reason=finish_reason_type,
                        )
                    ],
                    model=request.model,
                    usage=usage,
                )
                yield f"data: {final_chunk.model_dump_json()}\n\n"

            except Exception as e:
                error = self.create_streaming_error_response(str(e))
                yield f"data: {error}\n\n"

            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream_resp(),
            media_type="text/event-stream",
            background=self.tokenizer_manager.create_abort_task(adapted_request),
        )

    async def _handle_non_streaming_request(
        self,
        adapted_request: GenerateReqInput,
        request: ChatCompletionRequest,
        raw_request: Request,
    ) -> Union[ChatCompletionResponse, ErrorResponse]:
        """Handle non-streaming chat completion request"""
        try:
            ret = await self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            ).__anext__()
        except ValueError as e:
            return self.create_error_response(str(e))

        if not isinstance(ret, list):
            ret = [ret]

        response = self._build_chat_response(
            request,
            ret,
            int(time.time()),
            cache_report=self.tokenizer_manager.server_args.enable_cache_report,
            tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
            reasoning_parser=self.tokenizer_manager.server_args.reasoning_parser,
        )

        return response

    def _build_chat_response(
        self,
        request: ChatCompletionRequest,
        ret: List[Dict[str, Any]],
        created: int,
        cache_report: bool = False,
        tool_call_parser: Optional[str] = None,
        reasoning_parser: Optional[str] = None,
    ) -> ChatCompletionResponse:
        """Build chat completion response from generation results"""
        choices = []

        for idx, ret_item in enumerate(ret):
            # Process logprobs
            choice_logprobs = None
            if request.logprobs:
                choice_logprobs = self._process_response_logprobs(ret_item)

            finish_reason = ret_item["meta_info"]["finish_reason"]
            text = ret_item["text"]

            # Handle reasoning content
            reasoning_text = None
            enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
                "enable_thinking", True
            )
            if reasoning_parser and request.separate_reasoning and enable_thinking:
                try:
                    parser = ReasoningParser(
                        model_type=reasoning_parser, stream_reasoning=False
                    )
                    reasoning_text, text = parser.parse_non_stream(text)
                except Exception as e:
                    logger.error(f"Reasoning parsing error: {e}")
                    return self.create_error_response(
                        "Failed to parse reasoning content",
                        err_type="InternalServerError",
                        status_code=500,
                    )

            # Handle tool calls
            tool_calls = None
            if request.tool_choice != "none" and request.tools:
                tool_calls, text, finish_reason = self._process_tool_calls(
                    text, request.tools, tool_call_parser, finish_reason
                )

            choice_data = ChatCompletionResponseChoice(
                index=idx,
                message=ChatMessage(
                    role="assistant",
                    content=text if text else None,
                    tool_calls=tool_calls,
                    reasoning_content=reasoning_text if reasoning_text else None,
                ),
                logprobs=choice_logprobs,
                finish_reason=finish_reason["type"] if finish_reason else None,
                matched_stop=(
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            )
            choices.append(choice_data)

        # Calculate usage
        usage = aggregate_token_usage(ret, request.n, cache_report)

        return ChatCompletionResponse(
            id=ret[0]["meta_info"]["id"],
            created=created,
            model=request.model,
            choices=choices,
            usage=usage,
        )

def _process_logprobs_tokens(
|
||||
self, logprobs: LogProbs, use_token_index: bool = False
|
||||
) -> List[ChatCompletionTokenLogprob]:
|
||||
"""Common helper to process logprobs tokens for both streaming and non-streaming
|
||||
|
||||
Args:
|
||||
logprobs: LogProbs data from model
|
||||
use_token_index: True for non-streaming (use token_idx), False for streaming (use index 0)
|
||||
"""
|
||||
token_logprobs = []
|
||||
|
||||
for token_idx, (token, logprob) in enumerate(
|
||||
zip(logprobs.tokens, logprobs.token_logprobs)
|
||||
):
|
||||
token_bytes = list(token.encode("utf-8"))
|
||||
top_logprobs = []
|
||||
if logprobs.top_logprobs:
|
||||
# - Non-streaming (use_token_index=True): uses token_idx for full data
|
||||
# - Streaming (use_token_index=False): uses index 0 for pre-sliced data
|
||||
top_logprobs_idx = token_idx if use_token_index else 0
|
||||
for top_token, top_logprob in logprobs.top_logprobs[
|
||||
top_logprobs_idx
|
||||
].items():
|
||||
top_token_bytes = list(top_token.encode("utf-8"))
|
||||
top_logprobs.append(
|
||||
TopLogprob(
|
||||
token=top_token,
|
||||
bytes=top_token_bytes,
|
||||
logprob=top_logprob,
|
||||
)
|
||||
)
|
||||
token_logprobs.append(
|
||||
ChatCompletionTokenLogprob(
|
||||
token=token,
|
||||
bytes=token_bytes,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
return token_logprobs
|
||||
|
||||
def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs:
|
||||
"""Process logprobs for non-streaming response"""
|
||||
logprobs = to_openai_style_logprobs(
|
||||
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
|
||||
output_top_logprobs=ret_item["meta_info"].get("output_top_logprobs", None),
|
||||
)
|
||||
|
||||
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
|
||||
return ChoiceLogprobs(content=token_logprobs)
|
||||
|
||||
def _process_tool_calls(
|
||||
self,
|
||||
text: str,
|
||||
tools: List[Any],
|
||||
tool_call_parser: Optional[str],
|
||||
finish_reason: Dict[str, Any],
|
||||
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
|
||||
"""Process tool calls in the response"""
|
||||
parser = FunctionCallParser(tools, tool_call_parser)
|
||||
if parser.has_tool_call(text):
|
||||
if finish_reason["type"] == "stop":
|
||||
finish_reason["type"] = "tool_calls"
|
||||
finish_reason["matched"] = None
|
||||
try:
|
||||
text, call_info_list = parser.parse_non_stream(text)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
|
||||
function=FunctionResponse(
|
||||
name=call_info.name, arguments=call_info.parameters
|
||||
),
|
||||
)
|
||||
for call_info in call_info_list
|
||||
]
|
||||
return tool_calls, text, finish_reason
|
||||
except Exception as e:
|
||||
logger.error(f"Tool call parsing error: {e}")
|
||||
# Return error but don't fail the whole request
|
||||
return None, text, finish_reason
|
||||
|
||||
return None, text, finish_reason
|
||||
|
||||
def _process_streaming_logprobs(
|
||||
self, content: Dict[str, Any], n_prev_token: int
|
||||
) -> ChoiceLogprobs:
|
||||
"""Process logprobs for streaming response"""
|
||||
        logprobs = to_openai_style_logprobs(
            output_token_logprobs=content["meta_info"]["output_token_logprobs"][
                n_prev_token:
            ],
            output_top_logprobs=content["meta_info"].get("output_top_logprobs", [])[
                n_prev_token:
            ],
        )

        token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=False)
        return ChoiceLogprobs(content=token_logprobs)

    def _process_reasoning_stream(
        self,
        index: int,
        delta: str,
        reasoning_parser_dict: Dict[int, ReasoningParser],
        content: Dict[str, Any],
        request: ChatCompletionRequest,
    ) -> tuple[Optional[str], str]:
        """Process reasoning content in streaming response"""
        if index not in reasoning_parser_dict:
            reasoning_parser_dict[index] = ReasoningParser(
                self.tokenizer_manager.server_args.reasoning_parser,
                request.stream_reasoning,
            )
        reasoning_parser = reasoning_parser_dict[index]
        return reasoning_parser.parse_stream_chunk(delta)

    async def _process_tool_call_stream(
        self,
        index: int,
        delta: str,
        parser_dict: Dict[int, FunctionCallParser],
        content: Dict[str, Any],
        request: ChatCompletionRequest,
        finish_reason_type: Optional[str],
    ):
        """Process tool calls in streaming response"""
        if index not in parser_dict:
            parser_dict[index] = FunctionCallParser(
                tools=request.tools,
                tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
            )
        parser = parser_dict[index]

        normal_text, calls = parser.parse_stream_chunk(delta)

        # Yield normal text
        if normal_text:
            choice_data = ChatCompletionResponseStreamChoice(
                index=index,
                delta=DeltaMessage(content=normal_text),
                finish_reason=finish_reason_type,
            )
            chunk = ChatCompletionStreamResponse(
                id=content["meta_info"]["id"],
                created=int(time.time()),
                choices=[choice_data],
                model=request.model,
            )
            yield f"data: {chunk.model_dump_json()}\n\n"

        # Yield tool calls
        for call_item in calls:
            if finish_reason_type == "stop":
                # Handle remaining arguments
                latest_delta_len = 0
                if isinstance(call_item.parameters, str):
                    latest_delta_len = len(call_item.parameters)

                expected_call = json.dumps(
                    parser.detector.prev_tool_call_arr[index].get("arguments", {}),
                    ensure_ascii=False,
                )
                actual_call = parser.detector.streamed_args_for_tool[index]
                if latest_delta_len > 0:
                    actual_call = actual_call[:-latest_delta_len]
                remaining_call = expected_call.replace(actual_call, "", 1)
                call_item.parameters = remaining_call
                finish_reason_type = "tool_calls"

            tool_call = ToolCall(
                id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
                index=call_item.tool_index,
                function=FunctionResponse(
                    name=call_item.name,
                    arguments=call_item.parameters,
                ),
            )

            choice_data = ChatCompletionResponseStreamChoice(
                index=index,
                delta=DeltaMessage(tool_calls=[tool_call]),
                finish_reason=(
                    None
                    if request.stream_options and request.stream_options.include_usage
                    else finish_reason_type
                ),
            )
            chunk = ChatCompletionStreamResponse(
                id=content["meta_info"]["id"],
                created=int(time.time()),
                choices=[choice_data],
                model=request.model,
            )
            yield f"data: {chunk.model_dump_json()}\n\n"
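Not part of the patch: a minimal, self-contained sketch of the "remaining arguments" bookkeeping used above when a tool-call stream ends with finish_reason "stop". All values below are made up; only the string arithmetic mirrors the logic in _process_tool_call_stream.

# Illustration only: recover the unsent tail of the tool-call arguments.
import json

# Full argument object the parser has accumulated for this tool call.
expected_call = json.dumps({"city": "Paris", "unit": "celsius"}, ensure_ascii=False)

# What has been streamed so far, including a latest (still unsent) delta
# of length 10 that will be replaced by the remaining arguments.
streamed_so_far = '{"city": "Paris", "unit": "cel'
latest_delta_len = 10

# Drop the latest delta, then diff against the expected full call to get
# the suffix that still has to be emitted in the final chunk.
actual_call = streamed_so_far[:-latest_delta_len]
remaining_call = expected_call.replace(actual_call, "", 1)
print(remaining_call)  # the unsent suffix of the JSON arguments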
467
python/sglang/srt/entrypoints/openai/serving_completions.py
Normal file
@@ -0,0 +1,467 @@
import time
from typing import Any, Dict, List, Optional, Union

from fastapi import Request
from fastapi.responses import StreamingResponse

from sglang.srt.code_completion_parser import (
    generate_completion_prompt_from_request,
    is_completion_template_defined,
)
from sglang.srt.entrypoints.openai.protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.utils import (
    aggregate_token_usage,
    to_openai_style_logprobs,
)
from sglang.srt.managers.io_struct import GenerateReqInput


class OpenAIServingCompletion(OpenAIServingBase):
    """Handler for completion requests"""

    def _request_id_prefix(self) -> str:
        return "cmpl-"

    def _validate_request(self, request: CompletionRequest) -> Optional[str]:
        """Validate completion prompt format and content"""
        if not (prompt := request.prompt):
            return "Prompt cannot be None"

        if isinstance(prompt, str):
            if not prompt.strip():
                return "Prompt cannot be empty or whitespace only"
        elif isinstance(prompt, list):
            if not prompt:
                return "Prompt list cannot be empty"

            # Check if it's a list of strings
            if all(isinstance(item, str) for item in prompt):
                for i, item in enumerate(prompt):
                    if not item.strip():
                        return f"Prompt at index {i} cannot be empty or whitespace only"

            # Check if it's a list of token IDs (integers)
            elif all(isinstance(item, int) for item in prompt):
                if any(item < 0 for item in prompt):
                    return "Token IDs must be non-negative"

            # Check if it's a list of lists (multiple token sequences)
            elif all(isinstance(item, list) for item in prompt):
                for i, item in enumerate(prompt):
                    if not item:
                        return f"Token sequence at index {i} cannot be empty"
                    if not all(isinstance(token, int) for token in item):
                        return f"Token sequence at index {i} must contain only integers"
                    if any(token < 0 for token in item):
                        return (
                            f"Token sequence at index {i} contains negative token IDs"
                        )
            else:
                return "Prompt must be string, list of strings, list of integers, or list of integer lists"
        else:
            return "Prompt must be string or list"

        return None

    def _convert_to_internal_request(
        self,
        all_requests: List[CompletionRequest],
        request_ids: List[str],
    ) -> tuple[GenerateReqInput, Union[CompletionRequest, List[CompletionRequest]]]:
        """Convert OpenAI completion request to internal format"""
        # Validate batch requests
        if len(all_requests) > 1:
            first_prompt_type = type(all_requests[0].prompt)
            for request in all_requests:
                assert (
                    type(request.prompt) is first_prompt_type
                ), "All prompts must be of the same type in file input settings"
                if request.n > 1:
                    raise ValueError(
                        "Parallel sampling is not supported for completions from files"
                    )

        prompts = []
        sampling_params_list = []
        return_logprobs = []
        logprob_start_lens = []
        top_logprobs_nums = []
        lora_paths = []

        for request in all_requests:
            # Process prompt
            prompt = request.prompt
            if is_completion_template_defined():
                prompt = generate_completion_prompt_from_request(request)

            prompts.append(prompt)

            lora_paths.append(request.lora_path)

            # Set logprob start length based on echo and logprobs
            if request.echo and request.logprobs:
                current_logprob_start_len = 0
            else:
                current_logprob_start_len = -1

            # Build sampling parameters
            sampling_params = self._build_sampling_params(request)
            sampling_params_list.append(sampling_params)

            return_logprobs.append(request.logprobs is not None)
            logprob_start_lens.append(current_logprob_start_len)
            top_logprobs_nums.append(
                request.logprobs if request.logprobs is not None else 0
            )

        # Handle single vs multiple requests
        if len(all_requests) == 1:
            if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
                prompt_kwargs = {"text": prompts[0]}
            else:
                prompt_kwargs = {"input_ids": prompts[0]}
            sampling_params_list = sampling_params_list[0]
            return_logprobs = return_logprobs[0]
            logprob_start_lens = logprob_start_lens[0]
            top_logprobs_nums = top_logprobs_nums[0]
            lora_paths = lora_paths[0]
            request_ids = request_ids[0]
        else:
            if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
                prompt_kwargs = {"text": prompts}
            else:
                prompt_kwargs = {"input_ids": prompts}

        adapted_request = GenerateReqInput(
            **prompt_kwargs,
            sampling_params=sampling_params_list,
            return_logprob=return_logprobs,
            top_logprobs_num=top_logprobs_nums,
            logprob_start_len=logprob_start_lens,
            return_text_in_logprobs=True,
            stream=all_requests[0].stream,
            rid=request_ids,
            lora_path=lora_paths,
            bootstrap_host=all_requests[0].bootstrap_host,
            bootstrap_port=all_requests[0].bootstrap_port,
            bootstrap_room=all_requests[0].bootstrap_room,
        )

        return adapted_request, (
            all_requests if len(all_requests) > 1 else all_requests[0]
        )

    def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
        """Build sampling parameters for the request"""
        # Start with common parameters
        sampling_params = {
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "min_new_tokens": request.min_tokens,
            "stop": request.stop,
            "stop_token_ids": request.stop_token_ids,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "min_p": request.min_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "repetition_penalty": request.repetition_penalty,
            "regex": request.regex,
            "json_schema": request.json_schema,
            "ebnf": request.ebnf,
            "n": request.n,
            "no_stop_trim": request.no_stop_trim,
            "ignore_eos": request.ignore_eos,
            "skip_special_tokens": request.skip_special_tokens,
            "logit_bias": request.logit_bias,
        }

        # No additional completion-specific parameters needed currently
        # (json_schema is already handled in base method)

        return sampling_params

    async def _handle_streaming_request(
        self,
        adapted_request: GenerateReqInput,
        request: CompletionRequest,
        raw_request: Request,
    ) -> StreamingResponse:
        """Handle streaming completion request"""
        created = int(time.time())

        async def generate_stream_resp():
            stream_buffers = {}
            n_prev_tokens = {}
            prompt_tokens = {}
            completion_tokens = {}
            cached_tokens = {}

            try:
                async for content in self.tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
                    index = content.get("index", 0)

                    stream_buffer = stream_buffers.get(index, "")
                    n_prev_token = n_prev_tokens.get(index, 0)

                    text = content["text"]
                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                    cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)

                    # Handle echo for first chunk
                    if not stream_buffer:  # The first chunk
                        if request.echo:
                            echo_text = self._get_echo_text(request, index)
                            text = echo_text + text

                    # Handle logprobs
                    logprobs = None
                    if request.logprobs is not None:
                        # The first chunk and echo is enabled.
                        if not stream_buffer and request.echo:
                            input_token_logprobs = content["meta_info"][
                                "input_token_logprobs"
                            ]
                            input_top_logprobs = content["meta_info"][
                                "input_top_logprobs"
                            ]
                        else:
                            input_token_logprobs = None
                            input_top_logprobs = None

                        logprobs = to_openai_style_logprobs(
                            input_token_logprobs=input_token_logprobs,
                            input_top_logprobs=input_top_logprobs,
                            output_token_logprobs=content["meta_info"][
                                "output_token_logprobs"
                            ][n_prev_token:],
                            output_top_logprobs=content["meta_info"][
                                "output_top_logprobs"
                            ][n_prev_token:],
                        )
                        n_prev_token = len(
                            content["meta_info"]["output_token_logprobs"]
                        )

                    # Generate delta
                    delta = text[len(stream_buffer) :]
                    stream_buffer = stream_buffer + delta
                    finish_reason = content["meta_info"]["finish_reason"]

                    choice_data = CompletionResponseStreamChoice(
                        index=index,
                        text=delta,
                        logprobs=logprobs,
                        finish_reason=finish_reason["type"] if finish_reason else None,
                        matched_stop=(
                            finish_reason["matched"]
                            if finish_reason and "matched" in finish_reason
                            else None
                        ),
                    )
                    chunk = CompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        created=created,
                        object="text_completion",
                        choices=[choice_data],
                        model=request.model,
                    )

                    stream_buffers[index] = stream_buffer
                    n_prev_tokens[index] = n_prev_token

                    yield f"data: {chunk.model_dump_json()}\n\n"

                # Handle final usage chunk
                if request.stream_options and request.stream_options.include_usage:
                    usage = self._calculate_streaming_usage_base(
                        prompt_tokens, completion_tokens, cached_tokens, request.n
                    )
                    final_usage_chunk = CompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        created=created,
                        choices=[],
                        model=request.model,
                        usage=usage,
                    )
                    final_usage_data = final_usage_chunk.model_dump_json(
                        exclude_none=True
                    )
                    yield f"data: {final_usage_data}\n\n"

            except Exception as e:
                error = self.create_streaming_error_response(str(e))
                yield f"data: {error}\n\n"

            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream_resp(),
            media_type="text/event-stream",
            background=self.tokenizer_manager.create_abort_task(adapted_request),
        )

    async def _handle_non_streaming_request(
        self,
        adapted_request: GenerateReqInput,
        request: CompletionRequest,
        raw_request: Request,
    ) -> Union[CompletionResponse, ErrorResponse]:
        """Handle non-streaming completion request"""
        try:
            generator = self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            )
            ret = await generator.__anext__()
        except ValueError as e:
            return self.create_error_response(str(e))

        if not isinstance(ret, list):
            ret = [ret]

        response = self._build_completion_response(
            request,
            ret,
            int(time.time()),
            cache_report=self.tokenizer_manager.server_args.enable_cache_report,
        )

        return response

    def _build_completion_response(
        self,
        request: CompletionRequest,
        ret: List[Dict[str, Any]],
        created: int,
        cache_report: bool = False,
    ) -> CompletionResponse:
        """Build completion response from generation results"""
        choices = []
        echo = False

        # Prepare echo prompts if needed
        echo_prompts = []
        if (not isinstance(request, list)) and request.echo:
            echo_prompts = self._prepare_echo_prompts(request)
            echo = True

        for idx, ret_item in enumerate(ret):
            text = ret_item["text"]

            # Handle echo
            if isinstance(request, list) and request[idx].echo:
                echo = True
                text = request[idx].prompt + text
            elif echo and not isinstance(request, list):
                prompt_index = idx // request.n
                text = echo_prompts[prompt_index] + text

            # Handle logprobs
            logprobs = None
            if isinstance(request, list) and request[idx].logprobs is not None:
                logprobs = True
            elif (not isinstance(request, list)) and request.logprobs is not None:
                logprobs = True

            if logprobs:
                if echo:
                    input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
                    input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
                else:
                    input_token_logprobs = None
                    input_top_logprobs = None

                logprobs = to_openai_style_logprobs(
                    input_token_logprobs=input_token_logprobs,
                    input_top_logprobs=input_top_logprobs,
                    output_token_logprobs=ret_item["meta_info"][
                        "output_token_logprobs"
                    ],
                    output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
                )

            finish_reason = ret_item["meta_info"]["finish_reason"]

            choice_data = CompletionResponseChoice(
                index=idx,
                text=text,
                logprobs=logprobs,
                finish_reason=finish_reason["type"] if finish_reason else None,
                matched_stop=(
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            )
            choices.append(choice_data)

        # Calculate usage
        usage = aggregate_token_usage(ret, request.n, cache_report)

        return CompletionResponse(
            id=ret[0]["meta_info"]["id"],
            model=request.model,
            created=created,
            choices=choices,
            usage=usage,
        )

    def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
        """Get echo text for streaming response"""
        if isinstance(request.prompt, str):
            # for the case of single str prompts
            return request.prompt
        elif isinstance(request.prompt, list):
            if isinstance(request.prompt[0], str):
                # for the case of multiple str prompts
                return request.prompt[index // request.n]
            elif isinstance(request.prompt[0], int):
                # for the case of single token ids prompt
                return self.tokenizer_manager.tokenizer.decode(
                    request.prompt, skip_special_tokens=True
                )
            elif isinstance(request.prompt[0], list) and isinstance(
                request.prompt[0][0], int
            ):
                # for the case of multiple token ids prompts
                return self.tokenizer_manager.tokenizer.decode(
                    request.prompt[index // request.n],
                    skip_special_tokens=True,
                )
        return ""

    def _prepare_echo_prompts(self, request: CompletionRequest) -> List[str]:
        """Prepare echo prompts for non-streaming response"""
        # TODO: handle the case prompt is token ids
        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
            # for the case of multiple str prompts
            return request.prompt
        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
            # for the case of multiple token ids prompts
            return [
                self.tokenizer_manager.tokenizer.decode(
                    prompt, skip_special_tokens=True
                )
                for prompt in request.prompt
            ]
        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
            # for the case of single token ids prompt
            return [
                self.tokenizer_manager.tokenizer.decode(
                    request.prompt, skip_special_tokens=True
                )
            ]
        else:
            # for the case of single str prompt
            return [request.prompt]
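Not part of the patch: a client-side usage sketch for the completion handler above, exercising the echo and logprobs paths. It assumes an SGLang server is already running locally; the URL, port, and model name are placeholders.

# Illustration only: echo + logprobs against a locally running server.
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",   # assumed local endpoint
    json={
        "model": "my-model",                    # placeholder model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "temperature": 0.0,
        "echo": True,      # exercises _prepare_echo_prompts / _get_echo_text
        "logprobs": 2,     # exercises to_openai_style_logprobs
    },
    timeout=60,
)
data = resp.json()
print(data["choices"][0]["text"])   # echoed prompt followed by the completion
print(data["usage"])                # aggregated via aggregate_token_usage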
227
python/sglang/srt/entrypoints/openai/serving_embedding.py
Normal file
@@ -0,0 +1,227 @@
from typing import Any, Dict, List, Optional, Union

from fastapi import Request

from sglang.srt.conversation import generate_embedding_convs
from sglang.srt.entrypoints.openai.protocol import (
    EmbeddingObject,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    MultimodalEmbeddingInput,
    UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput


class OpenAIServingEmbedding(OpenAIServingBase):
    """Handler for embedding requests"""

    def _request_id_prefix(self) -> str:
        return "embd-"

    def _validate_request(self, request: EmbeddingRequest) -> Optional[str]:
        """Validate that the input is not empty or whitespace only."""
        if not (input := request.input):
            return "Input cannot be empty"

        # Handle single string
        if isinstance(input, str):
            if not input.strip():
                return "Input cannot be empty or whitespace only"
            return None

        # Handle list inputs
        if isinstance(input, list):
            if len(input) == 0:
                return "Input cannot be empty"

            # Check first element to determine type
            first_item = input[0]

            if isinstance(first_item, str):
                # List of strings
                for i, item in enumerate(input):
                    if not isinstance(item, str):
                        return "All items in input list must be strings"
                    if not item.strip():
                        return f"Input at index {i} cannot be empty or whitespace only"
            elif isinstance(first_item, int):
                # List of integers (token IDs)
                for i, item in enumerate(input):
                    if not isinstance(item, int):
                        return "All items in input list must be integers"
                    if item < 0:
                        return f"Token ID at index {i} must be non-negative"
            elif isinstance(first_item, list):
                # List of lists (multiple token sequences)
                for i, item in enumerate(input):
                    if not isinstance(item, list):
                        return f"Input at index {i} must be a list"
                    if not item:
                        return f"Input at index {i} cannot be empty"
                    if not all(isinstance(token, int) for token in item):
                        return f"Input at index {i} must contain only integers"
                    if any(token < 0 for token in item):
                        return f"Input at index {i} contains negative token IDs"
            # Note: MultimodalEmbeddingInput validation would be handled by Pydantic

        return None

    def _convert_to_internal_request(
        self,
        all_requests: List[EmbeddingRequest],
        request_ids: List[str],
    ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
        """Convert OpenAI embedding request to internal format"""
        prompts = [request.input for request in all_requests]

        # Handle single vs multiple requests
        if len(all_requests) == 1:
            prompt = prompts[0]
            if isinstance(prompt, str):
                # Single string input
                prompt_kwargs = {"text": prompt}
            elif isinstance(prompt, list):
                if len(prompt) > 0 and isinstance(prompt[0], str):
                    # List of strings
                    prompt_kwargs = {"text": prompt}
                elif len(prompt) > 0 and isinstance(
                    prompt[0], MultimodalEmbeddingInput
                ):
                    # Handle multimodal embedding inputs
                    texts = []
                    images = []
                    for item in prompt:
                        # Use padding for text if None - this could be improved
                        texts.append(item.text if item.text is not None else "padding")
                        images.append(item.image if item.image is not None else None)

                    generate_prompts = []
                    # Check if we have a chat template for multimodal embeddings
                    # This would need to be passed in from the server configuration
                    chat_template_name = getattr(
                        self.tokenizer_manager, "chat_template_name", None
                    )
                    if chat_template_name is not None:
                        convs = generate_embedding_convs(
                            texts, images, chat_template_name
                        )
                        for conv in convs:
                            generate_prompts.append(conv.get_prompt())
                    else:
                        generate_prompts = texts

                    if len(generate_prompts) == 1:
                        prompt_kwargs = {
                            "text": generate_prompts[0],
                            "image_data": images[0],
                        }
                    else:
                        prompt_kwargs = {
                            "text": generate_prompts,
                            "image_data": images,
                        }
                else:
                    # List of integers (token IDs) or empty list
                    prompt_kwargs = {"input_ids": prompt}
            else:
                # Other types (should not happen but handle gracefully)
                prompt_kwargs = {"input_ids": prompt}
            # Use the passed request_ids for single request
            final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids
        else:
            # Handle batch requests
            if len(prompts) > 0:
                # Validate that all prompts have the same type
                first_prompt = prompts[0]
                first_type = type(first_prompt)
                for i, prompt in enumerate(prompts[1:], 1):
                    if type(prompt) != first_type:
                        raise AssertionError(
                            f"All prompts in batch must have the same type, but prompt at index {i} has different type"
                        )

                if isinstance(first_prompt, str):
                    # Batch of strings
                    prompt_kwargs = {"text": prompts}
                elif isinstance(first_prompt, list):
                    if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
                        # Batch of lists of strings
                        prompt_kwargs = {"text": prompts}
                    elif len(first_prompt) > 0 and isinstance(
                        first_prompt[0], MultimodalEmbeddingInput
                    ):
                        # Handle multimodal batch requests
                        raise NotImplementedError(
                            "Multiple requests with multimodal inputs are not supported yet"
                        )
                    else:
                        # Batch of token ID lists
                        prompt_kwargs = {"input_ids": prompts}
                else:
                    # Other types
                    prompt_kwargs = {"input_ids": prompts}
            else:
                prompt_kwargs = {"input_ids": prompts}
            # Use the passed request_ids for batch requests
            final_request_id = request_ids

        adapted_request = EmbeddingReqInput(
            rid=final_request_id,
            **prompt_kwargs,
        )

        return adapted_request, (
            all_requests[0] if len(all_requests) == 1 else all_requests
        )

    async def _handle_non_streaming_request(
        self,
        adapted_request: EmbeddingReqInput,
        request: EmbeddingRequest,
        raw_request: Request,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """Handle the embedding request"""
        try:
            ret = await self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            ).__anext__()
        except ValueError as e:
            return self.create_error_response(str(e))

        if not isinstance(ret, list):
            ret = [ret]

        response = self._build_embedding_response(
            ret, self.tokenizer_manager.model_path
        )
        return response

    def _build_embedding_response(
        self, ret: List[Dict[str, Any]], model_path: str
    ) -> EmbeddingResponse:
        """Build the embedding response"""
        embedding_objects = []
        prompt_tokens = 0

        for idx, ret_item in enumerate(ret):
            embedding_objects.append(
                EmbeddingObject(
                    embedding=ret_item["embedding"],
                    index=idx,
                )
            )
            # Handle missing prompt_tokens gracefully
            meta_info = ret_item.get("meta_info", {})
            prompt_tokens += meta_info.get("prompt_tokens", 0)

        return EmbeddingResponse(
            data=embedding_objects,
            model=model_path,
            usage=UsageInfo(
                prompt_tokens=prompt_tokens,
                total_tokens=prompt_tokens,
            ),
        )
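Not part of the patch: a client-side usage sketch for the embedding handler above, taking the list-of-strings path. It assumes an SGLang server was started with an embedding model; the URL, port, and model name are placeholders.

# Illustration only: list-of-strings embedding request.
import requests

resp = requests.post(
    "http://localhost:30000/v1/embeddings",     # assumed local endpoint
    json={
        "model": "my-embedding-model",          # placeholder model name
        "input": ["hello world", "goodbye world"],
    },
    timeout=60,
)
data = resp.json()
print(len(data["data"]))                 # one EmbeddingObject per input item
print(data["usage"]["prompt_tokens"])    # summed in _build_embedding_response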
264
python/sglang/srt/entrypoints/openai/utils.py
Normal file
@@ -0,0 +1,264 @@
import logging
from typing import Any, Dict, List, Optional

import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils

from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo

logger = logging.getLogger(__name__)


# ============================================================================
# JINJA TEMPLATE CONTENT FORMAT DETECTION
# ============================================================================
#
# This adapts vLLM's approach for detecting chat template content format:
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
# - Analyzes Jinja template AST to detect content iteration patterns
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
# - 'string' format: templates that expect simple string content
# - Processes content accordingly to match template expectations


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Check if node is a variable access like {{ varname }}"""
    if isinstance(node, jinja2.nodes.Name):
        return node.ctx == "load" and node.name == varname
    return False


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
    if isinstance(node, jinja2.nodes.Getitem):
        return (
            _is_var_access(node.node, varname)
            and isinstance(node.arg, jinja2.nodes.Const)
            and node.arg.value == key
        )

    if isinstance(node, jinja2.nodes.Getattr):
        return _is_var_access(node.node, varname) and node.attr == key

    return False


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: str = None,
) -> bool:
    """Check if node accesses varname or varname[key] with filters/tests"""
    if isinstance(node, jinja2.nodes.Filter):
        return node.node is not None and _is_var_or_elems_access(
            node.node, varname, key
        )
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
        node.arg, jinja2.nodes.Slice
    ):
        return _is_var_or_elems_access(node.node, varname, key)

    return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)


def _try_extract_ast(chat_template: str):
    """Try to parse the Jinja template into an AST"""
    try:
        jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
        return jinja_compiled.environment.parse(chat_template)
    except Exception as e:
        logger.debug(f"Error when compiling Jinja template: {e}")
        return None


def detect_template_content_format(chat_template: str) -> str:
    """
    Detect whether a chat template expects 'string' or 'openai' content format.

    - 'string': content is a simple string (like DeepSeek templates)
    - 'openai': content is a list of structured dicts (like Llama4 templates)

    Detection logic:
    - If template has loops like {%- for content in message['content'] -%} → 'openai'
    - Otherwise → 'string'
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return "string"

    try:
        # Look for patterns like: {%- for content in message['content'] -%}
        for loop_ast in jinja_ast.find_all(jinja2.nodes.For):
            loop_iter = loop_ast.iter

            # Check if iterating over message['content'] or similar
            if _is_var_or_elems_access(loop_iter, "message", "content"):
                return "openai"  # Found content iteration → openai format

        return "string"  # No content loops found → string format
    except Exception as e:
        logger.debug(f"Error when parsing AST of Jinja template: {e}")
        return "string"
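Not part of the file: a minimal sketch of the two detection outcomes, assuming jinja2 and transformers are installed. The templates below are toy strings, not real model templates.

# Illustration only: detect_template_content_format on toy templates.
from sglang.srt.entrypoints.openai.utils import detect_template_content_format

string_tmpl = "{% for message in messages %}{{ message['content'] }}{% endfor %}"
openai_tmpl = (
    "{% for message in messages %}"
    "{%- for content in message['content'] -%}{{ content['text'] }}{%- endfor -%}"
    "{% endfor %}"
)

print(detect_template_content_format(string_tmpl))  # -> "string"
print(detect_template_content_format(openai_tmpl))  # -> "openai"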
def process_content_for_template_format(
    msg_dict: dict,
    content_format: str,
    image_data: list,
    audio_data: list,
    modalities: list,
) -> dict:
    """
    Process message content based on detected template format.

    Args:
        msg_dict: Message dictionary with content
        content_format: 'string' or 'openai' (detected via AST analysis)
        image_data: List to append extracted image URLs
        audio_data: List to append extracted audio URLs
        modalities: List to append modalities

    Returns:
        Processed message dictionary
    """
    if not isinstance(msg_dict.get("content"), list):
        # Already a string or None, no processing needed
        return {k: v for k, v in msg_dict.items() if v is not None}

    if content_format == "openai":
        # OpenAI format: preserve structured content list, normalize types
        processed_content_parts = []
        for chunk in msg_dict["content"]:
            if isinstance(chunk, dict):
                chunk_type = chunk.get("type")

                if chunk_type == "image_url":
                    image_data.append(chunk["image_url"]["url"])
                    if chunk.get("modalities"):
                        modalities.append(chunk.get("modalities"))
                    # Normalize to simple 'image' type for template compatibility
                    processed_content_parts.append({"type": "image"})
                elif chunk_type == "audio_url":
                    audio_data.append(chunk["audio_url"]["url"])
                    # Normalize to simple 'audio' type
                    processed_content_parts.append({"type": "audio"})
                else:
                    # Keep other content as-is (text, etc.)
                    processed_content_parts.append(chunk)

        new_msg = {
            k: v for k, v in msg_dict.items() if v is not None and k != "content"
        }
        new_msg["content"] = processed_content_parts
        return new_msg

    else:  # content_format == "string"
        # String format: flatten to text only (for templates like DeepSeek)
        text_parts = []
        for chunk in msg_dict["content"]:
            if isinstance(chunk, dict) and chunk.get("type") == "text":
                text_parts.append(chunk["text"])
            # Note: For string format, we ignore images/audio since the template
            # doesn't expect structured content - multimodal placeholders would
            # need to be inserted differently

        new_msg = msg_dict.copy()
        new_msg["content"] = " ".join(text_parts) if text_parts else ""
        new_msg = {k: v for k, v in new_msg.items() if v is not None}
        return new_msg


def calculate_token_usage(
    prompt_tokens: int,
    completion_tokens: int,
    cached_tokens: Optional[Dict[str, int]] = None,
) -> UsageInfo:
    """Calculate token usage information"""
    return UsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        prompt_tokens_details=cached_tokens,
    )


def aggregate_token_usage(
    responses: List[Dict[str, Any]],
    n_choices: int = 1,
    enable_cache_report: bool = False,
) -> UsageInfo:
    """Aggregate token usage from multiple responses

    Args:
        responses: List of response dictionaries with meta_info
        n_choices: Number of choices per request (for prompt token counting)
        enable_cache_report: Whether to include cached token details

    Returns:
        Aggregated UsageInfo
    """
    # Sum completion tokens from all responses
    completion_tokens = sum(
        response["meta_info"]["completion_tokens"] for response in responses
    )

    # For prompt tokens, only count every n_choices-th response to avoid double counting
    prompt_tokens = sum(
        responses[i]["meta_info"]["prompt_tokens"]
        for i in range(0, len(responses), n_choices)
    )

    # Handle cached tokens if cache reporting is enabled
    cached_tokens_details = None
    if enable_cache_report:
        cached_tokens_sum = sum(
            response["meta_info"].get("cached_tokens", 0) for response in responses
        )
        if cached_tokens_sum > 0:
            cached_tokens_details = {"cached_tokens": cached_tokens_sum}

    return calculate_token_usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        cached_tokens=cached_tokens_details,
    )


def to_openai_style_logprobs(
    input_token_logprobs=None,
    output_token_logprobs=None,
    input_top_logprobs=None,
    output_top_logprobs=None,
):
    ret_logprobs = LogProbs()

    def append_token_logprobs(token_logprobs):
        for logprob, _, token_text in token_logprobs:
            ret_logprobs.tokens.append(token_text)
            ret_logprobs.token_logprobs.append(logprob)

            # Not supported yet
            ret_logprobs.text_offset.append(-1)

    def append_top_logprobs(top_logprobs):
        for tokens in top_logprobs:
            if tokens is not None:
                ret_logprobs.top_logprobs.append(
                    {token[2]: token[0] for token in tokens}
                )
            else:
                ret_logprobs.top_logprobs.append(None)

    if input_token_logprobs is not None:
        append_token_logprobs(input_token_logprobs)
    if output_token_logprobs is not None:
        append_token_logprobs(output_token_logprobs)
    if input_top_logprobs is not None:
        append_top_logprobs(input_top_logprobs)
    if output_top_logprobs is not None:
        append_top_logprobs(output_top_logprobs)

    return ret_logprobs
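Not part of the file: a minimal sketch of aggregate_token_usage with two choices (n=2) generated from a single prompt; the meta_info values are made up.

# Illustration only: prompt tokens are counted once per request, not per choice.
from sglang.srt.entrypoints.openai.utils import aggregate_token_usage

responses = [
    {"meta_info": {"prompt_tokens": 12, "completion_tokens": 5, "cached_tokens": 4}},
    {"meta_info": {"prompt_tokens": 12, "completion_tokens": 7, "cached_tokens": 4}},
]

usage = aggregate_token_usage(responses, n_choices=2, enable_cache_report=True)
print(usage.prompt_tokens)          # 12
print(usage.completion_tokens)      # 12 (5 + 7)
print(usage.total_tokens)           # 24
print(usage.prompt_tokens_details)  # {"cached_tokens": 8}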