diff --git a/benchmark/hicache/data_processing.py b/benchmark/hicache/data_processing.py
index fcc44086e..0152406a8 100644
--- a/benchmark/hicache/data_processing.py
+++ b/benchmark/hicache/data_processing.py
@@ -20,7 +20,7 @@ from sglang.bench_serving import (
get_gen_prefix_cache_path,
)
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
-from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64
# type of content fields, can be only prompts or with images/videos
diff --git a/docs/backend/openai_api_embeddings.ipynb b/docs/backend/openai_api_embeddings.ipynb
index e4a40cd5c..a8828daac 100644
--- a/docs/backend/openai_api_embeddings.ipynb
+++ b/docs/backend/openai_api_embeddings.ipynb
@@ -64,11 +64,14 @@
"text = \"Once upon a time\"\n",
"\n",
"curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
+ " -H \"Content-Type: application/json\" \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n",
"\n",
- "text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n",
- " \"embedding\"\n",
- "]\n",
+ "result = subprocess.check_output(curl_text, shell=True)\n",
+ "\n",
+ "print(result)\n",
+ "\n",
+ "text_embedding = json.loads(result)[\"data\"][0][\"embedding\"]\n",
"\n",
"print_highlight(f\"Text embedding (first 10): {text_embedding[:10]}\")"
]
@@ -152,6 +155,7 @@
"input_ids = tokenizer.encode(text)\n",
"\n",
"curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
+ " -H \"Content-Type: application/json\" \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n",
"\n",
"input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n",
diff --git a/docs/backend/openai_api_vision.ipynb b/docs/backend/openai_api_vision.ipynb
index 0c80fdc0d..f183fd8a0 100644
--- a/docs/backend/openai_api_vision.ipynb
+++ b/docs/backend/openai_api_vision.ipynb
@@ -67,6 +67,7 @@
"\n",
"curl_command = f\"\"\"\n",
"curl -s http://localhost:{port}/v1/chat/completions \\\\\n",
+ " -H \"Content-Type: application/json\" \\\\\n",
" -d '{{\n",
" \"model\": \"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
" \"messages\": [\n",
diff --git a/docs/backend/vlm_query.ipynb b/docs/backend/vlm_query.ipynb
index 519811f75..b47d55580 100644
--- a/docs/backend/vlm_query.ipynb
+++ b/docs/backend/vlm_query.ipynb
@@ -36,7 +36,7 @@
"import requests\n",
"from PIL import Image\n",
"\n",
- "from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
+ "from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest\n",
"from sglang.srt.conversation import chat_templates\n",
"\n",
"image = Image.open(\n",
diff --git a/python/sglang/srt/code_completion_parser.py b/python/sglang/srt/code_completion_parser.py
index 5b32d8fb6..0067ac471 100644
--- a/python/sglang/srt/code_completion_parser.py
+++ b/python/sglang/srt/code_completion_parser.py
@@ -15,9 +15,7 @@
import dataclasses
-import json
import logging
-import os
from enum import auto
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
@@ -57,46 +55,6 @@ class CompletionTemplate:
completion_templates: dict[str, CompletionTemplate] = {}
-def load_completion_template_for_openai_api(completion_template_arg):
- global completion_template_name
-
- logger.info(
- f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
- )
-
- if not completion_template_exists(completion_template_arg):
- if not os.path.exists(completion_template_arg):
- raise RuntimeError(
- f"Completion template {completion_template_arg} is not a built-in template name "
- "or a valid completion template file path."
- )
-
- assert completion_template_arg.endswith(
- ".json"
- ), "unrecognized format of completion template file"
- with open(completion_template_arg, "r") as filep:
- template = json.load(filep)
- try:
- fim_position = FimPosition[template["fim_position"]]
- except KeyError:
- raise ValueError(
- f"Unknown fim position: {template['fim_position']}"
- ) from None
- register_completion_template(
- CompletionTemplate(
- name=template["name"],
- fim_begin_token=template["fim_begin_token"],
- fim_middle_token=template["fim_middle_token"],
- fim_end_token=template["fim_end_token"],
- fim_position=fim_position,
- ),
- override=True,
- )
- completion_template_name = template["name"]
- else:
- completion_template_name = completion_template_arg
-
-
def register_completion_template(template: CompletionTemplate, override: bool = False):
"""Register a new completion template."""
if not override:
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index ec5765a1f..661f47700 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -11,7 +11,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Conversation chat templates."""
+"""Conversation chat templates.
+
+This module provides conversation template definitions, data structures, and utilities
+for managing chat templates across different model types in SGLang.
+
+Key components:
+- Conversation class: Defines the structure and behavior of chat templates
+- SeparatorStyle enum: Different conversation formatting styles
+- Template registry: Functions to register and retrieve templates by name or model path
+- Built-in templates: Pre-defined templates for popular models
+"""
# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -20,7 +30,7 @@ import re
from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.utils import read_system_prompt_from_file
@@ -618,7 +628,7 @@ def generate_chat_conv(
# llama2 template
-# reference: https://huggingface.co/blog/codellama#conversational-instructions
+# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
register_conv_template(
Conversation(
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 4c88a5289..49bb76dba 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import torch
import uvloop
-from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
from sglang.srt.entrypoints.EngineBase import EngineBase
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import (
- guess_chat_template_name_from_model_path,
- load_chat_template_for_openai_api,
-)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
@@ -123,12 +119,13 @@ class Engine(EngineBase):
logger.info(f"{server_args=}")
# Launch subprocesses
- tokenizer_manager, scheduler_info = _launch_subprocesses(
+ tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args,
port_args=port_args,
)
self.server_args = server_args
self.tokenizer_manager = tokenizer_manager
+ self.template_manager = template_manager
self.scheduler_info = scheduler_info
context = zmq.Context(2)
@@ -647,7 +644,7 @@ def _set_envs_and_config(server_args: ServerArgs):
def _launch_subprocesses(
server_args: ServerArgs, port_args: Optional[PortArgs] = None
-) -> Tuple[TokenizerManager, Dict]:
+) -> Tuple[TokenizerManager, TemplateManager, Dict]:
"""
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
"""
@@ -732,7 +729,7 @@ def _launch_subprocesses(
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
# When using `Engine` as a Python API, we don't want to block here.
- return None, None
+ return None, None, None
launch_dummy_health_check_server(server_args.host, server_args.port)
@@ -741,7 +738,7 @@ def _launch_subprocesses(
logger.error(
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
)
- return None, None
+ return None, None, None
# Launch detokenizer process
detoken_proc = mp.Process(
@@ -755,15 +752,15 @@ def _launch_subprocesses(
# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
- if server_args.chat_template:
- load_chat_template_for_openai_api(
- tokenizer_manager, server_args.chat_template, server_args.model_path
- )
- else:
- guess_chat_template_name_from_model_path(server_args.model_path)
- if server_args.completion_template:
- load_completion_template_for_openai_api(server_args.completion_template)
+ # Initialize templates
+ template_manager = TemplateManager()
+ template_manager.initialize_templates(
+ tokenizer_manager=tokenizer_manager,
+ model_path=server_args.model_path,
+ chat_template=server_args.chat_template,
+ completion_template=server_args.completion_template,
+ )
# Wait for the model to finish loading
scheduler_infos = []
@@ -787,4 +784,4 @@ def _launch_subprocesses(
# Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0]
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
- return tokenizer_manager, scheduler_info
+ return tokenizer_manager, template_manager, scheduler_info
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 9262d10a9..b3e646147 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -38,7 +38,8 @@ import orjson
import requests
import uvicorn
import uvloop
-from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -47,6 +48,20 @@ from sglang.srt.disaggregation.utils import (
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses
+from sglang.srt.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ CompletionRequest,
+ EmbeddingRequest,
+ ModelCard,
+ ModelList,
+ ScoringRequest,
+ V1RerankReqInput,
+)
+from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
+from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
+from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
AbortReq,
@@ -67,26 +82,11 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
- V1RerankReqInput,
VertexGenerateReqInput,
)
+from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.adapter import (
- v1_batches,
- v1_cancel_batch,
- v1_chat_completions,
- v1_completions,
- v1_delete_file,
- v1_embeddings,
- v1_files_create,
- v1_rerank,
- v1_retrieve_batch,
- v1_retrieve_file,
- v1_retrieve_file_content,
- v1_score,
-)
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
@@ -109,6 +109,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass
class _GlobalState:
tokenizer_manager: TokenizerManager
+ template_manager: TemplateManager
scheduler_info: Dict
@@ -123,6 +124,24 @@ def set_global_state(global_state: _GlobalState):
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args
+
+ # Initialize OpenAI serving handlers
+ fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
+ _global_state.tokenizer_manager, _global_state.template_manager
+ )
+ fast_api_app.state.openai_serving_chat = OpenAIServingChat(
+ _global_state.tokenizer_manager, _global_state.template_manager
+ )
+ fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
+ _global_state.tokenizer_manager, _global_state.template_manager
+ )
+ fast_api_app.state.openai_serving_score = OpenAIServingScore(
+ _global_state.tokenizer_manager
+ )
+ fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
+ _global_state.tokenizer_manager
+ )
+
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager
@@ -148,6 +167,36 @@ app.add_middleware(
allow_headers=["*"],
)
+
+# Custom exception handlers to change validation error status codes
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+ """Override FastAPI's default 422 validation error with 400"""
+ return ORJSONResponse(
+ status_code=400,
+ content={
+ "detail": exc.errors(),
+ "body": exc.body,
+ },
+ )
+
+
+async def validate_json_request(raw_request: Request):
+ """Validate that the request content-type is application/json."""
+ content_type = raw_request.headers.get("content-type", "").lower()
+ media_type = content_type.split(";", maxsplit=1)[0]
+ if media_type != "application/json":
+ raise RequestValidationError(
+ errors=[
+ {
+ "loc": ["header", "content-type"],
+ "msg": "Unsupported Media Type: Only 'application/json' is allowed",
+ "type": "value_error",
+ }
+ ]
+ )
+
+
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
@@ -330,13 +379,14 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
return _create_error_response(e)
-@app.api_route("/v1/rerank", methods=["POST", "PUT"])
-async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
- try:
- ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
- return ret
- except ValueError as e:
- return _create_error_response(e)
+@app.api_route(
+ "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+ """Endpoint for reranking documents based on query relevance."""
+ return await raw_request.app.state.openai_serving_rerank.handle_request(
+ request, raw_request
+ )
@app.api_route("/flush_cache", methods=["GET", "POST"])
@@ -619,25 +669,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
##### OpenAI-compatible API endpoints #####
-@app.post("/v1/completions")
-async def openai_v1_completions(raw_request: Request):
- return await v1_completions(_global_state.tokenizer_manager, raw_request)
+@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
+ """OpenAI-compatible text completion endpoint."""
+ return await raw_request.app.state.openai_serving_completion.handle_request(
+ request, raw_request
+ )
-@app.post("/v1/chat/completions")
-async def openai_v1_chat_completions(raw_request: Request):
- return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_chat_completions(
+ request: ChatCompletionRequest, raw_request: Request
+):
+ """OpenAI-compatible chat completion endpoint."""
+ return await raw_request.app.state.openai_serving_chat.handle_request(
+ request, raw_request
+ )
-@app.post("/v1/embeddings", response_class=ORJSONResponse)
-async def openai_v1_embeddings(raw_request: Request):
- response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
- return response
+@app.post(
+ "/v1/embeddings",
+ response_class=ORJSONResponse,
+ dependencies=[Depends(validate_json_request)],
+)
+async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
+ """OpenAI-compatible embeddings endpoint."""
+ return await raw_request.app.state.openai_serving_embedding.handle_request(
+ request, raw_request
+ )
@app.get("/v1/models", response_class=ORJSONResponse)
-def available_models():
- """Show available models."""
+async def available_models():
+ """Show available models. OpenAI-compatible endpoint."""
served_model_names = [_global_state.tokenizer_manager.served_model_name]
model_cards = []
for served_model_name in served_model_names:
@@ -651,47 +715,31 @@ def available_models():
return ModelList(data=model_cards)
-@app.post("/v1/files")
-async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
- return await v1_files_create(
- file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
+@app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
+async def retrieve_model(model: str):
+ """Retrieves a model instance, providing basic information about the model."""
+ served_model_names = [_global_state.tokenizer_manager.served_model_name]
+
+ if model not in served_model_names:
+ return ORJSONResponse(
+ status_code=404,
+ content={
+ "error": {
+ "message": f"The model '{model}' does not exist",
+ "type": "invalid_request_error",
+ "param": "model",
+ "code": "model_not_found",
+ }
+ },
+ )
+
+ return ModelCard(
+ id=model,
+ root=model,
+ max_model_len=_global_state.tokenizer_manager.model_config.context_len,
)
-@app.delete("/v1/files/{file_id}")
-async def delete_file(file_id: str):
- # https://platform.openai.com/docs/api-reference/files/delete
- return await v1_delete_file(file_id)
-
-
-@app.post("/v1/batches")
-async def openai_v1_batches(raw_request: Request):
- return await v1_batches(_global_state.tokenizer_manager, raw_request)
-
-
-@app.post("/v1/batches/{batch_id}/cancel")
-async def cancel_batches(batch_id: str):
- # https://platform.openai.com/docs/api-reference/batch/cancel
- return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
-
-
-@app.get("/v1/batches/{batch_id}")
-async def retrieve_batch(batch_id: str):
- return await v1_retrieve_batch(batch_id)
-
-
-@app.get("/v1/files/{file_id}")
-async def retrieve_file(file_id: str):
- # https://platform.openai.com/docs/api-reference/files/retrieve
- return await v1_retrieve_file(file_id)
-
-
-@app.get("/v1/files/{file_id}/content")
-async def retrieve_file_content(file_id: str):
- # https://platform.openai.com/docs/api-reference/files/retrieve-contents
- return await v1_retrieve_file_content(file_id)
-
-
## SageMaker API
@app.get("/ping")
async def sagemaker_health() -> Response:
@@ -700,8 +748,13 @@ async def sagemaker_health() -> Response:
@app.post("/invocations")
-async def sagemaker_chat_completions(raw_request: Request):
- return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+async def sagemaker_chat_completions(
+ request: ChatCompletionRequest, raw_request: Request
+):
+ """OpenAI-compatible chat completion endpoint."""
+ return await raw_request.app.state.openai_serving_chat.handle_request(
+ request, raw_request
+ )
## Vertex AI API
@@ -732,10 +785,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
return ORJSONResponse({"predictions": ret})
-@app.post("/v1/score")
-async def v1_score_request(raw_request: Request):
+@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
+async def v1_score_request(request: ScoringRequest, raw_request: Request):
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
- return await v1_score(_global_state.tokenizer_manager, raw_request)
+ return await raw_request.app.state.openai_serving_score.handle_request(
+ request, raw_request
+ )
def _create_error_response(e):
@@ -764,10 +819,13 @@ def launch_server(
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
- tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+ tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+ server_args=server_args
+ )
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
+ template_manager=template_manager,
scheduler_info=scheduler_info,
)
)
diff --git a/python/sglang/srt/entrypoints/openai/api_server.py b/python/sglang/srt/entrypoints/openai/api_server.py
deleted file mode 100644
index a31643395..000000000
--- a/python/sglang/srt/entrypoints/openai/api_server.py
+++ /dev/null
@@ -1,388 +0,0 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""
-SGLang OpenAI-Compatible API Server.
-
-This file implements OpenAI-compatible HTTP APIs for the inference engine via FastAPI.
-"""
-
-import argparse
-import asyncio
-import logging
-import multiprocessing
-import os
-import threading
-import time
-from contextlib import asynccontextmanager
-from typing import Callable, Dict, Optional
-
-import numpy as np
-import requests
-import uvicorn
-import uvloop
-from fastapi import FastAPI, Request
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import Response
-
-from sglang.srt.disaggregation.utils import (
- FAKE_BOOTSTRAP_HOST,
- register_disaggregation_server,
-)
-from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
-from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
-from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import (
- add_prometheus_middleware,
- delete_directory,
- get_bool_env_var,
- kill_process_tree,
- set_uvicorn_logging_configs,
-)
-from sglang.srt.warmup import execute_warmups
-from sglang.utils import get_exception_traceback
-
-logger = logging.getLogger(__name__)
-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
-
-# Store global states
-class AppState:
- engine: Optional[Engine] = None
- server_args: Optional[ServerArgs] = None
- tokenizer_manager: Optional[TokenizerManager] = None
- scheduler_info: Optional[Dict] = None
- embedding_server: Optional[OpenAIServingEmbedding] = None
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- app.state.server_args.enable_metrics = True # By default, we enable metrics
-
- server_args = app.state.server_args
-
- # Initialize engine
- logger.info(f"SGLang OpenAI server (PID: {os.getpid()}) is initializing...")
-
- tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
- app.state.tokenizer_manager = tokenizer_manager
- app.state.scheduler_info = scheduler_info
- app.state.serving_embedding = OpenAIServingEmbedding(
- tokenizer_manager=tokenizer_manager
- )
-
- if server_args.enable_metrics:
- add_prometheus_middleware(app)
- enable_func_timer()
-
- # Initialize engine state attribute to None for now
- app.state.engine = None
-
- if server_args.warmups is not None:
- await execute_warmups(
- server_args.warmups.split(","), app.state.tokenizer_manager
- )
- logger.info("Warmup ended")
-
- warmup_thread = getattr(app, "warmup_thread", None)
- if warmup_thread is not None:
- warmup_thread.start()
-
- yield
-
- # Lifespan shutdown
- if hasattr(app.state, "engine") and app.state.engine is not None:
- logger.info("SGLang engine is shutting down.")
- # Add engine cleanup logic here when implemented
-
-
-# Fast API app with CORS enabled
-app = FastAPI(
- lifespan=lifespan,
- # TODO: check where /openai.json is created or why we use this
- openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
-)
-app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
-)
-
-
-@app.api_route("/health", methods=["GET"])
-async def health() -> Response:
- """Health check. Used for readiness and liveness probes."""
- # In the future, this could check engine health more deeply
- # For now, if the server is up, it's healthy.
- return Response(status_code=200)
-
-
-@app.api_route("/v1/models", methods=["GET"])
-async def show_models():
- """Show available models. Currently, it returns the served model name.
-
- This endpoint is compatible with the OpenAI API standard.
- """
- served_model_names = [app.state.tokenizer_manager.served_model_name]
- model_cards = []
- for served_model_name in served_model_names:
- model_cards.append(
- ModelCard(
- id=served_model_name,
- root=served_model_name,
- max_model_len=app.state.tokenizer_manager.model_config.context_len,
- )
- )
- return ModelList(data=model_cards)
-
-
-@app.get("/get_model_info")
-async def get_model_info():
- """Get the model information."""
- result = {
- "model_path": app.state.tokenizer_manager.model_path,
- "tokenizer_path": app.state.tokenizer_manager.server_args.tokenizer_path,
- "is_generation": app.state.tokenizer_manager.is_generation,
- }
- return result
-
-
-@app.post("/v1/completions")
-async def openai_v1_completions(raw_request: Request):
- pass
-
-
-@app.post("/v1/chat/completions")
-async def openai_v1_chat_completions(raw_request: Request):
- pass
-
-
-@app.post("/v1/embeddings")
-async def openai_v1_embeddings(raw_request: Request):
- try:
- request_json = await raw_request.json()
- request = EmbeddingRequest(**request_json)
- except Exception as e:
- return app.state.serving_embedding.create_error_response(
- f"Invalid request body, error: {str(e)}"
- )
-
- ret = await app.state.serving_embedding.handle_request(request, raw_request)
- return ret
-
-
-@app.post("/v1/score")
-async def v1_score_request(raw_request: Request):
- """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
- pass
-
-
-@app.api_route("/v1/models/{model_id}", methods=["GET"])
-async def show_model_detail(model_id: str):
- served_model_name = app.state.tokenizer_manager.served_model_name
-
- return ModelCard(
- id=served_model_name,
- root=served_model_name,
- max_model_len=app.state.tokenizer_manager.model_config.context_len,
- )
-
-
-# Additional API endpoints will be implemented in separate serving_*.py modules
-# and mounted as APIRouters in future PRs
-
-
-def _wait_and_warmup(
- server_args: ServerArgs,
- pipe_finish_writer: Optional[multiprocessing.connection.Connection],
- image_token_text: str,
- launch_callback: Optional[Callable[[], None]] = None,
-):
- return
- # TODO: Please wait until the /generate implementation is complete,
- # or confirm if modifications are needed before removing this.
-
- headers = {}
- url = server_args.url()
- if server_args.api_key:
- headers["Authorization"] = f"Bearer {server_args.api_key}"
-
- # Wait until the server is launched
- success = False
- for _ in range(120):
- time.sleep(1)
- try:
- res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
- assert res.status_code == 200, f"{res=}, {res.text=}"
- success = True
- break
- except (AssertionError, requests.exceptions.RequestException):
- last_traceback = get_exception_traceback()
- pass
-
- if not success:
- if pipe_finish_writer is not None:
- pipe_finish_writer.send(last_traceback)
- logger.error(f"Initialization failed. warmup error: {last_traceback}")
- kill_process_tree(os.getpid())
- return
-
- model_info = res.json()
-
- # Send a warmup request
- request_name = "/generate" if model_info["is_generation"] else "/encode"
- # TODO: Replace with OpenAI API
- max_new_tokens = 8 if model_info["is_generation"] else 1
- json_data = {
- "sampling_params": {
- "temperature": 0,
- "max_new_tokens": max_new_tokens,
- },
- }
- if server_args.skip_tokenizer_init:
- json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)]
- # TODO Workaround the bug that embedding errors for list of size 1
- if server_args.dp_size == 1:
- json_data["input_ids"] = json_data["input_ids"][0]
- else:
- json_data["text"] = ["The capital city of France is"] * server_args.dp_size
- # TODO Workaround the bug that embedding errors for list of size 1
- if server_args.dp_size == 1:
- json_data["text"] = json_data["text"][0]
-
- # Debug dumping
- if server_args.debug_tensor_dump_input_file:
- json_data.pop("text", None)
- json_data["input_ids"] = np.load(
- server_args.debug_tensor_dump_input_file
- ).tolist()
- json_data["sampling_params"]["max_new_tokens"] = 0
-
- try:
- if server_args.disaggregation_mode == "null":
- res = requests.post(
- url + request_name,
- json=json_data,
- headers=headers,
- timeout=600,
- )
- assert res.status_code == 200, f"{res}"
- else:
- logger.info(f"Start of prefill warmup ...")
- json_data = {
- "sampling_params": {
- "temperature": 0.0,
- "max_new_tokens": 8,
- "ignore_eos": True,
- },
- "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
- # This is a hack to ensure fake transfer is enabled during prefill warmup
- # ensure each dp rank has a unique bootstrap_room during prefill warmup
- "bootstrap_room": [
- i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
- for i in range(server_args.dp_size)
- ],
- "input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
- }
- res = requests.post(
- url + request_name,
- json=json_data,
- headers=headers,
- timeout=1800, # because of deep gemm precache is very long if not precache.
- )
- logger.info(
- f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
- )
-
- except Exception:
- last_traceback = get_exception_traceback()
- if pipe_finish_writer is not None:
- pipe_finish_writer.send(last_traceback)
- logger.error(f"Initialization failed. warmup error: {last_traceback}")
- kill_process_tree(os.getpid())
- return
-
- # Debug print
- # logger.info(f"{res.json()=}")
-
- logger.info("The server is fired up and ready to roll!")
- if pipe_finish_writer is not None:
- pipe_finish_writer.send("ready")
-
- if server_args.delete_ckpt_after_loading:
- delete_directory(server_args.model_path)
-
- if server_args.debug_tensor_dump_input_file:
- kill_process_tree(os.getpid())
-
- if server_args.pdlb_url is not None:
- register_disaggregation_server(
- server_args.disaggregation_mode,
- server_args.port,
- server_args.disaggregation_bootstrap_port,
- server_args.pdlb_url,
- )
-
- if launch_callback is not None:
- launch_callback()
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="SGLang OpenAI-Compatible API Server")
- # Add arguments from ServerArgs. This allows reuse of existing CLI definitions.
- ServerArgs.add_cli_args(parser)
- # Potentially add server-specific arguments here in the future if needed
-
- args = parser.parse_args()
- server_args = ServerArgs.from_cli_args(args)
-
- # Store server_args in app.state for access in lifespan and endpoints
- app.state.server_args = server_args
-
- # Configure logging
- logging.basicConfig(
- level=server_args.log_level.upper(),
- format="%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
- )
-
- # Send a warmup request - we will create the thread launch it
- # in the lifespan after all other warmups have fired.
- warmup_thread = threading.Thread(
- target=_wait_and_warmup,
- args=(
- server_args,
- None,
- None, # Never used
- None,
- ),
- )
- app.warmup_thread = warmup_thread
-
- try:
- # Start the server
- set_uvicorn_logging_configs()
- uvicorn.run(
- app,
- host=server_args.host,
- port=server_args.port,
- log_level=server_args.log_level.lower(),
- timeout_keep_alive=60, # Increased keep-alive for potentially long requests
- loop="uvloop", # Use uvloop for better performance if available
- )
- finally:
- warmup_thread.join()
diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 9acd50d02..89b2d3ab6 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -207,7 +207,7 @@ class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
- finish_reason: Literal["stop", "length", "content_filter", "abort"]
+ finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@@ -404,21 +404,13 @@ class ChatCompletionRequest(BaseModel):
@model_validator(mode="before")
@classmethod
def set_tool_choice_default(cls, values):
- if isinstance(values, dict):
- if values.get("tool_choice") is None:
- if values.get("tools") is None:
- values["tool_choice"] = "none"
- else:
- values["tool_choice"] = "auto"
+ if values.get("tool_choice") is None:
+ if values.get("tools") is None:
+ values["tool_choice"] = "none"
+ else:
+ values["tool_choice"] = "auto"
return values
- @field_validator("messages")
- @classmethod
- def validate_messages_not_empty(cls, v):
- if not v:
- raise ValueError("Messages cannot be empty")
- return v
-
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
@@ -457,9 +449,11 @@ class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
- finish_reason: Literal[
- "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
- ]
+ finish_reason: Optional[
+ Literal[
+ "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+ ]
+ ] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@@ -530,7 +524,7 @@ class EmbeddingRequest(BaseModel):
input: EmbeddingInput
model: str
encoding_format: str = "float"
- dimensions: int = None
+ dimensions: Optional[int] = None
user: Optional[str] = None
# The request id.
diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py
index 8e22c26c4..ba7514f0d 100644
--- a/python/sglang/srt/entrypoints/openai/serving_base.py
+++ b/python/sglang/srt/entrypoints/openai/serving_base.py
@@ -2,16 +2,12 @@ import json
import logging
import uuid
from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
-from sglang.srt.entrypoints.openai.protocol import (
- ErrorResponse,
- OpenAIServingRequest,
- UsageInfo,
-)
+from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -51,7 +47,7 @@ class OpenAIServingBase(ABC):
)
except Exception as e:
- logger.error(f"Error in request: {e}")
+ logger.exception(f"Error in request: {e}")
return self.create_error_response(
message=f"Internal server error: {str(e)}",
err_type="InternalServerError",
@@ -63,8 +59,12 @@ class OpenAIServingBase(ABC):
"""Generate request ID based on request type"""
pass
- def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
+ def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]:
"""Generate request ID based on request type"""
+ return None
+
+        # TODO(chang): the rid is used in io_struct check and often violates `The rid should be a list` AssertionError
+ # Temporarily return None in this function until the rid logic is clear.
if rid := getattr(request, "rid", None):
return rid
@@ -83,7 +83,7 @@ class OpenAIServingBase(ABC):
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
- ) -> StreamingResponse:
+ ) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]:
"""Handle streaming request
Override this method in child classes that support streaming requests.
@@ -99,7 +99,7 @@ class OpenAIServingBase(ABC):
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
- ) -> Union[Any, ErrorResponse]:
+ ) -> Union[Any, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming request
Override this method in child classes that support non-streaming requests.
@@ -110,7 +110,7 @@ class OpenAIServingBase(ABC):
status_code=501,
)
- def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
+ def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]:
"""Validate request"""
pass
@@ -122,6 +122,7 @@ class OpenAIServingBase(ABC):
param: Optional[str] = None,
) -> ORJSONResponse:
"""Create an error response"""
+ # TODO: remove fastapi dependency in openai and move response handling to the entrypoint
error = ErrorResponse(
object="error",
message=message,
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index cb128ff41..b91fee18c 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -1,4 +1,3 @@
-import base64
import json
import logging
import time
@@ -6,7 +5,7 @@ import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import Request
-from fastapi.responses import StreamingResponse
+from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.openai.protocol import (
@@ -28,13 +27,14 @@ from sglang.srt.entrypoints.openai.protocol import (
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import (
- detect_template_content_format,
- process_content_for_template_format,
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
+from sglang.srt.jinja_template_utils import process_content_for_template_format
from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.template_manager import TemplateManager
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str
@@ -42,13 +42,13 @@ logger = logging.getLogger(__name__)
class OpenAIServingChat(OpenAIServingBase):
- """Handler for chat completion requests"""
+ """Handler for /v1/chat/completions requests"""
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- # Instance-specific cache for template content format detection
- self._cached_chat_template = None
- self._cached_template_format = None
+ def __init__(
+ self, tokenizer_manager: TokenizerManager, template_manager: TemplateManager
+ ):
+ super().__init__(tokenizer_manager)
+ self.template_manager = template_manager
def _request_id_prefix(self) -> str:
return "chatcmpl-"
@@ -142,19 +142,14 @@ class OpenAIServingChat(OpenAIServingBase):
)
# Use chat template
- if (
- hasattr(self.tokenizer_manager, "chat_template_name")
- and self.tokenizer_manager.chat_template_name is None
- ):
+ if self.template_manager.chat_template_name is None:
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_jinja_template(request, tools, is_multimodal)
)
else:
- prompt, image_data, audio_data, modalities, stop = (
- self._apply_conversation_template(request)
+ prompt, prompt_ids, image_data, audio_data, modalities, stop = (
+ self._apply_conversation_template(request, is_multimodal)
)
- if not is_multimodal:
- prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
else:
# Use raw prompt
prompt_ids = request.messages
@@ -181,23 +176,14 @@ class OpenAIServingChat(OpenAIServingBase):
is_multimodal: bool,
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
"""Apply Jinja chat template"""
+ prompt = ""
+ prompt_ids = []
openai_compatible_messages = []
image_data = []
audio_data = []
modalities = []
- # Detect template content format
- current_template = self.tokenizer_manager.tokenizer.chat_template
- if current_template != self._cached_chat_template:
- self._cached_chat_template = current_template
- self._cached_template_format = detect_template_content_format(
- current_template
- )
- logger.info(
- f"Detected chat template content format: {self._cached_template_format}"
- )
-
- template_content_format = self._cached_template_format
+ template_content_format = self.template_manager.jinja_template_content_format
for message in request.messages:
if message.content is None:
@@ -262,14 +248,21 @@ class OpenAIServingChat(OpenAIServingBase):
if is_multimodal:
prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)
- stop = request.stop or []
+ stop = request.stop
+ image_data = image_data if image_data else None
+ audio_data = audio_data if audio_data else None
+ modalities = modalities if modalities else []
return prompt, prompt_ids, image_data, audio_data, modalities, stop
def _apply_conversation_template(
- self, request: ChatCompletionRequest
- ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]:
+ self,
+ request: ChatCompletionRequest,
+ is_multimodal: bool,
+ ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
"""Apply conversation template"""
- conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name)
+ prompt = ""
+ prompt_ids = []
+ conv = generate_chat_conv(request, self.template_manager.chat_template_name)
# If we should continue the final assistant message, adjust the conversation.
if (
@@ -296,9 +289,9 @@ class OpenAIServingChat(OpenAIServingBase):
else:
prompt = conv.get_prompt()
- image_data = conv.image_data
- audio_data = conv.audio_data
- modalities = conv.modalities
+ image_data = conv.image_data if conv.image_data else None
+ audio_data = conv.audio_data if conv.audio_data else None
+ modalities = conv.modalities if conv.modalities else []
stop = conv.stop_str or [] if not request.ignore_eos else []
if request.stop:
@@ -307,7 +300,10 @@ class OpenAIServingChat(OpenAIServingBase):
else:
stop.extend(request.stop)
- return prompt, image_data, audio_data, modalities, stop
+ if not is_multimodal:
+ prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
+
+ return prompt, prompt_ids, image_data, audio_data, modalities, stop
def _build_sampling_params(
self,
@@ -459,13 +455,9 @@ class OpenAIServingChat(OpenAIServingBase):
stream_buffers[index] = stream_buffer + delta
# Handle reasoning content
- enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
- "enable_thinking", True
- )
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
- and enable_thinking
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
@@ -591,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {usage_chunk.model_dump_json()}\n\n"
- except Exception as e:
+ except ValueError as e:
error = self.create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
@@ -602,7 +594,7 @@ class OpenAIServingChat(OpenAIServingBase):
adapted_request: GenerateReqInput,
request: ChatCompletionRequest,
raw_request: Request,
- ) -> Union[ChatCompletionResponse, ErrorResponse]:
+ ) -> Union[ChatCompletionResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming chat completion request"""
try:
ret = await self.tokenizer_manager.generate_request(
@@ -627,7 +619,7 @@ class OpenAIServingChat(OpenAIServingBase):
request: ChatCompletionRequest,
ret: List[Dict[str, Any]],
created: int,
- ) -> ChatCompletionResponse:
+ ) -> Union[ChatCompletionResponse, ORJSONResponse]:
"""Build chat completion response from generation results"""
choices = []
@@ -645,11 +637,8 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle reasoning content
reasoning_text = None
- enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
- "enable_thinking", True
- )
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
- if reasoning_parser and request.separate_reasoning and enable_thinking:
+ if reasoning_parser and request.separate_reasoning:
try:
parser = ReasoningParser(
model_type=reasoning_parser, stream_reasoning=False
@@ -691,9 +680,10 @@ class OpenAIServingChat(OpenAIServingBase):
choices.append(choice_data)
# Calculate usage
- cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = UsageProcessor.calculate_response_usage(
- ret, n_choices=request.n, enable_cache_report=cache_report
+ ret,
+ n_choices=request.n,
+ enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
)
return ChatCompletionResponse(
@@ -821,6 +811,25 @@ class OpenAIServingChat(OpenAIServingBase):
reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta)
+ def _get_enable_thinking_from_request(request: ChatCompletionRequest) -> bool:
+ """Extracts the 'enable_thinking' flag from request chat_template_kwargs.
+
+ NOTE: This parameter is only useful for models that support enable_thinking
+ flag, such as Qwen3.
+
+        Args:
+            request: The chat completion request to inspect.
+        Returns:
+            The boolean value of 'enable_thinking' if present, otherwise True.
+ """
+ if (
+ hasattr(request, "chat_template_kwargs")
+ and request.chat_template_kwargs
+ and request.chat_template_kwargs.get("enable_thinking") is not None
+ ):
+ return request.chat_template_kwargs.get("enable_thinking")
+ return True
+
async def _process_tool_call_stream(
self,
index: int,
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index af715b32d..3db881641 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -3,12 +3,9 @@ import time
from typing import Any, AsyncGenerator, Dict, List, Union
from fastapi import Request
-from fastapi.responses import StreamingResponse
+from fastapi.responses import ORJSONResponse, StreamingResponse
-from sglang.srt.code_completion_parser import (
- generate_completion_prompt_from_request,
- is_completion_template_defined,
-)
+from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
@@ -24,12 +21,22 @@ from sglang.srt.entrypoints.openai.utils import (
to_openai_style_logprobs,
)
from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.template_manager import TemplateManager
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
logger = logging.getLogger(__name__)
class OpenAIServingCompletion(OpenAIServingBase):
- """Handler for completion requests"""
+    """Handler for /v1/completions requests"""
+
+ def __init__(
+ self,
+ tokenizer_manager: TokenizerManager,
+ template_manager: TemplateManager,
+ ):
+ super().__init__(tokenizer_manager)
+ self.template_manager = template_manager
def _request_id_prefix(self) -> str:
return "cmpl-"
@@ -47,7 +54,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
)
# Process prompt
prompt = request.prompt
- if is_completion_template_defined():
+ if self.template_manager.completion_template_name is not None:
prompt = generate_completion_prompt_from_request(request)
# Set logprob start length based on echo and logprobs
@@ -141,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
+ hidden_states = {}
try:
async for content in self.tokenizer_manager.generate_request(
@@ -152,6 +160,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+ hidden_states[index] = content["meta_info"].get("hidden_states", None)
stream_buffer = stream_buffers.get(index, "")
# Handle echo for first chunk
@@ -192,7 +201,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
delta = text[len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
finish_reason = content["meta_info"]["finish_reason"]
- hidden_states = content["meta_info"].get("hidden_states", None)
choice_data = CompletionResponseStreamChoice(
index=index,
@@ -269,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
adapted_request: GenerateReqInput,
request: CompletionRequest,
raw_request: Request,
- ) -> Union[CompletionResponse, ErrorResponse]:
+ ) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming completion request"""
try:
generator = self.tokenizer_manager.generate_request(
diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py
index 4fe60f230..4f2db1dbe 100644
--- a/python/sglang/srt/entrypoints/openai/serving_embedding.py
+++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
+from fastapi.responses import ORJSONResponse
from sglang.srt.conversation import generate_embedding_convs
from sglang.srt.entrypoints.openai.protocol import (
@@ -13,10 +14,20 @@ from sglang.srt.entrypoints.openai.protocol import (
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput
+from sglang.srt.managers.template_manager import TemplateManager
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
class OpenAIServingEmbedding(OpenAIServingBase):
- """Handler for embedding requests"""
+    """Handler for /v1/embeddings requests"""
+
+ def __init__(
+ self,
+ tokenizer_manager: TokenizerManager,
+ template_manager: TemplateManager,
+ ):
+ super().__init__(tokenizer_manager)
+ self.template_manager = template_manager
def _request_id_prefix(self) -> str:
return "embd-"
@@ -68,11 +79,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list):
if len(prompt) > 0 and isinstance(prompt[0], str):
- # List of strings - if it's a single string in a list, treat as single string
- if len(prompt) == 1:
- prompt_kwargs = {"text": prompt[0]}
- else:
- prompt_kwargs = {"text": prompt}
+ prompt_kwargs = {"text": prompt}
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
# Handle multimodal embedding inputs
texts = []
@@ -84,11 +91,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
generate_prompts = []
# Check if we have a chat template for multimodal embeddings
- chat_template_name = getattr(
- self.tokenizer_manager, "chat_template_name", None
- )
- if chat_template_name is not None:
- convs = generate_embedding_convs(texts, images, chat_template_name)
+ if self.template_manager.chat_template_name is not None:
+ convs = generate_embedding_convs(
+ texts, images, self.template_manager.chat_template_name
+ )
for conv in convs:
generate_prompts.append(conv.get_prompt())
else:
@@ -122,7 +128,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
adapted_request: EmbeddingReqInput,
request: EmbeddingRequest,
raw_request: Request,
- ) -> Union[EmbeddingResponse, ErrorResponse]:
+ ) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]:
"""Handle the embedding request"""
try:
ret = await self.tokenizer_manager.generate_request(
diff --git a/python/sglang/srt/entrypoints/openai/serving_rerank.py b/python/sglang/srt/entrypoints/openai/serving_rerank.py
index 50be5c3cc..b053c55b3 100644
--- a/python/sglang/srt/entrypoints/openai/serving_rerank.py
+++ b/python/sglang/srt/entrypoints/openai/serving_rerank.py
@@ -2,6 +2,7 @@ import logging
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
+from fastapi.responses import ORJSONResponse
from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
@@ -15,7 +16,10 @@ logger = logging.getLogger(__name__)
class OpenAIServingRerank(OpenAIServingBase):
- """Handler for rerank requests"""
+ """Handler for /v1/rerank requests"""
+
+ # NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
+ # to another module in the future.
def _request_id_prefix(self) -> str:
return "rerank-"
@@ -61,7 +65,7 @@ class OpenAIServingRerank(OpenAIServingBase):
adapted_request: EmbeddingReqInput,
request: V1RerankReqInput,
raw_request: Request,
- ) -> Union[RerankResponse, ErrorResponse]:
+ ) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]:
"""Handle the rerank request"""
try:
ret = await self.tokenizer_manager.generate_request(
@@ -74,16 +78,16 @@ class OpenAIServingRerank(OpenAIServingBase):
if not isinstance(ret, list):
ret = [ret]
- response = self._build_rerank_response(ret, request)
- return response
+ responses = self._build_rerank_response(ret, request)
+ return responses
def _build_rerank_response(
self, ret: List[Dict[str, Any]], request: V1RerankReqInput
) -> List[RerankResponse]:
"""Build the rerank response from generation results"""
- response = []
+ responses = []
for idx, ret_item in enumerate(ret):
- response.append(
+ responses.append(
RerankResponse(
score=ret_item["embedding"],
document=request.documents[idx],
@@ -93,6 +97,6 @@ class OpenAIServingRerank(OpenAIServingBase):
)
# Sort by score in descending order (highest relevance first)
- response.sort(key=lambda x: x.score, reverse=True)
+ responses.sort(key=lambda x: x.score, reverse=True)
- return response
+ return responses
diff --git a/python/sglang/srt/entrypoints/openai/serving_score.py b/python/sglang/srt/entrypoints/openai/serving_score.py
index af73a866a..fc8ce5dca 100644
--- a/python/sglang/srt/entrypoints/openai/serving_score.py
+++ b/python/sglang/srt/entrypoints/openai/serving_score.py
@@ -1,5 +1,5 @@
import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Union
from fastapi import Request
@@ -14,7 +14,10 @@ logger = logging.getLogger(__name__)
class OpenAIServingScore(OpenAIServingBase):
- """Handler for scoring requests"""
+ """Handler for /v1/score requests"""
+
+    # NOTE: /v1/score is not an official OpenAI endpoint. This module may be moved
+    # to another module in the future.
def _request_id_prefix(self) -> str:
return "score-"
diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py
index e80125cf6..94ac5458d 100644
--- a/python/sglang/srt/entrypoints/openai/utils.py
+++ b/python/sglang/srt/entrypoints/openai/utils.py
@@ -1,9 +1,6 @@
import logging
from typing import Any, Dict, List, Optional, Union
-import jinja2.nodes
-import transformers.utils.chat_template_utils as hf_chat_utils
-
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
@@ -13,168 +10,6 @@ from sglang.srt.entrypoints.openai.protocol import (
logger = logging.getLogger(__name__)
-# ============================================================================
-# JINJA TEMPLATE CONTENT FORMAT DETECTION
-# ============================================================================
-#
-# This adapts vLLM's approach for detecting chat template content format:
-# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
-# - Analyzes Jinja template AST to detect content iteration patterns
-# - 'openai' format: templates with {%- for content in message['content'] -%} loops
-# - 'string' format: templates that expect simple string content
-# - Processes content accordingly to match template expectations
-
-
-def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
- """Check if node is a variable access like {{ varname }}"""
- if isinstance(node, jinja2.nodes.Name):
- return node.ctx == "load" and node.name == varname
- return False
-
-
-def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
- """Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
- if isinstance(node, jinja2.nodes.Getitem):
- return (
- _is_var_access(node.node, varname)
- and isinstance(node.arg, jinja2.nodes.Const)
- and node.arg.value == key
- )
-
- if isinstance(node, jinja2.nodes.Getattr):
- return _is_var_access(node.node, varname) and node.attr == key
-
- return False
-
-
-def _is_var_or_elems_access(
- node: jinja2.nodes.Node,
- varname: str,
- key: str = None,
-) -> bool:
- """Check if node accesses varname or varname[key] with filters/tests"""
- if isinstance(node, jinja2.nodes.Filter):
- return node.node is not None and _is_var_or_elems_access(
- node.node, varname, key
- )
- if isinstance(node, jinja2.nodes.Test):
- return _is_var_or_elems_access(node.node, varname, key)
-
- if isinstance(node, jinja2.nodes.Getitem) and isinstance(
- node.arg, jinja2.nodes.Slice
- ):
- return _is_var_or_elems_access(node.node, varname, key)
-
- return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
-
-
-def _try_extract_ast(chat_template: str):
- """Try to parse the Jinja template into an AST"""
- try:
- jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
- return jinja_compiled.environment.parse(chat_template)
- except Exception as e:
- logger.debug(f"Error when compiling Jinja template: {e}")
- return None
-
-
-def detect_template_content_format(chat_template: str) -> str:
- """
- Detect whether a chat template expects 'string' or 'openai' content format.
-
- - 'string': content is a simple string (like DeepSeek templates)
- - 'openai': content is a list of structured dicts (like Llama4 templates)
-
- Detection logic:
- - If template has loops like {%- for content in message['content'] -%} → 'openai'
- - Otherwise → 'string'
- """
- jinja_ast = _try_extract_ast(chat_template)
- if jinja_ast is None:
- return "string"
-
- try:
- # Look for patterns like: {%- for content in message['content'] -%}
- for loop_ast in jinja_ast.find_all(jinja2.nodes.For):
- loop_iter = loop_ast.iter
-
- # Check if iterating over message['content'] or similar
- if _is_var_or_elems_access(loop_iter, "message", "content"):
- return "openai" # Found content iteration → openai format
-
- return "string" # No content loops found → string format
- except Exception as e:
- logger.debug(f"Error when parsing AST of Jinja template: {e}")
- return "string"
-
-
-def process_content_for_template_format(
- msg_dict: dict,
- content_format: str,
- image_data: list,
- audio_data: list,
- modalities: list,
-) -> dict:
- """
- Process message content based on detected template format.
-
- Args:
- msg_dict: Message dictionary with content
- content_format: 'string' or 'openai' (detected via AST analysis)
- image_data: List to append extracted image URLs
- audio_data: List to append extracted audio URLs
- modalities: List to append modalities
-
- Returns:
- Processed message dictionary
- """
- if not isinstance(msg_dict.get("content"), list):
- # Already a string or None, no processing needed
- return {k: v for k, v in msg_dict.items() if v is not None}
-
- if content_format == "openai":
- # OpenAI format: preserve structured content list, normalize types
- processed_content_parts = []
- for chunk in msg_dict["content"]:
- if isinstance(chunk, dict):
- chunk_type = chunk.get("type")
-
- if chunk_type == "image_url":
- image_data.append(chunk["image_url"]["url"])
- if chunk.get("modalities"):
- modalities.append(chunk.get("modalities"))
- # Normalize to simple 'image' type for template compatibility
- processed_content_parts.append({"type": "image"})
- elif chunk_type == "audio_url":
- audio_data.append(chunk["audio_url"]["url"])
- # Normalize to simple 'audio' type
- processed_content_parts.append({"type": "audio"})
- else:
- # Keep other content as-is (text, etc.)
- processed_content_parts.append(chunk)
-
- new_msg = {
- k: v for k, v in msg_dict.items() if v is not None and k != "content"
- }
- new_msg["content"] = processed_content_parts
- return new_msg
-
- else: # content_format == "string"
- # String format: flatten to text only (for templates like DeepSeek)
- text_parts = []
- for chunk in msg_dict["content"]:
- if isinstance(chunk, dict) and chunk.get("type") == "text":
- text_parts.append(chunk["text"])
- # Note: For string format, we ignore images/audio since the template
- # doesn't expect structured content - multimodal placeholders would
- # need to be inserted differently
-
- new_msg = msg_dict.copy()
- new_msg["content"] = " ".join(text_parts) if text_parts else ""
- new_msg = {k: v for k, v in new_msg.items() if v is not None}
- return new_msg
-
-
def to_openai_style_logprobs(
input_token_logprobs=None,
output_token_logprobs=None,
diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py
index e0e95cff4..d1e414df6 100644
--- a/python/sglang/srt/function_call/base_format_detector.py
+++ b/python/sglang/srt/function_call/base_format_detector.py
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
+from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
_is_complete_json,
_partial_json_loads,
)
-from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/function_call/deepseekv3_detector.py b/python/sglang/srt/function_call/deepseekv3_detector.py
index 1245c3db4..e3befca5b 100644
--- a/python/sglang/srt/function_call/deepseekv3_detector.py
+++ b/python/sglang/srt/function_call/deepseekv3_detector.py
@@ -3,6 +3,7 @@ import logging
import re
from typing import List
+from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.function_call.utils import _is_complete_json
-from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py
index af1033a52..10b92a0af 100644
--- a/python/sglang/srt/function_call/function_call_parser.py
+++ b/python/sglang/srt/function_call/function_call_parser.py
@@ -1,6 +1,12 @@
import logging
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
+from sglang.srt.entrypoints.openai.protocol import (
+ StructuralTagResponseFormat,
+ StructuresResponseFormat,
+ Tool,
+ ToolChoice,
+)
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
-from sglang.srt.openai_api.protocol import (
- StructuralTagResponseFormat,
- StructuresResponseFormat,
- Tool,
- ToolChoice,
-)
logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/function_call/llama32_detector.py b/python/sglang/srt/function_call/llama32_detector.py
index 065ffd7f6..e7afeddb0 100644
--- a/python/sglang/srt/function_call/llama32_detector.py
+++ b/python/sglang/srt/function_call/llama32_detector.py
@@ -2,6 +2,7 @@ import json
import logging
from typing import List
+from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py
index 05d3bfead..031368006 100644
--- a/python/sglang/srt/function_call/mistral_detector.py
+++ b/python/sglang/srt/function_call/mistral_detector.py
@@ -3,6 +3,7 @@ import logging
import re
from typing import List
+from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py
index 4ef2e3db5..d3096d919 100644
--- a/python/sglang/srt/function_call/pythonic_detector.py
+++ b/python/sglang/srt/function_call/pythonic_detector.py
@@ -4,6 +4,7 @@ import logging
import re
from typing import List, Optional
+from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py
index ad1317777..cee3f18ea 100644
--- a/python/sglang/srt/function_call/qwen25_detector.py
+++ b/python/sglang/srt/function_call/qwen25_detector.py
@@ -3,6 +3,7 @@ import logging
import re
from typing import List
+from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/openai_api/utils.py b/python/sglang/srt/jinja_template_utils.py
similarity index 95%
rename from python/sglang/srt/openai_api/utils.py
rename to python/sglang/srt/jinja_template_utils.py
index 610251aff..14a94e487 100644
--- a/python/sglang/srt/openai_api/utils.py
+++ b/python/sglang/srt/jinja_template_utils.py
@@ -1,11 +1,12 @@
-"""
-Utility functions for OpenAI API adapter.
+"""Template utilities for Jinja template processing.
+
+This module provides utilities for analyzing and processing Jinja chat templates,
+including content format detection and message processing.
"""
import logging
-from typing import Dict, List
-import jinja2.nodes
+import jinja2
import transformers.utils.chat_template_utils as hf_chat_utils
logger = logging.getLogger(__name__)
@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
return None
-def detect_template_content_format(chat_template: str) -> str:
+def detect_jinja_template_content_format(chat_template: str) -> str:
"""
Detect whether a chat template expects 'string' or 'openai' content format.
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 3c4bf2a42..0fcb38227 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -864,12 +864,6 @@ class SetInternalStateReq:
server_args: Dict[str, Any]
-@dataclass
-class V1RerankReqInput:
- query: str
- documents: List[str]
-
-
@dataclass
class SetInternalStateReqOutput:
updated: bool
diff --git a/python/sglang/srt/managers/template_manager.py b/python/sglang/srt/managers/template_manager.py
new file mode 100644
index 000000000..4684bf1a0
--- /dev/null
+++ b/python/sglang/srt/managers/template_manager.py
@@ -0,0 +1,226 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Centralized template management for chat templates and completion templates.
+
+This module provides a unified interface for managing both chat conversation templates
+and code completion templates, eliminating global state and improving modularity.
+"""
+
+import json
+import logging
+import os
+from typing import Optional
+
+from sglang.srt.code_completion_parser import (
+ CompletionTemplate,
+ FimPosition,
+ completion_template_exists,
+ register_completion_template,
+)
+from sglang.srt.conversation import (
+ Conversation,
+ SeparatorStyle,
+ chat_template_exists,
+ get_conv_template_by_model_path,
+ register_conv_template,
+)
+from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
+
+logger = logging.getLogger(__name__)
+
+
+class TemplateManager:
+ """
+ Centralized manager for chat and completion templates.
+
+ This class encapsulates all template-related state and operations,
+ eliminating the need for global variables and providing a clean
+ interface for template management.
+ """
+
+ def __init__(self):
+ self._chat_template_name: Optional[str] = None
+ self._completion_template_name: Optional[str] = None
+ self._jinja_template_content_format: Optional[str] = None
+
+ @property
+ def chat_template_name(self) -> Optional[str]:
+ """Get the current chat template name."""
+ return self._chat_template_name
+
+ @property
+ def completion_template_name(self) -> Optional[str]:
+ """Get the current completion template name."""
+ return self._completion_template_name
+
+ @property
+ def jinja_template_content_format(self) -> Optional[str]:
+ """Get the detected Jinja template content format: 'string', 'openai', or None if no Jinja template has been loaded."""
+ return self._jinja_template_content_format
+
+ def load_chat_template(
+ self, tokenizer_manager, chat_template_arg: str, model_path: str
+ ) -> None:
+ """
+ Load a chat template from various sources.
+
+ Args:
+ tokenizer_manager: The tokenizer manager instance
+ chat_template_arg: Template name or file path
+ model_path: Path to the model (NOTE(review): currently unused in this method — confirm whether it can be dropped)
+ """
+ logger.info(f"Loading chat template: {chat_template_arg}")
+
+ if not chat_template_exists(chat_template_arg):
+ if not os.path.exists(chat_template_arg):
+ raise RuntimeError(
+ f"Chat template {chat_template_arg} is not a built-in template name "
+ "or a valid chat template file path."
+ )
+
+ if chat_template_arg.endswith(".jinja"):
+ self._load_jinja_template(tokenizer_manager, chat_template_arg)
+ else:
+ self._load_json_chat_template(chat_template_arg)
+ else:
+ self._chat_template_name = chat_template_arg
+
+ def guess_chat_template_from_model_path(self, model_path: str) -> None:
+ """
+ Infer chat template name from model path.
+
+ Args:
+ model_path: Path to the model
+ """
+ template_name = get_conv_template_by_model_path(model_path)
+ if template_name is not None:
+ logger.info(f"Inferred chat template from model path: {template_name}")
+ self._chat_template_name = template_name
+
+ def load_completion_template(self, completion_template_arg: str) -> None:
+ """
+ Load completion template for code completion.
+
+ Args:
+ completion_template_arg: Template name or file path
+ """
+ logger.info(f"Loading completion template: {completion_template_arg}")
+
+ if not completion_template_exists(completion_template_arg):
+ if not os.path.exists(completion_template_arg):
+ raise RuntimeError(
+ f"Completion template {completion_template_arg} is not a built-in template name "
+ "or a valid completion template file path."
+ )
+
+ self._load_json_completion_template(completion_template_arg)
+ else:
+ self._completion_template_name = completion_template_arg
+
+ def initialize_templates(
+ self,
+ tokenizer_manager,
+ model_path: str,
+ chat_template: Optional[str] = None,
+ completion_template: Optional[str] = None,
+ ) -> None:
+ """
+ Initialize all templates based on provided configuration.
+
+ Args:
+ tokenizer_manager: The tokenizer manager instance
+ model_path: Path to the model
+ chat_template: Optional chat template name/path
+ completion_template: Optional completion template name/path
+ """
+ # Load chat template
+ if chat_template:
+ self.load_chat_template(tokenizer_manager, chat_template, model_path)
+ else:
+ self.guess_chat_template_from_model_path(model_path)
+
+ # Load completion template
+ if completion_template:
+ self.load_completion_template(completion_template)
+
+ def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
+ """Load a Jinja template file."""
+ with open(template_path, "r") as f:
+ chat_template = "".join(f.readlines()).strip("\n")
+ tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
+ self._chat_template_name = None
+ # Detect content format from the loaded template
+ self._jinja_template_content_format = detect_jinja_template_content_format(
+ chat_template
+ )
+ logger.info(
+ f"Detected chat template content format: {self._jinja_template_content_format}"
+ )
+
+ def _load_json_chat_template(self, template_path: str) -> None:
+ """Load a JSON chat template file."""
+ assert template_path.endswith(
+ ".json"
+ ), "unrecognized format of chat template file"
+
+ with open(template_path, "r") as filep:
+ template = json.load(filep)
+ try:
+ sep_style = SeparatorStyle[template["sep_style"]]
+ except KeyError:
+ raise ValueError(
+ f"Unknown separator style: {template['sep_style']}"
+ ) from None
+
+ register_conv_template(
+ Conversation(
+ name=template["name"],
+ system_template=template["system"] + "\n{system_message}",
+ system_message=template.get("system_message", ""),
+ roles=(template["user"], template["assistant"]),
+ sep_style=sep_style,
+ sep=template.get("sep", "\n"),
+ stop_str=template["stop_str"],
+ ),
+ override=True,
+ )
+ self._chat_template_name = template["name"]
+
+ def _load_json_completion_template(self, template_path: str) -> None:
+ """Load a JSON completion template file."""
+ assert template_path.endswith(
+ ".json"
+ ), "unrecognized format of completion template file"
+
+ with open(template_path, "r") as filep:
+ template = json.load(filep)
+ try:
+ fim_position = FimPosition[template["fim_position"]]
+ except KeyError:
+ raise ValueError(
+ f"Unknown fim position: {template['fim_position']}"
+ ) from None
+
+ register_completion_template(
+ CompletionTemplate(
+ name=template["name"],
+ fim_begin_token=template["fim_begin_token"],
+ fim_middle_token=template["fim_middle_token"],
+ fim_end_token=template["fim_end_token"],
+ fim_position=fim_position,
+ ),
+ override=True,
+ )
+ self._completion_template_name = template["name"]
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index fbab668a4..959f23277 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -1058,12 +1058,7 @@ class TokenizerManager:
"lora_path",
]
)
- out_skip_names = set(
- [
- "text",
- "output_ids",
- ]
- )
+ out_skip_names = set(["text", "output_ids", "embedding"])
elif self.log_requests_level == 1:
max_length = 2048
elif self.log_requests_level == 2:
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
deleted file mode 100644
index aba1a5afd..000000000
--- a/python/sglang/srt/openai_api/adapter.py
+++ /dev/null
@@ -1,2148 +0,0 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Conversion between OpenAI APIs and native SRT APIs"""
-
-import asyncio
-import base64
-import json
-import logging
-import os
-import time
-import uuid
-from http import HTTPStatus
-from typing import Dict, List
-
-from fastapi import HTTPException, Request, UploadFile
-from fastapi.responses import ORJSONResponse, StreamingResponse
-from pydantic import ValidationError
-
-from sglang.srt.code_completion_parser import (
- generate_completion_prompt_from_request,
- is_completion_template_defined,
-)
-from sglang.srt.conversation import (
- Conversation,
- SeparatorStyle,
- chat_template_exists,
- generate_chat_conv,
- generate_embedding_convs,
- get_conv_template_by_model_path,
- register_conv_template,
-)
-from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.managers.io_struct import (
- EmbeddingReqInput,
- GenerateReqInput,
- V1RerankReqInput,
-)
-from sglang.srt.openai_api.protocol import (
- BatchRequest,
- BatchResponse,
- ChatCompletionRequest,
- ChatCompletionResponse,
- ChatCompletionResponseChoice,
- ChatCompletionResponseStreamChoice,
- ChatCompletionStreamResponse,
- ChatCompletionTokenLogprob,
- ChatMessage,
- ChoiceLogprobs,
- CompletionRequest,
- CompletionResponse,
- CompletionResponseChoice,
- CompletionResponseStreamChoice,
- CompletionStreamResponse,
- DeltaMessage,
- EmbeddingObject,
- EmbeddingRequest,
- EmbeddingResponse,
- ErrorResponse,
- FileDeleteResponse,
- FileRequest,
- FileResponse,
- FunctionResponse,
- LogProbs,
- MultimodalEmbeddingInput,
- RerankResponse,
- ScoringRequest,
- ScoringResponse,
- ToolCall,
- TopLogprob,
- UsageInfo,
-)
-from sglang.srt.openai_api.utils import (
- detect_template_content_format,
- process_content_for_template_format,
-)
-from sglang.srt.reasoning_parser import ReasoningParser
-from sglang.utils import convert_json_schema_to_str, get_exception_traceback
-
-logger = logging.getLogger(__name__)
-
-chat_template_name = None
-
-# Global cache for template content format detection (one model/template per instance)
-# NOTE: A better approach would be to initialize the chat template format when the endpoint is created
-_cached_chat_template = None
-_cached_template_format = None
-
-
-class FileMetadata:
- def __init__(self, filename: str, purpose: str):
- self.filename = filename
- self.purpose = purpose
-
-
-# In-memory storage for batch jobs and files
-batch_storage: Dict[str, BatchResponse] = {}
-file_id_request: Dict[str, FileMetadata] = {}
-file_id_response: Dict[str, FileResponse] = {}
-# map file id to file path in SGLang backend
-file_id_storage: Dict[str, str] = {}
-
-# backend storage directory
-storage_dir = None
-
-
-def create_error_response(
- message: str,
- err_type: str = "BadRequestError",
- status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
-):
- error = ErrorResponse(message=message, type=err_type, code=status_code.value)
- return ORJSONResponse(content=error.model_dump(), status_code=error.code)
-
-
-def create_streaming_error_response(
- message: str,
- err_type: str = "BadRequestError",
- status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
-) -> str:
- error = ErrorResponse(message=message, type=err_type, code=status_code.value)
- json_str = json.dumps({"error": error.model_dump()})
- return json_str
-
-
-def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, model_path):
- global chat_template_name
-
- logger.info(
- f"Use chat template for the OpenAI-compatible API server: {chat_template_arg}"
- )
-
- if not chat_template_exists(chat_template_arg):
- if not os.path.exists(chat_template_arg):
- raise RuntimeError(
- f"Chat template {chat_template_arg} is not a built-in template name "
- "or a valid chat template file path."
- )
- if chat_template_arg.endswith(".jinja"):
- with open(chat_template_arg, "r") as f:
- chat_template = "".join(f.readlines()).strip("\n")
- tokenizer_manager.tokenizer.chat_template = chat_template.replace(
- "\\n", "\n"
- )
- chat_template_name = None
- else:
- assert chat_template_arg.endswith(
- ".json"
- ), "unrecognized format of chat template file"
- with open(chat_template_arg, "r") as filep:
- template = json.load(filep)
- try:
- sep_style = SeparatorStyle[template["sep_style"]]
- except KeyError:
- raise ValueError(
- f"Unknown separator style: {template['sep_style']}"
- ) from None
- register_conv_template(
- Conversation(
- name=template["name"],
- system_template=template["system"] + "\n{system_message}",
- system_message=template.get("system_message", ""),
- roles=(template["user"], template["assistant"]),
- sep_style=sep_style,
- sep=template.get("sep", "\n"),
- stop_str=template["stop_str"],
- ),
- override=True,
- )
- chat_template_name = template["name"]
- else:
- chat_template_name = chat_template_arg
-
-
-def guess_chat_template_name_from_model_path(model_path):
- global chat_template_name
- chat_template_name = get_conv_template_by_model_path(model_path)
- if chat_template_name is not None:
- logger.info(
- f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
- )
-
-
-def _validate_prompt(prompt: str):
- """Validate that the prompt is not empty or whitespace only."""
- is_invalid = False
-
- # Check for empty/whitespace string
- if isinstance(prompt, str):
- is_invalid = not prompt.strip()
- # Check for various invalid list cases: [], [""], [" "], [[]]
- elif isinstance(prompt, list):
- is_invalid = not prompt or (
- len(prompt) == 1
- and (
- (isinstance(prompt[0], str) and not prompt[0].strip())
- or (isinstance(prompt[0], list) and not prompt[0])
- )
- )
-
- if is_invalid:
- raise HTTPException(
- status_code=400,
- detail="Input cannot be empty or contain only whitespace.",
- )
-
- return prompt
-
-
-async def v1_files_create(
- file: UploadFile, purpose: str, file_storage_path: str = None
-):
- try:
- global storage_dir
- if file_storage_path:
- storage_dir = file_storage_path
- # Read the file content
- file_content = await file.read()
-
- # Create an instance of RequestBody
- request_body = FileRequest(file=file_content, purpose=purpose)
-
- # Save the file to the sglang_oai_storage directory
- os.makedirs(storage_dir, exist_ok=True)
- file_id = f"backend_input_file-{uuid.uuid4()}"
- filename = f"{file_id}.jsonl"
- file_path = os.path.join(storage_dir, filename)
-
- with open(file_path, "wb") as f:
- f.write(request_body.file)
-
- # add info to global file map
- file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
- file_id_storage[file_id] = file_path
-
- # Return the response in the required format
- response = FileResponse(
- id=file_id,
- bytes=len(request_body.file),
- created_at=int(time.time()),
- filename=file.filename,
- purpose=request_body.purpose,
- )
- file_id_response[file_id] = response
-
- return response
- except ValidationError as e:
- return {"error": "Invalid input", "details": e.errors()}
-
-
-async def v1_delete_file(file_id: str):
- # Retrieve the file job from the in-memory storage
- file_response = file_id_response.get(file_id)
- if file_response is None:
- raise HTTPException(status_code=404, detail="File not found")
- file_path = file_id_storage.get(file_id)
- if file_path is None:
- raise HTTPException(status_code=404, detail="File not found")
- os.remove(file_path)
- del file_id_response[file_id]
- del file_id_storage[file_id]
- return FileDeleteResponse(id=file_id, deleted=True)
-
-
-async def v1_batches(tokenizer_manager, raw_request: Request):
- try:
- body = await raw_request.json()
-
- batch_request = BatchRequest(**body)
-
- batch_id = f"batch_{uuid.uuid4()}"
-
- # Create an instance of BatchResponse
- batch_response = BatchResponse(
- id=batch_id,
- endpoint=batch_request.endpoint,
- input_file_id=batch_request.input_file_id,
- completion_window=batch_request.completion_window,
- created_at=int(time.time()),
- metadata=batch_request.metadata,
- )
-
- batch_storage[batch_id] = batch_response
-
- # Start processing the batch asynchronously
- asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
-
- # Return the initial batch_response
- return batch_response
-
- except ValidationError as e:
- return {"error": "Invalid input", "details": e.errors()}
- except Exception as e:
- return {"error": str(e)}
-
-
-async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
- try:
- # Update the batch status to "in_progress"
- batch_storage[batch_id].status = "in_progress"
- batch_storage[batch_id].in_progress_at = int(time.time())
-
- # Retrieve the input file content
- input_file_request = file_id_request.get(batch_request.input_file_id)
- if not input_file_request:
- raise ValueError("Input file not found")
-
- # Parse the JSONL file and process each request
- input_file_path = file_id_storage.get(batch_request.input_file_id)
- with open(input_file_path, "r", encoding="utf-8") as f:
- lines = f.readlines()
-
- total_requests = len(lines)
- completed_requests = 0
- failed_requests = 0
-
- all_ret = []
- end_point = batch_storage[batch_id].endpoint
- file_request_list = []
- all_requests = []
- request_ids = []
- for line_id, line in enumerate(lines):
- request_data = json.loads(line)
- file_request_list.append(request_data)
- body = request_data["body"]
- request_ids.append(f"{batch_id}-req_{line_id}")
-
- # Although streaming is supported for standalone completions, it is not supported in
- # batch mode (multiple completions in single request).
- if body.get("stream", False):
- raise ValueError("Streaming requests are not supported in batch mode")
-
- if end_point == "/v1/chat/completions":
- all_requests.append(ChatCompletionRequest(**body))
- elif end_point == "/v1/completions":
- all_requests.append(CompletionRequest(**body))
-
- if end_point == "/v1/chat/completions":
- adapted_request, request = v1_chat_generate_request(
- all_requests, tokenizer_manager, request_ids=request_ids
- )
- elif end_point == "/v1/completions":
- adapted_request, request = v1_generate_request(
- all_requests, request_ids=request_ids
- )
-
- try:
- created = int(time.time())
- ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
- if not isinstance(ret, list):
- ret = [ret]
- if end_point == "/v1/chat/completions":
- responses = v1_chat_generate_response(
- request,
- ret,
- created,
- to_file=True,
- cache_report=tokenizer_manager.server_args.enable_cache_report,
- tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
- )
- else:
- responses = v1_generate_response(
- request,
- ret,
- tokenizer_manager,
- created,
- to_file=True,
- cache_report=tokenizer_manager.server_args.enable_cache_report,
- )
-
- except Exception as e:
- logger.error(f"error: {get_exception_traceback()}")
- responses = []
- error_json = {
- "id": f"batch_req_{uuid.uuid4()}",
- "custom_id": request_data.get("custom_id"),
- "response": None,
- "error": {"message": str(e)},
- }
- all_ret.append(error_json)
- failed_requests += len(file_request_list)
-
- for idx, response in enumerate(responses):
- # the batch_req here can be changed to be named within a batch granularity
- response_json = {
- "id": f"batch_req_{uuid.uuid4()}",
- "custom_id": file_request_list[idx].get("custom_id"),
- "response": response,
- "error": None,
- }
- all_ret.append(response_json)
- completed_requests += 1
-
- # Write results to a new file
- output_file_id = f"backend_result_file-{uuid.uuid4()}"
- global storage_dir
- output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
- with open(output_file_path, "w", encoding="utf-8") as f:
- for ret in all_ret:
- f.write(json.dumps(ret) + "\n")
-
- # Update batch response with output file information
- retrieve_batch = batch_storage[batch_id]
- retrieve_batch.output_file_id = output_file_id
- file_id_storage[output_file_id] = output_file_path
- file_id_response[output_file_id] = FileResponse(
- id=output_file_id,
- bytes=os.path.getsize(output_file_path),
- created_at=int(time.time()),
- filename=f"{output_file_id}.jsonl",
- purpose="batch_result",
- )
- # Update batch status to "completed"
- retrieve_batch.status = "completed"
- retrieve_batch.completed_at = int(time.time())
- retrieve_batch.request_counts = {
- "total": total_requests,
- "completed": completed_requests,
- "failed": failed_requests,
- }
-
- except Exception as e:
- logger.error(f"error: {e}")
- # Update batch status to "failed"
- retrieve_batch = batch_storage[batch_id]
- retrieve_batch.status = "failed"
- retrieve_batch.failed_at = int(time.time())
- retrieve_batch.errors = {"message": str(e)}
-
-
-async def v1_retrieve_batch(batch_id: str):
- # Retrieve the batch job from the in-memory storage
- batch_response = batch_storage.get(batch_id)
- if batch_response is None:
- raise HTTPException(status_code=404, detail="Batch not found")
-
- return batch_response
-
-
-async def v1_cancel_batch(tokenizer_manager, batch_id: str):
- # Retrieve the batch job from the in-memory storage
- batch_response = batch_storage.get(batch_id)
- if batch_response is None:
- raise HTTPException(status_code=404, detail="Batch not found")
-
- # Only do cancal when status is "validating" or "in_progress"
- if batch_response.status in ["validating", "in_progress"]:
- # Start cancelling the batch asynchronously
- asyncio.create_task(
- cancel_batch(
- tokenizer_manager=tokenizer_manager,
- batch_id=batch_id,
- input_file_id=batch_response.input_file_id,
- )
- )
-
- # Update batch status to "cancelling"
- batch_response.status = "cancelling"
-
- return batch_response
- else:
- raise HTTPException(
- status_code=500,
- detail=f"Current status is {batch_response.status}, no need to cancel",
- )
-
-
-async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
- try:
- # Update the batch status to "cancelling"
- batch_storage[batch_id].status = "cancelling"
-
- # Retrieve the input file content
- input_file_request = file_id_request.get(input_file_id)
- if not input_file_request:
- raise ValueError("Input file not found")
-
- # Parse the JSONL file and process each request
- input_file_path = file_id_storage.get(input_file_id)
- with open(input_file_path, "r", encoding="utf-8") as f:
- lines = f.readlines()
-
- # Cancel requests by request_ids
- for line_id in range(len(lines)):
- rid = f"{batch_id}-req_{line_id}"
- tokenizer_manager.abort_request(rid=rid)
-
- retrieve_batch = batch_storage[batch_id]
- retrieve_batch.status = "cancelled"
-
- except Exception as e:
- logger.error("error in SGLang:", e)
- # Update batch status to "failed"
- retrieve_batch = batch_storage[batch_id]
- retrieve_batch.status = "failed"
- retrieve_batch.failed_at = int(time.time())
- retrieve_batch.errors = {"message": str(e)}
-
-
-async def v1_retrieve_file(file_id: str):
- # Retrieve the batch job from the in-memory storage
- file_response = file_id_response.get(file_id)
- if file_response is None:
- raise HTTPException(status_code=404, detail="File not found")
- return file_response
-
-
-async def v1_retrieve_file_content(file_id: str):
- file_pth = file_id_storage.get(file_id)
- if not file_pth or not os.path.exists(file_pth):
- raise HTTPException(status_code=404, detail="File not found")
-
- def iter_file():
- with open(file_pth, mode="rb") as file_like:
- yield from file_like
-
- return StreamingResponse(iter_file(), media_type="application/octet-stream")
-
-
-def v1_generate_request(
- all_requests: List[CompletionRequest], request_ids: List[str] = None
-):
- if len(all_requests) > 1:
- first_prompt_type = type(all_requests[0].prompt)
- for request in all_requests:
- assert (
- type(request.prompt) is first_prompt_type
- ), "All prompts must be of the same type in file input settings"
- if request.n > 1:
- raise ValueError(
- "Parallel sampling is not supported for completions from files"
- )
-
- prompts = []
- sampling_params_list = []
- return_logprobs = []
- logprob_start_lens = []
- top_logprobs_nums = []
- lora_paths = []
- return_hidden_states = []
-
- for request in all_requests:
- # NOTE: with openai API, the prompt's logprobs are always not computed
- if request.echo and request.logprobs:
- logger.warning(
- "Echo is not compatible with logprobs. "
- "To compute logprobs of input prompt, please use the native /generate API."
- )
-
- prompt = request.prompt
- if is_completion_template_defined():
- prompt = generate_completion_prompt_from_request(request)
- prompts.append(prompt)
-
- lora_paths.append(request.lora_path)
- if request.echo and request.logprobs:
- current_logprob_start_len = 0
- else:
- current_logprob_start_len = -1
- sampling_params_list.append(
- {
- "temperature": request.temperature,
- "max_new_tokens": request.max_tokens,
- "min_new_tokens": request.min_tokens,
- "stop": request.stop,
- "stop_token_ids": request.stop_token_ids,
- "top_p": request.top_p,
- "top_k": request.top_k,
- "min_p": request.min_p,
- "presence_penalty": request.presence_penalty,
- "frequency_penalty": request.frequency_penalty,
- "repetition_penalty": request.repetition_penalty,
- "regex": request.regex,
- "json_schema": request.json_schema,
- "ebnf": request.ebnf,
- "n": request.n,
- "no_stop_trim": request.no_stop_trim,
- "ignore_eos": request.ignore_eos,
- "skip_special_tokens": request.skip_special_tokens,
- "logit_bias": request.logit_bias,
- }
- )
- return_logprobs.append(request.logprobs is not None)
- logprob_start_lens.append(current_logprob_start_len)
- top_logprobs_nums.append(
- request.logprobs if request.logprobs is not None else 0
- )
- return_hidden_states.append(request.return_hidden_states)
-
- if len(all_requests) == 1:
- if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
- prompt_kwargs = {"text": prompts[0]}
- else:
- prompt_kwargs = {"input_ids": prompts[0]}
- sampling_params_list = sampling_params_list[0]
- return_logprobs = return_logprobs[0]
- logprob_start_lens = logprob_start_lens[0]
- top_logprobs_nums = top_logprobs_nums[0]
- lora_paths = lora_paths[0]
- return_hidden_states = return_hidden_states[0]
- else:
- if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
- prompt_kwargs = {"text": prompts}
- else:
- prompt_kwargs = {"input_ids": prompts}
-
- adapted_request = GenerateReqInput(
- **prompt_kwargs,
- sampling_params=sampling_params_list,
- return_logprob=return_logprobs,
- top_logprobs_num=top_logprobs_nums,
- logprob_start_len=logprob_start_lens,
- return_text_in_logprobs=True,
- stream=all_requests[0].stream,
- rid=request_ids,
- lora_path=lora_paths,
- return_hidden_states=return_hidden_states,
- bootstrap_host=all_requests[0].bootstrap_host,
- bootstrap_port=all_requests[0].bootstrap_port,
- bootstrap_room=all_requests[0].bootstrap_room,
- )
-
- return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
-
-
-def v1_generate_response(
- request, ret, tokenizer_manager, created, to_file=False, cache_report=False
-):
- choices = []
- echo = False
-
- if (not isinstance(request, list)) and request.echo:
- # TODO: handle the case prompt is token ids
- if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
- # for the case of multiple str prompts
- prompts = request.prompt
- elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
- # for the case of multiple token ids prompts
- prompts = [
- tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
- for prompt in request.prompt
- ]
- elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
- # for the case of single token ids prompt
- prompts = [
- tokenizer_manager.tokenizer.decode(
- request.prompt, skip_special_tokens=True
- )
- ]
- else:
- # for the case of single str prompt
- prompts = [request.prompt]
- echo = True
-
- for idx, ret_item in enumerate(ret):
- text = ret_item["text"]
- if isinstance(request, list) and request[idx].echo:
- echo = True
- text = request[idx].prompt + text
- if echo and not isinstance(request, list):
- prompt_index = idx // request.n
- text = prompts[prompt_index] + text
-
- logprobs = False
- if isinstance(request, list) and request[idx].logprobs is not None:
- logprobs = True
- elif (not isinstance(request, list)) and request.logprobs is not None:
- logprobs = True
- if logprobs:
- if echo:
- input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
- input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
- else:
- input_token_logprobs = None
- input_top_logprobs = None
-
- logprobs = to_openai_style_logprobs(
- input_token_logprobs=input_token_logprobs,
- input_top_logprobs=input_top_logprobs,
- output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
- output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
- )
- else:
- logprobs = None
-
- hidden_states = None
- if isinstance(request, list) and request[idx].return_hidden_states:
- hidden_states = ret_item["meta_info"].get("hidden_states", None)
- elif (not isinstance(request, list)) and request.return_hidden_states:
- hidden_states = ret_item["meta_info"].get("hidden_states", None)
- if hidden_states is not None:
- hidden_states = (
- hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
- )
-
- finish_reason = ret_item["meta_info"]["finish_reason"]
-
- if to_file:
- # to make the choice data json serializable
- choice_data = {
- "index": 0,
- "text": text,
- "logprobs": logprobs,
- "finish_reason": finish_reason["type"] if finish_reason else None,
- "matched_stop": (
- finish_reason["matched"]
- if finish_reason and "matched" in finish_reason
- else None
- ),
- }
- if hidden_states is not None:
- choice_data["hidden_states"] = hidden_states
- else:
- choice_data = CompletionResponseChoice(
- index=idx,
- text=text,
- logprobs=logprobs,
- finish_reason=finish_reason["type"] if finish_reason else None,
- matched_stop=(
- finish_reason["matched"]
- if finish_reason and "matched" in finish_reason
- else None
- ),
- hidden_states=hidden_states,
- )
-
- choices.append(choice_data)
-
- if to_file:
- responses = []
- for i, choice in enumerate(choices):
- response = {
- "status_code": 200,
- "request_id": ret[i]["meta_info"]["id"],
- "body": {
- # remain the same but if needed we can change that
- "id": ret[i]["meta_info"]["id"],
- "object": "text_completion",
- "created": created,
- "model": request[i].model,
- "choices": choice,
- "usage": {
- "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
- "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
- "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
- + ret[i]["meta_info"]["completion_tokens"],
- },
- "system_fingerprint": None,
- },
- }
- responses.append(response)
- return responses
- else:
- prompt_tokens = sum(
- ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
- )
- completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
- cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
- response = CompletionResponse(
- id=ret[0]["meta_info"]["id"],
- model=request.model,
- created=created,
- choices=choices,
- usage=UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- prompt_tokens_details=(
- {"cached_tokens": cached_tokens} if cache_report else None
- ),
- ),
- )
- return response
-
-
-async def v1_completions(tokenizer_manager, raw_request: Request):
- try:
- request_json = await raw_request.json()
- except Exception as e:
- return create_error_response("Invalid request body, error: ", str(e))
- all_requests = [CompletionRequest(**request_json)]
- created = int(time.time())
- adapted_request, request = v1_generate_request(all_requests)
-
- if adapted_request.stream:
-
- async def generate_stream_resp():
- stream_buffers = {}
- n_prev_tokens = {}
- prompt_tokens = {}
- completion_tokens = {}
- cached_tokens = {}
- hidden_states = {}
-
- try:
- async for content in tokenizer_manager.generate_request(
- adapted_request, raw_request
- ):
- index = content.get("index", 0)
-
- stream_buffer = stream_buffers.get(index, "")
- n_prev_token = n_prev_tokens.get(index, 0)
-
- text = content["text"]
- prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
- completion_tokens[index] = content["meta_info"]["completion_tokens"]
- cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
- hidden_states[index] = content["meta_info"].get(
- "hidden_states", None
- ) or hidden_states.get(index)
-
- if not stream_buffer: # The first chunk
- if request.echo:
- if isinstance(request.prompt, str):
- # for the case of single str prompts
- prompts = request.prompt
- elif isinstance(request.prompt, list):
- if isinstance(request.prompt[0], str):
- # for the case of multiple str prompts
- prompts = request.prompt[index // request.n]
- elif isinstance(request.prompt[0], int):
- # for the case of single token ids prompt
- prompts = tokenizer_manager.tokenizer.decode(
- request.prompt, skip_special_tokens=True
- )
- elif isinstance(request.prompt[0], list) and isinstance(
- request.prompt[0][0], int
- ):
- # for the case of multiple token ids prompts
- prompts = tokenizer_manager.tokenizer.decode(
- request.prompt[index // request.n],
- skip_special_tokens=True,
- )
-
- # Prepend prompt in response text.
- text = prompts + text
-
- if request.logprobs is not None:
- # The first chunk and echo is enabled.
- if not stream_buffer and request.echo:
- input_token_logprobs = content["meta_info"][
- "input_token_logprobs"
- ]
- input_top_logprobs = content["meta_info"][
- "input_top_logprobs"
- ]
- else:
- input_token_logprobs = None
- input_top_logprobs = None
-
- logprobs = to_openai_style_logprobs(
- input_token_logprobs=input_token_logprobs,
- input_top_logprobs=input_top_logprobs,
- output_token_logprobs=content["meta_info"][
- "output_token_logprobs"
- ][n_prev_token:],
- output_top_logprobs=content["meta_info"][
- "output_top_logprobs"
- ][n_prev_token:],
- )
- n_prev_token = len(
- content["meta_info"]["output_token_logprobs"]
- )
- else:
- logprobs = None
-
- delta = text[len(stream_buffer) :]
- stream_buffer = stream_buffer + delta
- finish_reason = content["meta_info"]["finish_reason"]
- choice_data = CompletionResponseStreamChoice(
- index=index,
- text=delta,
- logprobs=logprobs,
- finish_reason=finish_reason["type"] if finish_reason else None,
- matched_stop=(
- finish_reason["matched"]
- if finish_reason and "matched" in finish_reason
- else None
- ),
- )
- chunk = CompletionStreamResponse(
- id=content["meta_info"]["id"],
- created=created,
- object="text_completion",
- choices=[choice_data],
- model=request.model,
- )
-
- stream_buffers[index] = stream_buffer
- n_prev_tokens[index] = n_prev_token
-
- yield f"data: {chunk.model_dump_json()}\n\n"
- if request.return_hidden_states and hidden_states:
- for index, choice_hidden_states in hidden_states.items():
- last_token_hidden_states = (
- choice_hidden_states[-1]
- if choice_hidden_states and len(choice_hidden_states) > 1
- else []
- )
- hidden_states_chunk = CompletionStreamResponse(
- id=content["meta_info"]["id"],
- created=created,
- choices=[
- CompletionResponseStreamChoice(
- text="",
- index=index,
- hidden_states=last_token_hidden_states,
- finish_reason=None,
- )
- ],
- model=request.model,
- )
- yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
- if request.stream_options and request.stream_options.include_usage:
- total_prompt_tokens = sum(
- tokens
- for i, tokens in prompt_tokens.items()
- if i % request.n == 0
- )
- total_completion_tokens = sum(
- tokens for tokens in completion_tokens.values()
- )
- cache_report = tokenizer_manager.server_args.enable_cache_report
- if cache_report:
- cached_tokens_sum = sum(
- tokens for tokens in cached_tokens.values()
- )
- prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
- else:
- prompt_tokens_details = None
- usage = UsageInfo(
- prompt_tokens=total_prompt_tokens,
- completion_tokens=total_completion_tokens,
- total_tokens=total_prompt_tokens + total_completion_tokens,
- prompt_tokens_details=prompt_tokens_details,
- )
-
- final_usage_chunk = CompletionStreamResponse(
- id=content["meta_info"]["id"],
- created=created,
- choices=[],
- model=request.model,
- usage=usage,
- )
- final_usage_data = final_usage_chunk.model_dump_json(
- exclude_none=True
- )
- yield f"data: {final_usage_data}\n\n"
- except ValueError as e:
- error = create_streaming_error_response(str(e))
- yield f"data: {error}\n\n"
- yield "data: [DONE]\n\n"
-
- return StreamingResponse(
- generate_stream_resp(),
- media_type="text/event-stream",
- background=tokenizer_manager.create_abort_task(adapted_request),
- )
-
- # Non-streaming response.
- try:
- ret = await tokenizer_manager.generate_request(
- adapted_request, raw_request
- ).__anext__()
- except ValueError as e:
- return create_error_response(str(e))
-
- if not isinstance(ret, list):
- ret = [ret]
-
- response = v1_generate_response(
- request,
- ret,
- tokenizer_manager,
- created,
- cache_report=tokenizer_manager.server_args.enable_cache_report,
- )
- return response
-
-
-def _get_enable_thinking_from_request(request_obj):
- """Extracts the 'enable_thinking' flag from request chat_template_kwargs.
-
- Args:
- request_obj: The request object (or an item from a list of requests).
-
- Returns:
- The boolean value of 'enable_thinking' if found and not True, otherwise True.
- """
- if (
- hasattr(request_obj, "chat_template_kwargs")
- and request_obj.chat_template_kwargs
- and request_obj.chat_template_kwargs.get("enable_thinking") is not None
- ):
- return request_obj.chat_template_kwargs.get("enable_thinking")
- return True
-
-
-def v1_chat_generate_request(
- all_requests: List[ChatCompletionRequest],
- tokenizer_manager,
- request_ids: List[str] = None,
-):
- input_ids = []
- prompts = []
- sampling_params_list = []
- image_data_list = []
- audio_data_list = []
- return_logprobs = []
- logprob_start_lens = []
- top_logprobs_nums = []
- modalities_list = []
- lora_paths = []
- return_hidden_states = []
-
- # NOTE: with openai API, the prompt's logprobs are always not computed
-
- is_multimodal = tokenizer_manager.model_config.is_multimodal
- for request in all_requests:
- # Prep the data needed for the underlying GenerateReqInput:
- # - prompt: The full prompt string.
- # - stop: Custom stop tokens.
- # - image_data: None or a list of image strings (URLs or base64 strings).
- # - audio_data: None or a list of audio strings (URLs).
- # None skips any image processing in GenerateReqInput.
- tool_call_constraint = None
- prompt = ""
- prompt_ids = []
- if not isinstance(request.messages, str):
- # Apply chat template and its stop strings.
- tools = None
- if request.tools and request.tool_choice != "none":
- request.skip_special_tokens = False
- if not isinstance(request.tool_choice, str):
- tools = [
- item.function.model_dump()
- for item in request.tools
- if item.function.name == request.tool_choice.function.name
- ]
- else:
- tools = [item.function.model_dump() for item in request.tools]
-
- tool_call_parser = tokenizer_manager.server_args.tool_call_parser
- parser = FunctionCallParser(request.tools, tool_call_parser)
- tool_call_constraint = parser.get_structure_constraint(
- request.tool_choice
- )
-
- if chat_template_name is None:
- openai_compatible_messages = []
- image_data = []
- audio_data = []
- modalities = []
-
- # Detect template content format by analyzing the jinja template (cached globally)
- global _cached_chat_template, _cached_template_format
- current_template = tokenizer_manager.tokenizer.chat_template
-
- if current_template != _cached_chat_template:
- # Template changed or first time - analyze it
- _cached_chat_template = current_template
- _cached_template_format = detect_template_content_format(
- current_template
- )
- logger.info(
- f"Detected chat template content format: {_cached_template_format}"
- )
-
- template_content_format = _cached_template_format
-
- for message in request.messages:
- if message.content is None:
- message.content = ""
- msg_dict = message.model_dump()
-
- # Process content based on detected template format
- processed_msg = process_content_for_template_format(
- msg_dict,
- template_content_format,
- image_data,
- audio_data,
- modalities,
- )
- openai_compatible_messages.append(processed_msg)
-
- # Handle assistant prefix for continue_final_message
- if (
- openai_compatible_messages
- and openai_compatible_messages[-1]["role"] == "assistant"
- ):
- if request.continue_final_message:
- # Remove the final assistant message so its content can be continued.
- assistant_prefix = openai_compatible_messages[-1]["content"]
- openai_compatible_messages = openai_compatible_messages[:-1]
- else:
- assistant_prefix = None
- else:
- assistant_prefix = None
-
- try:
- prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
- openai_compatible_messages,
- tokenize=True,
- add_generation_prompt=True,
- tools=tools,
- **(
- request.chat_template_kwargs
- if request.chat_template_kwargs
- else {}
- ),
- )
- except:
- # This except branch will be triggered when the chosen model
- # has a different tools input format that is not compatible
- # with openAI's apply_chat_template tool_call format, like Mistral.
- tools = [t if "function" in t else {"function": t} for t in tools]
- prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
- openai_compatible_messages,
- tokenize=True,
- add_generation_prompt=True,
- tools=tools,
- **(
- request.chat_template_kwargs
- if request.chat_template_kwargs
- else {}
- ),
- )
-
- if assistant_prefix:
- encoded = tokenizer_manager.tokenizer.encode(assistant_prefix)
- if (
- encoded
- and encoded[0] == tokenizer_manager.tokenizer.bos_token_id
- ):
- encoded = encoded[1:]
- prompt_ids += encoded
- if is_multimodal:
- prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
- stop = request.stop
- image_data = image_data if image_data else None
- audio_data = audio_data if audio_data else None
- modalities = modalities if modalities else []
- else:
- conv = generate_chat_conv(request, chat_template_name)
- # If we should continue the final assistant message, adjust the conversation.
- if (
- request.continue_final_message
- and request.messages
- and request.messages[-1].role == "assistant"
- ):
- # Remove the auto-added blank assistant turn, if present.
- if conv.messages and conv.messages[-1][1] is None:
- conv.messages.pop()
- # Rebuild the prompt from the conversation.
- prompt = conv.get_prompt()
- # Strip any trailing stop tokens or separators that indicate end-of-assistant.
- if isinstance(conv.stop_str, list):
- for stop_token in conv.stop_str:
- if prompt.endswith(stop_token):
- prompt = prompt[: -len(stop_token)]
- elif isinstance(conv.stop_str, str) and prompt.endswith(
- conv.stop_str
- ):
- prompt = prompt[: -len(conv.stop_str)]
- if conv.sep and prompt.endswith(conv.sep):
- prompt = prompt[: -len(conv.sep)]
- if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2):
- prompt = prompt[: -len(conv.sep2)]
- else:
- prompt = conv.get_prompt()
-
- image_data = conv.image_data
- audio_data = conv.audio_data
- modalities = conv.modalities
- stop = conv.stop_str or [] if not request.ignore_eos else []
-
- if request.stop:
- if isinstance(request.stop, str):
- stop.append(request.stop)
- else:
- stop.extend(request.stop)
-
- if not is_multimodal:
- prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
- else:
- # Use the raw prompt and stop strings if the messages is already a string.
- prompt_ids = request.messages
- stop = request.stop
- image_data = None
- audio_data = None
- modalities = []
- prompt = request.messages
- input_ids.append(prompt_ids)
- return_logprobs.append(request.logprobs)
- logprob_start_lens.append(-1)
- top_logprobs_nums.append(request.top_logprobs or 0)
- lora_paths.append(request.lora_path)
- prompts.append(prompt)
-
- sampling_params = {
- "temperature": request.temperature,
- "max_new_tokens": request.max_tokens or request.max_completion_tokens,
- "min_new_tokens": request.min_tokens,
- "stop": stop,
- "stop_token_ids": request.stop_token_ids,
- "top_p": request.top_p,
- "top_k": request.top_k,
- "min_p": request.min_p,
- "presence_penalty": request.presence_penalty,
- "frequency_penalty": request.frequency_penalty,
- "repetition_penalty": request.repetition_penalty,
- "regex": request.regex,
- "ebnf": request.ebnf,
- "n": request.n,
- "no_stop_trim": request.no_stop_trim,
- "ignore_eos": request.ignore_eos,
- "skip_special_tokens": request.skip_special_tokens,
- "logit_bias": request.logit_bias,
- }
-
- if request.response_format and request.response_format.type == "json_schema":
- sampling_params["json_schema"] = convert_json_schema_to_str(
- request.response_format.json_schema.schema_
- )
- elif request.response_format and request.response_format.type == "json_object":
- sampling_params["json_schema"] = '{"type": "object"}'
- elif (
- request.response_format and request.response_format.type == "structural_tag"
- ):
- sampling_params["structural_tag"] = convert_json_schema_to_str(
- request.response_format.model_dump(by_alias=True)
- )
-
- # Check if there are already existing output constraints
- has_existing_constraints = (
- sampling_params.get("regex")
- or sampling_params.get("ebnf")
- or sampling_params.get("structural_tag")
- or sampling_params.get("json_schema")
- )
-
- if tool_call_constraint and has_existing_constraints:
- logger.warning("Constrained decoding is not compatible with tool calls.")
- elif tool_call_constraint:
- constraint_type, constraint_value = tool_call_constraint
- if constraint_type == "structural_tag":
- sampling_params[constraint_type] = convert_json_schema_to_str(
- constraint_value.model_dump(by_alias=True)
- )
- else:
- sampling_params[constraint_type] = constraint_value
-
- sampling_params_list.append(sampling_params)
-
- image_data_list.append(image_data)
- audio_data_list.append(audio_data)
- modalities_list.append(modalities)
- return_hidden_states.append(request.return_hidden_states)
- if len(all_requests) == 1:
- if is_multimodal:
- # processor will need text input
- prompt_kwargs = {"text": prompts[0]}
- else:
- if isinstance(input_ids[0], str):
- prompt_kwargs = {"text": input_ids[0]}
- else:
- prompt_kwargs = {"input_ids": input_ids[0]}
- sampling_params_list = sampling_params_list[0]
- image_data_list = image_data_list[0]
- audio_data_list = audio_data_list[0]
- return_logprobs = return_logprobs[0]
- logprob_start_lens = logprob_start_lens[0]
- top_logprobs_nums = top_logprobs_nums[0]
- modalities_list = modalities_list[0]
- lora_paths = lora_paths[0]
- request_ids = request_ids[0]
- return_hidden_states = return_hidden_states[0]
- else:
- if tokenizer_manager.model_config.is_multimodal:
- # processor will need text input
- prompt_kwargs = {"text": prompts}
- else:
- if isinstance(input_ids[0], str):
- prompt_kwargs = {"text": input_ids}
- else:
- prompt_kwargs = {"input_ids": input_ids}
-
- adapted_request = GenerateReqInput(
- **prompt_kwargs,
- image_data=image_data_list,
- audio_data=audio_data_list,
- sampling_params=sampling_params_list,
- return_logprob=return_logprobs,
- logprob_start_len=logprob_start_lens,
- top_logprobs_num=top_logprobs_nums,
- stream=all_requests[0].stream,
- return_text_in_logprobs=True,
- rid=request_ids,
- modalities=modalities_list,
- lora_path=lora_paths,
- bootstrap_host=all_requests[0].bootstrap_host,
- bootstrap_port=all_requests[0].bootstrap_port,
- bootstrap_room=all_requests[0].bootstrap_room,
- return_hidden_states=return_hidden_states,
- )
-
- return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
-
-
-def v1_chat_generate_response(
- request,
- ret,
- created,
- to_file=False,
- cache_report=False,
- tool_call_parser=None,
- reasoning_parser=None,
-):
- choices = []
-
- for idx, ret_item in enumerate(ret):
- logprobs = False
- if isinstance(request, list) and request[idx].logprobs:
- logprobs = True
- elif (not isinstance(request, list)) and request.logprobs:
- logprobs = True
- if logprobs:
- logprobs = to_openai_style_logprobs(
- output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
- output_top_logprobs=ret_item["meta_info"].get(
- "output_top_logprobs", None
- ),
- )
- token_logprobs = []
- for token_idx, (token, logprob) in enumerate(
- zip(logprobs.tokens, logprobs.token_logprobs)
- ):
- token_bytes = list(token.encode("utf-8"))
- top_logprobs = []
- if logprobs.top_logprobs:
- for top_token, top_logprob in logprobs.top_logprobs[
- token_idx
- ].items():
- top_token_bytes = list(top_token.encode("utf-8"))
- top_logprobs.append(
- TopLogprob(
- token=top_token,
- bytes=top_token_bytes,
- logprob=top_logprob,
- )
- )
- token_logprobs.append(
- ChatCompletionTokenLogprob(
- token=token,
- bytes=token_bytes,
- logprob=logprob,
- top_logprobs=top_logprobs,
- )
- )
-
- choice_logprobs = ChoiceLogprobs(content=token_logprobs)
- else:
- choice_logprobs = None
-
- if isinstance(request, list) and request[idx].return_hidden_states:
- include_hidden_states = True
- elif not isinstance(request, list) and request.return_hidden_states:
- include_hidden_states = True
- else:
- include_hidden_states = False
- if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
- hidden_states = ret_item["meta_info"]["hidden_states"]
- hidden_states = (
- hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
- )
- else:
- hidden_states = None
-
- finish_reason = ret_item["meta_info"]["finish_reason"]
-
- tool_calls = None
- text = ret_item["text"]
-
- if isinstance(request, list):
- tool_choice = request[idx].tool_choice
- tools = request[idx].tools
- separate_reasoning = request[idx].separate_reasoning
- enable_thinking = _get_enable_thinking_from_request(request[idx])
- else:
- tool_choice = request.tool_choice
- tools = request.tools
- separate_reasoning = request.separate_reasoning
- enable_thinking = _get_enable_thinking_from_request(request)
-
- reasoning_text = None
- if reasoning_parser and separate_reasoning and enable_thinking:
- try:
- parser = ReasoningParser(
- model_type=reasoning_parser, stream_reasoning=False
- )
- reasoning_text, text = parser.parse_non_stream(text)
- except Exception as e:
- logger.error(f"Exception: {e}")
- return create_error_response(
- HTTPStatus.BAD_REQUEST,
- "Failed to parse reasoning related info to json format!",
- )
-
- if tool_choice != "none" and tools:
- parser = FunctionCallParser(tools, tool_call_parser)
- if parser.has_tool_call(text):
- if finish_reason["type"] == "stop":
- finish_reason["type"] = "tool_calls"
- finish_reason["matched"] = None
- try:
- text, call_info_list = parser.parse_non_stream(text)
- tool_calls = [
- ToolCall(
- id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
- function=FunctionResponse(
- name=call_info.name, arguments=call_info.parameters
- ),
- )
- for call_info in call_info_list
- ]
- except Exception as e:
- logger.error(f"Exception: {e}")
- return create_error_response(
- HTTPStatus.BAD_REQUEST,
- "Failed to parse fc related info to json format!",
- )
-
- if to_file:
- # to make the choice data json serializable
- choice_data = {
- "index": 0,
- "message": {
- "role": "assistant",
- "content": text if text else None,
- "tool_calls": tool_calls,
- "reasoning_content": reasoning_text if reasoning_text else None,
- },
- "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
- "finish_reason": finish_reason["type"] if finish_reason else None,
- "matched_stop": (
- finish_reason["matched"]
- if finish_reason and "matched" in finish_reason
- else None
- ),
- }
- if hidden_states is not None:
- choice_data["hidden_states"] = hidden_states
- else:
- choice_data = ChatCompletionResponseChoice(
- index=idx,
- message=ChatMessage(
- role="assistant",
- content=text if text else None,
- tool_calls=tool_calls,
- reasoning_content=reasoning_text if reasoning_text else None,
- ),
- logprobs=choice_logprobs,
- finish_reason=finish_reason["type"] if finish_reason else None,
- matched_stop=(
- finish_reason["matched"]
- if finish_reason and "matched" in finish_reason
- else None
- ),
- hidden_states=hidden_states,
- )
-
- choices.append(choice_data)
-
- if to_file:
- responses = []
-
- for i, choice in enumerate(choices):
- response = {
- "status_code": 200,
- "request_id": ret[i]["meta_info"]["id"],
- "body": {
- # remain the same but if needed we can change that
- "id": ret[i]["meta_info"]["id"],
- "object": "chat.completion",
- "created": created,
- "model": (
- request[i].model if isinstance(request, list) else request.model
- ),
- "choices": choice,
- "usage": {
- "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
- "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
- "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
- + ret[i]["meta_info"]["completion_tokens"],
- },
- "system_fingerprint": None,
- },
- }
- responses.append(response)
- return responses
- else:
- prompt_tokens = sum(
- ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
- )
- completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
- cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
- response = ChatCompletionResponse(
- id=ret[0]["meta_info"]["id"],
- created=created,
- model=request.model,
- choices=choices,
- usage=UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- prompt_tokens_details=(
- {"cached_tokens": cached_tokens} if cache_report else None
- ),
- ),
- )
- return response
-
-
-async def v1_chat_completions(
- tokenizer_manager, raw_request: Request, cache_report=False
-):
- try:
- request_json = await raw_request.json()
- except Exception as e:
- return create_error_response("Invalid request body, error: ", str(e))
- all_requests = [ChatCompletionRequest(**request_json)]
- created = int(time.time())
- adapted_request, request = v1_chat_generate_request(
- all_requests, tokenizer_manager, request_ids=[all_requests[0].rid]
- )
-
- if adapted_request.stream:
- parser_dict = {}
- reasoning_parser_dict = {}
-
- async def generate_stream_resp():
- tool_index_previous = -1
- is_firsts = {}
- stream_buffers = {}
- n_prev_tokens = {}
- prompt_tokens = {}
- completion_tokens = {}
- cached_tokens = {}
- hidden_states = {}
- try:
- async for content in tokenizer_manager.generate_request(
- adapted_request, raw_request
- ):
- index = content.get("index", 0)
- text = content["text"]
- hidden_states[index] = content["meta_info"].get(
- "hidden_states", None
- ) or hidden_states.get(index)
-
- is_first = is_firsts.get(index, True)
- stream_buffer = stream_buffers.get(index, "")
- n_prev_token = n_prev_tokens.get(index, 0)
-
- prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
- completion_tokens[index] = content["meta_info"]["completion_tokens"]
- cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
- if request.logprobs:
- logprobs = to_openai_style_logprobs(
- output_token_logprobs=content["meta_info"][
- "output_token_logprobs"
- ][n_prev_token:],
- output_top_logprobs=content["meta_info"].get(
- "output_top_logprobs", []
- )[n_prev_token:],
- )
-
- n_prev_token = len(
- content["meta_info"]["output_token_logprobs"]
- )
- token_logprobs = []
- for token, logprob in zip(
- logprobs.tokens, logprobs.token_logprobs
- ):
- token_bytes = list(token.encode("utf-8"))
- top_logprobs = []
- if logprobs.top_logprobs:
- for top_token, top_logprob in logprobs.top_logprobs[
- 0
- ].items():
- top_token_bytes = list(top_token.encode("utf-8"))
- top_logprobs.append(
- TopLogprob(
- token=top_token,
- bytes=top_token_bytes,
- logprob=top_logprob,
- )
- )
- token_logprobs.append(
- ChatCompletionTokenLogprob(
- token=token,
- bytes=token_bytes,
- logprob=logprob,
- top_logprobs=top_logprobs,
- )
- )
-
- choice_logprobs = ChoiceLogprobs(content=token_logprobs)
-
- else:
- choice_logprobs = None
-
- finish_reason = content["meta_info"]["finish_reason"]
- finish_reason_type = (
- finish_reason["type"] if finish_reason else None
- )
-
- if is_first:
- # First chunk with role
- is_first = False
- delta = DeltaMessage(role="assistant")
- choice_data = ChatCompletionResponseStreamChoice(
- index=index,
- delta=delta,
- finish_reason=finish_reason_type,
- matched_stop=(
- finish_reason["matched"]
- if finish_reason and "matched" in finish_reason
- else None
- ),
- logprobs=choice_logprobs,
- )
- chunk = ChatCompletionStreamResponse(
- id=content["meta_info"]["id"],
- created=created,
- choices=[choice_data],
- model=request.model,
- )
- yield f"data: {chunk.model_dump_json()}\n\n"
-
- text = content["text"]
- delta = text[len(stream_buffer) :]
- new_stream_buffer = stream_buffer + delta
-
- enable_thinking = _get_enable_thinking_from_request(request)
-
- if (
- tokenizer_manager.server_args.reasoning_parser
- and request.separate_reasoning
- and enable_thinking
- ):
- if index not in reasoning_parser_dict:
- reasoning_parser_dict[index] = ReasoningParser(
- tokenizer_manager.server_args.reasoning_parser,
- request.stream_reasoning,
- )
- reasoning_parser = reasoning_parser_dict[index]
- reasoning_text, delta = reasoning_parser.parse_stream_chunk(
- delta
- )
- if reasoning_text:
- choice_data = ChatCompletionResponseStreamChoice(
- index=index,
- delta=DeltaMessage(
- reasoning_content=(
- reasoning_text if reasoning_text else None
- )
- ),
- finish_reason=finish_reason_type,
- )
- chunk = ChatCompletionStreamResponse(
- id=content["meta_info"]["id"],
- created=created,
- choices=[choice_data],
- model=request.model,
- )
- yield f"data: {chunk.model_dump_json()}\n\n"
- if (delta and len(delta) == 0) or not delta:
- stream_buffers[index] = new_stream_buffer
- is_firsts[index] = is_first
- n_prev_tokens[index] = n_prev_token
- continue
-
- if request.tool_choice != "none" and request.tools:
- if index not in parser_dict:
- parser_dict[index] = FunctionCallParser(
- tools=request.tools,
- tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
- )
- parser = parser_dict[index]
-
- # parse_increment => returns (normal_text, calls)
- normal_text, calls = parser.parse_stream_chunk(delta)
-
- # 1) if there's normal_text, output it as normal content
- if normal_text:
- choice_data = ChatCompletionResponseStreamChoice(
- index=index,
- delta=DeltaMessage(
- content=normal_text if normal_text else None
- ),
- finish_reason=finish_reason_type,
- )
- chunk = ChatCompletionStreamResponse(
- id=content["meta_info"]["id"],
- created=created,
- choices=[choice_data],
- model=request.model,
- )
- yield f"data: {chunk.model_dump_json()}\n\n"
-
- # 2) if we found calls, we output them as separate chunk(s)
- for call_item in calls:
- tool_index_current = call_item.tool_index
- # transform call_item -> FunctionResponse + ToolCall
- if finish_reason_type == "stop":
- latest_delta_len = 0
- if isinstance(call_item.parameters, str):
- latest_delta_len = len(call_item.parameters)
-
- expected_call = json.dumps(
- parser.detector.prev_tool_call_arr[index].get(
- "arguments", {}
- ),
- ensure_ascii=False,
- )
- actual_call = parser.detector.streamed_args_for_tool[
- index
- ]
- if latest_delta_len > 0:
- actual_call = actual_call[:-latest_delta_len]
- remaining_call = expected_call.replace(
- actual_call, "", 1
- )
- call_item.parameters = remaining_call
-
- finish_reason_type = "tool_calls"
- tool_call = ToolCall(
- id=(
- f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
- if tool_index_previous != tool_index_current
- else None
- ),
- index=call_item.tool_index,
- function=FunctionResponse(
- name=call_item.name,
- arguments=call_item.parameters,
- ),
- )
- tool_index_previous = tool_index_current
- choice_data = ChatCompletionResponseStreamChoice(
- index=index,
- delta=DeltaMessage(tool_calls=[tool_call]),
- finish_reason=(
- None
- if request.stream_options
- and request.stream_options.include_usage
- else finish_reason_type
- ), # additional chunk will be return
- )
- chunk = ChatCompletionStreamResponse(
- id=content["meta_info"]["id"],
- created=created,
- choices=[choice_data],
- model=request.model,
- )
- yield f"data: {chunk.model_dump_json()}\n\n"
-
- stream_buffers[index] = new_stream_buffer
- is_firsts[index] = is_first
- n_prev_tokens[index] = n_prev_token
-
- else:
- # No tool calls => just treat this as normal text
- if delta or not (
- request.stream_options
- and request.stream_options.include_usage
- ):
- choice_data = ChatCompletionResponseStreamChoice(
- index=index,
- delta=DeltaMessage(content=delta if delta else None),
- finish_reason=(
- None
- if request.stream_options
- and request.stream_options.include_usage
- else finish_reason_type
- ),
- matched_stop=(
- finish_reason["matched"]
- if finish_reason and "matched" in finish_reason
- else None
- ),
- logprobs=choice_logprobs,
- )
- chunk = ChatCompletionStreamResponse(
- id=content["meta_info"]["id"],
- created=created,
- choices=[choice_data],
- model=request.model,
- )
- yield f"data: {chunk.model_dump_json()}\n\n"
- stream_buffers[index] = new_stream_buffer
- is_firsts[index] = is_first
- n_prev_tokens[index] = n_prev_token
- if finish_reason_type == "stop" and request.tool_choice != "none":
- parser = FunctionCallParser(
- tools=request.tools,
- tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
- )
- if parser.has_tool_call(new_stream_buffer):
- # if the stream ends with empty string after tool calls
- finish_reason_type = "tool_calls"
-
- if request.stream_options and request.stream_options.include_usage:
- total_prompt_tokens = sum(
- tokens
- for i, tokens in prompt_tokens.items()
- if i % request.n == 0
- )
- total_completion_tokens = sum(
- tokens for tokens in completion_tokens.values()
- )
- cache_report = tokenizer_manager.server_args.enable_cache_report
- if cache_report:
- cached_tokens_sum = sum(
- tokens for tokens in cached_tokens.values()
- )
- prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
- else:
- prompt_tokens_details = None
- usage = UsageInfo(
- prompt_tokens=total_prompt_tokens,
- completion_tokens=total_completion_tokens,
- total_tokens=total_prompt_tokens + total_completion_tokens,
- prompt_tokens_details=prompt_tokens_details,
- )
-
- else:
- usage = None
- if request.return_hidden_states and hidden_states:
- for index, choice_hidden_states in hidden_states.items():
- last_token_hidden_states = (
- choice_hidden_states[-1]
- if choice_hidden_states and len(choice_hidden_states) > 1
- else []
- )
- hidden_states_chunk = ChatCompletionStreamResponse(
- id=content["meta_info"]["id"],
- created=created,
- choices=[
- ChatCompletionResponseStreamChoice(
- index=index,
- delta=DeltaMessage(
- hidden_states=last_token_hidden_states
- ),
- finish_reason=finish_reason_type,
- )
- ],
- model=request.model,
- )
- yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
- final_usage_chunk = ChatCompletionStreamResponse(
- id=content["meta_info"]["id"],
- created=created,
- choices=[
- ChatCompletionResponseStreamChoice(
- index=index,
- delta=DeltaMessage(),
- finish_reason=finish_reason_type,
- )
- ],
- model=request.model,
- usage=usage,
- )
- yield f"data: {final_usage_chunk.model_dump_json()}\n\n"
- except ValueError as e:
- error = create_streaming_error_response(str(e))
- yield f"data: {error}\n\n"
- yield "data: [DONE]\n\n"
-
- return StreamingResponse(
- generate_stream_resp(),
- media_type="text/event-stream",
- background=tokenizer_manager.create_abort_task(adapted_request),
- )
-
- # Non-streaming response.
- try:
- ret = await tokenizer_manager.generate_request(
- adapted_request, raw_request
- ).__anext__()
- except ValueError as e:
- return create_error_response(str(e))
- if not isinstance(ret, list):
- ret = [ret]
-
- response = v1_chat_generate_response(
- request,
- ret,
- created,
- cache_report=tokenizer_manager.server_args.enable_cache_report,
- tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
- reasoning_parser=tokenizer_manager.server_args.reasoning_parser,
- )
-
- return response
-
-
-def v1_embedding_request(all_requests, tokenizer_manager):
- prompts = []
- sampling_params_list = []
- first_prompt_type = type(all_requests[0].input)
-
- for request in all_requests:
- prompt = request.input
- # Check for empty/whitespace string
- prompt = _validate_prompt(request.input)
- assert (
- type(prompt) is first_prompt_type
- ), "All prompts must be of the same type in file input settings"
- prompts.append(prompt)
-
- if len(all_requests) == 1:
- prompt = prompts[0]
- if isinstance(prompt, str) or isinstance(prompt[0], str):
- prompt_kwargs = {"text": prompt}
- elif isinstance(prompt, list) and isinstance(
- prompt[0], MultimodalEmbeddingInput
- ):
- texts = []
- images = []
- for item in prompt:
- # TODO simply use padding for text, we should use a better way to handle this
- texts.append(item.text if item.text is not None else "padding")
- images.append(item.image if item.image is not None else None)
- generate_prompts = []
- if chat_template_name is not None:
- convs = generate_embedding_convs(texts, images, chat_template_name)
- for conv in convs:
- generate_prompts.append(conv.get_prompt())
- else:
- generate_prompts = texts
- if len(generate_prompts) == 1:
- prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
- else:
- prompt_kwargs = {"text": generate_prompts, "image_data": images}
- else:
- prompt_kwargs = {"input_ids": prompt}
- request_ids = all_requests[0].rid
- else:
- if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
- prompt_kwargs = {"text": prompts}
- elif isinstance(prompts[0], list) and isinstance(
- prompts[0][0], MultimodalEmbeddingInput
- ):
- # TODO: multiple requests
- raise NotImplementedError(
- "Multiple requests with multimodal inputs are not supported yet"
- )
- else:
- prompt_kwargs = {"input_ids": prompts}
- request_ids = [req.rid for req in all_requests]
-
- adapted_request = EmbeddingReqInput(
- rid=request_ids,
- **prompt_kwargs,
- )
-
- if len(all_requests) == 1:
- return adapted_request, all_requests[0]
- return adapted_request, all_requests
-
-
-def v1_embedding_response(ret, model_path, to_file=False):
- embedding_objects = []
- prompt_tokens = 0
- for idx, ret_item in enumerate(ret):
- embedding_objects.append(
- EmbeddingObject(
- embedding=ret[idx]["embedding"],
- index=idx,
- )
- )
- prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
-
- return EmbeddingResponse(
- data=embedding_objects,
- model=model_path,
- usage=UsageInfo(
- prompt_tokens=prompt_tokens,
- total_tokens=prompt_tokens,
- ),
- )
-
-
-async def v1_embeddings(tokenizer_manager, raw_request: Request):
- try:
- request_json = await raw_request.json()
- except Exception as e:
- return create_error_response("Invalid request body, error: ", str(e))
- all_requests = [EmbeddingRequest(**request_json)]
- adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
-
- try:
- ret = await tokenizer_manager.generate_request(
- adapted_request, raw_request
- ).__anext__()
- except ValueError as e:
- return create_error_response(str(e))
-
- if not isinstance(ret, list):
- ret = [ret]
-
- response = v1_embedding_response(ret, tokenizer_manager.model_path)
-
- return response
-
-
-def v1_rerank_request(obj: V1RerankReqInput):
- if obj.query is None:
- raise ValueError("query is required")
- if obj.documents is None or len(obj.documents) == 0:
- raise ValueError("documents is required")
-
- pairs = []
- for doc in obj.documents:
- pairs.append([obj.query, doc])
-
- adapted_request = EmbeddingReqInput(
- text=pairs,
- is_cross_encoder_request=True,
- )
-
- return adapted_request
-
-
-def v1_rerank_response(ret, obj: V1RerankReqInput):
-
- response = []
- for idx, ret_item in enumerate(ret):
- response.append(
- RerankResponse(
- score=ret[idx]["embedding"],
- document=obj.documents[idx],
- index=idx,
- meta_info=ret[idx]["meta_info"],
- )
- )
-
- response.sort(key=lambda x: x.score, reverse=True)
-
- return response
-
-
-async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request):
- adapted_request = v1_rerank_request(obj)
-
- try:
- ret = await tokenizer_manager.generate_request(
- adapted_request, raw_request
- ).__anext__()
-
- except ValueError as e:
- return create_error_response(str(e))
-
- if not isinstance(ret, list):
- ret = [ret]
-
- response = v1_rerank_response(
- ret,
- obj,
- )
-
- return response
-
-
-def to_openai_style_logprobs(
- input_token_logprobs=None,
- output_token_logprobs=None,
- input_top_logprobs=None,
- output_top_logprobs=None,
-):
- ret_logprobs = LogProbs()
-
- def append_token_logprobs(token_logprobs):
- for logprob, _, token_text in token_logprobs:
- ret_logprobs.tokens.append(token_text)
- ret_logprobs.token_logprobs.append(logprob)
-
- # Not supported yet
- ret_logprobs.text_offset.append(-1)
-
- def append_top_logprobs(top_logprobs):
- for tokens in top_logprobs:
- if tokens is not None:
- ret_logprobs.top_logprobs.append(
- {token[2]: token[0] for token in tokens}
- )
- else:
- ret_logprobs.top_logprobs.append(None)
-
- if input_token_logprobs is not None:
- append_token_logprobs(input_token_logprobs)
- if output_token_logprobs is not None:
- append_token_logprobs(output_token_logprobs)
- if input_top_logprobs is not None:
- append_top_logprobs(input_top_logprobs)
- if output_top_logprobs is not None:
- append_top_logprobs(output_top_logprobs)
-
- return ret_logprobs
-
-
-async def v1_score(tokenizer_manager, raw_request):
- try:
- # Parse request
- request_data = await raw_request.json()
- request = ScoringRequest(**request_data)
-
- # Use tokenizer_manager's score_request method directly
- scores = await tokenizer_manager.score_request(
- query=request.query,
- items=request.items,
- label_token_ids=request.label_token_ids,
- apply_softmax=request.apply_softmax,
- item_first=request.item_first,
- request=request,
- )
-
- # Create response with just the scores, without usage info
- response = ScoringResponse(
- scores=scores,
- model=request.model,
- )
- return response
-
- except Exception as e:
- logger.error(f"Error in v1_score: {str(e)}")
- return create_error_response(str(e))
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
deleted file mode 100644
index 71153b912..000000000
--- a/python/sglang/srt/openai_api/protocol.py
+++ /dev/null
@@ -1,551 +0,0 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Pydantic models for OpenAI API protocol"""
-
-import time
-from typing import Dict, List, Optional, Union
-
-from pydantic import BaseModel, Field, model_serializer, root_validator
-from typing_extensions import Literal
-
-
-class ModelCard(BaseModel):
- """Model cards."""
-
- id: str
- object: str = "model"
- created: int = Field(default_factory=lambda: int(time.time()))
- owned_by: str = "sglang"
- root: Optional[str] = None
- max_model_len: Optional[int] = None
-
-
-class ModelList(BaseModel):
- """Model list consists of model cards."""
-
- object: str = "list"
- data: List[ModelCard] = Field(default_factory=list)
-
-
-class ErrorResponse(BaseModel):
- object: str = "error"
- message: str
- type: str
- param: Optional[str] = None
- code: int
-
-
-class LogProbs(BaseModel):
- text_offset: List[int] = Field(default_factory=list)
- token_logprobs: List[Optional[float]] = Field(default_factory=list)
- tokens: List[str] = Field(default_factory=list)
- top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
-
-
-class TopLogprob(BaseModel):
- token: str
- bytes: List[int]
- logprob: float
-
-
-class ChatCompletionTokenLogprob(BaseModel):
- token: str
- bytes: List[int]
- logprob: float
- top_logprobs: List[TopLogprob]
-
-
-class ChoiceLogprobs(BaseModel):
- # build for v1/chat/completions response
- content: List[ChatCompletionTokenLogprob]
-
-
-class UsageInfo(BaseModel):
- prompt_tokens: int = 0
- total_tokens: int = 0
- completion_tokens: Optional[int] = 0
- # only used to return cached tokens when --enable-cache-report is set
- prompt_tokens_details: Optional[Dict[str, int]] = None
-
-
-class StreamOptions(BaseModel):
- include_usage: Optional[bool] = False
-
-
-class JsonSchemaResponseFormat(BaseModel):
- name: str
- description: Optional[str] = None
- # use alias to workaround pydantic conflict
- schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
- strict: Optional[bool] = False
-
-
-class FileRequest(BaseModel):
- # https://platform.openai.com/docs/api-reference/files/create
- file: bytes # The File object (not file name) to be uploaded
- purpose: str = (
- "batch" # The intended purpose of the uploaded file, default is "batch"
- )
-
-
-class FileResponse(BaseModel):
- id: str
- object: str = "file"
- bytes: int
- created_at: int
- filename: str
- purpose: str
-
-
-class FileDeleteResponse(BaseModel):
- id: str
- object: str = "file"
- deleted: bool
-
-
-class BatchRequest(BaseModel):
- input_file_id: (
- str # The ID of an uploaded file that contains requests for the new batch
- )
- endpoint: str # The endpoint to be used for all requests in the batch
- completion_window: str # The time frame within which the batch should be processed
- metadata: Optional[dict] = None # Optional custom metadata for the batch
-
-
-class BatchResponse(BaseModel):
- id: str
- object: str = "batch"
- endpoint: str
- errors: Optional[dict] = None
- input_file_id: str
- completion_window: str
- status: str = "validating"
- output_file_id: Optional[str] = None
- error_file_id: Optional[str] = None
- created_at: int
- in_progress_at: Optional[int] = None
- expires_at: Optional[int] = None
- finalizing_at: Optional[int] = None
- completed_at: Optional[int] = None
- failed_at: Optional[int] = None
- expired_at: Optional[int] = None
- cancelling_at: Optional[int] = None
- cancelled_at: Optional[int] = None
- request_counts: Optional[dict] = None
- metadata: Optional[dict] = None
-
-
-class CompletionRequest(BaseModel):
- # Ordered by official OpenAI API documentation
- # https://platform.openai.com/docs/api-reference/completions/create
- model: str
- prompt: Union[List[int], List[List[int]], str, List[str]]
- best_of: Optional[int] = None
- echo: bool = False
- frequency_penalty: float = 0.0
- logit_bias: Optional[Dict[str, float]] = None
- logprobs: Optional[int] = None
- max_tokens: int = 16
- n: int = 1
- presence_penalty: float = 0.0
- seed: Optional[int] = None
- stop: Optional[Union[str, List[str]]] = None
- stream: bool = False
- stream_options: Optional[StreamOptions] = None
- suffix: Optional[str] = None
- temperature: float = 1.0
- top_p: float = 1.0
- user: Optional[str] = None
-
- # Extra parameters for SRT backend only and will be ignored by OpenAI models.
- top_k: int = -1
- min_p: float = 0.0
- min_tokens: int = 0
- json_schema: Optional[str] = None
- regex: Optional[str] = None
- ebnf: Optional[str] = None
- repetition_penalty: float = 1.0
- stop_token_ids: Optional[List[int]] = None
- no_stop_trim: bool = False
- ignore_eos: bool = False
- skip_special_tokens: bool = True
- lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
- session_params: Optional[Dict] = None
- return_hidden_states: Optional[bool] = False
-
- # For PD disaggregation
- bootstrap_host: Optional[str] = None
- bootstrap_port: Optional[int] = None
- bootstrap_room: Optional[int] = None
-
-
-class CompletionResponseChoice(BaseModel):
- index: int
- text: str
- logprobs: Optional[LogProbs] = None
- finish_reason: Literal["stop", "length", "content_filter", "abort"]
- matched_stop: Union[None, int, str] = None
- hidden_states: Optional[object] = None
-
- @model_serializer
- def _serialize(self):
- return exclude_if_none(self, ["hidden_states"])
-
-
-class CompletionResponse(BaseModel):
- id: str
- object: str = "text_completion"
- created: int = Field(default_factory=lambda: int(time.time()))
- model: str
- choices: List[CompletionResponseChoice]
- usage: UsageInfo
-
-
-class CompletionResponseStreamChoice(BaseModel):
- index: int
- text: str
- logprobs: Optional[LogProbs] = None
- finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
- matched_stop: Union[None, int, str] = None
- hidden_states: Optional[object] = None
-
- @model_serializer
- def _serialize(self):
- return exclude_if_none(self, ["hidden_states"])
-
-
-class CompletionStreamResponse(BaseModel):
- id: str
- object: str = "text_completion"
- created: int = Field(default_factory=lambda: int(time.time()))
- model: str
- choices: List[CompletionResponseStreamChoice]
- usage: Optional[UsageInfo] = None
-
-
-class ChatCompletionMessageContentTextPart(BaseModel):
- type: Literal["text"]
- text: str
-
-
-class ChatCompletionMessageContentImageURL(BaseModel):
- url: str
- detail: Optional[Literal["auto", "low", "high"]] = "auto"
-
-
-class ChatCompletionMessageContentAudioURL(BaseModel):
- url: str
-
-
-class ChatCompletionMessageContentImagePart(BaseModel):
- type: Literal["image_url"]
- image_url: ChatCompletionMessageContentImageURL
- modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
-
-
-class ChatCompletionMessageContentAudioPart(BaseModel):
- type: Literal["audio_url"]
- audio_url: ChatCompletionMessageContentAudioURL
-
-
-ChatCompletionMessageContentPart = Union[
- ChatCompletionMessageContentTextPart,
- ChatCompletionMessageContentImagePart,
- ChatCompletionMessageContentAudioPart,
-]
-
-
-class FunctionResponse(BaseModel):
- """Function response."""
-
- name: Optional[str] = None
- arguments: Optional[str] = None
-
-
-class ToolCall(BaseModel):
- """Tool call response."""
-
- id: Optional[str] = None
- index: Optional[int] = None
- type: Literal["function"] = "function"
- function: FunctionResponse
-
-
-class ChatCompletionMessageGenericParam(BaseModel):
- role: Literal["system", "assistant", "tool"]
- content: Union[str, List[ChatCompletionMessageContentTextPart], None]
- tool_call_id: Optional[str] = None
- name: Optional[str] = None
- reasoning_content: Optional[str] = None
- tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
-
-
-class ChatCompletionMessageUserParam(BaseModel):
- role: Literal["user"]
- content: Union[str, List[ChatCompletionMessageContentPart]]
-
-
-ChatCompletionMessageParam = Union[
- ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
-]
-
-
-class ResponseFormat(BaseModel):
- type: Literal["text", "json_object", "json_schema"]
- json_schema: Optional[JsonSchemaResponseFormat] = None
-
-
-class StructuresResponseFormat(BaseModel):
- begin: str
- schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
- end: str
-
-
-class StructuralTagResponseFormat(BaseModel):
- type: Literal["structural_tag"]
- structures: List[StructuresResponseFormat]
- triggers: List[str]
-
-
-class Function(BaseModel):
- """Function descriptions."""
-
- description: Optional[str] = Field(default=None, examples=[None])
- name: Optional[str] = None
- parameters: Optional[object] = None
- strict: bool = False
-
-
-class Tool(BaseModel):
- """Function wrapper."""
-
- type: str = Field(default="function", examples=["function"])
- function: Function
-
-
-class ToolChoiceFuncName(BaseModel):
- """The name of tool choice function."""
-
- name: Optional[str] = None
-
-
-class ToolChoice(BaseModel):
- """The tool choice definition."""
-
- function: ToolChoiceFuncName
- type: Literal["function"] = Field(default="function", examples=["function"])
-
-
-class ChatCompletionRequest(BaseModel):
- # Ordered by official OpenAI API documentation
- # https://platform.openai.com/docs/api-reference/chat/create
- messages: List[ChatCompletionMessageParam]
- model: str
- frequency_penalty: float = 0.0
- logit_bias: Optional[Dict[str, float]] = None
- logprobs: bool = False
- top_logprobs: Optional[int] = None
- max_tokens: Optional[int] = Field(
- default=None,
- deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
- description="The maximum number of tokens that can be generated in the chat completion. ",
- )
- max_completion_tokens: Optional[int] = Field(
- default=None,
- description="The maximum number of completion tokens for a chat completion request, "
- "including visible output tokens and reasoning tokens. Input tokens are not included. ",
- )
- n: int = 1
- presence_penalty: float = 0.0
- response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
- seed: Optional[int] = None
- stop: Optional[Union[str, List[str]]] = None
- stream: bool = False
- stream_options: Optional[StreamOptions] = None
- temperature: float = 0.7
- top_p: float = 1.0
- user: Optional[str] = None
- tools: Optional[List[Tool]] = Field(default=None, examples=[None])
- tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
- default="auto", examples=["none"]
- ) # noqa
-
- @root_validator(pre=True)
- def set_tool_choice_default(cls, values):
- if values.get("tool_choice") is None:
- if values.get("tools") is None:
- values["tool_choice"] = "none"
- else:
- values["tool_choice"] = "auto"
- return values
-
- # Extra parameters for SRT backend only and will be ignored by OpenAI models.
- top_k: int = -1
- min_p: float = 0.0
- min_tokens: int = 0
- regex: Optional[str] = None
- ebnf: Optional[str] = None
- repetition_penalty: float = 1.0
- stop_token_ids: Optional[List[int]] = None
- no_stop_trim: bool = False
- ignore_eos: bool = False
- continue_final_message: bool = False
- skip_special_tokens: bool = True
- lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
- session_params: Optional[Dict] = None
- separate_reasoning: bool = True
- stream_reasoning: bool = True
- chat_template_kwargs: Optional[Dict] = None
-
- # The request id.
- rid: Optional[str] = None
-
- # For PD disaggregation
- bootstrap_host: Optional[str] = None
- bootstrap_port: Optional[int] = None
- bootstrap_room: Optional[int] = None
-
- # Hidden States
- return_hidden_states: Optional[bool] = False
-
-
-class ChatMessage(BaseModel):
- role: Optional[str] = None
- content: Optional[str] = None
- reasoning_content: Optional[str] = None
- tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
-
-
-class ChatCompletionResponseChoice(BaseModel):
- index: int
- message: ChatMessage
- logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
- finish_reason: Literal[
- "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
- ]
- matched_stop: Union[None, int, str] = None
- hidden_states: Optional[object] = None
-
- @model_serializer
- def _serialize(self):
- return exclude_if_none(self, ["hidden_states"])
-
-
-class ChatCompletionResponse(BaseModel):
- id: str
- object: str = "chat.completion"
- created: int = Field(default_factory=lambda: int(time.time()))
- model: str
- choices: List[ChatCompletionResponseChoice]
- usage: UsageInfo
-
-
-class DeltaMessage(BaseModel):
- role: Optional[str] = None
- content: Optional[str] = None
- reasoning_content: Optional[str] = None
- tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
- hidden_states: Optional[object] = None
-
- @model_serializer
- def _serialize(self):
- return exclude_if_none(self, ["hidden_states"])
-
-
-class ChatCompletionResponseStreamChoice(BaseModel):
- index: int
- delta: DeltaMessage
- logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
- finish_reason: Optional[
- Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
- ] = None
- matched_stop: Union[None, int, str] = None
-
-
-class ChatCompletionStreamResponse(BaseModel):
- id: str
- object: str = "chat.completion.chunk"
- created: int = Field(default_factory=lambda: int(time.time()))
- model: str
- choices: List[ChatCompletionResponseStreamChoice]
- usage: Optional[UsageInfo] = None
-
-
-class MultimodalEmbeddingInput(BaseModel):
- text: Optional[str] = None
- image: Optional[str] = None
-
-
-class EmbeddingRequest(BaseModel):
- # Ordered by official OpenAI API documentation
- # https://platform.openai.com/docs/api-reference/embeddings/create
- input: Union[
- List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
- ]
- model: str
- encoding_format: str = "float"
- dimensions: int = None
- user: Optional[str] = None
-
- # The request id.
- rid: Optional[str] = None
-
-
-class EmbeddingObject(BaseModel):
- embedding: List[float]
- index: int
- object: str = "embedding"
-
-
-class EmbeddingResponse(BaseModel):
- data: List[EmbeddingObject]
- model: str
- object: str = "list"
- usage: Optional[UsageInfo] = None
-
-
-class ScoringRequest(BaseModel):
- query: Optional[Union[str, List[int]]] = (
- None # Query text or pre-tokenized token IDs
- )
- items: Optional[Union[str, List[str], List[List[int]]]] = (
- None # Item text(s) or pre-tokenized token IDs
- )
- label_token_ids: Optional[List[int]] = (
- None # Token IDs to compute probabilities for
- )
- apply_softmax: bool = False
- item_first: bool = False
- model: str
-
-
-class ScoringResponse(BaseModel):
- scores: List[
- List[float]
- ] # List of lists of probabilities, each in the order of label_token_ids
- model: str
- usage: Optional[UsageInfo] = None
- object: str = "scoring"
-
-
-class RerankResponse(BaseModel):
- score: float
- document: str
- index: int
- meta_info: Optional[dict] = None
-
-
-def exclude_if_none(obj, field_names: List[str]):
- omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
- return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py
index d8bf8f09c..746445bd9 100644
--- a/python/sglang/srt/reasoning_parser.py
+++ b/python/sglang/srt/reasoning_parser.py
@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple, Type
class StreamingParseResult:
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
One-time parsing: Detects and parses reasoning sections in the provided text.
Returns both reasoning content and normal text separately.
"""
- text = text.replace(self.think_start_token, "").strip()
- if self.think_end_token not in text:
+ in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
+
+ if not in_reasoning:
+ return StreamingParseResult(normal_text=text)
+
+ # The text is considered to be in a reasoning block.
+ processed_text = text.replace(self.think_start_token, "").strip()
+
+ if self.think_end_token not in processed_text:
# Assume reasoning was truncated before `` token
- return StreamingParseResult(reasoning_text=text)
+ return StreamingParseResult(reasoning_text=processed_text)
# Extract reasoning content
- splits = text.split(self.think_end_token, maxsplit=1)
+ splits = processed_text.split(self.think_end_token, maxsplit=1)
reasoning_text = splits[0]
- text = splits[1].strip()
+ normal_text = splits[1].strip()
- return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
+ return StreamingParseResult(
+ normal_text=normal_text, reasoning_text=reasoning_text
+ )
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
"""
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
if not self.stripped_think_start and self.think_start_token in current_text:
current_text = current_text.replace(self.think_start_token, "")
self.stripped_think_start = True
+ self._in_reasoning = True
# Handle end of reasoning block
if self._in_reasoning and self.think_end_token in current_text:
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
"""
def __init__(self, stream_reasoning: bool = True):
- # Qwen3 is assumed to be reasoning until `` token
+ # Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
super().__init__(
"",
"",
- force_reasoning=True,
+ force_reasoning=False,
stream_reasoning=stream_reasoning,
)
@@ -151,12 +161,12 @@ class ReasoningParser:
If True, streams reasoning content as it arrives.
"""
- DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
+ DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
"deepseek-r1": DeepSeekR1Detector,
"qwen3": Qwen3Detector,
}
- def __init__(self, model_type: str = None, stream_reasoning: bool = True):
+ def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
if not model_type:
raise ValueError("Model type must be specified")
diff --git a/test/srt/openai/conftest.py b/test/srt/openai/conftest.py
deleted file mode 100644
index ed88d624b..000000000
--- a/test/srt/openai/conftest.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# sglang/test/srt/openai/conftest.py
-import os
-import socket
-import subprocess
-import sys
-import tempfile
-import time
-from contextlib import closing
-from typing import Generator
-
-import pytest
-import requests
-
-from sglang.srt.utils import kill_process_tree # reuse SGLang helper
-from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
-
-SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
-DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
-STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
-
-
-def _pick_free_port() -> int:
- with closing(socket.socket()) as s:
- s.bind(("127.0.0.1", 0))
- return s.getsockname()[1]
-
-
-def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> None:
- start = time.perf_counter()
- while time.perf_counter() - start < timeout:
- if proc.poll() is not None: # crashed
- raise RuntimeError("api_server terminated prematurely")
- try:
- if requests.get(f"{base}/health", timeout=1).status_code == 200:
- return
- except requests.RequestException:
- pass
- time.sleep(0.4)
- raise RuntimeError("api_server readiness probe timed out")
-
-
-def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
- """Spawn the draft OpenAI-compatible server and wait until it's ready."""
- port = _pick_free_port()
- cmd = [
- sys.executable,
- "-m",
- SERVER_MODULE,
- "--model-path",
- model,
- "--host",
- "127.0.0.1",
- "--port",
- str(port),
- *map(str, kw.get("args", [])),
- ]
- env = {**os.environ, **kw.get("env", {})}
-
- # Write logs to a temp file so the child never blocks on a full pipe.
- log_file = tempfile.NamedTemporaryFile("w+", delete=False)
- proc = subprocess.Popen(
- cmd,
- env=env,
- stdout=log_file,
- stderr=subprocess.STDOUT,
- text=True,
- )
-
- base = f"http://127.0.0.1:{port}"
- try:
- _wait_until_healthy(proc, base, STARTUP_TIMEOUT)
- except Exception as e:
- proc.terminate()
- proc.wait(5)
- log_file.seek(0)
- print("\n--- api_server log ---\n", log_file.read(), file=sys.stderr)
- raise e
- return proc, base, log_file
-
-
-@pytest.fixture(scope="session")
-def openai_server() -> Generator[str, None, None]:
- """PyTest fixture that provides the server's base URL and cleans up."""
- proc, base, log_file = launch_openai_server()
- yield base
- kill_process_tree(proc.pid)
- log_file.close()
diff --git a/test/srt/openai/test_protocol.py b/test/srt/openai/test_protocol.py
index 06260024a..65b4e4c50 100644
--- a/test/srt/openai/test_protocol.py
+++ b/test/srt/openai/test_protocol.py
@@ -67,29 +67,6 @@ from sglang.srt.entrypoints.openai.protocol import (
class TestModelCard(unittest.TestCase):
"""Test ModelCard protocol model"""
- def test_basic_model_card_creation(self):
- """Test basic model card creation with required fields"""
- card = ModelCard(id="test-model")
- self.assertEqual(card.id, "test-model")
- self.assertEqual(card.object, "model")
- self.assertEqual(card.owned_by, "sglang")
- self.assertIsInstance(card.created, int)
- self.assertIsNone(card.root)
- self.assertIsNone(card.max_model_len)
-
- def test_model_card_with_optional_fields(self):
- """Test model card with optional fields"""
- card = ModelCard(
- id="test-model",
- root="/path/to/model",
- max_model_len=2048,
- created=1234567890,
- )
- self.assertEqual(card.id, "test-model")
- self.assertEqual(card.root, "/path/to/model")
- self.assertEqual(card.max_model_len, 2048)
- self.assertEqual(card.created, 1234567890)
-
def test_model_card_serialization(self):
"""Test model card JSON serialization"""
card = ModelCard(id="test-model", max_model_len=4096)
@@ -120,53 +97,6 @@ class TestModelList(unittest.TestCase):
self.assertEqual(model_list.data[1].id, "model-2")
-class TestErrorResponse(unittest.TestCase):
- """Test ErrorResponse protocol model"""
-
- def test_basic_error_response(self):
- """Test basic error response creation"""
- error = ErrorResponse(
- message="Invalid request", type="BadRequestError", code=400
- )
- self.assertEqual(error.object, "error")
- self.assertEqual(error.message, "Invalid request")
- self.assertEqual(error.type, "BadRequestError")
- self.assertEqual(error.code, 400)
- self.assertIsNone(error.param)
-
- def test_error_response_with_param(self):
- """Test error response with parameter"""
- error = ErrorResponse(
- message="Invalid temperature",
- type="ValidationError",
- code=422,
- param="temperature",
- )
- self.assertEqual(error.param, "temperature")
-
-
-class TestUsageInfo(unittest.TestCase):
- """Test UsageInfo protocol model"""
-
- def test_basic_usage_info(self):
- """Test basic usage info creation"""
- usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
- self.assertEqual(usage.prompt_tokens, 10)
- self.assertEqual(usage.completion_tokens, 20)
- self.assertEqual(usage.total_tokens, 30)
- self.assertIsNone(usage.prompt_tokens_details)
-
- def test_usage_info_with_cache_details(self):
- """Test usage info with cache details"""
- usage = UsageInfo(
- prompt_tokens=10,
- completion_tokens=20,
- total_tokens=30,
- prompt_tokens_details={"cached_tokens": 5},
- )
- self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5})
-
-
class TestCompletionRequest(unittest.TestCase):
"""Test CompletionRequest protocol model"""
@@ -181,30 +111,6 @@ class TestCompletionRequest(unittest.TestCase):
self.assertFalse(request.stream) # default
self.assertFalse(request.echo) # default
- def test_completion_request_with_options(self):
- """Test completion request with various options"""
- request = CompletionRequest(
- model="test-model",
- prompt=["Hello", "world"],
- max_tokens=100,
- temperature=0.7,
- top_p=0.9,
- n=2,
- stream=True,
- echo=True,
- stop=[".", "!"],
- logprobs=5,
- )
- self.assertEqual(request.prompt, ["Hello", "world"])
- self.assertEqual(request.max_tokens, 100)
- self.assertEqual(request.temperature, 0.7)
- self.assertEqual(request.top_p, 0.9)
- self.assertEqual(request.n, 2)
- self.assertTrue(request.stream)
- self.assertTrue(request.echo)
- self.assertEqual(request.stop, [".", "!"])
- self.assertEqual(request.logprobs, 5)
-
def test_completion_request_sglang_extensions(self):
"""Test completion request with SGLang-specific extensions"""
request = CompletionRequest(
@@ -233,26 +139,6 @@ class TestCompletionRequest(unittest.TestCase):
CompletionRequest(model="test-model") # missing prompt
-class TestCompletionResponse(unittest.TestCase):
- """Test CompletionResponse protocol model"""
-
- def test_basic_completion_response(self):
- """Test basic completion response"""
- choice = CompletionResponseChoice(
- index=0, text="Hello world!", finish_reason="stop"
- )
- usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
- response = CompletionResponse(
- id="test-id", model="test-model", choices=[choice], usage=usage
- )
- self.assertEqual(response.id, "test-id")
- self.assertEqual(response.object, "text_completion")
- self.assertEqual(response.model, "test-model")
- self.assertEqual(len(response.choices), 1)
- self.assertEqual(response.choices[0].text, "Hello world!")
- self.assertEqual(response.usage.total_tokens, 5)
-
-
class TestChatCompletionRequest(unittest.TestCase):
"""Test ChatCompletionRequest protocol model"""
@@ -268,48 +154,6 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertFalse(request.stream) # default
self.assertEqual(request.tool_choice, "none") # default when no tools
- def test_chat_completion_with_multimodal_content(self):
- """Test chat completion with multimodal content"""
- messages = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "What's in this image?"},
- {
- "type": "image_url",
- "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
- },
- ],
- }
- ]
- request = ChatCompletionRequest(model="test-model", messages=messages)
- self.assertEqual(len(request.messages[0].content), 2)
- self.assertEqual(request.messages[0].content[0].type, "text")
- self.assertEqual(request.messages[0].content[1].type, "image_url")
-
- def test_chat_completion_with_tools(self):
- """Test chat completion with tools"""
- messages = [{"role": "user", "content": "What's the weather?"}]
- tools = [
- {
- "type": "function",
- "function": {
- "name": "get_weather",
- "description": "Get weather information",
- "parameters": {
- "type": "object",
- "properties": {"location": {"type": "string"}},
- },
- },
- }
- ]
- request = ChatCompletionRequest(
- model="test-model", messages=messages, tools=tools
- )
- self.assertEqual(len(request.tools), 1)
- self.assertEqual(request.tools[0].function.name, "get_weather")
- self.assertEqual(request.tool_choice, "auto") # default when tools present
-
def test_chat_completion_tool_choice_validation(self):
"""Test tool choice validation logic"""
messages = [{"role": "user", "content": "Hello"}]
@@ -349,289 +193,6 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
-class TestChatCompletionResponse(unittest.TestCase):
- """Test ChatCompletionResponse protocol model"""
-
- def test_basic_chat_completion_response(self):
- """Test basic chat completion response"""
- message = ChatMessage(role="assistant", content="Hello there!")
- choice = ChatCompletionResponseChoice(
- index=0, message=message, finish_reason="stop"
- )
- usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
- response = ChatCompletionResponse(
- id="test-id", model="test-model", choices=[choice], usage=usage
- )
- self.assertEqual(response.id, "test-id")
- self.assertEqual(response.object, "chat.completion")
- self.assertEqual(response.model, "test-model")
- self.assertEqual(len(response.choices), 1)
- self.assertEqual(response.choices[0].message.content, "Hello there!")
-
- def test_chat_completion_response_with_tool_calls(self):
- """Test chat completion response with tool calls"""
- tool_call = ToolCall(
- id="call_123",
- function=FunctionResponse(
- name="get_weather", arguments='{"location": "San Francisco"}'
- ),
- )
- message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call])
- choice = ChatCompletionResponseChoice(
- index=0, message=message, finish_reason="tool_calls"
- )
- usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15)
- response = ChatCompletionResponse(
- id="test-id", model="test-model", choices=[choice], usage=usage
- )
- self.assertEqual(
- response.choices[0].message.tool_calls[0].function.name, "get_weather"
- )
- self.assertEqual(response.choices[0].finish_reason, "tool_calls")
-
-
-class TestEmbeddingRequest(unittest.TestCase):
- """Test EmbeddingRequest protocol model"""
-
- def test_basic_embedding_request(self):
- """Test basic embedding request"""
- request = EmbeddingRequest(model="test-model", input="Hello world")
- self.assertEqual(request.model, "test-model")
- self.assertEqual(request.input, "Hello world")
- self.assertEqual(request.encoding_format, "float") # default
- self.assertIsNone(request.dimensions) # default
-
- def test_embedding_request_with_list_input(self):
- """Test embedding request with list input"""
- request = EmbeddingRequest(
- model="test-model", input=["Hello", "world"], dimensions=512
- )
- self.assertEqual(request.input, ["Hello", "world"])
- self.assertEqual(request.dimensions, 512)
-
- def test_multimodal_embedding_request(self):
- """Test multimodal embedding request"""
- multimodal_input = [
- MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
- MultimodalEmbeddingInput(text="World", image=None),
- ]
- request = EmbeddingRequest(model="test-model", input=multimodal_input)
- self.assertEqual(len(request.input), 2)
- self.assertEqual(request.input[0].text, "Hello")
- self.assertEqual(request.input[0].image, "base64_image_data")
- self.assertEqual(request.input[1].text, "World")
- self.assertIsNone(request.input[1].image)
-
-
-class TestEmbeddingResponse(unittest.TestCase):
- """Test EmbeddingResponse protocol model"""
-
- def test_basic_embedding_response(self):
- """Test basic embedding response"""
- embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)
- usage = UsageInfo(prompt_tokens=3, total_tokens=3)
- response = EmbeddingResponse(
- data=[embedding_obj], model="test-model", usage=usage
- )
- self.assertEqual(response.object, "list")
- self.assertEqual(len(response.data), 1)
- self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
- self.assertEqual(response.data[0].index, 0)
- self.assertEqual(response.usage.prompt_tokens, 3)
-
-
-class TestScoringRequest(unittest.TestCase):
- """Test ScoringRequest protocol model"""
-
- def test_basic_scoring_request(self):
- """Test basic scoring request"""
- request = ScoringRequest(
- model="test-model", query="Hello", items=["World", "Earth"]
- )
- self.assertEqual(request.model, "test-model")
- self.assertEqual(request.query, "Hello")
- self.assertEqual(request.items, ["World", "Earth"])
- self.assertFalse(request.apply_softmax) # default
- self.assertFalse(request.item_first) # default
-
- def test_scoring_request_with_token_ids(self):
- """Test scoring request with token IDs"""
- request = ScoringRequest(
- model="test-model",
- query=[1, 2, 3],
- items=[[4, 5], [6, 7]],
- label_token_ids=[8, 9],
- apply_softmax=True,
- item_first=True,
- )
- self.assertEqual(request.query, [1, 2, 3])
- self.assertEqual(request.items, [[4, 5], [6, 7]])
- self.assertEqual(request.label_token_ids, [8, 9])
- self.assertTrue(request.apply_softmax)
- self.assertTrue(request.item_first)
-
-
-class TestScoringResponse(unittest.TestCase):
- """Test ScoringResponse protocol model"""
-
- def test_basic_scoring_response(self):
- """Test basic scoring response"""
- response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
- self.assertEqual(response.object, "scoring")
- self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]])
- self.assertEqual(response.model, "test-model")
- self.assertIsNone(response.usage) # default
-
-
-class TestFileOperations(unittest.TestCase):
- """Test file operation protocol models"""
-
- def test_file_request(self):
- """Test file request model"""
- file_data = b"test file content"
- request = FileRequest(file=file_data, purpose="batch")
- self.assertEqual(request.file, file_data)
- self.assertEqual(request.purpose, "batch")
-
- def test_file_response(self):
- """Test file response model"""
- response = FileResponse(
- id="file-123",
- bytes=1024,
- created_at=1234567890,
- filename="test.jsonl",
- purpose="batch",
- )
- self.assertEqual(response.id, "file-123")
- self.assertEqual(response.object, "file")
- self.assertEqual(response.bytes, 1024)
- self.assertEqual(response.filename, "test.jsonl")
-
- def test_file_delete_response(self):
- """Test file delete response model"""
- response = FileDeleteResponse(id="file-123", deleted=True)
- self.assertEqual(response.id, "file-123")
- self.assertEqual(response.object, "file")
- self.assertTrue(response.deleted)
-
-
-class TestBatchOperations(unittest.TestCase):
- """Test batch operation protocol models"""
-
- def test_batch_request(self):
- """Test batch request model"""
- request = BatchRequest(
- input_file_id="file-123",
- endpoint="/v1/chat/completions",
- completion_window="24h",
- metadata={"custom": "value"},
- )
- self.assertEqual(request.input_file_id, "file-123")
- self.assertEqual(request.endpoint, "/v1/chat/completions")
- self.assertEqual(request.completion_window, "24h")
- self.assertEqual(request.metadata, {"custom": "value"})
-
- def test_batch_response(self):
- """Test batch response model"""
- response = BatchResponse(
- id="batch-123",
- endpoint="/v1/chat/completions",
- input_file_id="file-123",
- completion_window="24h",
- created_at=1234567890,
- )
- self.assertEqual(response.id, "batch-123")
- self.assertEqual(response.object, "batch")
- self.assertEqual(response.status, "validating") # default
- self.assertEqual(response.endpoint, "/v1/chat/completions")
-
-
-class TestResponseFormats(unittest.TestCase):
- """Test response format protocol models"""
-
- def test_basic_response_format(self):
- """Test basic response format"""
- format_obj = ResponseFormat(type="json_object")
- self.assertEqual(format_obj.type, "json_object")
- self.assertIsNone(format_obj.json_schema)
-
- def test_json_schema_response_format(self):
- """Test JSON schema response format"""
- schema = {"type": "object", "properties": {"name": {"type": "string"}}}
- json_schema = JsonSchemaResponseFormat(
- name="person_schema", description="Person schema", schema=schema
- )
- format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
- self.assertEqual(format_obj.type, "json_schema")
- self.assertEqual(format_obj.json_schema.name, "person_schema")
- self.assertEqual(format_obj.json_schema.schema_, schema)
-
- def test_structural_tag_response_format(self):
- """Test structural tag response format"""
- structures = [
- {
- "begin": "",
- "schema_": {"type": "string"},
- "end": "",
- }
- ]
- format_obj = StructuralTagResponseFormat(
- type="structural_tag", structures=structures, triggers=["think"]
- )
- self.assertEqual(format_obj.type, "structural_tag")
- self.assertEqual(len(format_obj.structures), 1)
- self.assertEqual(format_obj.triggers, ["think"])
-
-
-class TestLogProbs(unittest.TestCase):
- """Test LogProbs protocol models"""
-
- def test_basic_logprobs(self):
- """Test basic LogProbs model"""
- logprobs = LogProbs(
- text_offset=[0, 5, 11],
- token_logprobs=[-0.1, -0.2, -0.3],
- tokens=["Hello", " ", "world"],
- top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
- )
- self.assertEqual(len(logprobs.tokens), 3)
- self.assertEqual(logprobs.tokens, ["Hello", " ", "world"])
- self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3])
-
- def test_choice_logprobs(self):
- """Test ChoiceLogprobs model"""
- token_logprob = ChatCompletionTokenLogprob(
- token="Hello",
- bytes=[72, 101, 108, 108, 111],
- logprob=-0.1,
- top_logprobs=[
- TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1)
- ],
- )
- choice_logprobs = ChoiceLogprobs(content=[token_logprob])
- self.assertEqual(len(choice_logprobs.content), 1)
- self.assertEqual(choice_logprobs.content[0].token, "Hello")
-
-
-class TestStreamingModels(unittest.TestCase):
- """Test streaming response models"""
-
- def test_stream_options(self):
- """Test StreamOptions model"""
- options = StreamOptions(include_usage=True)
- self.assertTrue(options.include_usage)
-
- def test_chat_completion_stream_response(self):
- """Test ChatCompletionStreamResponse model"""
- delta = DeltaMessage(role="assistant", content="Hello")
- choice = ChatCompletionResponseStreamChoice(index=0, delta=delta)
- response = ChatCompletionStreamResponse(
- id="test-id", model="test-model", choices=[choice]
- )
- self.assertEqual(response.object, "chat.completion.chunk")
- self.assertEqual(response.choices[0].delta.content, "Hello")
-
-
class TestModelSerialization(unittest.TestCase):
"""Test model serialization with hidden states"""
@@ -680,11 +241,6 @@ class TestModelSerialization(unittest.TestCase):
class TestValidationEdgeCases(unittest.TestCase):
"""Test edge cases and validation scenarios"""
- def test_empty_messages_validation(self):
- """Test validation with empty messages"""
- with self.assertRaises(ValidationError):
- ChatCompletionRequest(model="test-model", messages=[])
-
def test_invalid_tool_choice_type(self):
"""Test invalid tool choice type"""
messages = [{"role": "user", "content": "Hello"}]
@@ -698,13 +254,6 @@ class TestValidationEdgeCases(unittest.TestCase):
with self.assertRaises(ValidationError):
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
- def test_invalid_temperature_range(self):
- """Test invalid temperature values"""
- # Note: The current protocol doesn't enforce temperature range,
- # but this test documents expected behavior
- request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
- self.assertEqual(request.temperature, 5.0) # Currently allowed
-
def test_model_serialization_roundtrip(self):
"""Test that models can be serialized and deserialized"""
original_request = ChatCompletionRequest(
diff --git a/test/srt/openai/test_server.py b/test/srt/openai/test_server.py
deleted file mode 100644
index 3de52f4cd..000000000
--- a/test/srt/openai/test_server.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# sglang/test/srt/openai/test_server.py
-import requests
-
-from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
-
-
-def test_health(openai_server: str):
- r = requests.get(f"{openai_server}/health")
- assert r.status_code == 200
- # FastAPI returns an empty body → r.text == ""
- assert r.text == ""
-
-
-def test_models_endpoint(openai_server: str):
- r = requests.get(f"{openai_server}/v1/models")
- assert r.status_code == 200, r.text
- payload = r.json()
-
- # Basic contract
- assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
-
- # Validate fields of the first model card
- first = payload["data"][0]
- for key in ("id", "root", "max_model_len"):
- assert key in first, f"missing {key} in {first}"
-
- # max_model_len must be positive
- assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
-
- # The server should report the same model id we launched it with
- ids = {m["id"] for m in payload["data"]}
- assert MODEL_ID in ids
-
-
-def test_get_model_info(openai_server: str):
- r = requests.get(f"{openai_server}/get_model_info")
- assert r.status_code == 200, r.text
- info = r.json()
-
- expected_keys = {"model_path", "tokenizer_path", "is_generation"}
- assert expected_keys.issubset(info.keys())
-
- # model_path must end with the one we passed on the CLI
- assert info["model_path"].endswith(MODEL_ID)
-
- # is_generation is documented as a boolean
- assert isinstance(info["is_generation"], bool)
-
-
-def test_unknown_route_returns_404(openai_server: str):
- r = requests.get(f"{openai_server}/definitely-not-a-real-route")
- assert r.status_code == 404
diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py
index ff38fccc7..701dc2e55 100644
--- a/test/srt/openai/test_serving_chat.py
+++ b/test/srt/openai/test_serving_chat.py
@@ -57,11 +57,21 @@ class _MockTokenizerManager:
self.create_abort_task = Mock()
+class _MockTemplateManager:
+ """Minimal mock for TemplateManager."""
+
+ def __init__(self):
+ self.chat_template_name: Optional[str] = "llama-3"
+ self.jinja_template_content_format: Optional[str] = None
+ self.completion_template_name: Optional[str] = None
+
+
class ServingChatTestCase(unittest.TestCase):
# ------------- common fixtures -------------
def setUp(self):
self.tm = _MockTokenizerManager()
- self.chat = OpenAIServingChat(self.tm)
+ self.template_manager = _MockTemplateManager()
+ self.chat = OpenAIServingChat(self.tm, self.template_manager)
# frequently reused requests
self.basic_req = ChatCompletionRequest(
@@ -109,96 +119,6 @@ class ServingChatTestCase(unittest.TestCase):
self.assertFalse(adapted.stream)
self.assertEqual(processed, self.basic_req)
- # # ------------- tool-call branch -------------
- # def test_tool_call_request_conversion(self):
- # req = ChatCompletionRequest(
- # model="x",
- # messages=[{"role": "user", "content": "Weather?"}],
- # tools=[
- # {
- # "type": "function",
- # "function": {
- # "name": "get_weather",
- # "parameters": {"type": "object", "properties": {}},
- # },
- # }
- # ],
- # tool_choice="auto",
- # )
-
- # with patch.object(
- # self.chat,
- # "_process_messages",
- # return_value=("Prompt", [1, 2, 3], None, None, [], [""], None),
- # ):
- # adapted, _ = self.chat._convert_to_internal_request(req, "rid")
- # self.assertEqual(adapted.rid, "rid")
-
- # def test_tool_choice_none(self):
- # req = ChatCompletionRequest(
- # model="x",
- # messages=[{"role": "user", "content": "Hi"}],
- # tools=[{"type": "function", "function": {"name": "noop"}}],
- # tool_choice="none",
- # )
- # with patch.object(
- # self.chat,
- # "_process_messages",
- # return_value=("Prompt", [1, 2, 3], None, None, [], [""], None),
- # ):
- # adapted, _ = self.chat._convert_to_internal_request(req, "rid")
- # self.assertEqual(adapted.rid, "rid")
-
- # ------------- multimodal branch -------------
- def test_multimodal_request_with_images(self):
- self.tm.model_config.is_multimodal = True
-
- req = ChatCompletionRequest(
- model="x",
- messages=[
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "What's in the image?"},
- {
- "type": "image_url",
- "image_url": {"url": "data:image/jpeg;base64,"},
- },
- ],
- }
- ],
- )
-
- with patch.object(
- self.chat,
- "_apply_jinja_template",
- return_value=("prompt", [1, 2], ["img"], None, [], []),
- ), patch.object(
- self.chat,
- "_apply_conversation_template",
- return_value=("prompt", ["img"], None, [], []),
- ):
- out = self.chat._process_messages(req, True)
- _, _, image_data, *_ = out
- self.assertEqual(image_data, ["img"])
-
- # ------------- template handling -------------
- def test_jinja_template_processing(self):
- req = ChatCompletionRequest(
- model="x", messages=[{"role": "user", "content": "Hello"}]
- )
- self.tm.chat_template_name = None
- self.tm.tokenizer.chat_template = ""
-
- with patch.object(
- self.chat,
- "_apply_jinja_template",
- return_value=("processed", [1], None, None, [], [""]),
- ), patch("builtins.hasattr", return_value=True):
- prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
- self.assertEqual(prompt, "processed")
- self.assertEqual(prompt_ids, [1])
-
# ------------- sampling-params -------------
def test_sampling_param_build(self):
req = ChatCompletionRequest(
diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py
index 7a42523c7..c0568e93b 100644
--- a/test/srt/openai/test_serving_completions.py
+++ b/test/srt/openai/test_serving_completions.py
@@ -5,6 +5,7 @@ Run with:
"""
import unittest
+from typing import Optional
from unittest.mock import AsyncMock, Mock, patch
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
@@ -12,6 +13,17 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl
from sglang.srt.managers.tokenizer_manager import TokenizerManager
+class _MockTemplateManager:
+ """Minimal mock for TemplateManager."""
+
+ def __init__(self):
+ self.chat_template_name: Optional[str] = None
+ self.jinja_template_content_format: Optional[str] = None
+ self.completion_template_name: Optional[str] = (
+ None # Set to None to avoid template processing
+ )
+
+
class ServingCompletionTestCase(unittest.TestCase):
"""Bundle all prompt/echo tests in one TestCase."""
@@ -31,7 +43,8 @@ class ServingCompletionTestCase(unittest.TestCase):
tm.generate_request = AsyncMock()
tm.create_abort_task = Mock()
- self.sc = OpenAIServingCompletion(tm)
+ self.template_manager = _MockTemplateManager()
+ self.sc = OpenAIServingCompletion(tm, self.template_manager)
# ---------- prompt-handling ----------
def test_single_string_prompt(self):
@@ -44,20 +57,6 @@ class ServingCompletionTestCase(unittest.TestCase):
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
- def test_completion_template_handling(self):
- req = CompletionRequest(
- model="x", prompt="def f():", suffix="return 1", max_tokens=100
- )
- with patch(
- "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
- return_value=True,
- ), patch(
- "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
- return_value="processed_prompt",
- ):
- internal, _ = self.sc._convert_to_internal_request(req)
- self.assertEqual(internal.text, "processed_prompt")
-
# ---------- echo-handling ----------
def test_echo_with_string_prompt_streaming(self):
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py
index b6e3094df..1305e668d 100644
--- a/test/srt/openai/test_serving_embedding.py
+++ b/test/srt/openai/test_serving_embedding.py
@@ -5,25 +5,16 @@ These tests ensure that the embedding serving implementation maintains compatibi
with the original adapter.py functionality and follows OpenAI API specifications.
"""
-import asyncio
-import json
-import time
import unittest
import uuid
-from typing import Any, Dict, List
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import Mock
from fastapi import Request
-from fastapi.responses import ORJSONResponse
-from pydantic_core import ValidationError
from sglang.srt.entrypoints.openai.protocol import (
- EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
- ErrorResponse,
MultimodalEmbeddingInput,
- UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput
@@ -58,11 +49,22 @@ class _MockTokenizerManager:
self.generate_request = Mock(return_value=mock_generate_embedding())
+# Mock TemplateManager for embedding tests
+class _MockTemplateManager:
+ def __init__(self):
+ self.chat_template_name = None # None for embeddings usually
+ self.jinja_template_content_format = None
+ self.completion_template_name = None
+
+
class ServingEmbeddingTestCase(unittest.TestCase):
def setUp(self):
"""Set up test fixtures."""
self.tokenizer_manager = _MockTokenizerManager()
- self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
+ self.template_manager = _MockTemplateManager()
+ self.serving_embedding = OpenAIServingEmbedding(
+ self.tokenizer_manager, self.template_manager
+ )
self.request = Mock(spec=Request)
self.request.headers = {}
@@ -141,132 +143,6 @@ class ServingEmbeddingTestCase(unittest.TestCase):
self.assertIsNone(adapted_request.image_data[1])
# self.assertEqual(adapted_request.rid, "test-id")
- def test_build_single_embedding_response(self):
- """Test building response for single embedding."""
- ret_data = [
- {
- "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
- "meta_info": {"prompt_tokens": 5},
- }
- ]
-
- response = self.serving_embedding._build_embedding_response(ret_data)
-
- self.assertIsInstance(response, EmbeddingResponse)
- self.assertEqual(response.model, "test-model")
- self.assertEqual(len(response.data), 1)
- self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
- self.assertEqual(response.data[0].index, 0)
- self.assertEqual(response.data[0].object, "embedding")
- self.assertEqual(response.usage.prompt_tokens, 5)
- self.assertEqual(response.usage.total_tokens, 5)
- self.assertEqual(response.usage.completion_tokens, 0)
-
- def test_build_multiple_embedding_response(self):
- """Test building response for multiple embeddings."""
- ret_data = [
- {
- "embedding": [0.1, 0.2, 0.3],
- "meta_info": {"prompt_tokens": 3},
- },
- {
- "embedding": [0.4, 0.5, 0.6],
- "meta_info": {"prompt_tokens": 4},
- },
- ]
-
- response = self.serving_embedding._build_embedding_response(ret_data)
-
- self.assertIsInstance(response, EmbeddingResponse)
- self.assertEqual(len(response.data), 2)
- self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
- self.assertEqual(response.data[0].index, 0)
- self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
- self.assertEqual(response.data[1].index, 1)
- self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
- self.assertEqual(response.usage.total_tokens, 7)
-
- def test_handle_request_success(self):
- """Test successful embedding request handling."""
-
- async def run_test():
- # Mock the generate_request to return expected data
- async def mock_generate():
- yield {
- "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
- "meta_info": {"prompt_tokens": 5},
- }
-
- self.serving_embedding.tokenizer_manager.generate_request = Mock(
- return_value=mock_generate()
- )
-
- response = await self.serving_embedding.handle_request(
- self.basic_req, self.request
- )
-
- self.assertIsInstance(response, EmbeddingResponse)
- self.assertEqual(len(response.data), 1)
- self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
-
- asyncio.run(run_test())
-
- def test_handle_request_validation_error(self):
- """Test handling request with validation error."""
-
- async def run_test():
- invalid_request = EmbeddingRequest(model="test-model", input="")
-
- response = await self.serving_embedding.handle_request(
- invalid_request, self.request
- )
-
- self.assertIsInstance(response, ORJSONResponse)
- self.assertEqual(response.status_code, 400)
-
- asyncio.run(run_test())
-
- def test_handle_request_generation_error(self):
- """Test handling request with generation error."""
-
- async def run_test():
- # Mock generate_request to raise an error
- async def mock_generate_error():
- raise ValueError("Generation failed")
- yield # This won't be reached but needed for async generator
-
- self.serving_embedding.tokenizer_manager.generate_request = Mock(
- return_value=mock_generate_error()
- )
-
- response = await self.serving_embedding.handle_request(
- self.basic_req, self.request
- )
-
- self.assertIsInstance(response, ORJSONResponse)
- self.assertEqual(response.status_code, 400)
-
- asyncio.run(run_test())
-
- def test_handle_request_internal_error(self):
- """Test handling request with internal server error."""
-
- async def run_test():
- # Mock _convert_to_internal_request to raise an exception
- with patch.object(
- self.serving_embedding,
- "_convert_to_internal_request",
- side_effect=Exception("Internal error"),
- ):
- response = await self.serving_embedding.handle_request(
- self.basic_req, self.request
- )
-
- self.assertIsInstance(response, ORJSONResponse)
- self.assertEqual(response.status_code, 500)
-
- asyncio.run(run_test())
-
if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 327d084d9..26ff2b31b 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -29,6 +29,10 @@ suites = {
TestFile("models/test_reward_models.py", 132),
TestFile("models/test_vlm_models.py", 437),
TestFile("models/test_transformers_models.py", 320),
+ TestFile("openai/test_protocol.py", 10),
+ TestFile("openai/test_serving_chat.py", 10),
+ TestFile("openai/test_serving_completions.py", 10),
+ TestFile("openai/test_serving_embedding.py", 10),
TestFile("test_abort.py", 51),
TestFile("test_block_int8.py", 22),
TestFile("test_create_kvindices.py", 2),
@@ -49,6 +53,7 @@ suites = {
TestFile("test_hidden_states.py", 55),
TestFile("test_int8_kernel.py", 8),
TestFile("test_input_embeddings.py", 38),
+ TestFile("test_jinja_template_utils.py", 1),
TestFile("test_json_constrained.py", 98),
TestFile("test_large_max_new_tokens.py", 41),
TestFile("test_metrics.py", 32),
@@ -59,14 +64,8 @@ suites = {
TestFile("test_mla_fp8.py", 93),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 234),
- TestFile("test_openai_adapter.py", 1),
TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149),
- TestFile("openai/test_server.py", 120),
- TestFile("openai/test_protocol.py", 60),
- TestFile("openai/test_serving_chat.py", 120),
- TestFile("openai/test_serving_completions.py", 120),
- TestFile("openai/test_serving_embedding.py", 120),
TestFile("test_openai_server_hidden_states.py", 240),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py
index ab6c8b999..35b75d715 100644
--- a/test/srt/test_function_call_parser.py
+++ b/test/srt/test_function_call_parser.py
@@ -3,6 +3,7 @@ import unittest
from xgrammar import GrammarCompiler, TokenizerInfo
+from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
@@ -10,7 +11,6 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.openai_api.protocol import Function, Tool
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_openai_adapter.py b/test/srt/test_jinja_template_utils.py
similarity index 95%
rename from test/srt/test_openai_adapter.py
rename to test/srt/test_jinja_template_utils.py
index 598ddfd49..b6bacd12f 100644
--- a/test/srt/test_openai_adapter.py
+++ b/test/srt/test_jinja_template_utils.py
@@ -5,8 +5,8 @@ Unit tests for OpenAI adapter utils.
import unittest
from unittest.mock import patch
-from sglang.srt.openai_api.utils import (
- detect_template_content_format,
+from sglang.srt.jinja_template_utils import (
+ detect_jinja_template_content_format,
process_content_for_template_format,
)
from sglang.test.test_utils import CustomTestCase
@@ -33,7 +33,7 @@ class TestTemplateContentFormatDetection(CustomTestCase):
{%- endfor %}
"""
- result = detect_template_content_format(llama4_pattern)
+ result = detect_jinja_template_content_format(llama4_pattern)
self.assertEqual(result, "openai")
def test_detect_deepseek_string_format(self):
@@ -46,19 +46,19 @@ class TestTemplateContentFormatDetection(CustomTestCase):
{%- endfor %}
"""
- result = detect_template_content_format(deepseek_pattern)
+ result = detect_jinja_template_content_format(deepseek_pattern)
self.assertEqual(result, "string")
def test_detect_invalid_template(self):
"""Test handling of invalid template (should default to 'string')."""
invalid_pattern = "{{{{ invalid jinja syntax }}}}"
- result = detect_template_content_format(invalid_pattern)
+ result = detect_jinja_template_content_format(invalid_pattern)
self.assertEqual(result, "string")
def test_detect_empty_template(self):
"""Test handling of empty template (should default to 'string')."""
- result = detect_template_content_format("")
+ result = detect_jinja_template_content_format("")
self.assertEqual(result, "string")
def test_process_content_openai_format(self):
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index 4913eb38c..8d5f5ad89 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -235,6 +235,7 @@ class TestOpenAIServer(CustomTestCase):
)
is_firsts = {}
+ is_finished = {}
for response in generator:
usage = response.usage
if usage is not None:
@@ -244,6 +245,10 @@ class TestOpenAIServer(CustomTestCase):
continue
index = response.choices[0].index
+ finish_reason = response.choices[0].finish_reason
+ if finish_reason is not None:
+ is_finished[index] = True
+
data = response.choices[0].delta
if is_firsts.get(index, True):
@@ -253,7 +258,7 @@ class TestOpenAIServer(CustomTestCase):
is_firsts[index] = False
continue
- if logprobs:
+ if logprobs and not is_finished.get(index, False):
assert response.choices[0].logprobs, f"logprobs was not returned"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
@@ -271,7 +276,7 @@ class TestOpenAIServer(CustomTestCase):
assert (
isinstance(data.content, str)
or isinstance(data.reasoning_content, str)
- or len(data.tool_calls) > 0
+ or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
or response.choices[0].finish_reason
)
assert response.id
@@ -282,152 +287,6 @@ class TestOpenAIServer(CustomTestCase):
index, True
), f"index {index} is not found in the response"
- def _create_batch(self, mode, client):
- if mode == "completion":
- input_file_path = "complete_input.jsonl"
- # write content to input file
- content = [
- {
- "custom_id": "request-1",
- "method": "POST",
- "url": "/v1/completions",
- "body": {
- "model": "gpt-3.5-turbo-instruct",
- "prompt": "List 3 names of famous soccer player: ",
- "max_tokens": 20,
- },
- },
- {
- "custom_id": "request-2",
- "method": "POST",
- "url": "/v1/completions",
- "body": {
- "model": "gpt-3.5-turbo-instruct",
- "prompt": "List 6 names of famous basketball player: ",
- "max_tokens": 40,
- },
- },
- {
- "custom_id": "request-3",
- "method": "POST",
- "url": "/v1/completions",
- "body": {
- "model": "gpt-3.5-turbo-instruct",
- "prompt": "List 6 names of famous tenniss player: ",
- "max_tokens": 40,
- },
- },
- ]
-
- else:
- input_file_path = "chat_input.jsonl"
- content = [
- {
- "custom_id": "request-1",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": "gpt-3.5-turbo-0125",
- "messages": [
- {
- "role": "system",
- "content": "You are a helpful assistant.",
- },
- {
- "role": "user",
- "content": "Hello! List 3 NBA players and tell a story",
- },
- ],
- "max_tokens": 30,
- },
- },
- {
- "custom_id": "request-2",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": "gpt-3.5-turbo-0125",
- "messages": [
- {"role": "system", "content": "You are an assistant. "},
- {
- "role": "user",
- "content": "Hello! List three capital and tell a story",
- },
- ],
- "max_tokens": 50,
- },
- },
- ]
-
- with open(input_file_path, "w") as file:
- for line in content:
- file.write(json.dumps(line) + "\n")
-
- with open(input_file_path, "rb") as file:
- uploaded_file = client.files.create(file=file, purpose="batch")
- if mode == "completion":
- endpoint = "/v1/completions"
- elif mode == "chat":
- endpoint = "/v1/chat/completions"
- completion_window = "24h"
- batch_job = client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint=endpoint,
- completion_window=completion_window,
- )
-
- return batch_job, content, uploaded_file
-
- def run_batch(self, mode):
- client = openai.Client(api_key=self.api_key, base_url=self.base_url)
- batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)
-
- while batch_job.status not in ["completed", "failed", "cancelled"]:
- time.sleep(3)
- print(
- f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
- )
- batch_job = client.batches.retrieve(batch_job.id)
- assert (
- batch_job.status == "completed"
- ), f"Batch job status is not completed: {batch_job.status}"
- assert batch_job.request_counts.completed == len(content)
- assert batch_job.request_counts.failed == 0
- assert batch_job.request_counts.total == len(content)
-
- result_file_id = batch_job.output_file_id
- file_response = client.files.content(result_file_id)
- result_content = file_response.read().decode("utf-8") # Decode bytes to string
- results = [
- json.loads(line)
- for line in result_content.split("\n")
- if line.strip() != ""
- ]
- assert len(results) == len(content)
- for delete_fid in [uploaded_file.id, result_file_id]:
- del_pesponse = client.files.delete(delete_fid)
- assert del_pesponse.deleted
-
- def run_cancel_batch(self, mode):
- client = openai.Client(api_key=self.api_key, base_url=self.base_url)
- batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)
-
- assert batch_job.status not in ["cancelling", "cancelled"]
-
- batch_job = client.batches.cancel(batch_id=batch_job.id)
- assert batch_job.status == "cancelling"
-
- while batch_job.status not in ["failed", "cancelled"]:
- batch_job = client.batches.retrieve(batch_job.id)
- print(
- f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
- )
- time.sleep(3)
-
- assert batch_job.status == "cancelled"
- del_response = client.files.delete(uploaded_file.id)
- assert del_response.deleted
-
def test_completion(self):
for echo in [False, True]:
for logprobs in [None, 5]:
@@ -467,14 +326,6 @@ class TestOpenAIServer(CustomTestCase):
for parallel_sample_num in [1, 2]:
self.run_chat_completion_stream(logprobs, parallel_sample_num)
- def test_batch(self):
- for mode in ["completion", "chat"]:
- self.run_batch(mode)
-
- def test_cancel_batch(self):
- for mode in ["completion", "chat"]:
- self.run_cancel_batch(mode)
-
def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -559,6 +410,18 @@ The SmartHome Mini is a compact smart home assistant available in black or white
assert len(models) == 1
assert isinstance(getattr(models[0], "max_model_len", None), int)
+ def test_retrieve_model(self):
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+ # Test retrieving an existing model
+ retrieved_model = client.models.retrieve(self.model)
+ self.assertEqual(retrieved_model.id, self.model)
+ self.assertEqual(retrieved_model.root, self.model)
+
+ # Test retrieving a non-existent model
+ with self.assertRaises(openai.NotFoundError):
+ client.models.retrieve("non-existent-model")
+
# -------------------------------------------------------------------------
# EBNF Test Class: TestOpenAIServerEBNF
@@ -684,6 +547,31 @@ class TestOpenAIEmbedding(CustomTestCase):
self.assertTrue(len(response.data[0].embedding) > 0)
self.assertTrue(len(response.data[1].embedding) > 0)
+ def test_embedding_single_batch_str(self):
+ """Test embedding with a List[str] and length equals to 1"""
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+ response = client.embeddings.create(model=self.model, input=["Hello world"])
+ self.assertEqual(len(response.data), 1)
+ self.assertTrue(len(response.data[0].embedding) > 0)
+
+ def test_embedding_single_int_list(self):
+ """Test embedding with a List[int] or List[List[int]]]"""
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+ response = client.embeddings.create(
+ model=self.model,
+ input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
+ )
+ self.assertEqual(len(response.data), 1)
+ self.assertTrue(len(response.data[0].embedding) > 0)
+
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+ response = client.embeddings.create(
+ model=self.model,
+ input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
+ )
+ self.assertEqual(len(response.data), 1)
+ self.assertTrue(len(response.data[0].embedding) > 0)
+
def test_empty_string_embedding(self):
"""Test embedding an empty string."""
diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py
index 5a90f0853..518ec8671 100644
--- a/test/srt/test_vlm_accuracy.py
+++ b/test/srt/test_vlm_accuracy.py
@@ -21,6 +21,7 @@ from transformers import (
from sglang import Engine
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
@@ -31,7 +32,6 @@ from sglang.srt.managers.schedule_batch import (
MultimodalInputs,
)
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs
diff --git a/test/srt/test_vlm_input_format.py b/test/srt/test_vlm_input_format.py
index 2911e04d1..d2670ecac 100644
--- a/test/srt/test_vlm_input_format.py
+++ b/test/srt/test_vlm_input_format.py
@@ -15,7 +15,7 @@ from transformers import (
from sglang import Engine
from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"