feat(oai refactor): Replace openai_api with entrypoints/openai (#7351)

Co-authored-by: Jin Pan <jpan236@wisc.edu>

Author: Chang Su
Date: 2025-06-21 13:21:06 -07:00
Committed by: GitHub
Parent: 02bf31ef29
Commit: 72676cd6c0

43 changed files with 674 additions and 4555 deletions

View File

@@ -20,7 +20,7 @@ from sglang.bench_serving import (
get_gen_prefix_cache_path,
)
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64
# type of content fields, can be only prompts or with images/videos

View File

@@ -64,11 +64,14 @@
"text = \"Once upon a time\"\n",
"\n",
"curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
" -H \"Content-Type: application/json\" \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n",
"\n",
"text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n",
" \"embedding\"\n",
"]\n",
"result = subprocess.check_output(curl_text, shell=True)\n",
"\n",
"print(result)\n",
"\n",
"text_embedding = json.loads(result)[\"data\"][0][\"embedding\"]\n",
"\n",
"print_highlight(f\"Text embedding (first 10): {text_embedding[:10]}\")"
]
@@ -152,6 +155,7 @@
"input_ids = tokenizer.encode(text)\n",
"\n",
"curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
" -H \"Content-Type: application/json\" \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n",
"\n",
"input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n",

View File

@@ -67,6 +67,7 @@
"\n",
"curl_command = f\"\"\"\n",
"curl -s http://localhost:{port}/v1/chat/completions \\\\\n",
" -H \"Content-Type: application/json\" \\\\\n",
" -d '{{\n",
" \"model\": \"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
" \"messages\": [\n",

View File

@@ -36,7 +36,7 @@
"import requests\n",
"from PIL import Image\n",
"\n",
"from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
"from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest\n",
"from sglang.srt.conversation import chat_templates\n",
"\n",
"image = Image.open(\n",

View File

@@ -15,9 +15,7 @@
import dataclasses
import json
import logging
import os
from enum import auto
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
@@ -57,46 +55,6 @@ class CompletionTemplate:
completion_templates: dict[str, CompletionTemplate] = {}
def load_completion_template_for_openai_api(completion_template_arg):
global completion_template_name
logger.info(
f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
)
if not completion_template_exists(completion_template_arg):
if not os.path.exists(completion_template_arg):
raise RuntimeError(
f"Completion template {completion_template_arg} is not a built-in template name "
"or a valid completion template file path."
)
assert completion_template_arg.endswith(
".json"
), "unrecognized format of completion template file"
with open(completion_template_arg, "r") as filep:
template = json.load(filep)
try:
fim_position = FimPosition[template["fim_position"]]
except KeyError:
raise ValueError(
f"Unknown fim position: {template['fim_position']}"
) from None
register_completion_template(
CompletionTemplate(
name=template["name"],
fim_begin_token=template["fim_begin_token"],
fim_middle_token=template["fim_middle_token"],
fim_end_token=template["fim_end_token"],
fim_position=fim_position,
),
override=True,
)
completion_template_name = template["name"]
else:
completion_template_name = completion_template_arg
def register_completion_template(template: CompletionTemplate, override: bool = False):
"""Register a new completion template."""
if not override:

View File

@@ -11,7 +11,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conversation chat templates."""
"""Conversation chat templates.
This module provides conversation template definitions, data structures, and utilities
for managing chat templates across different model types in SGLang.
Key components:
- Conversation class: Defines the structure and behavior of chat templates
- SeparatorStyle enum: Different conversation formatting styles
- Template registry: Functions to register and retrieve templates by name or model path
- Built-in templates: Pre-defined templates for popular models
"""
# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -20,7 +30,7 @@ import re
from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.utils import read_system_prompt_from_file
@@ -618,7 +628,7 @@ def generate_chat_conv(
# llama2 template
# reference: https://huggingface.co/blog/codellama#conversational-instructions
# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
register_conv_template(
Conversation(

View File

@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import torch
import uvloop
from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
from sglang.srt.entrypoints.EngineBase import EngineBase
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
guess_chat_template_name_from_model_path,
load_chat_template_for_openai_api,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
@@ -123,12 +119,13 @@ class Engine(EngineBase):
logger.info(f"{server_args=}")
# Launch subprocesses
tokenizer_manager, scheduler_info = _launch_subprocesses(
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args,
port_args=port_args,
)
self.server_args = server_args
self.tokenizer_manager = tokenizer_manager
self.template_manager = template_manager
self.scheduler_info = scheduler_info
context = zmq.Context(2)
@@ -647,7 +644,7 @@ def _set_envs_and_config(server_args: ServerArgs):
def _launch_subprocesses(
server_args: ServerArgs, port_args: Optional[PortArgs] = None
) -> Tuple[TokenizerManager, Dict]:
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
"""
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
"""
@@ -732,7 +729,7 @@ def _launch_subprocesses(
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
# When using `Engine` as a Python API, we don't want to block here.
return None, None
return None, None, None
launch_dummy_health_check_server(server_args.host, server_args.port)
@@ -741,7 +738,7 @@ def _launch_subprocesses(
logger.error(
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
)
return None, None
return None, None, None
# Launch detokenizer process
detoken_proc = mp.Process(
@@ -755,15 +752,15 @@ def _launch_subprocesses(
# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(
tokenizer_manager, server_args.chat_template, server_args.model_path
)
else:
guess_chat_template_name_from_model_path(server_args.model_path)
if server_args.completion_template:
load_completion_template_for_openai_api(server_args.completion_template)
# Initialize templates
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
)
# Wait for the model to finish loading
scheduler_infos = []
@@ -787,4 +784,4 @@ def _launch_subprocesses(
# Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0]
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
return tokenizer_manager, scheduler_info
return tokenizer_manager, template_manager, scheduler_info
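Editor's note: the hunks above thread a new TemplateManager through _launch_subprocesses and the Engine. Below is a rough sketch of the surface implied by these call sites; the names mirror the attributes used elsewhere in this commit, while the real class lives in sglang.srt.managers.template_manager and does the actual loading.

    from typing import Optional

    # Sketch only: attributes and method mirror how template_manager is used in
    # this commit (engine launch, serving_chat, serving_completions,
    # serving_embedding); not the real implementation.
    class TemplateManagerSketch:
        def __init__(self) -> None:
            self.chat_template_name: Optional[str] = None
            self.completion_template_name: Optional[str] = None
            self.jinja_template_content_format: Optional[str] = None

        def initialize_templates(
            self,
            tokenizer_manager,
            model_path: str,
            chat_template: Optional[str] = None,
            completion_template: Optional[str] = None,
        ) -> None:
            # Load the explicit chat/completion templates when given, otherwise
            # guess a chat template from the model path.
            ...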

View File

@@ -38,7 +38,8 @@ import orjson
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi import Depends, FastAPI, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -47,6 +48,20 @@ from sglang.srt.disaggregation.utils import (
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
EmbeddingRequest,
ModelCard,
ModelList,
ScoringRequest,
V1RerankReqInput,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
AbortReq,
@@ -67,26 +82,11 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
V1RerankReqInput,
VertexGenerateReqInput,
)
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.adapter import (
v1_batches,
v1_cancel_batch,
v1_chat_completions,
v1_completions,
v1_delete_file,
v1_embeddings,
v1_files_create,
v1_rerank,
v1_retrieve_batch,
v1_retrieve_file,
v1_retrieve_file_content,
v1_score,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
@@ -109,6 +109,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass
class _GlobalState:
tokenizer_manager: TokenizerManager
template_manager: TemplateManager
scheduler_info: Dict
@@ -123,6 +124,24 @@ def set_global_state(global_state: _GlobalState):
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_chat = OpenAIServingChat(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_score = OpenAIServingScore(
_global_state.tokenizer_manager
)
fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
_global_state.tokenizer_manager
)
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager
@@ -148,6 +167,36 @@ app.add_middleware(
allow_headers=["*"],
)
# Custom exception handlers to change validation error status codes
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Override FastAPI's default 422 validation error with 400"""
return ORJSONResponse(
status_code=400,
content={
"detail": exc.errors(),
"body": exc.body,
},
)
async def validate_json_request(raw_request: Request):
"""Validate that the request content-type is application/json."""
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
raise RequestValidationError(
errors=[
{
"loc": ["header", "content-type"],
"msg": "Unsupported Media Type: Only 'application/json' is allowed",
"type": "value_error",
}
]
)
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
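Editor's note: a quick client-side check of the new validate_json_request dependency and the 400 override, as a sketch; it assumes a server on localhost:30000, and the model name is a placeholder for the served model.

    import requests

    base = "http://localhost:30000"

    # Wrong Content-Type: validation fails and the custom handler above turns
    # it into HTTP 400 instead of FastAPI's default 422.
    bad = requests.post(
        f"{base}/v1/completions",
        data='{"prompt": "hi", "max_tokens": 4}',
        headers={"Content-Type": "text/plain"},
    )
    print(bad.status_code)  # expected: 400

    # A proper application/json request reaches the completion handler.
    ok = requests.post(
        f"{base}/v1/completions",
        json={"model": "default", "prompt": "hi", "max_tokens": 4},  # replace "default" with the served model name
    )
    print(ok.status_code)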
@@ -330,13 +379,14 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
return _create_error_response(e)
@app.api_route("/v1/rerank", methods=["POST", "PUT"])
async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
try:
ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
return ret
except ValueError as e:
return _create_error_response(e)
@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
"""Endpoint for reranking documents based on query relevance."""
return await raw_request.app.state.openai_serving_rerank.handle_request(
request, raw_request
)
@app.api_route("/flush_cache", methods=["GET", "POST"])
@@ -619,25 +669,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
##### OpenAI-compatible API endpoints #####
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
return await v1_completions(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
"""OpenAI-compatible text completion endpoint."""
return await raw_request.app.state.openai_serving_completion.handle_request(
request, raw_request
)
@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
async def openai_v1_chat_completions(
request: ChatCompletionRequest, raw_request: Request
):
"""OpenAI-compatible chat completion endpoint."""
return await raw_request.app.state.openai_serving_chat.handle_request(
request, raw_request
)
@app.post("/v1/embeddings", response_class=ORJSONResponse)
async def openai_v1_embeddings(raw_request: Request):
response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
return response
@app.post(
"/v1/embeddings",
response_class=ORJSONResponse,
dependencies=[Depends(validate_json_request)],
)
async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
"""OpenAI-compatible embeddings endpoint."""
return await raw_request.app.state.openai_serving_embedding.handle_request(
request, raw_request
)
@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
"""Show available models."""
async def available_models():
"""Show available models. OpenAI-compatible endpoint."""
served_model_names = [_global_state.tokenizer_manager.served_model_name]
model_cards = []
for served_model_name in served_model_names:
@@ -651,47 +715,31 @@ def available_models():
return ModelList(data=model_cards)
@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
return await v1_files_create(
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
@app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
async def retrieve_model(model: str):
"""Retrieves a model instance, providing basic information about the model."""
served_model_names = [_global_state.tokenizer_manager.served_model_name]
if model not in served_model_names:
return ORJSONResponse(
status_code=404,
content={
"error": {
"message": f"The model '{model}' does not exist",
"type": "invalid_request_error",
"param": "model",
"code": "model_not_found",
}
},
)
return ModelCard(
id=model,
root=model,
max_model_len=_global_state.tokenizer_manager.model_config.context_len,
)
@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/delete
return await v1_delete_file(file_id)
@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
return await v1_batches(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
# https://platform.openai.com/docs/api-reference/batch/cancel
return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
return await v1_retrieve_batch(batch_id)
@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve
return await v1_retrieve_file(file_id)
@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve-contents
return await v1_retrieve_file_content(file_id)
## SageMaker API
@app.get("/ping")
async def sagemaker_health() -> Response:
@@ -700,8 +748,13 @@ async def sagemaker_health() -> Response:
@app.post("/invocations")
async def sagemaker_chat_completions(raw_request: Request):
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
async def sagemaker_chat_completions(
request: ChatCompletionRequest, raw_request: Request
):
"""OpenAI-compatible chat completion endpoint."""
return await raw_request.app.state.openai_serving_chat.handle_request(
request, raw_request
)
## Vertex AI API
@@ -732,10 +785,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
return ORJSONResponse({"predictions": ret})
@app.post("/v1/score")
async def v1_score_request(raw_request: Request):
@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
async def v1_score_request(request: ScoringRequest, raw_request: Request):
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
return await v1_score(_global_state.tokenizer_manager, raw_request)
return await raw_request.app.state.openai_serving_score.handle_request(
request, raw_request
)
def _create_error_response(e):
@@ -764,10 +819,13 @@ def launch_server(
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args
)
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
template_manager=template_manager,
scheduler_info=scheduler_info,
)
)
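Editor's note: end to end, the refactored endpoints remain drop-in compatible with the official OpenAI client. A minimal sketch follows; the base URL, port, and api_key value are assumptions about a local deployment.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:30000/v1", api_key="sk-local")

    # /v1/models still lists served models, and /v1/models/{model} retrieval
    # was added above.
    model = client.models.list().data[0].id

    reply = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Say hello in one word."}],
        max_tokens=8,
    )
    print(reply.choices[0].message.content)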

View File

@@ -1,388 +0,0 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
SGLang OpenAI-Compatible API Server.
This file implements OpenAI-compatible HTTP APIs for the inference engine via FastAPI.
"""
import argparse
import asyncio
import logging
import multiprocessing
import os
import threading
import time
from contextlib import asynccontextmanager
from typing import Callable, Dict, Optional
import numpy as np
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
add_prometheus_middleware,
delete_directory,
get_bool_env_var,
kill_process_tree,
set_uvicorn_logging_configs,
)
from sglang.srt.warmup import execute_warmups
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Store global states
class AppState:
engine: Optional[Engine] = None
server_args: Optional[ServerArgs] = None
tokenizer_manager: Optional[TokenizerManager] = None
scheduler_info: Optional[Dict] = None
embedding_server: Optional[OpenAIServingEmbedding] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.server_args.enable_metrics = True # By default, we enable metrics
server_args = app.state.server_args
# Initialize engine
logger.info(f"SGLang OpenAI server (PID: {os.getpid()}) is initializing...")
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
app.state.tokenizer_manager = tokenizer_manager
app.state.scheduler_info = scheduler_info
app.state.serving_embedding = OpenAIServingEmbedding(
tokenizer_manager=tokenizer_manager
)
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
# Initialize engine state attribute to None for now
app.state.engine = None
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), app.state.tokenizer_manager
)
logger.info("Warmup ended")
warmup_thread = getattr(app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
yield
# Lifespan shutdown
if hasattr(app.state, "engine") and app.state.engine is not None:
logger.info("SGLang engine is shutting down.")
# Add engine cleanup logic here when implemented
# Fast API app with CORS enabled
app = FastAPI(
lifespan=lifespan,
# TODO: check where /openai.json is created or why we use this
openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.api_route("/health", methods=["GET"])
async def health() -> Response:
"""Health check. Used for readiness and liveness probes."""
# In the future, this could check engine health more deeply
# For now, if the server is up, it's healthy.
return Response(status_code=200)
@app.api_route("/v1/models", methods=["GET"])
async def show_models():
"""Show available models. Currently, it returns the served model name.
This endpoint is compatible with the OpenAI API standard.
"""
served_model_names = [app.state.tokenizer_manager.served_model_name]
model_cards = []
for served_model_name in served_model_names:
model_cards.append(
ModelCard(
id=served_model_name,
root=served_model_name,
max_model_len=app.state.tokenizer_manager.model_config.context_len,
)
)
return ModelList(data=model_cards)
@app.get("/get_model_info")
async def get_model_info():
"""Get the model information."""
result = {
"model_path": app.state.tokenizer_manager.model_path,
"tokenizer_path": app.state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": app.state.tokenizer_manager.is_generation,
}
return result
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
pass
@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
pass
@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
try:
request_json = await raw_request.json()
request = EmbeddingRequest(**request_json)
except Exception as e:
return app.state.serving_embedding.create_error_response(
f"Invalid request body, error: {str(e)}"
)
ret = await app.state.serving_embedding.handle_request(request, raw_request)
return ret
@app.post("/v1/score")
async def v1_score_request(raw_request: Request):
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
pass
@app.api_route("/v1/models/{model_id}", methods=["GET"])
async def show_model_detail(model_id: str):
served_model_name = app.state.tokenizer_manager.served_model_name
return ModelCard(
id=served_model_name,
root=served_model_name,
max_model_len=app.state.tokenizer_manager.model_config.context_len,
)
# Additional API endpoints will be implemented in separate serving_*.py modules
# and mounted as APIRouters in future PRs
def _wait_and_warmup(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
image_token_text: str,
launch_callback: Optional[Callable[[], None]] = None,
):
return
# TODO: Please wait until the /generate implementation is complete,
# or confirm if modifications are needed before removing this.
headers = {}
url = server_args.url()
if server_args.api_key:
headers["Authorization"] = f"Bearer {server_args.api_key}"
# Wait until the server is launched
success = False
for _ in range(120):
time.sleep(1)
try:
res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
assert res.status_code == 200, f"{res=}, {res.text=}"
success = True
break
except (AssertionError, requests.exceptions.RequestException):
last_traceback = get_exception_traceback()
pass
if not success:
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_process_tree(os.getpid())
return
model_info = res.json()
# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
# TODO: Replace with OpenAI API
max_new_tokens = 8 if model_info["is_generation"] else 1
json_data = {
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
},
}
if server_args.skip_tokenizer_init:
json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)]
# TODO Workaround the bug that embedding errors for list of size 1
if server_args.dp_size == 1:
json_data["input_ids"] = json_data["input_ids"][0]
else:
json_data["text"] = ["The capital city of France is"] * server_args.dp_size
# TODO Workaround the bug that embedding errors for list of size 1
if server_args.dp_size == 1:
json_data["text"] = json_data["text"][0]
# Debug dumping
if server_args.debug_tensor_dump_input_file:
json_data.pop("text", None)
json_data["input_ids"] = np.load(
server_args.debug_tensor_dump_input_file
).tolist()
json_data["sampling_params"]["max_new_tokens"] = 0
try:
if server_args.disaggregation_mode == "null":
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=600,
)
assert res.status_code == 200, f"{res}"
else:
logger.info(f"Start of prefill warmup ...")
json_data = {
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": True,
},
"bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
# This is a hack to ensure fake transfer is enabled during prefill warmup
# ensure each dp rank has a unique bootstrap_room during prefill warmup
"bootstrap_room": [
i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
for i in range(server_args.dp_size)
],
"input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
}
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=1800,  # DeepGEMM precache can take very long if not already cached.

)
logger.info(
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
)
except Exception:
last_traceback = get_exception_traceback()
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_process_tree(os.getpid())
return
# Debug print
# logger.info(f"{res.json()=}")
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("ready")
if server_args.delete_ckpt_after_loading:
delete_directory(server_args.model_path)
if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid())
if server_args.pdlb_url is not None:
register_disaggregation_server(
server_args.disaggregation_mode,
server_args.port,
server_args.disaggregation_bootstrap_port,
server_args.pdlb_url,
)
if launch_callback is not None:
launch_callback()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SGLang OpenAI-Compatible API Server")
# Add arguments from ServerArgs. This allows reuse of existing CLI definitions.
ServerArgs.add_cli_args(parser)
# Potentially add server-specific arguments here in the future if needed
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
# Store server_args in app.state for access in lifespan and endpoints
app.state.server_args = server_args
# Configure logging
logging.basicConfig(
level=server_args.log_level.upper(),
format="%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
)
# Send a warmup request - we create the thread here and launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
None,
None, # Never used
None,
),
)
app.warmup_thread = warmup_thread
try:
# Start the server
set_uvicorn_logging_configs()
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level.lower(),
timeout_keep_alive=60, # Increased keep-alive for potentially long requests
loop="uvloop", # Use uvloop for better performance if available
)
finally:
warmup_thread.join()

View File

@@ -207,7 +207,7 @@ class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Literal["stop", "length", "content_filter", "abort"]
finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@@ -404,21 +404,13 @@ class ChatCompletionRequest(BaseModel):
@model_validator(mode="before")
@classmethod
def set_tool_choice_default(cls, values):
if isinstance(values, dict):
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
return values
@field_validator("messages")
@classmethod
def validate_messages_not_empty(cls, v):
if not v:
raise ValueError("Messages cannot be empty")
return v
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
@@ -457,9 +449,11 @@ class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
finish_reason: Optional[
Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@@ -530,7 +524,7 @@ class EmbeddingRequest(BaseModel):
input: EmbeddingInput
model: str
encoding_format: str = "float"
dimensions: int = None
dimensions: Optional[int] = None
user: Optional[str] = None
# The request id.

View File

@@ -2,16 +2,12 @@ import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
OpenAIServingRequest,
UsageInfo,
)
from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -51,7 +47,7 @@ class OpenAIServingBase(ABC):
)
except Exception as e:
logger.error(f"Error in request: {e}")
logger.exception(f"Error in request: {e}")
return self.create_error_response(
message=f"Internal server error: {str(e)}",
err_type="InternalServerError",
@@ -63,8 +59,12 @@ class OpenAIServingBase(ABC):
"""Generate request ID based on request type"""
pass
def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]:
"""Generate request ID based on request type"""
return None
# TODO(chang): the rid is used in io_struct check and often violates `The rid should be a list` AssertionError
# Temporarily return None in this function until the rid logic is clear.
if rid := getattr(request, "rid", None):
return rid
@@ -83,7 +83,7 @@ class OpenAIServingBase(ABC):
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
) -> StreamingResponse:
) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]:
"""Handle streaming request
Override this method in child classes that support streaming requests.
@@ -99,7 +99,7 @@ class OpenAIServingBase(ABC):
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
) -> Union[Any, ErrorResponse]:
) -> Union[Any, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming request
Override this method in child classes that support non-streaming requests.
@@ -110,7 +110,7 @@ class OpenAIServingBase(ABC):
status_code=501,
)
def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]:
"""Validate request"""
pass
@@ -122,6 +122,7 @@ class OpenAIServingBase(ABC):
param: Optional[str] = None,
) -> ORJSONResponse:
"""Create an error response"""
# TODO: remove fastapi dependency in openai and move response handling to the entrypoint
error = ErrorResponse(
object="error",
message=message,

View File

@@ -1,4 +1,3 @@
import base64
import json
import logging
import time
@@ -6,7 +5,7 @@ import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import StreamingResponse
from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.openai.protocol import (
@@ -28,13 +27,14 @@ from sglang.srt.entrypoints.openai.protocol import (
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import (
detect_template_content_format,
process_content_for_template_format,
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.jinja_template_utils import process_content_for_template_format
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str
@@ -42,13 +42,13 @@ logger = logging.getLogger(__name__)
class OpenAIServingChat(OpenAIServingBase):
"""Handler for chat completion requests"""
"""Handler for /v1/chat/completions requests"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Instance-specific cache for template content format detection
self._cached_chat_template = None
self._cached_template_format = None
def __init__(
self, tokenizer_manager: TokenizerManager, template_manager: TemplateManager
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
def _request_id_prefix(self) -> str:
return "chatcmpl-"
@@ -142,19 +142,14 @@ class OpenAIServingChat(OpenAIServingBase):
)
# Use chat template
if (
hasattr(self.tokenizer_manager, "chat_template_name")
and self.tokenizer_manager.chat_template_name is None
):
if self.template_manager.chat_template_name is None:
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_jinja_template(request, tools, is_multimodal)
)
else:
prompt, image_data, audio_data, modalities, stop = (
self._apply_conversation_template(request)
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_conversation_template(request, is_multimodal)
)
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
else:
# Use raw prompt
prompt_ids = request.messages
@@ -181,23 +176,14 @@ class OpenAIServingChat(OpenAIServingBase):
is_multimodal: bool,
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
"""Apply Jinja chat template"""
prompt = ""
prompt_ids = []
openai_compatible_messages = []
image_data = []
audio_data = []
modalities = []
# Detect template content format
current_template = self.tokenizer_manager.tokenizer.chat_template
if current_template != self._cached_chat_template:
self._cached_chat_template = current_template
self._cached_template_format = detect_template_content_format(
current_template
)
logger.info(
f"Detected chat template content format: {self._cached_template_format}"
)
template_content_format = self._cached_template_format
template_content_format = self.template_manager.jinja_template_content_format
for message in request.messages:
if message.content is None:
@@ -262,14 +248,21 @@ class OpenAIServingChat(OpenAIServingBase):
if is_multimodal:
prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)
stop = request.stop or []
stop = request.stop
image_data = image_data if image_data else None
audio_data = audio_data if audio_data else None
modalities = modalities if modalities else []
return prompt, prompt_ids, image_data, audio_data, modalities, stop
def _apply_conversation_template(
self, request: ChatCompletionRequest
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]:
self,
request: ChatCompletionRequest,
is_multimodal: bool,
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
"""Apply conversation template"""
conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name)
prompt = ""
prompt_ids = []
conv = generate_chat_conv(request, self.template_manager.chat_template_name)
# If we should continue the final assistant message, adjust the conversation.
if (
@@ -296,9 +289,9 @@ class OpenAIServingChat(OpenAIServingBase):
else:
prompt = conv.get_prompt()
image_data = conv.image_data
audio_data = conv.audio_data
modalities = conv.modalities
image_data = conv.image_data if conv.image_data else None
audio_data = conv.audio_data if conv.audio_data else None
modalities = conv.modalities if conv.modalities else []
stop = conv.stop_str or [] if not request.ignore_eos else []
if request.stop:
@@ -307,7 +300,10 @@ class OpenAIServingChat(OpenAIServingBase):
else:
stop.extend(request.stop)
return prompt, image_data, audio_data, modalities, stop
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
return prompt, prompt_ids, image_data, audio_data, modalities, stop
def _build_sampling_params(
self,
@@ -459,13 +455,9 @@ class OpenAIServingChat(OpenAIServingBase):
stream_buffers[index] = stream_buffer + delta
# Handle reasoning content
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
)
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
and enable_thinking
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
@@ -591,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {usage_chunk.model_dump_json()}\n\n"
except Exception as e:
except ValueError as e:
error = self.create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
@@ -602,7 +594,7 @@ class OpenAIServingChat(OpenAIServingBase):
adapted_request: GenerateReqInput,
request: ChatCompletionRequest,
raw_request: Request,
) -> Union[ChatCompletionResponse, ErrorResponse]:
) -> Union[ChatCompletionResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming chat completion request"""
try:
ret = await self.tokenizer_manager.generate_request(
@@ -627,7 +619,7 @@ class OpenAIServingChat(OpenAIServingBase):
request: ChatCompletionRequest,
ret: List[Dict[str, Any]],
created: int,
) -> ChatCompletionResponse:
) -> Union[ChatCompletionResponse, ORJSONResponse]:
"""Build chat completion response from generation results"""
choices = []
@@ -645,11 +637,8 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle reasoning content
reasoning_text = None
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
)
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
if reasoning_parser and request.separate_reasoning and enable_thinking:
if reasoning_parser and request.separate_reasoning:
try:
parser = ReasoningParser(
model_type=reasoning_parser, stream_reasoning=False
@@ -691,9 +680,10 @@ class OpenAIServingChat(OpenAIServingBase):
choices.append(choice_data)
# Calculate usage
cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = UsageProcessor.calculate_response_usage(
ret, n_choices=request.n, enable_cache_report=cache_report
ret,
n_choices=request.n,
enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
)
return ChatCompletionResponse(
@@ -821,6 +811,25 @@ class OpenAIServingChat(OpenAIServingBase):
reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta)
def _get_enable_thinking_from_request(request: ChatCompletionRequest) -> bool:
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
NOTE: This parameter is only useful for models that support enable_thinking
flag, such as Qwen3.
Args:
request: The request object (or an item from a list of requests).
Returns:
The boolean value of 'enable_thinking' if set in chat_template_kwargs, otherwise True.
"""
if (
hasattr(request, "chat_template_kwargs")
and request.chat_template_kwargs
and request.chat_template_kwargs.get("enable_thinking") is not None
):
return request.chat_template_kwargs.get("enable_thinking")
return True
async def _process_tool_call_stream(
self,
index: int,
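Editor's note: the new _get_enable_thinking_from_request helper boils down to a dictionary lookup with a True default. A simplified standalone version, operating directly on the chat_template_kwargs mapping instead of the request object:

    # Simplified restatement of _get_enable_thinking_from_request.
    def get_enable_thinking(chat_template_kwargs) -> bool:
        if chat_template_kwargs and chat_template_kwargs.get("enable_thinking") is not None:
            return chat_template_kwargs["enable_thinking"]
        return True

    assert get_enable_thinking(None) is True
    assert get_enable_thinking({"enable_thinking": False}) is False
    assert get_enable_thinking({"other": 1}) is True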

View File

@@ -3,12 +3,9 @@ import time
from typing import Any, AsyncGenerator, Dict, List, Union
from fastapi import Request
from fastapi.responses import StreamingResponse
from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.code_completion_parser import (
generate_completion_prompt_from_request,
is_completion_template_defined,
)
from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
@@ -24,12 +21,22 @@ from sglang.srt.entrypoints.openai.utils import (
to_openai_style_logprobs,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
logger = logging.getLogger(__name__)
class OpenAIServingCompletion(OpenAIServingBase):
"""Handler for completion requests"""
"""Handler for /v1/completion requests"""
def __init__(
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
def _request_id_prefix(self) -> str:
return "cmpl-"
@@ -47,7 +54,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
)
# Process prompt
prompt = request.prompt
if is_completion_template_defined():
if self.template_manager.completion_template_name is not None:
prompt = generate_completion_prompt_from_request(request)
# Set logprob start length based on echo and logprobs
@@ -141,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
hidden_states = {}
try:
async for content in self.tokenizer_manager.generate_request(
@@ -152,6 +160,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get("hidden_states", None)
stream_buffer = stream_buffers.get(index, "")
# Handle echo for first chunk
@@ -192,7 +201,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
delta = text[len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
finish_reason = content["meta_info"]["finish_reason"]
hidden_states = content["meta_info"].get("hidden_states", None)
choice_data = CompletionResponseStreamChoice(
index=index,
@@ -269,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
adapted_request: GenerateReqInput,
request: CompletionRequest,
raw_request: Request,
) -> Union[CompletionResponse, ErrorResponse]:
) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming completion request"""
try:
generator = self.tokenizer_manager.generate_request(

View File

@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse
from sglang.srt.conversation import generate_embedding_convs
from sglang.srt.entrypoints.openai.protocol import (
@@ -13,10 +14,20 @@ from sglang.srt.entrypoints.openai.protocol import (
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
class OpenAIServingEmbedding(OpenAIServingBase):
"""Handler for embedding requests"""
"""Handler for v1/embeddings requests"""
def __init__(
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
def _request_id_prefix(self) -> str:
return "embd-"
@@ -68,11 +79,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list):
if len(prompt) > 0 and isinstance(prompt[0], str):
# List of strings - if it's a single string in a list, treat as single string
if len(prompt) == 1:
prompt_kwargs = {"text": prompt[0]}
else:
prompt_kwargs = {"text": prompt}
prompt_kwargs = {"text": prompt}
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
# Handle multimodal embedding inputs
texts = []
@@ -84,11 +91,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
generate_prompts = []
# Check if we have a chat template for multimodal embeddings
chat_template_name = getattr(
self.tokenizer_manager, "chat_template_name", None
)
if chat_template_name is not None:
convs = generate_embedding_convs(texts, images, chat_template_name)
if self.template_manager.chat_template_name is not None:
convs = generate_embedding_convs(
texts, images, self.template_manager.chat_template_name
)
for conv in convs:
generate_prompts.append(conv.get_prompt())
else:
@@ -122,7 +128,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
adapted_request: EmbeddingReqInput,
request: EmbeddingRequest,
raw_request: Request,
) -> Union[EmbeddingResponse, ErrorResponse]:
) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]:
"""Handle the embedding request"""
try:
ret = await self.tokenizer_manager.generate_request(
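Editor's note: with the single-element special case removed above, a one-string list is now forwarded as a list instead of being unwrapped. Client-side it looks the same either way; a sketch, assuming a local server on port 30000 and the embedding model from the notebook earlier:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:30000/v1", api_key="sk-local")
    out = client.embeddings.create(
        model="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        input=["Once upon a time"],  # one-element list, no longer unwrapped server-side
    )
    print(len(out.data), len(out.data[0].embedding))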

View File

@@ -2,6 +2,7 @@ import logging
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse
from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
@@ -15,7 +16,10 @@ logger = logging.getLogger(__name__)
class OpenAIServingRerank(OpenAIServingBase):
"""Handler for rerank requests"""
"""Handler for /v1/rerank requests"""
# NOTE: /v1/rerank is not an official OpenAI endpoint. This handler may be moved
# to a different module in the future.
def _request_id_prefix(self) -> str:
return "rerank-"
@@ -61,7 +65,7 @@ class OpenAIServingRerank(OpenAIServingBase):
adapted_request: EmbeddingReqInput,
request: V1RerankReqInput,
raw_request: Request,
) -> Union[RerankResponse, ErrorResponse]:
) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]:
"""Handle the rerank request"""
try:
ret = await self.tokenizer_manager.generate_request(
@@ -74,16 +78,16 @@ class OpenAIServingRerank(OpenAIServingBase):
if not isinstance(ret, list):
ret = [ret]
response = self._build_rerank_response(ret, request)
return response
responses = self._build_rerank_response(ret, request)
return responses
def _build_rerank_response(
self, ret: List[Dict[str, Any]], request: V1RerankReqInput
) -> List[RerankResponse]:
"""Build the rerank response from generation results"""
response = []
responses = []
for idx, ret_item in enumerate(ret):
response.append(
responses.append(
RerankResponse(
score=ret_item["embedding"],
document=request.documents[idx],
@@ -93,6 +97,6 @@ class OpenAIServingRerank(OpenAIServingBase):
)
# Sort by score in descending order (highest relevance first)
response.sort(key=lambda x: x.score, reverse=True)
responses.sort(key=lambda x: x.score, reverse=True)
return response
return responses
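Editor's note: the handler above returns a list of RerankResponse objects sorted by descending score. A client-side sketch follows; the documents field name is shown in the hunk, while query is an assumption about the V1RerankReqInput schema, as are the host and port.

    import requests

    resp = requests.post(
        "http://localhost:30000/v1/rerank",
        json={
            "query": "What is SGLang?",
            "documents": [
                "SGLang is a fast serving framework for LLMs.",
                "Bananas are yellow.",
            ],
        },
    )
    # Results come back sorted by score, highest relevance first.
    for item in resp.json():
        print(round(item["score"], 4), item["document"])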

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Union
from fastapi import Request
@@ -14,7 +14,10 @@ logger = logging.getLogger(__name__)
class OpenAIServingScore(OpenAIServingBase):
"""Handler for scoring requests"""
"""Handler for /v1/score requests"""
# NOTE: /v1/score is not an official OpenAI endpoint. This handler may be moved
# to a different module in the future.
def _request_id_prefix(self) -> str:
return "score-"

View File

@@ -1,9 +1,6 @@
import logging
from typing import Any, Dict, List, Optional, Union
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
@@ -13,168 +10,6 @@ from sglang.srt.entrypoints.openai.protocol import (
logger = logging.getLogger(__name__)
# ============================================================================
# JINJA TEMPLATE CONTENT FORMAT DETECTION
# ============================================================================
#
# This adapts vLLM's approach for detecting chat template content format:
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
# - Analyzes Jinja template AST to detect content iteration patterns
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
# - 'string' format: templates that expect simple string content
# - Processes content accordingly to match template expectations
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
"""Check if node is a variable access like {{ varname }}"""
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
"""Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
if isinstance(node, jinja2.nodes.Getitem):
return (
_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key
)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
return False
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
key: str = None,
) -> bool:
"""Check if node accesses varname or varname[key] with filters/tests"""
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key
)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice
):
return _is_var_or_elems_access(node.node, varname, key)
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
def _try_extract_ast(chat_template: str):
"""Try to parse the Jinja template into an AST"""
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception as e:
logger.debug(f"Error when compiling Jinja template: {e}")
return None
def detect_template_content_format(chat_template: str) -> str:
"""
Detect whether a chat template expects 'string' or 'openai' content format.
- 'string': content is a simple string (like DeepSeek templates)
- 'openai': content is a list of structured dicts (like Llama4 templates)
Detection logic:
- If template has loops like {%- for content in message['content'] -%} → 'openai'
- Otherwise → 'string'
"""
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return "string"
try:
# Look for patterns like: {%- for content in message['content'] -%}
for loop_ast in jinja_ast.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
# Check if iterating over message['content'] or similar
if _is_var_or_elems_access(loop_iter, "message", "content"):
return "openai" # Found content iteration → openai format
return "string" # No content loops found → string format
except Exception as e:
logger.debug(f"Error when parsing AST of Jinja template: {e}")
return "string"
def process_content_for_template_format(
msg_dict: dict,
content_format: str,
image_data: list,
audio_data: list,
modalities: list,
) -> dict:
"""
Process message content based on detected template format.
Args:
msg_dict: Message dictionary with content
content_format: 'string' or 'openai' (detected via AST analysis)
image_data: List to append extracted image URLs
audio_data: List to append extracted audio URLs
modalities: List to append modalities
Returns:
Processed message dictionary
"""
if not isinstance(msg_dict.get("content"), list):
# Already a string or None, no processing needed
return {k: v for k, v in msg_dict.items() if v is not None}
if content_format == "openai":
# OpenAI format: preserve structured content list, normalize types
processed_content_parts = []
for chunk in msg_dict["content"]:
if isinstance(chunk, dict):
chunk_type = chunk.get("type")
if chunk_type == "image_url":
image_data.append(chunk["image_url"]["url"])
if chunk.get("modalities"):
modalities.append(chunk.get("modalities"))
# Normalize to simple 'image' type for template compatibility
processed_content_parts.append({"type": "image"})
elif chunk_type == "audio_url":
audio_data.append(chunk["audio_url"]["url"])
# Normalize to simple 'audio' type
processed_content_parts.append({"type": "audio"})
else:
# Keep other content as-is (text, etc.)
processed_content_parts.append(chunk)
new_msg = {
k: v for k, v in msg_dict.items() if v is not None and k != "content"
}
new_msg["content"] = processed_content_parts
return new_msg
else: # content_format == "string"
# String format: flatten to text only (for templates like DeepSeek)
text_parts = []
for chunk in msg_dict["content"]:
if isinstance(chunk, dict) and chunk.get("type") == "text":
text_parts.append(chunk["text"])
# Note: For string format, we ignore images/audio since the template
# doesn't expect structured content - multimodal placeholders would
# need to be inserted differently
new_msg = msg_dict.copy()
new_msg["content"] = " ".join(text_parts) if text_parts else ""
new_msg = {k: v for k, v in new_msg.items() if v is not None}
return new_msg
def to_openai_style_logprobs(
input_token_logprobs=None,
output_token_logprobs=None,
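Editor's note: the content-format detection removed above (and re-homed alongside the new jinja_template_utils import in serving_chat) hinges on one AST pattern: a for-loop over message['content'] means 'openai' format, anything else means 'string'. Below is a miniature stand-in for illustration, not the real implementation.

    from jinja2 import Environment, TemplateSyntaxError, nodes

    def detect_format(chat_template: str) -> str:
        """Return 'openai' if the template iterates over message['content'], else 'string'."""
        try:
            ast = Environment().parse(chat_template)
        except TemplateSyntaxError:
            return "string"
        for loop in ast.find_all(nodes.For):
            it = loop.iter
            if (
                isinstance(it, nodes.Getitem)
                and isinstance(it.node, nodes.Name)
                and it.node.name == "message"
                and getattr(it.arg, "value", None) == "content"
            ):
                return "openai"
        return "string"

    string_tmpl = "{% for message in messages %}{{ message['content'] }}{% endfor %}"
    openai_tmpl = (
        "{% for message in messages %}"
        "{% for content in message['content'] %}{{ content['text'] }}{% endfor %}"
        "{% endfor %}"
    )
    print(detect_format(string_tmpl))  # string
    print(detect_format(openai_tmpl))  # openai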

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, List
from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
_is_complete_json,
_partial_json_loads,
)
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)

View File

@@ -3,6 +3,7 @@ import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.function_call.utils import _is_complete_json
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)

View File

@@ -1,6 +1,12 @@
import logging
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
from sglang.srt.entrypoints.openai.protocol import (
StructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
ToolChoice,
)
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.openai_api.protocol import (
StructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
ToolChoice,
)
logger = logging.getLogger(__name__)

View File

@@ -2,6 +2,7 @@ import json
import logging
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)

View File

@@ -3,6 +3,7 @@ import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)

View File

@@ -4,6 +4,7 @@ import logging
import re
from typing import List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)

View File

@@ -3,6 +3,7 @@ import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)

View File

@@ -1,11 +1,12 @@
"""
Utility functions for OpenAI API adapter.
"""Template utilities for Jinja template processing.
This module provides utilities for analyzing and processing Jinja chat templates,
including content format detection and message processing.
"""
import logging
from typing import Dict, List
import jinja2.nodes
import jinja2
import transformers.utils.chat_template_utils as hf_chat_utils
logger = logging.getLogger(__name__)
@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
return None
def detect_template_content_format(chat_template: str) -> str:
def detect_jinja_template_content_format(chat_template: str) -> str:
"""
Detect whether a chat template expects 'string' or 'openai' content format.

View File

@@ -864,12 +864,6 @@ class SetInternalStateReq:
server_args: Dict[str, Any]
@dataclass
class V1RerankReqInput:
query: str
documents: List[str]
@dataclass
class SetInternalStateReqOutput:
updated: bool

View File

@@ -0,0 +1,226 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Centralized template management for chat templates and completion templates.
This module provides a unified interface for managing both chat conversation templates
and code completion templates, eliminating global state and improving modularity.
"""
import json
import logging
import os
from typing import Optional
from sglang.srt.code_completion_parser import (
CompletionTemplate,
FimPosition,
completion_template_exists,
register_completion_template,
)
from sglang.srt.conversation import (
Conversation,
SeparatorStyle,
chat_template_exists,
get_conv_template_by_model_path,
register_conv_template,
)
from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
logger = logging.getLogger(__name__)
class TemplateManager:
"""
Centralized manager for chat and completion templates.
This class encapsulates all template-related state and operations,
eliminating the need for global variables and providing a clean
interface for template management.
"""
def __init__(self):
self._chat_template_name: Optional[str] = None
self._completion_template_name: Optional[str] = None
self._jinja_template_content_format: Optional[str] = None
@property
def chat_template_name(self) -> Optional[str]:
"""Get the current chat template name."""
return self._chat_template_name
@property
def completion_template_name(self) -> Optional[str]:
"""Get the current completion template name."""
return self._completion_template_name
@property
def jinja_template_content_format(self) -> Optional[str]:
"""Get the detected template content format ('string' or 'openai' or None)."""
return self._jinja_template_content_format
def load_chat_template(
self, tokenizer_manager, chat_template_arg: str, model_path: str
) -> None:
"""
Load a chat template from various sources.
Args:
tokenizer_manager: The tokenizer manager instance
chat_template_arg: Template name or file path
model_path: Path to the model
"""
logger.info(f"Loading chat template: {chat_template_arg}")
if not chat_template_exists(chat_template_arg):
if not os.path.exists(chat_template_arg):
raise RuntimeError(
f"Chat template {chat_template_arg} is not a built-in template name "
"or a valid chat template file path."
)
if chat_template_arg.endswith(".jinja"):
self._load_jinja_template(tokenizer_manager, chat_template_arg)
else:
self._load_json_chat_template(chat_template_arg)
else:
self._chat_template_name = chat_template_arg
def guess_chat_template_from_model_path(self, model_path: str) -> None:
"""
Infer chat template name from model path.
Args:
model_path: Path to the model
"""
template_name = get_conv_template_by_model_path(model_path)
if template_name is not None:
logger.info(f"Inferred chat template from model path: {template_name}")
self._chat_template_name = template_name
def load_completion_template(self, completion_template_arg: str) -> None:
"""
Load completion template for code completion.
Args:
completion_template_arg: Template name or file path
"""
logger.info(f"Loading completion template: {completion_template_arg}")
if not completion_template_exists(completion_template_arg):
if not os.path.exists(completion_template_arg):
raise RuntimeError(
f"Completion template {completion_template_arg} is not a built-in template name "
"or a valid completion template file path."
)
self._load_json_completion_template(completion_template_arg)
else:
self._completion_template_name = completion_template_arg
def initialize_templates(
self,
tokenizer_manager,
model_path: str,
chat_template: Optional[str] = None,
completion_template: Optional[str] = None,
) -> None:
"""
Initialize all templates based on provided configuration.
Args:
tokenizer_manager: The tokenizer manager instance
model_path: Path to the model
chat_template: Optional chat template name/path
completion_template: Optional completion template name/path
"""
# Load chat template
if chat_template:
self.load_chat_template(tokenizer_manager, chat_template, model_path)
else:
self.guess_chat_template_from_model_path(model_path)
# Load completion template
if completion_template:
self.load_completion_template(completion_template)
def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
"""Load a Jinja template file."""
with open(template_path, "r") as f:
chat_template = "".join(f.readlines()).strip("\n")
tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
self._chat_template_name = None
# Detect content format from the loaded template
self._jinja_template_content_format = detect_jinja_template_content_format(
chat_template
)
logger.info(
f"Detected chat template content format: {self._jinja_template_content_format}"
)
def _load_json_chat_template(self, template_path: str) -> None:
"""Load a JSON chat template file."""
assert template_path.endswith(
".json"
), "unrecognized format of chat template file"
with open(template_path, "r") as filep:
template = json.load(filep)
try:
sep_style = SeparatorStyle[template["sep_style"]]
except KeyError:
raise ValueError(
f"Unknown separator style: {template['sep_style']}"
) from None
register_conv_template(
Conversation(
name=template["name"],
system_template=template["system"] + "\n{system_message}",
system_message=template.get("system_message", ""),
roles=(template["user"], template["assistant"]),
sep_style=sep_style,
sep=template.get("sep", "\n"),
stop_str=template["stop_str"],
),
override=True,
)
self._chat_template_name = template["name"]
def _load_json_completion_template(self, template_path: str) -> None:
"""Load a JSON completion template file."""
assert template_path.endswith(
".json"
), "unrecognized format of completion template file"
with open(template_path, "r") as filep:
template = json.load(filep)
try:
fim_position = FimPosition[template["fim_position"]]
except KeyError:
raise ValueError(
f"Unknown fim position: {template['fim_position']}"
) from None
register_completion_template(
CompletionTemplate(
name=template["name"],
fim_begin_token=template["fim_begin_token"],
fim_middle_token=template["fim_middle_token"],
fim_end_token=template["fim_end_token"],
fim_position=fim_position,
),
override=True,
)
self._completion_template_name = template["name"]
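A usage sketch of the new manager (the model path below is illustrative, and tokenizer_manager stands in for a live TokenizerManager instance; this is not a definitive setup):
# Sketch only: per-instance state replaces the old module-level template globals.
template_manager = TemplateManager()
template_manager.initialize_templates(
    tokenizer_manager,                                 # assumed: a real TokenizerManager
    model_path="meta-llama/Llama-3.1-8B-Instruct",     # illustrative model path
    chat_template=None,                                # fall back to guessing from model_path
    completion_template=None,                          # no FIM template in this sketch
)
print(template_manager.chat_template_name)
# jinja_template_content_format is populated only when a .jinja file is loaded
print(template_manager.jinja_template_content_format)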

View File

@@ -1058,12 +1058,7 @@ class TokenizerManager:
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
out_skip_names = set(["text", "output_ids", "embedding"])
elif self.log_requests_level == 1:
max_length = 2048
elif self.log_requests_level == 2:

File diff suppressed because it is too large Load Diff

View File

@@ -1,551 +0,0 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pydantic models for OpenAI API protocol"""
import time
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, model_serializer, root_validator
from typing_extensions import Literal
class ModelCard(BaseModel):
"""Model cards."""
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "sglang"
root: Optional[str] = None
max_model_len: Optional[int] = None
class ModelList(BaseModel):
"""Model list consists of model cards."""
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class TopLogprob(BaseModel):
token: str
bytes: List[int]
logprob: float
class ChatCompletionTokenLogprob(BaseModel):
token: str
bytes: List[int]
logprob: float
top_logprobs: List[TopLogprob]
class ChoiceLogprobs(BaseModel):
# build for v1/chat/completions response
content: List[ChatCompletionTokenLogprob]
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
# only used to return cached tokens when --enable-cache-report is set
prompt_tokens_details: Optional[Dict[str, int]] = None
class StreamOptions(BaseModel):
include_usage: Optional[bool] = False
class JsonSchemaResponseFormat(BaseModel):
name: str
description: Optional[str] = None
# use alias to workaround pydantic conflict
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
strict: Optional[bool] = False
class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded
purpose: str = (
"batch" # The intended purpose of the uploaded file, default is "batch"
)
class FileResponse(BaseModel):
id: str
object: str = "file"
bytes: int
created_at: int
filename: str
purpose: str
class FileDeleteResponse(BaseModel):
id: str
object: str = "file"
deleted: bool
class BatchRequest(BaseModel):
input_file_id: (
str # The ID of an uploaded file that contains requests for the new batch
)
endpoint: str # The endpoint to be used for all requests in the batch
completion_window: str # The time frame within which the batch should be processed
metadata: Optional[dict] = None # Optional custom metadata for the batch
class BatchResponse(BaseModel):
id: str
object: str = "batch"
endpoint: str
errors: Optional[dict] = None
input_file_id: str
completion_window: str
status: str = "validating"
output_file_id: Optional[str] = None
error_file_id: Optional[str] = None
created_at: int
in_progress_at: Optional[int] = None
expires_at: Optional[int] = None
finalizing_at: Optional[int] = None
completed_at: Optional[int] = None
failed_at: Optional[int] = None
expired_at: Optional[int] = None
cancelling_at: Optional[int] = None
cancelled_at: Optional[int] = None
request_counts: Optional[dict] = None
metadata: Optional[dict] = None
class CompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: bool = False
frequency_penalty: float = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: int = 16
n: int = 1
presence_penalty: float = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: float = 1.0
top_p: float = 1.0
user: Optional[str] = None
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
json_schema: Optional[str] = None
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
return_hidden_states: Optional[bool] = False
# For PD disaggregation
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Literal["stop", "length", "content_filter", "abort"]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class CompletionResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class CompletionStreamResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo] = None
class ChatCompletionMessageContentTextPart(BaseModel):
type: Literal["text"]
text: str
class ChatCompletionMessageContentImageURL(BaseModel):
url: str
detail: Optional[Literal["auto", "low", "high"]] = "auto"
class ChatCompletionMessageContentAudioURL(BaseModel):
url: str
class ChatCompletionMessageContentImagePart(BaseModel):
type: Literal["image_url"]
image_url: ChatCompletionMessageContentImageURL
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
class ChatCompletionMessageContentAudioPart(BaseModel):
type: Literal["audio_url"]
audio_url: ChatCompletionMessageContentAudioURL
ChatCompletionMessageContentPart = Union[
ChatCompletionMessageContentTextPart,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentAudioPart,
]
class FunctionResponse(BaseModel):
"""Function response."""
name: Optional[str] = None
arguments: Optional[str] = None
class ToolCall(BaseModel):
"""Tool call response."""
id: Optional[str] = None
index: Optional[int] = None
type: Literal["function"] = "function"
function: FunctionResponse
class ChatCompletionMessageGenericParam(BaseModel):
role: Literal["system", "assistant", "tool"]
content: Union[str, List[ChatCompletionMessageContentTextPart], None]
tool_call_id: Optional[str] = None
name: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionMessageUserParam(BaseModel):
role: Literal["user"]
content: Union[str, List[ChatCompletionMessageContentPart]]
ChatCompletionMessageParam = Union[
ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
]
class ResponseFormat(BaseModel):
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None
class StructuresResponseFormat(BaseModel):
begin: str
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
end: str
class StructuralTagResponseFormat(BaseModel):
type: Literal["structural_tag"]
structures: List[StructuresResponseFormat]
triggers: List[str]
class Function(BaseModel):
"""Function descriptions."""
description: Optional[str] = Field(default=None, examples=[None])
name: Optional[str] = None
parameters: Optional[object] = None
strict: bool = False
class Tool(BaseModel):
"""Function wrapper."""
type: str = Field(default="function", examples=["function"])
function: Function
class ToolChoiceFuncName(BaseModel):
"""The name of tool choice function."""
name: Optional[str] = None
class ToolChoice(BaseModel):
"""The tool choice definition."""
function: ToolChoiceFuncName
type: Literal["function"] = Field(default="function", examples=["function"])
class ChatCompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: float = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: bool = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = Field(
default=None,
deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
description="The maximum number of tokens that can be generated in the chat completion. ",
)
max_completion_tokens: Optional[int] = Field(
default=None,
description="The maximum number of completion tokens for a chat completion request, "
"including visible output tokens and reasoning tokens. Input tokens are not included. ",
)
n: int = 1
presence_penalty: float = 0.0
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
temperature: float = 0.7
top_p: float = 1.0
user: Optional[str] = None
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
default="auto", examples=["none"]
) # noqa
@root_validator(pre=True)
def set_tool_choice_default(cls, values):
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
return values
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
continue_final_message: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
separate_reasoning: bool = True
stream_reasoning: bool = True
chat_template_kwargs: Optional[Dict] = None
# The request id.
rid: Optional[str] = None
# For PD disaggregation
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
# Hidden States
return_hidden_states: Optional[bool] = False
class ChatMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Optional[
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
] = None
matched_stop: Union[None, int, str] = None
class ChatCompletionStreamResponse(BaseModel):
id: str
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = None
class MultimodalEmbeddingInput(BaseModel):
text: Optional[str] = None
image: Optional[str] = None
class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings/create
input: Union[
List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
]
model: str
encoding_format: str = "float"
dimensions: int = None
user: Optional[str] = None
# The request id.
rid: Optional[str] = None
class EmbeddingObject(BaseModel):
embedding: List[float]
index: int
object: str = "embedding"
class EmbeddingResponse(BaseModel):
data: List[EmbeddingObject]
model: str
object: str = "list"
usage: Optional[UsageInfo] = None
class ScoringRequest(BaseModel):
query: Optional[Union[str, List[int]]] = (
None # Query text or pre-tokenized token IDs
)
items: Optional[Union[str, List[str], List[List[int]]]] = (
None # Item text(s) or pre-tokenized token IDs
)
label_token_ids: Optional[List[int]] = (
None # Token IDs to compute probabilities for
)
apply_softmax: bool = False
item_first: bool = False
model: str
class ScoringResponse(BaseModel):
scores: List[
List[float]
] # List of lists of probabilities, each in the order of label_token_ids
model: str
usage: Optional[UsageInfo] = None
object: str = "scoring"
class RerankResponse(BaseModel):
score: float
document: str
index: int
meta_info: Optional[dict] = None
def exclude_if_none(obj, field_names: List[str]):
omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}

View File

@@ -1,4 +1,4 @@
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple, Type
class StreamingParseResult:
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
One-time parsing: Detects and parses reasoning sections in the provided text.
Returns both reasoning content and normal text separately.
"""
text = text.replace(self.think_start_token, "").strip()
if self.think_end_token not in text:
in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
if not in_reasoning:
return StreamingParseResult(normal_text=text)
# The text is considered to be in a reasoning block.
processed_text = text.replace(self.think_start_token, "").strip()
if self.think_end_token not in processed_text:
# Assume reasoning was truncated before `</think>` token
return StreamingParseResult(reasoning_text=text)
return StreamingParseResult(reasoning_text=processed_text)
# Extract reasoning content
splits = text.split(self.think_end_token, maxsplit=1)
splits = processed_text.split(self.think_end_token, maxsplit=1)
reasoning_text = splits[0]
text = splits[1].strip()
normal_text = splits[1].strip()
return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
return StreamingParseResult(
normal_text=normal_text, reasoning_text=reasoning_text
)
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
"""
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
if not self.stripped_think_start and self.think_start_token in current_text:
current_text = current_text.replace(self.think_start_token, "")
self.stripped_think_start = True
self._in_reasoning = True
# Handle end of reasoning block
if self._in_reasoning and self.think_end_token in current_text:
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
"""
def __init__(self, stream_reasoning: bool = True):
# Qwen3 is assumed to be reasoning until `</think>` token
# Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
super().__init__(
"<think>",
"</think>",
force_reasoning=True,
force_reasoning=False,
stream_reasoning=stream_reasoning,
)
@@ -151,12 +161,12 @@ class ReasoningParser:
If True, streams reasoning content as it arrives.
"""
DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
"deepseek-r1": DeepSeekR1Detector,
"qwen3": Qwen3Detector,
}
def __init__(self, model_type: str = None, stream_reasoning: bool = True):
def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
if not model_type:
raise ValueError("Model type must be specified")
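A quick sketch of the updated Qwen3 behaviour described above (the import path is assumed; names follow the detector class shown in this diff):
# Sketch: reasoning is only extracted when the text actually opens with <think>,
# e.g. when the user passes enable_thinking=False no reasoning block is assumed.
from sglang.srt.reasoning_parser import Qwen3Detector  # module path assumed

detector = Qwen3Detector(stream_reasoning=False)

with_think = detector.detect_and_parse("<think>step by step</think>42")
# with_think.reasoning_text == "step by step"; with_think.normal_text == "42"

no_think = detector.detect_and_parse("42")
# no_think.normal_text == "42"; nothing is treated as reasoning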

View File

@@ -1,87 +0,0 @@
# sglang/test/srt/openai/conftest.py
import os
import socket
import subprocess
import sys
import tempfile
import time
from contextlib import closing
from typing import Generator
import pytest
import requests
from sglang.srt.utils import kill_process_tree # reuse SGLang helper
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
def _pick_free_port() -> int:
with closing(socket.socket()) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> None:
start = time.perf_counter()
while time.perf_counter() - start < timeout:
if proc.poll() is not None: # crashed
raise RuntimeError("api_server terminated prematurely")
try:
if requests.get(f"{base}/health", timeout=1).status_code == 200:
return
except requests.RequestException:
pass
time.sleep(0.4)
raise RuntimeError("api_server readiness probe timed out")
def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
"""Spawn the draft OpenAI-compatible server and wait until it's ready."""
port = _pick_free_port()
cmd = [
sys.executable,
"-m",
SERVER_MODULE,
"--model-path",
model,
"--host",
"127.0.0.1",
"--port",
str(port),
*map(str, kw.get("args", [])),
]
env = {**os.environ, **kw.get("env", {})}
# Write logs to a temp file so the child never blocks on a full pipe.
log_file = tempfile.NamedTemporaryFile("w+", delete=False)
proc = subprocess.Popen(
cmd,
env=env,
stdout=log_file,
stderr=subprocess.STDOUT,
text=True,
)
base = f"http://127.0.0.1:{port}"
try:
_wait_until_healthy(proc, base, STARTUP_TIMEOUT)
except Exception as e:
proc.terminate()
proc.wait(5)
log_file.seek(0)
print("\n--- api_server log ---\n", log_file.read(), file=sys.stderr)
raise e
return proc, base, log_file
@pytest.fixture(scope="session")
def openai_server() -> Generator[str, None, None]:
"""PyTest fixture that provides the server's base URL and cleans up."""
proc, base, log_file = launch_openai_server()
yield base
kill_process_tree(proc.pid)
log_file.close()

View File

@@ -67,29 +67,6 @@ from sglang.srt.entrypoints.openai.protocol import (
class TestModelCard(unittest.TestCase):
"""Test ModelCard protocol model"""
def test_basic_model_card_creation(self):
"""Test basic model card creation with required fields"""
card = ModelCard(id="test-model")
self.assertEqual(card.id, "test-model")
self.assertEqual(card.object, "model")
self.assertEqual(card.owned_by, "sglang")
self.assertIsInstance(card.created, int)
self.assertIsNone(card.root)
self.assertIsNone(card.max_model_len)
def test_model_card_with_optional_fields(self):
"""Test model card with optional fields"""
card = ModelCard(
id="test-model",
root="/path/to/model",
max_model_len=2048,
created=1234567890,
)
self.assertEqual(card.id, "test-model")
self.assertEqual(card.root, "/path/to/model")
self.assertEqual(card.max_model_len, 2048)
self.assertEqual(card.created, 1234567890)
def test_model_card_serialization(self):
"""Test model card JSON serialization"""
card = ModelCard(id="test-model", max_model_len=4096)
@@ -120,53 +97,6 @@ class TestModelList(unittest.TestCase):
self.assertEqual(model_list.data[1].id, "model-2")
class TestErrorResponse(unittest.TestCase):
"""Test ErrorResponse protocol model"""
def test_basic_error_response(self):
"""Test basic error response creation"""
error = ErrorResponse(
message="Invalid request", type="BadRequestError", code=400
)
self.assertEqual(error.object, "error")
self.assertEqual(error.message, "Invalid request")
self.assertEqual(error.type, "BadRequestError")
self.assertEqual(error.code, 400)
self.assertIsNone(error.param)
def test_error_response_with_param(self):
"""Test error response with parameter"""
error = ErrorResponse(
message="Invalid temperature",
type="ValidationError",
code=422,
param="temperature",
)
self.assertEqual(error.param, "temperature")
class TestUsageInfo(unittest.TestCase):
"""Test UsageInfo protocol model"""
def test_basic_usage_info(self):
"""Test basic usage info creation"""
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
self.assertEqual(usage.prompt_tokens, 10)
self.assertEqual(usage.completion_tokens, 20)
self.assertEqual(usage.total_tokens, 30)
self.assertIsNone(usage.prompt_tokens_details)
def test_usage_info_with_cache_details(self):
"""Test usage info with cache details"""
usage = UsageInfo(
prompt_tokens=10,
completion_tokens=20,
total_tokens=30,
prompt_tokens_details={"cached_tokens": 5},
)
self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5})
class TestCompletionRequest(unittest.TestCase):
"""Test CompletionRequest protocol model"""
@@ -181,30 +111,6 @@ class TestCompletionRequest(unittest.TestCase):
self.assertFalse(request.stream) # default
self.assertFalse(request.echo) # default
def test_completion_request_with_options(self):
"""Test completion request with various options"""
request = CompletionRequest(
model="test-model",
prompt=["Hello", "world"],
max_tokens=100,
temperature=0.7,
top_p=0.9,
n=2,
stream=True,
echo=True,
stop=[".", "!"],
logprobs=5,
)
self.assertEqual(request.prompt, ["Hello", "world"])
self.assertEqual(request.max_tokens, 100)
self.assertEqual(request.temperature, 0.7)
self.assertEqual(request.top_p, 0.9)
self.assertEqual(request.n, 2)
self.assertTrue(request.stream)
self.assertTrue(request.echo)
self.assertEqual(request.stop, [".", "!"])
self.assertEqual(request.logprobs, 5)
def test_completion_request_sglang_extensions(self):
"""Test completion request with SGLang-specific extensions"""
request = CompletionRequest(
@@ -233,26 +139,6 @@ class TestCompletionRequest(unittest.TestCase):
CompletionRequest(model="test-model") # missing prompt
class TestCompletionResponse(unittest.TestCase):
"""Test CompletionResponse protocol model"""
def test_basic_completion_response(self):
"""Test basic completion response"""
choice = CompletionResponseChoice(
index=0, text="Hello world!", finish_reason="stop"
)
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
response = CompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
self.assertEqual(response.id, "test-id")
self.assertEqual(response.object, "text_completion")
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.choices), 1)
self.assertEqual(response.choices[0].text, "Hello world!")
self.assertEqual(response.usage.total_tokens, 5)
class TestChatCompletionRequest(unittest.TestCase):
"""Test ChatCompletionRequest protocol model"""
@@ -268,48 +154,6 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertFalse(request.stream) # default
self.assertEqual(request.tool_choice, "none") # default when no tools
def test_chat_completion_with_multimodal_content(self):
"""Test chat completion with multimodal content"""
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
},
],
}
]
request = ChatCompletionRequest(model="test-model", messages=messages)
self.assertEqual(len(request.messages[0].content), 2)
self.assertEqual(request.messages[0].content[0].type, "text")
self.assertEqual(request.messages[0].content[1].type, "image_url")
def test_chat_completion_with_tools(self):
"""Test chat completion with tools"""
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather information",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
]
request = ChatCompletionRequest(
model="test-model", messages=messages, tools=tools
)
self.assertEqual(len(request.tools), 1)
self.assertEqual(request.tools[0].function.name, "get_weather")
self.assertEqual(request.tool_choice, "auto") # default when tools present
def test_chat_completion_tool_choice_validation(self):
"""Test tool choice validation logic"""
messages = [{"role": "user", "content": "Hello"}]
@@ -349,289 +193,6 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
class TestChatCompletionResponse(unittest.TestCase):
"""Test ChatCompletionResponse protocol model"""
def test_basic_chat_completion_response(self):
"""Test basic chat completion response"""
message = ChatMessage(role="assistant", content="Hello there!")
choice = ChatCompletionResponseChoice(
index=0, message=message, finish_reason="stop"
)
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
self.assertEqual(response.id, "test-id")
self.assertEqual(response.object, "chat.completion")
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.choices), 1)
self.assertEqual(response.choices[0].message.content, "Hello there!")
def test_chat_completion_response_with_tool_calls(self):
"""Test chat completion response with tool calls"""
tool_call = ToolCall(
id="call_123",
function=FunctionResponse(
name="get_weather", arguments='{"location": "San Francisco"}'
),
)
message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call])
choice = ChatCompletionResponseChoice(
index=0, message=message, finish_reason="tool_calls"
)
usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15)
response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
self.assertEqual(
response.choices[0].message.tool_calls[0].function.name, "get_weather"
)
self.assertEqual(response.choices[0].finish_reason, "tool_calls")
class TestEmbeddingRequest(unittest.TestCase):
"""Test EmbeddingRequest protocol model"""
def test_basic_embedding_request(self):
"""Test basic embedding request"""
request = EmbeddingRequest(model="test-model", input="Hello world")
self.assertEqual(request.model, "test-model")
self.assertEqual(request.input, "Hello world")
self.assertEqual(request.encoding_format, "float") # default
self.assertIsNone(request.dimensions) # default
def test_embedding_request_with_list_input(self):
"""Test embedding request with list input"""
request = EmbeddingRequest(
model="test-model", input=["Hello", "world"], dimensions=512
)
self.assertEqual(request.input, ["Hello", "world"])
self.assertEqual(request.dimensions, 512)
def test_multimodal_embedding_request(self):
"""Test multimodal embedding request"""
multimodal_input = [
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
MultimodalEmbeddingInput(text="World", image=None),
]
request = EmbeddingRequest(model="test-model", input=multimodal_input)
self.assertEqual(len(request.input), 2)
self.assertEqual(request.input[0].text, "Hello")
self.assertEqual(request.input[0].image, "base64_image_data")
self.assertEqual(request.input[1].text, "World")
self.assertIsNone(request.input[1].image)
class TestEmbeddingResponse(unittest.TestCase):
"""Test EmbeddingResponse protocol model"""
def test_basic_embedding_response(self):
"""Test basic embedding response"""
embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)
usage = UsageInfo(prompt_tokens=3, total_tokens=3)
response = EmbeddingResponse(
data=[embedding_obj], model="test-model", usage=usage
)
self.assertEqual(response.object, "list")
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.usage.prompt_tokens, 3)
class TestScoringRequest(unittest.TestCase):
"""Test ScoringRequest protocol model"""
def test_basic_scoring_request(self):
"""Test basic scoring request"""
request = ScoringRequest(
model="test-model", query="Hello", items=["World", "Earth"]
)
self.assertEqual(request.model, "test-model")
self.assertEqual(request.query, "Hello")
self.assertEqual(request.items, ["World", "Earth"])
self.assertFalse(request.apply_softmax) # default
self.assertFalse(request.item_first) # default
def test_scoring_request_with_token_ids(self):
"""Test scoring request with token IDs"""
request = ScoringRequest(
model="test-model",
query=[1, 2, 3],
items=[[4, 5], [6, 7]],
label_token_ids=[8, 9],
apply_softmax=True,
item_first=True,
)
self.assertEqual(request.query, [1, 2, 3])
self.assertEqual(request.items, [[4, 5], [6, 7]])
self.assertEqual(request.label_token_ids, [8, 9])
self.assertTrue(request.apply_softmax)
self.assertTrue(request.item_first)
class TestScoringResponse(unittest.TestCase):
"""Test ScoringResponse protocol model"""
def test_basic_scoring_response(self):
"""Test basic scoring response"""
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
self.assertEqual(response.object, "scoring")
self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]])
self.assertEqual(response.model, "test-model")
self.assertIsNone(response.usage) # default
class TestFileOperations(unittest.TestCase):
"""Test file operation protocol models"""
def test_file_request(self):
"""Test file request model"""
file_data = b"test file content"
request = FileRequest(file=file_data, purpose="batch")
self.assertEqual(request.file, file_data)
self.assertEqual(request.purpose, "batch")
def test_file_response(self):
"""Test file response model"""
response = FileResponse(
id="file-123",
bytes=1024,
created_at=1234567890,
filename="test.jsonl",
purpose="batch",
)
self.assertEqual(response.id, "file-123")
self.assertEqual(response.object, "file")
self.assertEqual(response.bytes, 1024)
self.assertEqual(response.filename, "test.jsonl")
def test_file_delete_response(self):
"""Test file delete response model"""
response = FileDeleteResponse(id="file-123", deleted=True)
self.assertEqual(response.id, "file-123")
self.assertEqual(response.object, "file")
self.assertTrue(response.deleted)
class TestBatchOperations(unittest.TestCase):
"""Test batch operation protocol models"""
def test_batch_request(self):
"""Test batch request model"""
request = BatchRequest(
input_file_id="file-123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"custom": "value"},
)
self.assertEqual(request.input_file_id, "file-123")
self.assertEqual(request.endpoint, "/v1/chat/completions")
self.assertEqual(request.completion_window, "24h")
self.assertEqual(request.metadata, {"custom": "value"})
def test_batch_response(self):
"""Test batch response model"""
response = BatchResponse(
id="batch-123",
endpoint="/v1/chat/completions",
input_file_id="file-123",
completion_window="24h",
created_at=1234567890,
)
self.assertEqual(response.id, "batch-123")
self.assertEqual(response.object, "batch")
self.assertEqual(response.status, "validating") # default
self.assertEqual(response.endpoint, "/v1/chat/completions")
class TestResponseFormats(unittest.TestCase):
"""Test response format protocol models"""
def test_basic_response_format(self):
"""Test basic response format"""
format_obj = ResponseFormat(type="json_object")
self.assertEqual(format_obj.type, "json_object")
self.assertIsNone(format_obj.json_schema)
def test_json_schema_response_format(self):
"""Test JSON schema response format"""
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
json_schema = JsonSchemaResponseFormat(
name="person_schema", description="Person schema", schema=schema
)
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
self.assertEqual(format_obj.type, "json_schema")
self.assertEqual(format_obj.json_schema.name, "person_schema")
self.assertEqual(format_obj.json_schema.schema_, schema)
def test_structural_tag_response_format(self):
"""Test structural tag response format"""
structures = [
{
"begin": "<thinking>",
"schema_": {"type": "string"},
"end": "</thinking>",
}
]
format_obj = StructuralTagResponseFormat(
type="structural_tag", structures=structures, triggers=["think"]
)
self.assertEqual(format_obj.type, "structural_tag")
self.assertEqual(len(format_obj.structures), 1)
self.assertEqual(format_obj.triggers, ["think"])
class TestLogProbs(unittest.TestCase):
"""Test LogProbs protocol models"""
def test_basic_logprobs(self):
"""Test basic LogProbs model"""
logprobs = LogProbs(
text_offset=[0, 5, 11],
token_logprobs=[-0.1, -0.2, -0.3],
tokens=["Hello", " ", "world"],
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
)
self.assertEqual(len(logprobs.tokens), 3)
self.assertEqual(logprobs.tokens, ["Hello", " ", "world"])
self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3])
def test_choice_logprobs(self):
"""Test ChoiceLogprobs model"""
token_logprob = ChatCompletionTokenLogprob(
token="Hello",
bytes=[72, 101, 108, 108, 111],
logprob=-0.1,
top_logprobs=[
TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1)
],
)
choice_logprobs = ChoiceLogprobs(content=[token_logprob])
self.assertEqual(len(choice_logprobs.content), 1)
self.assertEqual(choice_logprobs.content[0].token, "Hello")
class TestStreamingModels(unittest.TestCase):
"""Test streaming response models"""
def test_stream_options(self):
"""Test StreamOptions model"""
options = StreamOptions(include_usage=True)
self.assertTrue(options.include_usage)
def test_chat_completion_stream_response(self):
"""Test ChatCompletionStreamResponse model"""
delta = DeltaMessage(role="assistant", content="Hello")
choice = ChatCompletionResponseStreamChoice(index=0, delta=delta)
response = ChatCompletionStreamResponse(
id="test-id", model="test-model", choices=[choice]
)
self.assertEqual(response.object, "chat.completion.chunk")
self.assertEqual(response.choices[0].delta.content, "Hello")
class TestModelSerialization(unittest.TestCase):
"""Test model serialization with hidden states"""
@@ -680,11 +241,6 @@ class TestModelSerialization(unittest.TestCase):
class TestValidationEdgeCases(unittest.TestCase):
"""Test edge cases and validation scenarios"""
def test_empty_messages_validation(self):
"""Test validation with empty messages"""
with self.assertRaises(ValidationError):
ChatCompletionRequest(model="test-model", messages=[])
def test_invalid_tool_choice_type(self):
"""Test invalid tool choice type"""
messages = [{"role": "user", "content": "Hello"}]
@@ -698,13 +254,6 @@ class TestValidationEdgeCases(unittest.TestCase):
with self.assertRaises(ValidationError):
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
def test_invalid_temperature_range(self):
"""Test invalid temperature values"""
# Note: The current protocol doesn't enforce temperature range,
# but this test documents expected behavior
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
self.assertEqual(request.temperature, 5.0) # Currently allowed
def test_model_serialization_roundtrip(self):
"""Test that models can be serialized and deserialized"""
original_request = ChatCompletionRequest(

View File

@@ -1,52 +0,0 @@
# sglang/test/srt/openai/test_server.py
import requests
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
def test_health(openai_server: str):
r = requests.get(f"{openai_server}/health")
assert r.status_code == 200
# FastAPI returns an empty body → r.text == ""
assert r.text == ""
def test_models_endpoint(openai_server: str):
r = requests.get(f"{openai_server}/v1/models")
assert r.status_code == 200, r.text
payload = r.json()
# Basic contract
assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
# Validate fields of the first model card
first = payload["data"][0]
for key in ("id", "root", "max_model_len"):
assert key in first, f"missing {key} in {first}"
# max_model_len must be positive
assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
# The server should report the same model id we launched it with
ids = {m["id"] for m in payload["data"]}
assert MODEL_ID in ids
def test_get_model_info(openai_server: str):
r = requests.get(f"{openai_server}/get_model_info")
assert r.status_code == 200, r.text
info = r.json()
expected_keys = {"model_path", "tokenizer_path", "is_generation"}
assert expected_keys.issubset(info.keys())
# model_path must end with the one we passed on the CLI
assert info["model_path"].endswith(MODEL_ID)
# is_generation is documented as a boolean
assert isinstance(info["is_generation"], bool)
def test_unknown_route_returns_404(openai_server: str):
r = requests.get(f"{openai_server}/definitely-not-a-real-route")
assert r.status_code == 404

View File

@@ -57,11 +57,21 @@ class _MockTokenizerManager:
self.create_abort_task = Mock()
class _MockTemplateManager:
"""Minimal mock for TemplateManager."""
def __init__(self):
self.chat_template_name: Optional[str] = "llama-3"
self.jinja_template_content_format: Optional[str] = None
self.completion_template_name: Optional[str] = None
class ServingChatTestCase(unittest.TestCase):
# ------------- common fixtures -------------
def setUp(self):
self.tm = _MockTokenizerManager()
self.chat = OpenAIServingChat(self.tm)
self.template_manager = _MockTemplateManager()
self.chat = OpenAIServingChat(self.tm, self.template_manager)
# frequently reused requests
self.basic_req = ChatCompletionRequest(
@@ -109,96 +119,6 @@ class ServingChatTestCase(unittest.TestCase):
self.assertFalse(adapted.stream)
self.assertEqual(processed, self.basic_req)
# # ------------- tool-call branch -------------
# def test_tool_call_request_conversion(self):
# req = ChatCompletionRequest(
# model="x",
# messages=[{"role": "user", "content": "Weather?"}],
# tools=[
# {
# "type": "function",
# "function": {
# "name": "get_weather",
# "parameters": {"type": "object", "properties": {}},
# },
# }
# ],
# tool_choice="auto",
# )
# with patch.object(
# self.chat,
# "_process_messages",
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
# ):
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
# self.assertEqual(adapted.rid, "rid")
# def test_tool_choice_none(self):
# req = ChatCompletionRequest(
# model="x",
# messages=[{"role": "user", "content": "Hi"}],
# tools=[{"type": "function", "function": {"name": "noop"}}],
# tool_choice="none",
# )
# with patch.object(
# self.chat,
# "_process_messages",
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
# ):
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
# self.assertEqual(adapted.rid, "rid")
# ------------- multimodal branch -------------
def test_multimodal_request_with_images(self):
self.tm.model_config.is_multimodal = True
req = ChatCompletionRequest(
model="x",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in the image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,"},
},
],
}
],
)
with patch.object(
self.chat,
"_apply_jinja_template",
return_value=("prompt", [1, 2], ["img"], None, [], []),
), patch.object(
self.chat,
"_apply_conversation_template",
return_value=("prompt", ["img"], None, [], []),
):
out = self.chat._process_messages(req, True)
_, _, image_data, *_ = out
self.assertEqual(image_data, ["img"])
# ------------- template handling -------------
def test_jinja_template_processing(self):
req = ChatCompletionRequest(
model="x", messages=[{"role": "user", "content": "Hello"}]
)
self.tm.chat_template_name = None
self.tm.tokenizer.chat_template = "<jinja>"
with patch.object(
self.chat,
"_apply_jinja_template",
return_value=("processed", [1], None, None, [], ["</s>"]),
), patch("builtins.hasattr", return_value=True):
prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
self.assertEqual(prompt, "processed")
self.assertEqual(prompt_ids, [1])
# ------------- sampling-params -------------
def test_sampling_param_build(self):
req = ChatCompletionRequest(

View File

@@ -5,6 +5,7 @@ Run with:
"""
import unittest
from typing import Optional
from unittest.mock import AsyncMock, Mock, patch
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
@@ -12,6 +13,17 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl
from sglang.srt.managers.tokenizer_manager import TokenizerManager
class _MockTemplateManager:
"""Minimal mock for TemplateManager."""
def __init__(self):
self.chat_template_name: Optional[str] = None
self.jinja_template_content_format: Optional[str] = None
self.completion_template_name: Optional[str] = (
None # Set to None to avoid template processing
)
class ServingCompletionTestCase(unittest.TestCase):
"""Bundle all prompt/echo tests in one TestCase."""
@@ -31,7 +43,8 @@ class ServingCompletionTestCase(unittest.TestCase):
tm.generate_request = AsyncMock()
tm.create_abort_task = Mock()
self.sc = OpenAIServingCompletion(tm)
self.template_manager = _MockTemplateManager()
self.sc = OpenAIServingCompletion(tm, self.template_manager)
# ---------- prompt-handling ----------
def test_single_string_prompt(self):
@@ -44,20 +57,6 @@ class ServingCompletionTestCase(unittest.TestCase):
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
def test_completion_template_handling(self):
req = CompletionRequest(
model="x", prompt="def f():", suffix="return 1", max_tokens=100
)
with patch(
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
return_value=True,
), patch(
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
return_value="processed_prompt",
):
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.text, "processed_prompt")
# ---------- echo-handling ----------
def test_echo_with_string_prompt_streaming(self):
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)

View File

@@ -5,25 +5,16 @@ These tests ensure that the embedding serving implementation maintains compatibi
with the original adapter.py functionality and follows OpenAI API specifications.
"""
import asyncio
import json
import time
import unittest
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import Mock
from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError
from sglang.srt.entrypoints.openai.protocol import (
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
MultimodalEmbeddingInput,
UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput
@@ -58,11 +49,22 @@ class _MockTokenizerManager:
self.generate_request = Mock(return_value=mock_generate_embedding())
# Mock TemplateManager for embedding tests
class _MockTemplateManager:
def __init__(self):
self.chat_template_name = None # None for embeddings usually
self.jinja_template_content_format = None
self.completion_template_name = None
class ServingEmbeddingTestCase(unittest.TestCase):
def setUp(self):
"""Set up test fixtures."""
self.tokenizer_manager = _MockTokenizerManager()
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
self.template_manager = _MockTemplateManager()
self.serving_embedding = OpenAIServingEmbedding(
self.tokenizer_manager, self.template_manager
)
self.request = Mock(spec=Request)
self.request.headers = {}
@@ -141,132 +143,6 @@ class ServingEmbeddingTestCase(unittest.TestCase):
self.assertIsNone(adapted_request.image_data[1])
# self.assertEqual(adapted_request.rid, "test-id")
def test_build_single_embedding_response(self):
"""Test building response for single embedding."""
ret_data = [
{
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"meta_info": {"prompt_tokens": 5},
}
]
response = self.serving_embedding._build_embedding_response(ret_data)
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.data[0].object, "embedding")
self.assertEqual(response.usage.prompt_tokens, 5)
self.assertEqual(response.usage.total_tokens, 5)
self.assertEqual(response.usage.completion_tokens, 0)
def test_build_multiple_embedding_response(self):
"""Test building response for multiple embeddings."""
ret_data = [
{
"embedding": [0.1, 0.2, 0.3],
"meta_info": {"prompt_tokens": 3},
},
{
"embedding": [0.4, 0.5, 0.6],
"meta_info": {"prompt_tokens": 4},
},
]
response = self.serving_embedding._build_embedding_response(ret_data)
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(len(response.data), 2)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
self.assertEqual(response.data[1].index, 1)
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
self.assertEqual(response.usage.total_tokens, 7)
def test_handle_request_success(self):
"""Test successful embedding request handling."""
async def run_test():
# Mock the generate_request to return expected data
async def mock_generate():
yield {
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"meta_info": {"prompt_tokens": 5},
}
self.serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate()
)
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
asyncio.run(run_test())
def test_handle_request_validation_error(self):
"""Test handling request with validation error."""
async def run_test():
invalid_request = EmbeddingRequest(model="test-model", input="")
response = await self.serving_embedding.handle_request(
invalid_request, self.request
)
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 400)
asyncio.run(run_test())
def test_handle_request_generation_error(self):
"""Test handling request with generation error."""
async def run_test():
# Mock generate_request to raise an error
async def mock_generate_error():
raise ValueError("Generation failed")
yield # This won't be reached but needed for async generator
self.serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate_error()
)
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 400)
asyncio.run(run_test())
def test_handle_request_internal_error(self):
"""Test handling request with internal server error."""
async def run_test():
# Mock _convert_to_internal_request to raise an exception
with patch.object(
self.serving_embedding,
"_convert_to_internal_request",
side_effect=Exception("Internal error"),
):
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 500)
asyncio.run(run_test())
if __name__ == "__main__":
unittest.main(verbosity=2)

View File

@@ -29,6 +29,10 @@ suites = {
TestFile("models/test_reward_models.py", 132),
TestFile("models/test_vlm_models.py", 437),
TestFile("models/test_transformers_models.py", 320),
TestFile("openai/test_protocol.py", 10),
TestFile("openai/test_serving_chat.py", 10),
TestFile("openai/test_serving_completions.py", 10),
TestFile("openai/test_serving_embedding.py", 10),
TestFile("test_abort.py", 51),
TestFile("test_block_int8.py", 22),
TestFile("test_create_kvindices.py", 2),
@@ -49,6 +53,7 @@ suites = {
TestFile("test_hidden_states.py", 55),
TestFile("test_int8_kernel.py", 8),
TestFile("test_input_embeddings.py", 38),
TestFile("test_jinja_template_utils.py", 1),
TestFile("test_json_constrained.py", 98),
TestFile("test_large_max_new_tokens.py", 41),
TestFile("test_metrics.py", 32),
@@ -59,14 +64,8 @@ suites = {
TestFile("test_mla_fp8.py", 93),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 234),
TestFile("test_openai_adapter.py", 1),
TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149),
TestFile("openai/test_server.py", 120),
TestFile("openai/test_protocol.py", 60),
TestFile("openai/test_serving_chat.py", 120),
TestFile("openai/test_serving_completions.py", 120),
TestFile("openai/test_serving_embedding.py", 120),
TestFile("test_openai_server_hidden_states.py", 240),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),

View File

@@ -3,6 +3,7 @@ import unittest
from xgrammar import GrammarCompiler, TokenizerInfo
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
@@ -10,7 +11,6 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import Function, Tool
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

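With the import moved, Function and Tool are now constructed from sglang.srt.entrypoints.openai.protocol; a small sketch with an illustrative tool definition:

from sglang.srt.entrypoints.openai.protocol import Function, Tool

# Hypothetical tool; the name, description, and JSON schema are illustrative.
weather_tool = Tool(
    type="function",
    function=Function(
        name="get_weather",
        description="Look up the current weather for a city",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    ),
)
print(weather_tool.function.name)  # get_weather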
View File

@@ -5,8 +5,8 @@ Unit tests for OpenAI adapter utils.
import unittest
from unittest.mock import patch
from sglang.srt.openai_api.utils import (
detect_template_content_format,
from sglang.srt.jinja_template_utils import (
detect_jinja_template_content_format,
process_content_for_template_format,
)
from sglang.test.test_utils import CustomTestCase
@@ -33,7 +33,7 @@ class TestTemplateContentFormatDetection(CustomTestCase):
{%- endfor %}
"""
result = detect_template_content_format(llama4_pattern)
result = detect_jinja_template_content_format(llama4_pattern)
self.assertEqual(result, "openai")
def test_detect_deepseek_string_format(self):
@@ -46,19 +46,19 @@ class TestTemplateContentFormatDetection(CustomTestCase):
{%- endfor %}
"""
result = detect_template_content_format(deepseek_pattern)
result = detect_jinja_template_content_format(deepseek_pattern)
self.assertEqual(result, "string")
def test_detect_invalid_template(self):
"""Test handling of invalid template (should default to 'string')."""
invalid_pattern = "{{{{ invalid jinja syntax }}}}"
result = detect_template_content_format(invalid_pattern)
result = detect_jinja_template_content_format(invalid_pattern)
self.assertEqual(result, "string")
def test_detect_empty_template(self):
"""Test handling of empty template (should default to 'string')."""
result = detect_template_content_format("")
result = detect_jinja_template_content_format("")
self.assertEqual(result, "string")
def test_process_content_openai_format(self):

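A short usage sketch of the renamed helper; the two templates below are simplified, and the commented results follow the classifications asserted above.

from sglang.srt.jinja_template_utils import detect_jinja_template_content_format

# A template that renders message content directly as a string...
string_style = "{% for m in messages %}{{ m['content'] }}{% endfor %}"
# ...versus one that iterates over content as a list of parts.
openai_style = (
    "{% for m in messages %}{% for part in m['content'] %}"
    "{{ part['text'] }}{% endfor %}{% endfor %}"
)

print(detect_jinja_template_content_format(string_style))  # expected: "string"
print(detect_jinja_template_content_format(openai_style))  # expected: "openai"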
View File

@@ -235,6 +235,7 @@ class TestOpenAIServer(CustomTestCase):
)
is_firsts = {}
is_finished = {}
for response in generator:
usage = response.usage
if usage is not None:
@@ -244,6 +245,10 @@ class TestOpenAIServer(CustomTestCase):
continue
index = response.choices[0].index
finish_reason = response.choices[0].finish_reason
if finish_reason is not None:
is_finished[index] = True
data = response.choices[0].delta
if is_firsts.get(index, True):
@@ -253,7 +258,7 @@ class TestOpenAIServer(CustomTestCase):
is_firsts[index] = False
continue
if logprobs:
if logprobs and not is_finished.get(index, False):
assert response.choices[0].logprobs, f"logprobs was not returned"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
@@ -271,7 +276,7 @@ class TestOpenAIServer(CustomTestCase):
assert (
isinstance(data.content, str)
or isinstance(data.reasoning_content, str)
or len(data.tool_calls) > 0
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
or response.choices[0].finish_reason
)
assert response.id
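The updated assertions above reflect that a choice's final streamed chunk may carry a finish_reason without logprobs or content; a consumption sketch in the same spirit (endpoint and model name are placeholders):

import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")  # placeholder
stream = client.chat.completions.create(
    model="my-model",  # placeholder
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    n=2,
)

is_finished = {}
for chunk in stream:
    if not chunk.choices:
        continue  # e.g. a trailing usage-only chunk
    choice = chunk.choices[0]
    if choice.finish_reason is not None:
        is_finished[choice.index] = True
    # Only chunks before the finishing one are expected to carry content/logprobs.
    if choice.delta.content and not is_finished.get(choice.index, False):
        print(choice.index, choice.delta.content)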
@@ -282,152 +287,6 @@ class TestOpenAIServer(CustomTestCase):
index, True
), f"index {index} is not found in the response"
def _create_batch(self, mode, client):
if mode == "completion":
input_file_path = "complete_input.jsonl"
# write content to input file
content = [
{
"custom_id": "request-1",
"method": "POST",
"url": "/v1/completions",
"body": {
"model": "gpt-3.5-turbo-instruct",
"prompt": "List 3 names of famous soccer player: ",
"max_tokens": 20,
},
},
{
"custom_id": "request-2",
"method": "POST",
"url": "/v1/completions",
"body": {
"model": "gpt-3.5-turbo-instruct",
"prompt": "List 6 names of famous basketball player: ",
"max_tokens": 40,
},
},
{
"custom_id": "request-3",
"method": "POST",
"url": "/v1/completions",
"body": {
"model": "gpt-3.5-turbo-instruct",
"prompt": "List 6 names of famous tenniss player: ",
"max_tokens": 40,
},
},
]
else:
input_file_path = "chat_input.jsonl"
content = [
{
"custom_id": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-3.5-turbo-0125",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": "Hello! List 3 NBA players and tell a story",
},
],
"max_tokens": 30,
},
},
{
"custom_id": "request-2",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-3.5-turbo-0125",
"messages": [
{"role": "system", "content": "You are an assistant. "},
{
"role": "user",
"content": "Hello! List three capital and tell a story",
},
],
"max_tokens": 50,
},
},
]
with open(input_file_path, "w") as file:
for line in content:
file.write(json.dumps(line) + "\n")
with open(input_file_path, "rb") as file:
uploaded_file = client.files.create(file=file, purpose="batch")
if mode == "completion":
endpoint = "/v1/completions"
elif mode == "chat":
endpoint = "/v1/chat/completions"
completion_window = "24h"
batch_job = client.batches.create(
input_file_id=uploaded_file.id,
endpoint=endpoint,
completion_window=completion_window,
)
return batch_job, content, uploaded_file
def run_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)
while batch_job.status not in ["completed", "failed", "cancelled"]:
time.sleep(3)
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
batch_job = client.batches.retrieve(batch_job.id)
assert (
batch_job.status == "completed"
), f"Batch job status is not completed: {batch_job.status}"
assert batch_job.request_counts.completed == len(content)
assert batch_job.request_counts.failed == 0
assert batch_job.request_counts.total == len(content)
result_file_id = batch_job.output_file_id
file_response = client.files.content(result_file_id)
result_content = file_response.read().decode("utf-8") # Decode bytes to string
results = [
json.loads(line)
for line in result_content.split("\n")
if line.strip() != ""
]
assert len(results) == len(content)
for delete_fid in [uploaded_file.id, result_file_id]:
            del_response = client.files.delete(delete_fid)
            assert del_response.deleted
def run_cancel_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)
assert batch_job.status not in ["cancelling", "cancelled"]
batch_job = client.batches.cancel(batch_id=batch_job.id)
assert batch_job.status == "cancelling"
while batch_job.status not in ["failed", "cancelled"]:
batch_job = client.batches.retrieve(batch_job.id)
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
time.sleep(3)
assert batch_job.status == "cancelled"
del_response = client.files.delete(uploaded_file.id)
assert del_response.deleted
def test_completion(self):
for echo in [False, True]:
for logprobs in [None, 5]:
@@ -467,14 +326,6 @@ class TestOpenAIServer(CustomTestCase):
for parallel_sample_num in [1, 2]:
self.run_chat_completion_stream(logprobs, parallel_sample_num)
def test_batch(self):
for mode in ["completion", "chat"]:
self.run_batch(mode)
def test_cancel_batch(self):
for mode in ["completion", "chat"]:
self.run_cancel_batch(mode)
def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -559,6 +410,18 @@ The SmartHome Mini is a compact smart home assistant available in black or white
assert len(models) == 1
assert isinstance(getattr(models[0], "max_model_len", None), int)
def test_retrieve_model(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
# Test retrieving an existing model
retrieved_model = client.models.retrieve(self.model)
self.assertEqual(retrieved_model.id, self.model)
self.assertEqual(retrieved_model.root, self.model)
# Test retrieving a non-existent model
with self.assertRaises(openai.NotFoundError):
client.models.retrieve("non-existent-model")
# -------------------------------------------------------------------------
# EBNF Test Class: TestOpenAIServerEBNF
@@ -684,6 +547,31 @@ class TestOpenAIEmbedding(CustomTestCase):
self.assertTrue(len(response.data[0].embedding) > 0)
self.assertTrue(len(response.data[1].embedding) > 0)
def test_embedding_single_batch_str(self):
"""Test embedding with a List[str] and length equals to 1"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.embeddings.create(model=self.model, input=["Hello world"])
self.assertEqual(len(response.data), 1)
self.assertTrue(len(response.data[0].embedding) > 0)
def test_embedding_single_int_list(self):
"""Test embedding with a List[int] or List[List[int]]]"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.embeddings.create(
model=self.model,
input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
)
self.assertEqual(len(response.data), 1)
self.assertTrue(len(response.data[0].embedding) > 0)
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.embeddings.create(
model=self.model,
input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
)
self.assertEqual(len(response.data), 1)
self.assertTrue(len(response.data[0].embedding) > 0)
def test_empty_string_embedding(self):
"""Test embedding an empty string."""

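The embedding tests added above cover the accepted input shapes; a compact sketch of the same calls (endpoint and model name are placeholders):

import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")  # placeholder

# A plain string, a list of strings, a single token-id list, and a batch of
# token-id lists are all valid /v1/embeddings inputs.
for payload in (
    "Hello world",
    ["Hello world"],
    [15339, 314, 703],
    [[15339, 314, 703], [284, 612, 262]],
):
    response = client.embeddings.create(model="my-embedding-model", input=payload)
    print(len(response.data), len(response.data[0].embedding))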
View File

@@ -21,6 +21,7 @@ from transformers import (
from sglang import Engine
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
@@ -31,7 +32,6 @@ from sglang.srt.managers.schedule_batch import (
MultimodalInputs,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs

View File

@@ -15,7 +15,7 @@ from transformers import (
from sglang import Engine
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"