feat(oai refactor): Replace openai_api with entrypoints/openai (#7351)
Co-authored-by: Jin Pan <jpan236@wisc.edu>
@@ -20,7 +20,7 @@ from sglang.bench_serving import (
get_gen_prefix_cache_path,
)
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64

# type of content fields, can be only prompts or with images/videos
@@ -64,11 +64,14 @@
"text = \"Once upon a time\"\n",
"\n",
"curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
" -H \"Content-Type: application/json\" \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n",
"\n",
"text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n",
" \"embedding\"\n",
"]\n",
"result = subprocess.check_output(curl_text, shell=True)\n",
"\n",
"print(result)\n",
"\n",
"text_embedding = json.loads(result)[\"data\"][0][\"embedding\"]\n",
"\n",
"print_highlight(f\"Text embedding (first 10): {text_embedding[:10]}\")"
]
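
[Illustration, not part of the diff: the cell above shells out to curl; a minimal sketch of the same /v1/embeddings call with the requests library, assuming the notebook's `port` and `text` variables and a server that is still running.]
import requests  # sketch only; mirrors the curl call in the cell above
resp = requests.post(
    f"http://localhost:{port}/v1/embeddings",
    json={"model": "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "input": text},
)
text_embedding = resp.json()["data"][0]["embedding"]
print_highlight(f"Text embedding (first 10): {text_embedding[:10]}")
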
@@ -152,6 +155,7 @@
"input_ids = tokenizer.encode(text)\n",
"\n",
"curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
" -H \"Content-Type: application/json\" \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n",
"\n",
"input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n",
@@ -67,6 +67,7 @@
"\n",
"curl_command = f\"\"\"\n",
"curl -s http://localhost:{port}/v1/chat/completions \\\\\n",
" -H \"Content-Type: application/json\" \\\\\n",
" -d '{{\n",
" \"model\": \"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
" \"messages\": [\n",

@@ -36,7 +36,7 @@
"import requests\n",
"from PIL import Image\n",
"\n",
"from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
"from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest\n",
"from sglang.srt.conversation import chat_templates\n",
"\n",
"image = Image.open(\n",
@@ -15,9 +15,7 @@

import dataclasses
import json
import logging
import os
from enum import auto

from sglang.srt.entrypoints.openai.protocol import CompletionRequest
@@ -57,46 +55,6 @@ class CompletionTemplate:
completion_templates: dict[str, CompletionTemplate] = {}

def load_completion_template_for_openai_api(completion_template_arg):
global completion_template_name

logger.info(
f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
)

if not completion_template_exists(completion_template_arg):
if not os.path.exists(completion_template_arg):
raise RuntimeError(
f"Completion template {completion_template_arg} is not a built-in template name "
"or a valid completion template file path."
)

assert completion_template_arg.endswith(
".json"
), "unrecognized format of completion template file"
with open(completion_template_arg, "r") as filep:
template = json.load(filep)
try:
fim_position = FimPosition[template["fim_position"]]
except KeyError:
raise ValueError(
f"Unknown fim position: {template['fim_position']}"
) from None
register_completion_template(
CompletionTemplate(
name=template["name"],
fim_begin_token=template["fim_begin_token"],
fim_middle_token=template["fim_middle_token"],
fim_end_token=template["fim_end_token"],
fim_position=fim_position,
),
override=True,
)
completion_template_name = template["name"]
else:
completion_template_name = completion_template_arg

def register_completion_template(template: CompletionTemplate, override: bool = False):
"""Register a new completion template."""
if not override:
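
[Illustration, not part of the diff: the removed load_completion_template_for_openai_api above parsed a JSON file whose keys mirror CompletionTemplate; an illustrative file, with placeholder token values and an assumed FimPosition name.]
# completion_template.json (illustrative only)
# {
#     "name": "example-fim",
#     "fim_begin_token": "<fim_prefix>",
#     "fim_middle_token": "<fim_middle>",
#     "fim_end_token": "<fim_suffix>",
#     "fim_position": "MIDDLE"
# }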

@@ -11,7 +11,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conversation chat templates."""
"""Conversation chat templates.

This module provides conversation template definitions, data structures, and utilities
for managing chat templates across different model types in SGLang.

Key components:
- Conversation class: Defines the structure and behavior of chat templates
- SeparatorStyle enum: Different conversation formatting styles
- Template registry: Functions to register and retrieve templates by name or model path
- Built-in templates: Pre-defined templates for popular models
"""

# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -20,7 +30,7 @@ import re
from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union

from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.utils import read_system_prompt_from_file

@@ -618,7 +628,7 @@ def generate_chat_conv(

# llama2 template
# reference: https://huggingface.co/blog/codellama#conversational-instructions
# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
register_conv_template(
Conversation(
@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import torch
import uvloop

from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
from sglang.srt.entrypoints.EngineBase import EngineBase
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
guess_chat_template_name_from_model_path,
load_chat_template_for_openai_api,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
@@ -123,12 +119,13 @@ class Engine(EngineBase):
logger.info(f"{server_args=}")

# Launch subprocesses
tokenizer_manager, scheduler_info = _launch_subprocesses(
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args,
port_args=port_args,
)
self.server_args = server_args
self.tokenizer_manager = tokenizer_manager
self.template_manager = template_manager
self.scheduler_info = scheduler_info

context = zmq.Context(2)
@@ -647,7 +644,7 @@ def _set_envs_and_config(server_args: ServerArgs):

def _launch_subprocesses(
server_args: ServerArgs, port_args: Optional[PortArgs] = None
) -> Tuple[TokenizerManager, Dict]:
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
"""
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
"""
@@ -732,7 +729,7 @@ def _launch_subprocesses(

if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
# When using `Engine` as a Python API, we don't want to block here.
return None, None
return None, None, None

launch_dummy_health_check_server(server_args.host, server_args.port)

@@ -741,7 +738,7 @@ def _launch_subprocesses(
logger.error(
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
)
return None, None
return None, None, None

# Launch detokenizer process
detoken_proc = mp.Process(
@@ -755,15 +752,15 @@ def _launch_subprocesses(

# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(
tokenizer_manager, server_args.chat_template, server_args.model_path
)
else:
guess_chat_template_name_from_model_path(server_args.model_path)

if server_args.completion_template:
load_completion_template_for_openai_api(server_args.completion_template)
# Initialize templates
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
)

# Wait for the model to finish loading
scheduler_infos = []
@@ -787,4 +784,4 @@ def _launch_subprocesses(
# Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0]
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
return tokenizer_manager, scheduler_info
return tokenizer_manager, template_manager, scheduler_info
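
[Illustration, not part of the diff: with template handling moved into TemplateManager, the Engine Python API also exposes it; a minimal sketch, where the model path is a placeholder.]
import sglang as sgl
engine = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
print(engine.template_manager.chat_template_name)
engine.shutdown()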

@@ -38,7 +38,8 @@ import orjson
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi import Depends, FastAPI, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

@@ -47,6 +48,20 @@ from sglang.srt.disaggregation.utils import (
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
EmbeddingRequest,
ModelCard,
ModelList,
ScoringRequest,
V1RerankReqInput,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
AbortReq,
@@ -67,26 +82,11 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
V1RerankReqInput,
VertexGenerateReqInput,
)
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.adapter import (
v1_batches,
v1_cancel_batch,
v1_chat_completions,
v1_completions,
v1_delete_file,
v1_embeddings,
v1_files_create,
v1_rerank,
v1_retrieve_batch,
v1_retrieve_file,
v1_retrieve_file_content,
v1_score,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
@@ -109,6 +109,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass
class _GlobalState:
tokenizer_manager: TokenizerManager
template_manager: TemplateManager
scheduler_info: Dict

@@ -123,6 +124,24 @@ def set_global_state(global_state: _GlobalState):
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args

# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_chat = OpenAIServingChat(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_score = OpenAIServingScore(
_global_state.tokenizer_manager
)
fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
_global_state.tokenizer_manager
)

if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager
@@ -148,6 +167,36 @@ app.add_middleware(
allow_headers=["*"],
)

# Custom exception handlers to change validation error status codes
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Override FastAPI's default 422 validation error with 400"""
return ORJSONResponse(
status_code=400,
content={
"detail": exc.errors(),
"body": exc.body,
},
)

async def validate_json_request(raw_request: Request):
"""Validate that the request content-type is application/json."""
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
raise RequestValidationError(
errors=[
{
"loc": ["header", "content-type"],
"msg": "Unsupported Media Type: Only 'application/json' is allowed",
"type": "value_error",
}
]
)
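
[Illustration, not part of the diff: together, the 400 handler and validate_json_request above make a non-JSON body on the OpenAI routes fail fast; host and port below are placeholders for a locally running server.]
import requests
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # placeholder host/port
    data="not json",
    headers={"Content-Type": "text/plain"},
)
assert resp.status_code == 400  # remapped from FastAPI's default 422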

HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

@@ -330,13 +379,14 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
return _create_error_response(e)

@app.api_route("/v1/rerank", methods=["POST", "PUT"])
async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
try:
ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
return ret
except ValueError as e:
return _create_error_response(e)
@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
"""Endpoint for reranking documents based on query relevance."""
return await raw_request.app.state.openai_serving_rerank.handle_request(
request, raw_request
)

@app.api_route("/flush_cache", methods=["GET", "POST"])
@@ -619,25 +669,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
##### OpenAI-compatible API endpoints #####

@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
return await v1_completions(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
"""OpenAI-compatible text completion endpoint."""
return await raw_request.app.state.openai_serving_completion.handle_request(
request, raw_request
)

@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
async def openai_v1_chat_completions(
request: ChatCompletionRequest, raw_request: Request
):
"""OpenAI-compatible chat completion endpoint."""
return await raw_request.app.state.openai_serving_chat.handle_request(
request, raw_request
)
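
[Illustration, not part of the diff: the route keeps the OpenAI wire format, so the standard OpenAI client still works against it; base URL and model name are placeholders.]
from openai import OpenAI
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")  # placeholder URL
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder; use the served model name
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)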

@app.post("/v1/embeddings", response_class=ORJSONResponse)
async def openai_v1_embeddings(raw_request: Request):
response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
return response
@app.post(
"/v1/embeddings",
response_class=ORJSONResponse,
dependencies=[Depends(validate_json_request)],
)
async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
"""OpenAI-compatible embeddings endpoint."""
return await raw_request.app.state.openai_serving_embedding.handle_request(
request, raw_request
)

@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
"""Show available models."""
async def available_models():
"""Show available models. OpenAI-compatible endpoint."""
served_model_names = [_global_state.tokenizer_manager.served_model_name]
model_cards = []
for served_model_name in served_model_names:
@@ -651,47 +715,31 @@ def available_models():
return ModelList(data=model_cards)

@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
return await v1_files_create(
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
@app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
async def retrieve_model(model: str):
"""Retrieves a model instance, providing basic information about the model."""
served_model_names = [_global_state.tokenizer_manager.served_model_name]

if model not in served_model_names:
return ORJSONResponse(
status_code=404,
content={
"error": {
"message": f"The model '{model}' does not exist",
"type": "invalid_request_error",
"param": "model",
"code": "model_not_found",
}
},
)

return ModelCard(
id=model,
root=model,
max_model_len=_global_state.tokenizer_manager.model_config.context_len,
)

@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/delete
return await v1_delete_file(file_id)

@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
return await v1_batches(_global_state.tokenizer_manager, raw_request)

@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
# https://platform.openai.com/docs/api-reference/batch/cancel
return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)

@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
return await v1_retrieve_batch(batch_id)

@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve
return await v1_retrieve_file(file_id)

@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve-contents
return await v1_retrieve_file_content(file_id)

## SageMaker API
@app.get("/ping")
async def sagemaker_health() -> Response:
@@ -700,8 +748,13 @@ async def sagemaker_health() -> Response:

@app.post("/invocations")
async def sagemaker_chat_completions(raw_request: Request):
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
async def sagemaker_chat_completions(
request: ChatCompletionRequest, raw_request: Request
):
"""OpenAI-compatible chat completion endpoint."""
return await raw_request.app.state.openai_serving_chat.handle_request(
request, raw_request
)

## Vertex AI API
@@ -732,10 +785,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
return ORJSONResponse({"predictions": ret})

@app.post("/v1/score")
async def v1_score_request(raw_request: Request):
@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
async def v1_score_request(request: ScoringRequest, raw_request: Request):
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
return await v1_score(_global_state.tokenizer_manager, raw_request)
return await raw_request.app.state.openai_serving_score.handle_request(
request, raw_request
)

def _create_error_response(e):
@@ -764,10 +819,13 @@ def launch_server(
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args
)
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
template_manager=template_manager,
scheduler_info=scheduler_info,
)
)
@@ -1,388 +0,0 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
SGLang OpenAI-Compatible API Server.

This file implements OpenAI-compatible HTTP APIs for the inference engine via FastAPI.
"""

import argparse
import asyncio
import logging
import multiprocessing
import os
import threading
import time
from contextlib import asynccontextmanager
from typing import Callable, Dict, Optional

import numpy as np
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response

from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
add_prometheus_middleware,
delete_directory,
get_bool_env_var,
kill_process_tree,
set_uvicorn_logging_configs,
)
from sglang.srt.warmup import execute_warmups
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Store global states
class AppState:
engine: Optional[Engine] = None
server_args: Optional[ServerArgs] = None
tokenizer_manager: Optional[TokenizerManager] = None
scheduler_info: Optional[Dict] = None
embedding_server: Optional[OpenAIServingEmbedding] = None

@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.server_args.enable_metrics = True # By default, we enable metrics

server_args = app.state.server_args

# Initialize engine
logger.info(f"SGLang OpenAI server (PID: {os.getpid()}) is initializing...")

tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
app.state.tokenizer_manager = tokenizer_manager
app.state.scheduler_info = scheduler_info
app.state.serving_embedding = OpenAIServingEmbedding(
tokenizer_manager=tokenizer_manager
)

if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()

# Initialize engine state attribute to None for now
app.state.engine = None

if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), app.state.tokenizer_manager
)
logger.info("Warmup ended")

warmup_thread = getattr(app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()

yield

# Lifespan shutdown
if hasattr(app.state, "engine") and app.state.engine is not None:
logger.info("SGLang engine is shutting down.")
# Add engine cleanup logic here when implemented

# Fast API app with CORS enabled
app = FastAPI(
lifespan=lifespan,
# TODO: check where /openai.json is created or why we use this
openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

@app.api_route("/health", methods=["GET"])
async def health() -> Response:
"""Health check. Used for readiness and liveness probes."""
# In the future, this could check engine health more deeply
# For now, if the server is up, it's healthy.
return Response(status_code=200)

@app.api_route("/v1/models", methods=["GET"])
async def show_models():
"""Show available models. Currently, it returns the served model name.

This endpoint is compatible with the OpenAI API standard.
"""
served_model_names = [app.state.tokenizer_manager.served_model_name]
model_cards = []
for served_model_name in served_model_names:
model_cards.append(
ModelCard(
id=served_model_name,
root=served_model_name,
max_model_len=app.state.tokenizer_manager.model_config.context_len,
)
)
return ModelList(data=model_cards)

@app.get("/get_model_info")
async def get_model_info():
"""Get the model information."""
result = {
"model_path": app.state.tokenizer_manager.model_path,
"tokenizer_path": app.state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": app.state.tokenizer_manager.is_generation,
}
return result

@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
pass

@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
pass

@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
try:
request_json = await raw_request.json()
request = EmbeddingRequest(**request_json)
except Exception as e:
return app.state.serving_embedding.create_error_response(
f"Invalid request body, error: {str(e)}"
)

ret = await app.state.serving_embedding.handle_request(request, raw_request)
return ret

@app.post("/v1/score")
async def v1_score_request(raw_request: Request):
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
pass

@app.api_route("/v1/models/{model_id}", methods=["GET"])
async def show_model_detail(model_id: str):
served_model_name = app.state.tokenizer_manager.served_model_name

return ModelCard(
id=served_model_name,
root=served_model_name,
max_model_len=app.state.tokenizer_manager.model_config.context_len,
)

# Additional API endpoints will be implemented in separate serving_*.py modules
# and mounted as APIRouters in future PRs

def _wait_and_warmup(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
image_token_text: str,
launch_callback: Optional[Callable[[], None]] = None,
):
return
# TODO: Please wait until the /generate implementation is complete,
# or confirm if modifications are needed before removing this.

headers = {}
url = server_args.url()
if server_args.api_key:
headers["Authorization"] = f"Bearer {server_args.api_key}"

# Wait until the server is launched
success = False
for _ in range(120):
time.sleep(1)
try:
res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
assert res.status_code == 200, f"{res=}, {res.text=}"
success = True
break
except (AssertionError, requests.exceptions.RequestException):
last_traceback = get_exception_traceback()
pass

if not success:
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_process_tree(os.getpid())
return

model_info = res.json()

# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
# TODO: Replace with OpenAI API
max_new_tokens = 8 if model_info["is_generation"] else 1
json_data = {
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
},
}
if server_args.skip_tokenizer_init:
json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)]
# TODO Workaround the bug that embedding errors for list of size 1
if server_args.dp_size == 1:
json_data["input_ids"] = json_data["input_ids"][0]
else:
json_data["text"] = ["The capital city of France is"] * server_args.dp_size
# TODO Workaround the bug that embedding errors for list of size 1
if server_args.dp_size == 1:
json_data["text"] = json_data["text"][0]

# Debug dumping
if server_args.debug_tensor_dump_input_file:
json_data.pop("text", None)
json_data["input_ids"] = np.load(
server_args.debug_tensor_dump_input_file
).tolist()
json_data["sampling_params"]["max_new_tokens"] = 0

try:
if server_args.disaggregation_mode == "null":
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=600,
)
assert res.status_code == 200, f"{res}"
else:
logger.info(f"Start of prefill warmup ...")
json_data = {
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": True,
},
"bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
# This is a hack to ensure fake transfer is enabled during prefill warmup
# ensure each dp rank has a unique bootstrap_room during prefill warmup
"bootstrap_room": [
i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
for i in range(server_args.dp_size)
],
"input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
}
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=1800, # because of deep gemm precache is very long if not precache.
)
logger.info(
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
)

except Exception:
last_traceback = get_exception_traceback()
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_process_tree(os.getpid())
return

# Debug print
# logger.info(f"{res.json()=}")

logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("ready")

if server_args.delete_ckpt_after_loading:
delete_directory(server_args.model_path)

if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid())

if server_args.pdlb_url is not None:
register_disaggregation_server(
server_args.disaggregation_mode,
server_args.port,
server_args.disaggregation_bootstrap_port,
server_args.pdlb_url,
)

if launch_callback is not None:
launch_callback()

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SGLang OpenAI-Compatible API Server")
# Add arguments from ServerArgs. This allows reuse of existing CLI definitions.
ServerArgs.add_cli_args(parser)
# Potentially add server-specific arguments here in the future if needed

args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)

# Store server_args in app.state for access in lifespan and endpoints
app.state.server_args = server_args

# Configure logging
logging.basicConfig(
level=server_args.log_level.upper(),
format="%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
)

# Send a warmup request - we will create the thread launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
None,
None, # Never used
None,
),
)
app.warmup_thread = warmup_thread

try:
# Start the server
set_uvicorn_logging_configs()
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level.lower(),
timeout_keep_alive=60, # Increased keep-alive for potentially long requests
loop="uvloop", # Use uvloop for better performance if available
)
finally:
warmup_thread.join()
@@ -207,7 +207,7 @@ class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Literal["stop", "length", "content_filter", "abort"]
finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None

@@ -404,21 +404,13 @@ class ChatCompletionRequest(BaseModel):
@model_validator(mode="before")
@classmethod
def set_tool_choice_default(cls, values):
if isinstance(values, dict):
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
return values

@field_validator("messages")
@classmethod
def validate_messages_not_empty(cls, v):
if not v:
raise ValueError("Messages cannot be empty")
return v

# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
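
[Illustration, not part of the diff: the simplified set_tool_choice_default above keeps the same defaults; a sketch of the resulting behavior, with illustrative field values.]
# no "tools" in the payload              -> tool_choice defaults to "none"
# "tools" present but no explicit choice -> tool_choice defaults to "auto"
req = ChatCompletionRequest(
    model="placeholder-model",
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
)
assert req.tool_choice == "auto"
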
@@ -457,9 +449,11 @@ class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
finish_reason: Optional[
Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None

@@ -530,7 +524,7 @@ class EmbeddingRequest(BaseModel):
input: EmbeddingInput
model: str
encoding_format: str = "float"
dimensions: int = None
dimensions: Optional[int] = None
user: Optional[str] = None

# The request id.
@@ -2,16 +2,12 @@ import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union

from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse

from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
OpenAIServingRequest,
UsageInfo,
)
from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager

@@ -51,7 +47,7 @@ class OpenAIServingBase(ABC):
)

except Exception as e:
logger.error(f"Error in request: {e}")
logger.exception(f"Error in request: {e}")
return self.create_error_response(
message=f"Internal server error: {str(e)}",
err_type="InternalServerError",
@@ -63,8 +59,12 @@ class OpenAIServingBase(ABC):
"""Generate request ID based on request type"""
pass

def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]:
"""Generate request ID based on request type"""
return None

# TODO(chang): the rid is used in io_strcut check and often violates `The rid should be a list` AssertionError
# Temporarily return None in this function until the rid logic is clear.
if rid := getattr(request, "rid", None):
return rid

@@ -83,7 +83,7 @@ class OpenAIServingBase(ABC):
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
) -> StreamingResponse:
) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]:
"""Handle streaming request

Override this method in child classes that support streaming requests.
@@ -99,7 +99,7 @@ class OpenAIServingBase(ABC):
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
) -> Union[Any, ErrorResponse]:
) -> Union[Any, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming request

Override this method in child classes that support non-streaming requests.
@@ -110,7 +110,7 @@ class OpenAIServingBase(ABC):
status_code=501,
)

def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]:
"""Validate request"""
pass

@@ -122,6 +122,7 @@ class OpenAIServingBase(ABC):
param: Optional[str] = None,
) -> ORJSONResponse:
"""Create an error response"""
# TODO: remove fastapi dependency in openai and move response handling to the entrypoint
error = ErrorResponse(
object="error",
message=message,
@@ -1,4 +1,3 @@
import base64
import json
import logging
import time
@@ -6,7 +5,7 @@ import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Union

from fastapi import Request
from fastapi.responses import StreamingResponse
from fastapi.responses import ORJSONResponse, StreamingResponse

from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.openai.protocol import (
@@ -28,13 +27,14 @@ from sglang.srt.entrypoints.openai.protocol import (
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import (
detect_template_content_format,
process_content_for_template_format,
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.jinja_template_utils import process_content_for_template_format
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str

@@ -42,13 +42,13 @@ logger = logging.getLogger(__name__)

class OpenAIServingChat(OpenAIServingBase):
"""Handler for chat completion requests"""
"""Handler for /v1/chat/completions requests"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Instance-specific cache for template content format detection
self._cached_chat_template = None
self._cached_template_format = None
def __init__(
self, tokenizer_manager: TokenizerManager, template_manager: TemplateManager
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager

def _request_id_prefix(self) -> str:
return "chatcmpl-"
@@ -142,19 +142,14 @@ class OpenAIServingChat(OpenAIServingBase):
)

# Use chat template
if (
hasattr(self.tokenizer_manager, "chat_template_name")
and self.tokenizer_manager.chat_template_name is None
):
if self.template_manager.chat_template_name is None:
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_jinja_template(request, tools, is_multimodal)
)
else:
prompt, image_data, audio_data, modalities, stop = (
self._apply_conversation_template(request)
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_conversation_template(request, is_multimodal)
)
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
else:
# Use raw prompt
prompt_ids = request.messages
@@ -181,23 +176,14 @@ class OpenAIServingChat(OpenAIServingBase):
is_multimodal: bool,
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
"""Apply Jinja chat template"""
prompt = ""
prompt_ids = []
openai_compatible_messages = []
image_data = []
audio_data = []
modalities = []

# Detect template content format
current_template = self.tokenizer_manager.tokenizer.chat_template
if current_template != self._cached_chat_template:
self._cached_chat_template = current_template
self._cached_template_format = detect_template_content_format(
current_template
)
logger.info(
f"Detected chat template content format: {self._cached_template_format}"
)

template_content_format = self._cached_template_format
template_content_format = self.template_manager.jinja_template_content_format

for message in request.messages:
if message.content is None:
@@ -262,14 +248,21 @@ class OpenAIServingChat(OpenAIServingBase):
if is_multimodal:
prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)

stop = request.stop or []
stop = request.stop
image_data = image_data if image_data else None
audio_data = audio_data if audio_data else None
modalities = modalities if modalities else []
return prompt, prompt_ids, image_data, audio_data, modalities, stop

def _apply_conversation_template(
self, request: ChatCompletionRequest
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]:
self,
request: ChatCompletionRequest,
is_multimodal: bool,
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
"""Apply conversation template"""
conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name)
prompt = ""
prompt_ids = []
conv = generate_chat_conv(request, self.template_manager.chat_template_name)

# If we should continue the final assistant message, adjust the conversation.
if (
@@ -296,9 +289,9 @@ class OpenAIServingChat(OpenAIServingBase):
else:
prompt = conv.get_prompt()

image_data = conv.image_data
audio_data = conv.audio_data
modalities = conv.modalities
image_data = conv.image_data if conv.image_data else None
audio_data = conv.audio_data if conv.audio_data else None
modalities = conv.modalities if conv.modalities else []
stop = conv.stop_str or [] if not request.ignore_eos else []

if request.stop:
@@ -307,7 +300,10 @@ class OpenAIServingChat(OpenAIServingBase):
else:
stop.extend(request.stop)

return prompt, image_data, audio_data, modalities, stop
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)

return prompt, prompt_ids, image_data, audio_data, modalities, stop

def _build_sampling_params(
self,
@@ -459,13 +455,9 @@ class OpenAIServingChat(OpenAIServingBase):
stream_buffers[index] = stream_buffer + delta

# Handle reasoning content
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
)
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
and enable_thinking
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
@@ -591,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {usage_chunk.model_dump_json()}\n\n"

except Exception as e:
except ValueError as e:
error = self.create_streaming_error_response(str(e))
yield f"data: {error}\n\n"

@@ -602,7 +594,7 @@ class OpenAIServingChat(OpenAIServingBase):
adapted_request: GenerateReqInput,
request: ChatCompletionRequest,
raw_request: Request,
) -> Union[ChatCompletionResponse, ErrorResponse]:
) -> Union[ChatCompletionResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming chat completion request"""
try:
ret = await self.tokenizer_manager.generate_request(
@@ -627,7 +619,7 @@ class OpenAIServingChat(OpenAIServingBase):
request: ChatCompletionRequest,
ret: List[Dict[str, Any]],
created: int,
) -> ChatCompletionResponse:
) -> Union[ChatCompletionResponse, ORJSONResponse]:
"""Build chat completion response from generation results"""
choices = []

@@ -645,11 +637,8 @@ class OpenAIServingChat(OpenAIServingBase):

# Handle reasoning content
reasoning_text = None
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
)
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
if reasoning_parser and request.separate_reasoning and enable_thinking:
if reasoning_parser and request.separate_reasoning:
try:
parser = ReasoningParser(
model_type=reasoning_parser, stream_reasoning=False
@@ -691,9 +680,10 @@ class OpenAIServingChat(OpenAIServingBase):
choices.append(choice_data)

# Calculate usage
cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = UsageProcessor.calculate_response_usage(
ret, n_choices=request.n, enable_cache_report=cache_report
ret,
n_choices=request.n,
enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
)

return ChatCompletionResponse(
@@ -821,6 +811,25 @@ class OpenAIServingChat(OpenAIServingBase):
reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta)

def _get_enable_thinking_from_request(request: ChatCompletionRequest) -> bool:
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.

NOTE: This parameter is only useful for models that support enable_thinking
flag, such as Qwen3.

Args:
request_obj: The request object (or an item from a list of requests).
Returns:
The boolean value of 'enable_thinking' if found and not True, otherwise True.
"""
if (
hasattr(request, "chat_template_kwargs")
and request.chat_template_kwargs
and request.chat_template_kwargs.get("enable_thinking") is not None
):
return request.chat_template_kwargs.get("enable_thinking")
return True
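
[Illustration, not part of the diff: for models such as Qwen3 the flag travels in chat_template_kwargs; with the payload below (model name is a placeholder) the helper above returns False instead of the default True.]
# {
#   "model": "Qwen/Qwen3-8B",
#   "messages": [{"role": "user", "content": "Hi"}],
#   "chat_template_kwargs": {"enable_thinking": false}
# }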

async def _process_tool_call_stream(
self,
index: int,

@@ -3,12 +3,9 @@ import time
from typing import Any, AsyncGenerator, Dict, List, Union

from fastapi import Request
from fastapi.responses import StreamingResponse
from fastapi.responses import ORJSONResponse, StreamingResponse

from sglang.srt.code_completion_parser import (
generate_completion_prompt_from_request,
is_completion_template_defined,
)
from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
@@ -24,12 +21,22 @@ from sglang.srt.entrypoints.openai.utils import (
to_openai_style_logprobs,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager

logger = logging.getLogger(__name__)

class OpenAIServingCompletion(OpenAIServingBase):
"""Handler for completion requests"""
"""Handler for /v1/completion requests"""

def __init__(
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager

def _request_id_prefix(self) -> str:
return "cmpl-"
@@ -47,7 +54,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
)
# Process prompt
prompt = request.prompt
if is_completion_template_defined():
if self.template_manager.completion_template_name is not None:
prompt = generate_completion_prompt_from_request(request)

# Set logprob start length based on echo and logprobs
@@ -141,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
hidden_states = {}

try:
async for content in self.tokenizer_manager.generate_request(
@@ -152,6 +160,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get("hidden_states", None)

stream_buffer = stream_buffers.get(index, "")
# Handle echo for first chunk
@@ -192,7 +201,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
delta = text[len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
finish_reason = content["meta_info"]["finish_reason"]
hidden_states = content["meta_info"].get("hidden_states", None)

choice_data = CompletionResponseStreamChoice(
index=index,
@@ -269,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
adapted_request: GenerateReqInput,
request: CompletionRequest,
raw_request: Request,
) -> Union[CompletionResponse, ErrorResponse]:
) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming completion request"""
try:
generator = self.tokenizer_manager.generate_request(
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, Union

from fastapi import Request
from fastapi.responses import ORJSONResponse

from sglang.srt.conversation import generate_embedding_convs
from sglang.srt.entrypoints.openai.protocol import (
@@ -13,10 +14,20 @@ from sglang.srt.entrypoints.openai.protocol import (
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager

class OpenAIServingEmbedding(OpenAIServingBase):
"""Handler for embedding requests"""
"""Handler for v1/embeddings requests"""

def __init__(
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager

def _request_id_prefix(self) -> str:
return "embd-"
@@ -68,11 +79,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list):
if len(prompt) > 0 and isinstance(prompt[0], str):
# List of strings - if it's a single string in a list, treat as single string
if len(prompt) == 1:
prompt_kwargs = {"text": prompt[0]}
else:
prompt_kwargs = {"text": prompt}
prompt_kwargs = {"text": prompt}
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
# Handle multimodal embedding inputs
texts = []
@@ -84,11 +91,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):

generate_prompts = []
# Check if we have a chat template for multimodal embeddings
chat_template_name = getattr(
self.tokenizer_manager, "chat_template_name", None
)
if chat_template_name is not None:
convs = generate_embedding_convs(texts, images, chat_template_name)
if self.template_manager.chat_template_name is not None:
convs = generate_embedding_convs(
texts, images, self.template_manager.chat_template_name
)
for conv in convs:
generate_prompts.append(conv.get_prompt())
else:
@@ -122,7 +128,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
adapted_request: EmbeddingReqInput,
request: EmbeddingRequest,
raw_request: Request,
) -> Union[EmbeddingResponse, ErrorResponse]:
) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]:
"""Handle the embedding request"""
try:
ret = await self.tokenizer_manager.generate_request(
@@ -2,6 +2,7 @@ import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
@@ -15,7 +16,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingRerank(OpenAIServingBase):
|
||||
"""Handler for rerank requests"""
|
||||
"""Handler for /v1/rerank requests"""
|
||||
|
||||
# NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
|
||||
# to another module in the future.
|
||||
|
||||
def _request_id_prefix(self) -> str:
|
||||
return "rerank-"
|
||||
@@ -61,7 +65,7 @@ class OpenAIServingRerank(OpenAIServingBase):
|
||||
adapted_request: EmbeddingReqInput,
|
||||
request: V1RerankReqInput,
|
||||
raw_request: Request,
|
||||
) -> Union[RerankResponse, ErrorResponse]:
|
||||
) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]:
|
||||
"""Handle the rerank request"""
|
||||
try:
|
||||
ret = await self.tokenizer_manager.generate_request(
|
||||
@@ -74,16 +78,16 @@ class OpenAIServingRerank(OpenAIServingBase):
|
||||
if not isinstance(ret, list):
|
||||
ret = [ret]
|
||||
|
||||
response = self._build_rerank_response(ret, request)
|
||||
return response
|
||||
responses = self._build_rerank_response(ret, request)
|
||||
return responses
|
||||
|
||||
def _build_rerank_response(
|
||||
self, ret: List[Dict[str, Any]], request: V1RerankReqInput
|
||||
) -> List[RerankResponse]:
|
||||
"""Build the rerank response from generation results"""
|
||||
response = []
|
||||
responses = []
|
||||
for idx, ret_item in enumerate(ret):
|
||||
response.append(
|
||||
responses.append(
|
||||
RerankResponse(
|
||||
score=ret_item["embedding"],
|
||||
document=request.documents[idx],
|
||||
@@ -93,6 +97,6 @@ class OpenAIServingRerank(OpenAIServingBase):
|
||||
)
|
||||
|
||||
# Sort by score in descending order (highest relevance first)
|
||||
response.sort(key=lambda x: x.score, reverse=True)
|
||||
responses.sort(key=lambda x: x.score, reverse=True)
|
||||
|
||||
return response
|
||||
return responses
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
@@ -14,7 +14,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingScore(OpenAIServingBase):
|
||||
"""Handler for scoring requests"""
|
||||
"""Handler for /v1/score requests"""
|
||||
|
||||
# NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
|
||||
# to another module in the future.
|
||||
|
||||
def _request_id_prefix(self) -> str:
|
||||
return "score-"
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import jinja2.nodes
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
@@ -13,168 +10,6 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# JINJA TEMPLATE CONTENT FORMAT DETECTION
|
||||
# ============================================================================
|
||||
#
|
||||
# This adapts vLLM's approach for detecting chat template content format:
|
||||
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
|
||||
# - Analyzes Jinja template AST to detect content iteration patterns
|
||||
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
|
||||
# - 'string' format: templates that expect simple string content
|
||||
# - Processes content accordingly to match template expectations
|
||||
|
||||
|
||||
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
|
||||
"""Check if node is a variable access like {{ varname }}"""
|
||||
if isinstance(node, jinja2.nodes.Name):
|
||||
return node.ctx == "load" and node.name == varname
|
||||
return False
|
||||
|
||||
|
||||
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
|
||||
"""Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
|
||||
if isinstance(node, jinja2.nodes.Getitem):
|
||||
return (
|
||||
_is_var_access(node.node, varname)
|
||||
and isinstance(node.arg, jinja2.nodes.Const)
|
||||
and node.arg.value == key
|
||||
)
|
||||
|
||||
if isinstance(node, jinja2.nodes.Getattr):
|
||||
return _is_var_access(node.node, varname) and node.attr == key
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _is_var_or_elems_access(
|
||||
node: jinja2.nodes.Node,
|
||||
varname: str,
|
||||
key: str = None,
|
||||
) -> bool:
|
||||
"""Check if node accesses varname or varname[key] with filters/tests"""
|
||||
if isinstance(node, jinja2.nodes.Filter):
|
||||
return node.node is not None and _is_var_or_elems_access(
|
||||
node.node, varname, key
|
||||
)
|
||||
if isinstance(node, jinja2.nodes.Test):
|
||||
return _is_var_or_elems_access(node.node, varname, key)
|
||||
|
||||
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
|
||||
node.arg, jinja2.nodes.Slice
|
||||
):
|
||||
return _is_var_or_elems_access(node.node, varname, key)
|
||||
|
||||
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
|
||||
|
||||
|
||||
def _try_extract_ast(chat_template: str):
|
||||
"""Try to parse the Jinja template into an AST"""
|
||||
try:
|
||||
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
|
||||
return jinja_compiled.environment.parse(chat_template)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error when compiling Jinja template: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def detect_template_content_format(chat_template: str) -> str:
|
||||
"""
|
||||
Detect whether a chat template expects 'string' or 'openai' content format.
|
||||
|
||||
- 'string': content is a simple string (like DeepSeek templates)
|
||||
- 'openai': content is a list of structured dicts (like Llama4 templates)
|
||||
|
||||
Detection logic:
|
||||
- If template has loops like {%- for content in message['content'] -%} → 'openai'
|
||||
- Otherwise → 'string'
|
||||
"""
|
||||
jinja_ast = _try_extract_ast(chat_template)
|
||||
if jinja_ast is None:
|
||||
return "string"
|
||||
|
||||
try:
|
||||
# Look for patterns like: {%- for content in message['content'] -%}
|
||||
for loop_ast in jinja_ast.find_all(jinja2.nodes.For):
|
||||
loop_iter = loop_ast.iter
|
||||
|
||||
# Check if iterating over message['content'] or similar
|
||||
if _is_var_or_elems_access(loop_iter, "message", "content"):
|
||||
return "openai" # Found content iteration → openai format
|
||||
|
||||
return "string" # No content loops found → string format
|
||||
except Exception as e:
|
||||
logger.debug(f"Error when parsing AST of Jinja template: {e}")
|
||||
return "string"
|
||||
|
||||
|
||||
def process_content_for_template_format(
|
||||
msg_dict: dict,
|
||||
content_format: str,
|
||||
image_data: list,
|
||||
audio_data: list,
|
||||
modalities: list,
|
||||
) -> dict:
|
||||
"""
|
||||
Process message content based on detected template format.
|
||||
|
||||
Args:
|
||||
msg_dict: Message dictionary with content
|
||||
content_format: 'string' or 'openai' (detected via AST analysis)
|
||||
image_data: List to append extracted image URLs
|
||||
audio_data: List to append extracted audio URLs
|
||||
modalities: List to append modalities
|
||||
|
||||
Returns:
|
||||
Processed message dictionary
|
||||
"""
|
||||
if not isinstance(msg_dict.get("content"), list):
|
||||
# Already a string or None, no processing needed
|
||||
return {k: v for k, v in msg_dict.items() if v is not None}
|
||||
|
||||
if content_format == "openai":
|
||||
# OpenAI format: preserve structured content list, normalize types
|
||||
processed_content_parts = []
|
||||
for chunk in msg_dict["content"]:
|
||||
if isinstance(chunk, dict):
|
||||
chunk_type = chunk.get("type")
|
||||
|
||||
if chunk_type == "image_url":
|
||||
image_data.append(chunk["image_url"]["url"])
|
||||
if chunk.get("modalities"):
|
||||
modalities.append(chunk.get("modalities"))
|
||||
# Normalize to simple 'image' type for template compatibility
|
||||
processed_content_parts.append({"type": "image"})
|
||||
elif chunk_type == "audio_url":
|
||||
audio_data.append(chunk["audio_url"]["url"])
|
||||
# Normalize to simple 'audio' type
|
||||
processed_content_parts.append({"type": "audio"})
|
||||
else:
|
||||
# Keep other content as-is (text, etc.)
|
||||
processed_content_parts.append(chunk)
|
||||
|
||||
new_msg = {
|
||||
k: v for k, v in msg_dict.items() if v is not None and k != "content"
|
||||
}
|
||||
new_msg["content"] = processed_content_parts
|
||||
return new_msg
|
||||
|
||||
else: # content_format == "string"
|
||||
# String format: flatten to text only (for templates like DeepSeek)
|
||||
text_parts = []
|
||||
for chunk in msg_dict["content"]:
|
||||
if isinstance(chunk, dict) and chunk.get("type") == "text":
|
||||
text_parts.append(chunk["text"])
|
||||
# Note: For string format, we ignore images/audio since the template
|
||||
# doesn't expect structured content - multimodal placeholders would
|
||||
# need to be inserted differently
|
||||
|
||||
new_msg = msg_dict.copy()
|
||||
new_msg["content"] = " ".join(text_parts) if text_parts else ""
|
||||
new_msg = {k: v for k, v in new_msg.items() if v is not None}
|
||||
return new_msg
|
||||
|
||||
|
||||
def to_openai_style_logprobs(
|
||||
input_token_logprobs=None,
|
||||
output_token_logprobs=None,
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
|
||||
from partial_json_parser.core.exceptions import MalformedJSON
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
ToolCallItem,
|
||||
@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
|
||||
_is_complete_json,
|
||||
_partial_json_loads,
|
||||
)
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.function_call.utils import _is_complete_json
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
StructuralTagResponseFormat,
|
||||
StructuresResponseFormat,
|
||||
Tool,
|
||||
ToolChoice,
|
||||
)
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import ToolCallItem
|
||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||
@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
StructuralTagResponseFormat,
|
||||
StructuresResponseFormat,
|
||||
Tool,
|
||||
ToolChoice,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""
|
||||
Utility functions for OpenAI API adapter.
|
||||
"""Template utilities for Jinja template processing.
|
||||
|
||||
This module provides utilities for analyzing and processing Jinja chat templates,
|
||||
including content format detection and message processing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
import jinja2.nodes
|
||||
import jinja2
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
|
||||
return None
|
||||
|
||||
|
||||
def detect_template_content_format(chat_template: str) -> str:
|
||||
def detect_jinja_template_content_format(chat_template: str) -> str:
|
||||
"""
|
||||
Detect whether a chat template expects 'string' or 'openai' content format.
|
||||
|
||||
@@ -864,12 +864,6 @@ class SetInternalStateReq:
|
||||
server_args: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class V1RerankReqInput:
|
||||
query: str
|
||||
documents: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetInternalStateReqOutput:
|
||||
updated: bool
|
||||
|
||||
226
python/sglang/srt/managers/template_manager.py
Normal file
226
python/sglang/srt/managers/template_manager.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Centralized template management for chat templates and completion templates.
|
||||
|
||||
This module provides a unified interface for managing both chat conversation templates
|
||||
and code completion templates, eliminating global state and improving modularity.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from sglang.srt.code_completion_parser import (
|
||||
CompletionTemplate,
|
||||
FimPosition,
|
||||
completion_template_exists,
|
||||
register_completion_template,
|
||||
)
|
||||
from sglang.srt.conversation import (
|
||||
Conversation,
|
||||
SeparatorStyle,
|
||||
chat_template_exists,
|
||||
get_conv_template_by_model_path,
|
||||
register_conv_template,
|
||||
)
|
||||
from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TemplateManager:
|
||||
"""
|
||||
Centralized manager for chat and completion templates.
|
||||
|
||||
This class encapsulates all template-related state and operations,
|
||||
eliminating the need for global variables and providing a clean
|
||||
interface for template management.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._chat_template_name: Optional[str] = None
|
||||
self._completion_template_name: Optional[str] = None
|
||||
self._jinja_template_content_format: Optional[str] = None
|
||||
|
||||
@property
|
||||
def chat_template_name(self) -> Optional[str]:
|
||||
"""Get the current chat template name."""
|
||||
return self._chat_template_name
|
||||
|
||||
@property
|
||||
def completion_template_name(self) -> Optional[str]:
|
||||
"""Get the current completion template name."""
|
||||
return self._completion_template_name
|
||||
|
||||
@property
|
||||
def jinja_template_content_format(self) -> Optional[str]:
|
||||
"""Get the detected template content format ('string' or 'openai' or None)."""
|
||||
return self._jinja_template_content_format
|
||||
|
||||
def load_chat_template(
|
||||
self, tokenizer_manager, chat_template_arg: str, model_path: str
|
||||
) -> None:
|
||||
"""
|
||||
Load a chat template from various sources.
|
||||
|
||||
Args:
|
||||
tokenizer_manager: The tokenizer manager instance
|
||||
chat_template_arg: Template name or file path
|
||||
model_path: Path to the model
|
||||
"""
|
||||
logger.info(f"Loading chat template: {chat_template_arg}")
|
||||
|
||||
if not chat_template_exists(chat_template_arg):
|
||||
if not os.path.exists(chat_template_arg):
|
||||
raise RuntimeError(
|
||||
f"Chat template {chat_template_arg} is not a built-in template name "
|
||||
"or a valid chat template file path."
|
||||
)
|
||||
|
||||
if chat_template_arg.endswith(".jinja"):
|
||||
self._load_jinja_template(tokenizer_manager, chat_template_arg)
|
||||
else:
|
||||
self._load_json_chat_template(chat_template_arg)
|
||||
else:
|
||||
self._chat_template_name = chat_template_arg
|
||||
|
||||
def guess_chat_template_from_model_path(self, model_path: str) -> None:
|
||||
"""
|
||||
Infer chat template name from model path.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model
|
||||
"""
|
||||
template_name = get_conv_template_by_model_path(model_path)
|
||||
if template_name is not None:
|
||||
logger.info(f"Inferred chat template from model path: {template_name}")
|
||||
self._chat_template_name = template_name
|
||||
|
||||
def load_completion_template(self, completion_template_arg: str) -> None:
|
||||
"""
|
||||
Load completion template for code completion.
|
||||
|
||||
Args:
|
||||
completion_template_arg: Template name or file path
|
||||
"""
|
||||
logger.info(f"Loading completion template: {completion_template_arg}")
|
||||
|
||||
if not completion_template_exists(completion_template_arg):
|
||||
if not os.path.exists(completion_template_arg):
|
||||
raise RuntimeError(
|
||||
f"Completion template {completion_template_arg} is not a built-in template name "
|
||||
"or a valid completion template file path."
|
||||
)
|
||||
|
||||
self._load_json_completion_template(completion_template_arg)
|
||||
else:
|
||||
self._completion_template_name = completion_template_arg
|
||||
|
||||
def initialize_templates(
|
||||
self,
|
||||
tokenizer_manager,
|
||||
model_path: str,
|
||||
chat_template: Optional[str] = None,
|
||||
completion_template: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize all templates based on provided configuration.
|
||||
|
||||
Args:
|
||||
tokenizer_manager: The tokenizer manager instance
|
||||
model_path: Path to the model
|
||||
chat_template: Optional chat template name/path
|
||||
completion_template: Optional completion template name/path
|
||||
"""
|
||||
# Load chat template
|
||||
if chat_template:
|
||||
self.load_chat_template(tokenizer_manager, chat_template, model_path)
|
||||
else:
|
||||
self.guess_chat_template_from_model_path(model_path)
|
||||
|
||||
# Load completion template
|
||||
if completion_template:
|
||||
self.load_completion_template(completion_template)
|
||||
|
||||
def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
|
||||
"""Load a Jinja template file."""
|
||||
with open(template_path, "r") as f:
|
||||
chat_template = "".join(f.readlines()).strip("\n")
|
||||
tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
|
||||
self._chat_template_name = None
|
||||
# Detect content format from the loaded template
|
||||
self._jinja_template_content_format = detect_jinja_template_content_format(
|
||||
chat_template
|
||||
)
|
||||
logger.info(
|
||||
f"Detected chat template content format: {self._jinja_template_content_format}"
|
||||
)
|
||||
|
||||
def _load_json_chat_template(self, template_path: str) -> None:
|
||||
"""Load a JSON chat template file."""
|
||||
assert template_path.endswith(
|
||||
".json"
|
||||
), "unrecognized format of chat template file"
|
||||
|
||||
with open(template_path, "r") as filep:
|
||||
template = json.load(filep)
|
||||
try:
|
||||
sep_style = SeparatorStyle[template["sep_style"]]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Unknown separator style: {template['sep_style']}"
|
||||
) from None
|
||||
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name=template["name"],
|
||||
system_template=template["system"] + "\n{system_message}",
|
||||
system_message=template.get("system_message", ""),
|
||||
roles=(template["user"], template["assistant"]),
|
||||
sep_style=sep_style,
|
||||
sep=template.get("sep", "\n"),
|
||||
stop_str=template["stop_str"],
|
||||
),
|
||||
override=True,
|
||||
)
|
||||
self._chat_template_name = template["name"]
|
||||
|
||||
def _load_json_completion_template(self, template_path: str) -> None:
|
||||
"""Load a JSON completion template file."""
|
||||
assert template_path.endswith(
|
||||
".json"
|
||||
), "unrecognized format of completion template file"
|
||||
|
||||
with open(template_path, "r") as filep:
|
||||
template = json.load(filep)
|
||||
try:
|
||||
fim_position = FimPosition[template["fim_position"]]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Unknown fim position: {template['fim_position']}"
|
||||
) from None
|
||||
|
||||
register_completion_template(
|
||||
CompletionTemplate(
|
||||
name=template["name"],
|
||||
fim_begin_token=template["fim_begin_token"],
|
||||
fim_middle_token=template["fim_middle_token"],
|
||||
fim_end_token=template["fim_end_token"],
|
||||
fim_position=fim_position,
|
||||
),
|
||||
override=True,
|
||||
)
|
||||
self._completion_template_name = template["name"]
|
||||
@@ -1058,12 +1058,7 @@ class TokenizerManager:
|
||||
"lora_path",
|
||||
]
|
||||
)
|
||||
out_skip_names = set(
|
||||
[
|
||||
"text",
|
||||
"output_ids",
|
||||
]
|
||||
)
|
||||
out_skip_names = set(["text", "output_ids", "embedding"])
|
||||
elif self.log_requests_level == 1:
|
||||
max_length = 2048
|
||||
elif self.log_requests_level == 2:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,551 +0,0 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Pydantic models for OpenAI API protocol"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, model_serializer, root_validator
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
"""Model cards."""
|
||||
|
||||
id: str
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "sglang"
|
||||
root: Optional[str] = None
|
||||
max_model_len: Optional[int] = None
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
"""Model list consists of model cards."""
|
||||
|
||||
object: str = "list"
|
||||
data: List[ModelCard] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
object: str = "error"
|
||||
message: str
|
||||
type: str
|
||||
param: Optional[str] = None
|
||||
code: int
|
||||
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
tokens: List[str] = Field(default_factory=list)
|
||||
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TopLogprob(BaseModel):
|
||||
token: str
|
||||
bytes: List[int]
|
||||
logprob: float
|
||||
|
||||
|
||||
class ChatCompletionTokenLogprob(BaseModel):
|
||||
token: str
|
||||
bytes: List[int]
|
||||
logprob: float
|
||||
top_logprobs: List[TopLogprob]
|
||||
|
||||
|
||||
class ChoiceLogprobs(BaseModel):
|
||||
# build for v1/chat/completions response
|
||||
content: List[ChatCompletionTokenLogprob]
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
completion_tokens: Optional[int] = 0
|
||||
# only used to return cached tokens when --enable-cache-report is set
|
||||
prompt_tokens_details: Optional[Dict[str, int]] = None
|
||||
|
||||
|
||||
class StreamOptions(BaseModel):
|
||||
include_usage: Optional[bool] = False
|
||||
|
||||
|
||||
class JsonSchemaResponseFormat(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
# use alias to workaround pydantic conflict
|
||||
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
||||
strict: Optional[bool] = False
|
||||
|
||||
|
||||
class FileRequest(BaseModel):
|
||||
# https://platform.openai.com/docs/api-reference/files/create
|
||||
file: bytes # The File object (not file name) to be uploaded
|
||||
purpose: str = (
|
||||
"batch" # The intended purpose of the uploaded file, default is "batch"
|
||||
)
|
||||
|
||||
|
||||
class FileResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "file"
|
||||
bytes: int
|
||||
created_at: int
|
||||
filename: str
|
||||
purpose: str
|
||||
|
||||
|
||||
class FileDeleteResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "file"
|
||||
deleted: bool
|
||||
|
||||
|
||||
class BatchRequest(BaseModel):
|
||||
input_file_id: (
|
||||
str # The ID of an uploaded file that contains requests for the new batch
|
||||
)
|
||||
endpoint: str # The endpoint to be used for all requests in the batch
|
||||
completion_window: str # The time frame within which the batch should be processed
|
||||
metadata: Optional[dict] = None # Optional custom metadata for the batch
|
||||
|
||||
|
||||
class BatchResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "batch"
|
||||
endpoint: str
|
||||
errors: Optional[dict] = None
|
||||
input_file_id: str
|
||||
completion_window: str
|
||||
status: str = "validating"
|
||||
output_file_id: Optional[str] = None
|
||||
error_file_id: Optional[str] = None
|
||||
created_at: int
|
||||
in_progress_at: Optional[int] = None
|
||||
expires_at: Optional[int] = None
|
||||
finalizing_at: Optional[int] = None
|
||||
completed_at: Optional[int] = None
|
||||
failed_at: Optional[int] = None
|
||||
expired_at: Optional[int] = None
|
||||
cancelling_at: Optional[int] = None
|
||||
cancelled_at: Optional[int] = None
|
||||
request_counts: Optional[dict] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/completions/create
|
||||
model: str
|
||||
prompt: Union[List[int], List[List[int]], str, List[str]]
|
||||
best_of: Optional[int] = None
|
||||
echo: bool = False
|
||||
frequency_penalty: float = 0.0
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logprobs: Optional[int] = None
|
||||
max_tokens: int = 16
|
||||
n: int = 1
|
||||
presence_penalty: float = 0.0
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stream: bool = False
|
||||
stream_options: Optional[StreamOptions] = None
|
||||
suffix: Optional[str] = None
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
user: Optional[str] = None
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
min_tokens: int = 0
|
||||
json_schema: Optional[str] = None
|
||||
regex: Optional[str] = None
|
||||
ebnf: Optional[str] = None
|
||||
repetition_penalty: float = 1.0
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
no_stop_trim: bool = False
|
||||
ignore_eos: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
return_hidden_states: Optional[bool] = False
|
||||
|
||||
# For PD disaggregation
|
||||
bootstrap_host: Optional[str] = None
|
||||
bootstrap_port: Optional[int] = None
|
||||
bootstrap_room: Optional[int] = None
|
||||
|
||||
|
||||
class CompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
text: str
|
||||
logprobs: Optional[LogProbs] = None
|
||||
finish_reason: Literal["stop", "length", "content_filter", "abort"]
|
||||
matched_stop: Union[None, int, str] = None
|
||||
hidden_states: Optional[object] = None
|
||||
|
||||
@model_serializer
|
||||
def _serialize(self):
|
||||
return exclude_if_none(self, ["hidden_states"])
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "text_completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[CompletionResponseChoice]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class CompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
text: str
|
||||
logprobs: Optional[LogProbs] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
|
||||
matched_stop: Union[None, int, str] = None
|
||||
hidden_states: Optional[object] = None
|
||||
|
||||
@model_serializer
|
||||
def _serialize(self):
|
||||
return exclude_if_none(self, ["hidden_states"])
|
||||
|
||||
|
||||
class CompletionStreamResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "text_completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[CompletionResponseStreamChoice]
|
||||
usage: Optional[UsageInfo] = None
|
||||
|
||||
|
||||
class ChatCompletionMessageContentTextPart(BaseModel):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ChatCompletionMessageContentImageURL(BaseModel):
|
||||
url: str
|
||||
detail: Optional[Literal["auto", "low", "high"]] = "auto"
|
||||
|
||||
|
||||
class ChatCompletionMessageContentAudioURL(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class ChatCompletionMessageContentImagePart(BaseModel):
|
||||
type: Literal["image_url"]
|
||||
image_url: ChatCompletionMessageContentImageURL
|
||||
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
|
||||
|
||||
|
||||
class ChatCompletionMessageContentAudioPart(BaseModel):
|
||||
type: Literal["audio_url"]
|
||||
audio_url: ChatCompletionMessageContentAudioURL
|
||||
|
||||
|
||||
ChatCompletionMessageContentPart = Union[
|
||||
ChatCompletionMessageContentTextPart,
|
||||
ChatCompletionMessageContentImagePart,
|
||||
ChatCompletionMessageContentAudioPart,
|
||||
]
|
||||
|
||||
|
||||
class FunctionResponse(BaseModel):
|
||||
"""Function response."""
|
||||
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
"""Tool call response."""
|
||||
|
||||
id: Optional[str] = None
|
||||
index: Optional[int] = None
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionResponse
|
||||
|
||||
|
||||
class ChatCompletionMessageGenericParam(BaseModel):
|
||||
role: Literal["system", "assistant", "tool"]
|
||||
content: Union[str, List[ChatCompletionMessageContentTextPart], None]
|
||||
tool_call_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
|
||||
|
||||
class ChatCompletionMessageUserParam(BaseModel):
|
||||
role: Literal["user"]
|
||||
content: Union[str, List[ChatCompletionMessageContentPart]]
|
||||
|
||||
|
||||
ChatCompletionMessageParam = Union[
|
||||
ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
|
||||
]
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
type: Literal["text", "json_object", "json_schema"]
|
||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||
|
||||
|
||||
class StructuresResponseFormat(BaseModel):
|
||||
begin: str
|
||||
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
||||
end: str
|
||||
|
||||
|
||||
class StructuralTagResponseFormat(BaseModel):
|
||||
type: Literal["structural_tag"]
|
||||
structures: List[StructuresResponseFormat]
|
||||
triggers: List[str]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""Function descriptions."""
|
||||
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
strict: bool = False
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
"""Function wrapper."""
|
||||
|
||||
type: str = Field(default="function", examples=["function"])
|
||||
function: Function
|
||||
|
||||
|
||||
class ToolChoiceFuncName(BaseModel):
|
||||
"""The name of tool choice function."""
|
||||
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class ToolChoice(BaseModel):
|
||||
"""The tool choice definition."""
|
||||
|
||||
function: ToolChoiceFuncName
|
||||
type: Literal["function"] = Field(default="function", examples=["function"])
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/chat/create
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
model: str
|
||||
frequency_penalty: float = 0.0
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logprobs: bool = False
|
||||
top_logprobs: Optional[int] = None
|
||||
max_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
|
||||
description="The maximum number of tokens that can be generated in the chat completion. ",
|
||||
)
|
||||
max_completion_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The maximum number of completion tokens for a chat completion request, "
|
||||
"including visible output tokens and reasoning tokens. Input tokens are not included. ",
|
||||
)
|
||||
n: int = 1
|
||||
presence_penalty: float = 0.0
|
||||
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stream: bool = False
|
||||
stream_options: Optional[StreamOptions] = None
|
||||
temperature: float = 0.7
|
||||
top_p: float = 1.0
|
||||
user: Optional[str] = None
|
||||
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
|
||||
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
|
||||
default="auto", examples=["none"]
|
||||
) # noqa
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_tool_choice_default(cls, values):
|
||||
if values.get("tool_choice") is None:
|
||||
if values.get("tools") is None:
|
||||
values["tool_choice"] = "none"
|
||||
else:
|
||||
values["tool_choice"] = "auto"
|
||||
return values
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
min_tokens: int = 0
|
||||
regex: Optional[str] = None
|
||||
ebnf: Optional[str] = None
|
||||
repetition_penalty: float = 1.0
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
no_stop_trim: bool = False
|
||||
ignore_eos: bool = False
|
||||
continue_final_message: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
separate_reasoning: bool = True
|
||||
stream_reasoning: bool = True
|
||||
chat_template_kwargs: Optional[Dict] = None
|
||||
|
||||
# The request id.
|
||||
rid: Optional[str] = None
|
||||
|
||||
# For PD disaggregation
|
||||
bootstrap_host: Optional[str] = None
|
||||
bootstrap_port: Optional[int] = None
|
||||
bootstrap_room: Optional[int] = None
|
||||
|
||||
# Hidden States
|
||||
return_hidden_states: Optional[bool] = False
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
||||
finish_reason: Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
|
||||
]
|
||||
matched_stop: Union[None, int, str] = None
|
||||
hidden_states: Optional[object] = None
|
||||
|
||||
@model_serializer
|
||||
def _serialize(self):
|
||||
return exclude_if_none(self, ["hidden_states"])
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
hidden_states: Optional[object] = None
|
||||
|
||||
@model_serializer
|
||||
def _serialize(self):
|
||||
return exclude_if_none(self, ["hidden_states"])
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
||||
finish_reason: Optional[
|
||||
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
|
||||
] = None
|
||||
matched_stop: Union[None, int, str] = None
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
usage: Optional[UsageInfo] = None
|
||||
|
||||
|
||||
class MultimodalEmbeddingInput(BaseModel):
|
||||
text: Optional[str] = None
|
||||
image: Optional[str] = None
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
input: Union[
|
||||
List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
|
||||
]
|
||||
model: str
|
||||
encoding_format: str = "float"
|
||||
dimensions: int = None
|
||||
user: Optional[str] = None
|
||||
|
||||
# The request id.
|
||||
rid: Optional[str] = None
|
||||
|
||||
|
||||
class EmbeddingObject(BaseModel):
|
||||
embedding: List[float]
|
||||
index: int
|
||||
object: str = "embedding"
|
||||
|
||||
|
||||
class EmbeddingResponse(BaseModel):
|
||||
data: List[EmbeddingObject]
|
||||
model: str
|
||||
object: str = "list"
|
||||
usage: Optional[UsageInfo] = None
|
||||
|
||||
|
||||
class ScoringRequest(BaseModel):
|
||||
query: Optional[Union[str, List[int]]] = (
|
||||
None # Query text or pre-tokenized token IDs
|
||||
)
|
||||
items: Optional[Union[str, List[str], List[List[int]]]] = (
|
||||
None # Item text(s) or pre-tokenized token IDs
|
||||
)
|
||||
label_token_ids: Optional[List[int]] = (
|
||||
None # Token IDs to compute probabilities for
|
||||
)
|
||||
apply_softmax: bool = False
|
||||
item_first: bool = False
|
||||
model: str
|
||||
|
||||
|
||||
class ScoringResponse(BaseModel):
|
||||
scores: List[
|
||||
List[float]
|
||||
] # List of lists of probabilities, each in the order of label_token_ids
|
||||
model: str
|
||||
usage: Optional[UsageInfo] = None
|
||||
object: str = "scoring"
|
||||
|
||||
|
||||
class RerankResponse(BaseModel):
|
||||
score: float
|
||||
document: str
|
||||
index: int
|
||||
meta_info: Optional[dict] = None
|
||||
|
||||
|
||||
def exclude_if_none(obj, field_names: List[str]):
|
||||
omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
|
||||
return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict, Optional, Tuple, Type
|
||||
|
||||
|
||||
class StreamingParseResult:
|
||||
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
|
||||
One-time parsing: Detects and parses reasoning sections in the provided text.
|
||||
Returns both reasoning content and normal text separately.
|
||||
"""
|
||||
text = text.replace(self.think_start_token, "").strip()
|
||||
if self.think_end_token not in text:
|
||||
in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
|
||||
|
||||
if not in_reasoning:
|
||||
return StreamingParseResult(normal_text=text)
|
||||
|
||||
# The text is considered to be in a reasoning block.
|
||||
processed_text = text.replace(self.think_start_token, "").strip()
|
||||
|
||||
if self.think_end_token not in processed_text:
|
||||
# Assume reasoning was truncated before `</think>` token
|
||||
return StreamingParseResult(reasoning_text=text)
|
||||
return StreamingParseResult(reasoning_text=processed_text)
|
||||
|
||||
# Extract reasoning content
|
||||
splits = text.split(self.think_end_token, maxsplit=1)
|
||||
splits = processed_text.split(self.think_end_token, maxsplit=1)
|
||||
reasoning_text = splits[0]
|
||||
text = splits[1].strip()
|
||||
normal_text = splits[1].strip()
|
||||
|
||||
return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
|
||||
return StreamingParseResult(
|
||||
normal_text=normal_text, reasoning_text=reasoning_text
|
||||
)
|
||||
|
||||
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
|
||||
"""
|
||||
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
|
||||
if not self.stripped_think_start and self.think_start_token in current_text:
|
||||
current_text = current_text.replace(self.think_start_token, "")
|
||||
self.stripped_think_start = True
|
||||
self._in_reasoning = True
|
||||
|
||||
# Handle end of reasoning block
|
||||
if self._in_reasoning and self.think_end_token in current_text:
|
||||
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
|
||||
"""
|
||||
|
||||
def __init__(self, stream_reasoning: bool = True):
|
||||
# Qwen3 is assumed to be reasoning until `</think>` token
|
||||
# Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
|
||||
super().__init__(
|
||||
"<think>",
|
||||
"</think>",
|
||||
force_reasoning=True,
|
||||
force_reasoning=False,
|
||||
stream_reasoning=stream_reasoning,
|
||||
)
|
||||
|
||||
@@ -151,12 +161,12 @@ class ReasoningParser:
|
||||
If True, streams reasoning content as it arrives.
|
||||
"""
|
||||
|
||||
DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
|
||||
DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
|
||||
"deepseek-r1": DeepSeekR1Detector,
|
||||
"qwen3": Qwen3Detector,
|
||||
}
|
||||
|
||||
def __init__(self, model_type: str = None, stream_reasoning: bool = True):
|
||||
def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
|
||||
if not model_type:
|
||||
raise ValueError("Model type must be specified")
|
||||
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
# sglang/test/srt/openai/conftest.py
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import closing
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree # reuse SGLang helper
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
|
||||
DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
|
||||
|
||||
|
||||
def _pick_free_port() -> int:
|
||||
with closing(socket.socket()) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> None:
|
||||
start = time.perf_counter()
|
||||
while time.perf_counter() - start < timeout:
|
||||
if proc.poll() is not None: # crashed
|
||||
raise RuntimeError("api_server terminated prematurely")
|
||||
try:
|
||||
if requests.get(f"{base}/health", timeout=1).status_code == 200:
|
||||
return
|
||||
except requests.RequestException:
|
||||
pass
|
||||
time.sleep(0.4)
|
||||
raise RuntimeError("api_server readiness probe timed out")
|
||||
|
||||
|
||||
def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
|
||||
"""Spawn the draft OpenAI-compatible server and wait until it's ready."""
|
||||
port = _pick_free_port()
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
SERVER_MODULE,
|
||||
"--model-path",
|
||||
model,
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
str(port),
|
||||
*map(str, kw.get("args", [])),
|
||||
]
|
||||
env = {**os.environ, **kw.get("env", {})}
|
||||
|
||||
# Write logs to a temp file so the child never blocks on a full pipe.
|
||||
log_file = tempfile.NamedTemporaryFile("w+", delete=False)
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
base = f"http://127.0.0.1:{port}"
|
||||
try:
|
||||
_wait_until_healthy(proc, base, STARTUP_TIMEOUT)
|
||||
except Exception as e:
|
||||
proc.terminate()
|
||||
proc.wait(5)
|
||||
log_file.seek(0)
|
||||
print("\n--- api_server log ---\n", log_file.read(), file=sys.stderr)
|
||||
raise e
|
||||
return proc, base, log_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def openai_server() -> Generator[str, None, None]:
|
||||
"""PyTest fixture that provides the server's base URL and cleans up."""
|
||||
proc, base, log_file = launch_openai_server()
|
||||
yield base
|
||||
kill_process_tree(proc.pid)
|
||||
log_file.close()
|
||||
@@ -67,29 +67,6 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
class TestModelCard(unittest.TestCase):
|
||||
"""Test ModelCard protocol model"""
|
||||
|
||||
def test_basic_model_card_creation(self):
|
||||
"""Test basic model card creation with required fields"""
|
||||
card = ModelCard(id="test-model")
|
||||
self.assertEqual(card.id, "test-model")
|
||||
self.assertEqual(card.object, "model")
|
||||
self.assertEqual(card.owned_by, "sglang")
|
||||
self.assertIsInstance(card.created, int)
|
||||
self.assertIsNone(card.root)
|
||||
self.assertIsNone(card.max_model_len)
|
||||
|
||||
def test_model_card_with_optional_fields(self):
|
||||
"""Test model card with optional fields"""
|
||||
card = ModelCard(
|
||||
id="test-model",
|
||||
root="/path/to/model",
|
||||
max_model_len=2048,
|
||||
created=1234567890,
|
||||
)
|
||||
self.assertEqual(card.id, "test-model")
|
||||
self.assertEqual(card.root, "/path/to/model")
|
||||
self.assertEqual(card.max_model_len, 2048)
|
||||
self.assertEqual(card.created, 1234567890)
|
||||
|
||||
def test_model_card_serialization(self):
|
||||
"""Test model card JSON serialization"""
|
||||
card = ModelCard(id="test-model", max_model_len=4096)
|
||||
@@ -120,53 +97,6 @@ class TestModelList(unittest.TestCase):
|
||||
self.assertEqual(model_list.data[1].id, "model-2")
|
||||
|
||||
|
||||
class TestErrorResponse(unittest.TestCase):
|
||||
"""Test ErrorResponse protocol model"""
|
||||
|
||||
def test_basic_error_response(self):
|
||||
"""Test basic error response creation"""
|
||||
error = ErrorResponse(
|
||||
message="Invalid request", type="BadRequestError", code=400
|
||||
)
|
||||
self.assertEqual(error.object, "error")
|
||||
self.assertEqual(error.message, "Invalid request")
|
||||
self.assertEqual(error.type, "BadRequestError")
|
||||
self.assertEqual(error.code, 400)
|
||||
self.assertIsNone(error.param)
|
||||
|
||||
def test_error_response_with_param(self):
|
||||
"""Test error response with parameter"""
|
||||
error = ErrorResponse(
|
||||
message="Invalid temperature",
|
||||
type="ValidationError",
|
||||
code=422,
|
||||
param="temperature",
|
||||
)
|
||||
self.assertEqual(error.param, "temperature")
|
||||
|
||||
|
||||
class TestUsageInfo(unittest.TestCase):
|
||||
"""Test UsageInfo protocol model"""
|
||||
|
||||
def test_basic_usage_info(self):
|
||||
"""Test basic usage info creation"""
|
||||
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
self.assertEqual(usage.prompt_tokens, 10)
|
||||
self.assertEqual(usage.completion_tokens, 20)
|
||||
self.assertEqual(usage.total_tokens, 30)
|
||||
self.assertIsNone(usage.prompt_tokens_details)
|
||||
|
||||
def test_usage_info_with_cache_details(self):
|
||||
"""Test usage info with cache details"""
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=20,
|
||||
total_tokens=30,
|
||||
prompt_tokens_details={"cached_tokens": 5},
|
||||
)
|
||||
self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5})
|
||||
|
||||
|
||||
class TestCompletionRequest(unittest.TestCase):
|
||||
"""Test CompletionRequest protocol model"""
|
||||
|
||||
@@ -181,30 +111,6 @@ class TestCompletionRequest(unittest.TestCase):
|
||||
self.assertFalse(request.stream) # default
|
||||
self.assertFalse(request.echo) # default
|
||||
|
||||
def test_completion_request_with_options(self):
|
||||
"""Test completion request with various options"""
|
||||
request = CompletionRequest(
|
||||
model="test-model",
|
||||
prompt=["Hello", "world"],
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
n=2,
|
||||
stream=True,
|
||||
echo=True,
|
||||
stop=[".", "!"],
|
||||
logprobs=5,
|
||||
)
|
||||
self.assertEqual(request.prompt, ["Hello", "world"])
|
||||
self.assertEqual(request.max_tokens, 100)
|
||||
self.assertEqual(request.temperature, 0.7)
|
||||
self.assertEqual(request.top_p, 0.9)
|
||||
self.assertEqual(request.n, 2)
|
||||
self.assertTrue(request.stream)
|
||||
self.assertTrue(request.echo)
|
||||
self.assertEqual(request.stop, [".", "!"])
|
||||
self.assertEqual(request.logprobs, 5)
|
||||
|
||||
def test_completion_request_sglang_extensions(self):
|
||||
"""Test completion request with SGLang-specific extensions"""
|
||||
request = CompletionRequest(
|
||||
@@ -233,26 +139,6 @@ class TestCompletionRequest(unittest.TestCase):
|
||||
CompletionRequest(model="test-model") # missing prompt
|
||||
|
||||
|
||||
class TestCompletionResponse(unittest.TestCase):
|
||||
"""Test CompletionResponse protocol model"""
|
||||
|
||||
def test_basic_completion_response(self):
|
||||
"""Test basic completion response"""
|
||||
choice = CompletionResponseChoice(
|
||||
index=0, text="Hello world!", finish_reason="stop"
|
||||
)
|
||||
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
|
||||
response = CompletionResponse(
|
||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
||||
)
|
||||
self.assertEqual(response.id, "test-id")
|
||||
self.assertEqual(response.object, "text_completion")
|
||||
self.assertEqual(response.model, "test-model")
|
||||
self.assertEqual(len(response.choices), 1)
|
||||
self.assertEqual(response.choices[0].text, "Hello world!")
|
||||
self.assertEqual(response.usage.total_tokens, 5)
|
||||
|
||||
|
||||
class TestChatCompletionRequest(unittest.TestCase):
|
||||
"""Test ChatCompletionRequest protocol model"""
|
||||
|
||||
@@ -268,48 +154,6 @@ class TestChatCompletionRequest(unittest.TestCase):
|
||||
self.assertFalse(request.stream) # default
|
||||
self.assertEqual(request.tool_choice, "none") # default when no tools
|
||||
|
||||
def test_chat_completion_with_multimodal_content(self):
|
||||
"""Test chat completion with multimodal content"""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
request = ChatCompletionRequest(model="test-model", messages=messages)
|
||||
self.assertEqual(len(request.messages[0].content), 2)
|
||||
self.assertEqual(request.messages[0].content[0].type, "text")
|
||||
self.assertEqual(request.messages[0].content[1].type, "image_url")
|
||||
|
||||
def test_chat_completion_with_tools(self):
|
||||
"""Test chat completion with tools"""
|
||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather information",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model", messages=messages, tools=tools
|
||||
)
|
||||
self.assertEqual(len(request.tools), 1)
|
||||
self.assertEqual(request.tools[0].function.name, "get_weather")
|
||||
self.assertEqual(request.tool_choice, "auto") # default when tools present
|
||||
|
||||
def test_chat_completion_tool_choice_validation(self):
|
||||
"""Test tool choice validation logic"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
@@ -349,289 +193,6 @@ class TestChatCompletionRequest(unittest.TestCase):
|
||||
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
||||
|
||||
|
||||
class TestChatCompletionResponse(unittest.TestCase):
|
||||
"""Test ChatCompletionResponse protocol model"""
|
||||
|
||||
def test_basic_chat_completion_response(self):
|
||||
"""Test basic chat completion response"""
|
||||
message = ChatMessage(role="assistant", content="Hello there!")
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0, message=message, finish_reason="stop"
|
||||
)
|
||||
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
|
||||
response = ChatCompletionResponse(
|
||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
||||
)
|
||||
self.assertEqual(response.id, "test-id")
|
||||
self.assertEqual(response.object, "chat.completion")
|
||||
self.assertEqual(response.model, "test-model")
|
||||
self.assertEqual(len(response.choices), 1)
|
||||
self.assertEqual(response.choices[0].message.content, "Hello there!")
|
||||
|
||||
def test_chat_completion_response_with_tool_calls(self):
|
||||
"""Test chat completion response with tool calls"""
|
||||
tool_call = ToolCall(
|
||||
id="call_123",
|
||||
function=FunctionResponse(
|
||||
name="get_weather", arguments='{"location": "San Francisco"}'
|
||||
),
|
||||
)
|
||||
message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call])
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0, message=message, finish_reason="tool_calls"
|
||||
)
|
||||
usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||
response = ChatCompletionResponse(
|
||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
||||
)
|
||||
self.assertEqual(
|
||||
response.choices[0].message.tool_calls[0].function.name, "get_weather"
|
||||
)
|
||||
self.assertEqual(response.choices[0].finish_reason, "tool_calls")
|
||||
|
||||
|
||||
class TestEmbeddingRequest(unittest.TestCase):
|
||||
"""Test EmbeddingRequest protocol model"""
|
||||
|
||||
def test_basic_embedding_request(self):
|
||||
"""Test basic embedding request"""
|
||||
request = EmbeddingRequest(model="test-model", input="Hello world")
|
||||
self.assertEqual(request.model, "test-model")
|
||||
self.assertEqual(request.input, "Hello world")
|
||||
self.assertEqual(request.encoding_format, "float") # default
|
||||
self.assertIsNone(request.dimensions) # default
|
||||
|
||||
def test_embedding_request_with_list_input(self):
|
||||
"""Test embedding request with list input"""
|
||||
request = EmbeddingRequest(
|
||||
model="test-model", input=["Hello", "world"], dimensions=512
|
||||
)
|
||||
self.assertEqual(request.input, ["Hello", "world"])
|
||||
self.assertEqual(request.dimensions, 512)
|
||||
|
||||
def test_multimodal_embedding_request(self):
|
||||
"""Test multimodal embedding request"""
|
||||
multimodal_input = [
|
||||
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
|
||||
MultimodalEmbeddingInput(text="World", image=None),
|
||||
]
|
||||
request = EmbeddingRequest(model="test-model", input=multimodal_input)
|
||||
self.assertEqual(len(request.input), 2)
|
||||
self.assertEqual(request.input[0].text, "Hello")
|
||||
self.assertEqual(request.input[0].image, "base64_image_data")
|
||||
self.assertEqual(request.input[1].text, "World")
|
||||
self.assertIsNone(request.input[1].image)
|
||||
|
||||
|
||||
class TestEmbeddingResponse(unittest.TestCase):
|
||||
"""Test EmbeddingResponse protocol model"""
|
||||
|
||||
def test_basic_embedding_response(self):
|
||||
"""Test basic embedding response"""
|
||||
embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)
|
||||
usage = UsageInfo(prompt_tokens=3, total_tokens=3)
|
||||
response = EmbeddingResponse(
|
||||
data=[embedding_obj], model="test-model", usage=usage
|
||||
)
|
||||
self.assertEqual(response.object, "list")
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
|
||||
self.assertEqual(response.data[0].index, 0)
|
||||
self.assertEqual(response.usage.prompt_tokens, 3)
|
||||
|
||||
|
||||
class TestScoringRequest(unittest.TestCase):
|
||||
"""Test ScoringRequest protocol model"""
|
||||
|
||||
def test_basic_scoring_request(self):
|
||||
"""Test basic scoring request"""
|
||||
request = ScoringRequest(
|
||||
model="test-model", query="Hello", items=["World", "Earth"]
|
||||
)
|
||||
self.assertEqual(request.model, "test-model")
|
||||
self.assertEqual(request.query, "Hello")
|
||||
self.assertEqual(request.items, ["World", "Earth"])
|
||||
self.assertFalse(request.apply_softmax) # default
|
||||
self.assertFalse(request.item_first) # default
|
||||
|
||||
def test_scoring_request_with_token_ids(self):
|
||||
"""Test scoring request with token IDs"""
|
||||
request = ScoringRequest(
|
||||
model="test-model",
|
||||
query=[1, 2, 3],
|
||||
items=[[4, 5], [6, 7]],
|
||||
label_token_ids=[8, 9],
|
||||
apply_softmax=True,
|
||||
item_first=True,
|
||||
)
|
||||
self.assertEqual(request.query, [1, 2, 3])
|
||||
self.assertEqual(request.items, [[4, 5], [6, 7]])
|
||||
self.assertEqual(request.label_token_ids, [8, 9])
|
||||
self.assertTrue(request.apply_softmax)
|
||||
self.assertTrue(request.item_first)
|
||||
|
||||
|
||||
class TestScoringResponse(unittest.TestCase):
|
||||
"""Test ScoringResponse protocol model"""
|
||||
|
||||
def test_basic_scoring_response(self):
|
||||
"""Test basic scoring response"""
|
||||
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
|
||||
self.assertEqual(response.object, "scoring")
|
||||
self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]])
|
||||
self.assertEqual(response.model, "test-model")
|
||||
self.assertIsNone(response.usage) # default
|
||||
|
||||
|
||||
class TestFileOperations(unittest.TestCase):
|
||||
"""Test file operation protocol models"""
|
||||
|
||||
def test_file_request(self):
|
||||
"""Test file request model"""
|
||||
file_data = b"test file content"
|
||||
request = FileRequest(file=file_data, purpose="batch")
|
||||
self.assertEqual(request.file, file_data)
|
||||
self.assertEqual(request.purpose, "batch")
|
||||
|
||||
def test_file_response(self):
|
||||
"""Test file response model"""
|
||||
response = FileResponse(
|
||||
id="file-123",
|
||||
bytes=1024,
|
||||
created_at=1234567890,
|
||||
filename="test.jsonl",
|
||||
purpose="batch",
|
||||
)
|
||||
self.assertEqual(response.id, "file-123")
|
||||
self.assertEqual(response.object, "file")
|
||||
self.assertEqual(response.bytes, 1024)
|
||||
self.assertEqual(response.filename, "test.jsonl")
|
||||
|
||||
def test_file_delete_response(self):
|
||||
"""Test file delete response model"""
|
||||
response = FileDeleteResponse(id="file-123", deleted=True)
|
||||
self.assertEqual(response.id, "file-123")
|
||||
self.assertEqual(response.object, "file")
|
||||
self.assertTrue(response.deleted)
|
||||
|
||||
|
||||
class TestBatchOperations(unittest.TestCase):
|
||||
"""Test batch operation protocol models"""
|
||||
|
||||
def test_batch_request(self):
|
||||
"""Test batch request model"""
|
||||
request = BatchRequest(
|
||||
input_file_id="file-123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"custom": "value"},
|
||||
)
|
||||
self.assertEqual(request.input_file_id, "file-123")
|
||||
self.assertEqual(request.endpoint, "/v1/chat/completions")
|
||||
self.assertEqual(request.completion_window, "24h")
|
||||
self.assertEqual(request.metadata, {"custom": "value"})
|
||||
|
||||
def test_batch_response(self):
|
||||
"""Test batch response model"""
|
||||
response = BatchResponse(
|
||||
id="batch-123",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="file-123",
|
||||
completion_window="24h",
|
||||
created_at=1234567890,
|
||||
)
|
||||
self.assertEqual(response.id, "batch-123")
|
||||
self.assertEqual(response.object, "batch")
|
||||
self.assertEqual(response.status, "validating") # default
|
||||
self.assertEqual(response.endpoint, "/v1/chat/completions")
|
||||
|
||||
|
||||
class TestResponseFormats(unittest.TestCase):
|
||||
"""Test response format protocol models"""
|
||||
|
||||
def test_basic_response_format(self):
|
||||
"""Test basic response format"""
|
||||
format_obj = ResponseFormat(type="json_object")
|
||||
self.assertEqual(format_obj.type, "json_object")
|
||||
self.assertIsNone(format_obj.json_schema)
|
||||
|
||||
def test_json_schema_response_format(self):
|
||||
"""Test JSON schema response format"""
|
||||
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
json_schema = JsonSchemaResponseFormat(
|
||||
name="person_schema", description="Person schema", schema=schema
|
||||
)
|
||||
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
|
||||
self.assertEqual(format_obj.type, "json_schema")
|
||||
self.assertEqual(format_obj.json_schema.name, "person_schema")
|
||||
self.assertEqual(format_obj.json_schema.schema_, schema)
|
||||
|
||||
def test_structural_tag_response_format(self):
|
||||
"""Test structural tag response format"""
|
||||
structures = [
|
||||
{
|
||||
"begin": "<thinking>",
|
||||
"schema_": {"type": "string"},
|
||||
"end": "</thinking>",
|
||||
}
|
||||
]
|
||||
format_obj = StructuralTagResponseFormat(
|
||||
type="structural_tag", structures=structures, triggers=["think"]
|
||||
)
|
||||
self.assertEqual(format_obj.type, "structural_tag")
|
||||
self.assertEqual(len(format_obj.structures), 1)
|
||||
self.assertEqual(format_obj.triggers, ["think"])
|
||||
|
||||
|
||||
class TestLogProbs(unittest.TestCase):
|
||||
"""Test LogProbs protocol models"""
|
||||
|
||||
def test_basic_logprobs(self):
|
||||
"""Test basic LogProbs model"""
|
||||
logprobs = LogProbs(
|
||||
text_offset=[0, 5, 11],
|
||||
token_logprobs=[-0.1, -0.2, -0.3],
|
||||
tokens=["Hello", " ", "world"],
|
||||
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
|
||||
)
|
||||
self.assertEqual(len(logprobs.tokens), 3)
|
||||
self.assertEqual(logprobs.tokens, ["Hello", " ", "world"])
|
||||
self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3])
|
||||
|
||||
def test_choice_logprobs(self):
|
||||
"""Test ChoiceLogprobs model"""
|
||||
token_logprob = ChatCompletionTokenLogprob(
|
||||
token="Hello",
|
||||
bytes=[72, 101, 108, 108, 111],
|
||||
logprob=-0.1,
|
||||
top_logprobs=[
|
||||
TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1)
|
||||
],
|
||||
)
|
||||
choice_logprobs = ChoiceLogprobs(content=[token_logprob])
|
||||
self.assertEqual(len(choice_logprobs.content), 1)
|
||||
self.assertEqual(choice_logprobs.content[0].token, "Hello")
|
||||
|
||||
|
||||
class TestStreamingModels(unittest.TestCase):
|
||||
"""Test streaming response models"""
|
||||
|
||||
def test_stream_options(self):
|
||||
"""Test StreamOptions model"""
|
||||
options = StreamOptions(include_usage=True)
|
||||
self.assertTrue(options.include_usage)
|
||||
|
||||
def test_chat_completion_stream_response(self):
|
||||
"""Test ChatCompletionStreamResponse model"""
|
||||
delta = DeltaMessage(role="assistant", content="Hello")
|
||||
choice = ChatCompletionResponseStreamChoice(index=0, delta=delta)
|
||||
response = ChatCompletionStreamResponse(
|
||||
id="test-id", model="test-model", choices=[choice]
|
||||
)
|
||||
self.assertEqual(response.object, "chat.completion.chunk")
|
||||
self.assertEqual(response.choices[0].delta.content, "Hello")
|
||||
|
||||
|
||||
class TestModelSerialization(unittest.TestCase):
|
||||
"""Test model serialization with hidden states"""
|
||||
|
||||
@@ -680,11 +241,6 @@ class TestModelSerialization(unittest.TestCase):
|
||||
class TestValidationEdgeCases(unittest.TestCase):
|
||||
"""Test edge cases and validation scenarios"""
|
||||
|
||||
def test_empty_messages_validation(self):
|
||||
"""Test validation with empty messages"""
|
||||
with self.assertRaises(ValidationError):
|
||||
ChatCompletionRequest(model="test-model", messages=[])
|
||||
|
||||
def test_invalid_tool_choice_type(self):
|
||||
"""Test invalid tool choice type"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
@@ -698,13 +254,6 @@ class TestValidationEdgeCases(unittest.TestCase):
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
|
||||
|
||||
def test_invalid_temperature_range(self):
|
||||
"""Test invalid temperature values"""
|
||||
# Note: The current protocol doesn't enforce temperature range,
|
||||
# but this test documents expected behavior
|
||||
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
|
||||
self.assertEqual(request.temperature, 5.0) # Currently allowed
|
||||
|
||||
def test_model_serialization_roundtrip(self):
|
||||
"""Test that models can be serialized and deserialized"""
|
||||
original_request = ChatCompletionRequest(
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
# sglang/test/srt/openai/test_server.py
|
||||
import requests
|
||||
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
|
||||
|
||||
|
||||
def test_health(openai_server: str):
|
||||
r = requests.get(f"{openai_server}/health")
|
||||
assert r.status_code == 200
|
||||
# FastAPI returns an empty body → r.text == ""
|
||||
assert r.text == ""
|
||||
|
||||
|
||||
def test_models_endpoint(openai_server: str):
|
||||
r = requests.get(f"{openai_server}/v1/models")
|
||||
assert r.status_code == 200, r.text
|
||||
payload = r.json()
|
||||
|
||||
# Basic contract
|
||||
assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
|
||||
|
||||
# Validate fields of the first model card
|
||||
first = payload["data"][0]
|
||||
for key in ("id", "root", "max_model_len"):
|
||||
assert key in first, f"missing {key} in {first}"
|
||||
|
||||
# max_model_len must be positive
|
||||
assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
|
||||
|
||||
# The server should report the same model id we launched it with
|
||||
ids = {m["id"] for m in payload["data"]}
|
||||
assert MODEL_ID in ids
|
||||
|
||||
|
||||
def test_get_model_info(openai_server: str):
|
||||
r = requests.get(f"{openai_server}/get_model_info")
|
||||
assert r.status_code == 200, r.text
|
||||
info = r.json()
|
||||
|
||||
expected_keys = {"model_path", "tokenizer_path", "is_generation"}
|
||||
assert expected_keys.issubset(info.keys())
|
||||
|
||||
# model_path must end with the one we passed on the CLI
|
||||
assert info["model_path"].endswith(MODEL_ID)
|
||||
|
||||
# is_generation is documented as a boolean
|
||||
assert isinstance(info["is_generation"], bool)
|
||||
|
||||
|
||||
def test_unknown_route_returns_404(openai_server: str):
|
||||
r = requests.get(f"{openai_server}/definitely-not-a-real-route")
|
||||
assert r.status_code == 404
|
||||
@@ -57,11 +57,21 @@ class _MockTokenizerManager:
|
||||
self.create_abort_task = Mock()
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = "llama-3"
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = None
|
||||
|
||||
|
||||
class ServingChatTestCase(unittest.TestCase):
|
||||
# ------------- common fixtures -------------
|
||||
def setUp(self):
|
||||
self.tm = _MockTokenizerManager()
|
||||
self.chat = OpenAIServingChat(self.tm)
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.chat = OpenAIServingChat(self.tm, self.template_manager)
|
||||
|
||||
# frequently reused requests
|
||||
self.basic_req = ChatCompletionRequest(
|
||||
@@ -109,96 +119,6 @@ class ServingChatTestCase(unittest.TestCase):
|
||||
self.assertFalse(adapted.stream)
|
||||
self.assertEqual(processed, self.basic_req)
|
||||
|
||||
# # ------------- tool-call branch -------------
|
||||
# def test_tool_call_request_conversion(self):
|
||||
# req = ChatCompletionRequest(
|
||||
# model="x",
|
||||
# messages=[{"role": "user", "content": "Weather?"}],
|
||||
# tools=[
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "get_weather",
|
||||
# "parameters": {"type": "object", "properties": {}},
|
||||
# },
|
||||
# }
|
||||
# ],
|
||||
# tool_choice="auto",
|
||||
# )
|
||||
|
||||
# with patch.object(
|
||||
# self.chat,
|
||||
# "_process_messages",
|
||||
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||
# ):
|
||||
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
||||
# self.assertEqual(adapted.rid, "rid")
|
||||
|
||||
# def test_tool_choice_none(self):
|
||||
# req = ChatCompletionRequest(
|
||||
# model="x",
|
||||
# messages=[{"role": "user", "content": "Hi"}],
|
||||
# tools=[{"type": "function", "function": {"name": "noop"}}],
|
||||
# tool_choice="none",
|
||||
# )
|
||||
# with patch.object(
|
||||
# self.chat,
|
||||
# "_process_messages",
|
||||
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||
# ):
|
||||
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
||||
# self.assertEqual(adapted.rid, "rid")
|
||||
|
||||
# ------------- multimodal branch -------------
|
||||
def test_multimodal_request_with_images(self):
|
||||
self.tm.model_config.is_multimodal = True
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in the image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,"},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_apply_jinja_template",
|
||||
return_value=("prompt", [1, 2], ["img"], None, [], []),
|
||||
), patch.object(
|
||||
self.chat,
|
||||
"_apply_conversation_template",
|
||||
return_value=("prompt", ["img"], None, [], []),
|
||||
):
|
||||
out = self.chat._process_messages(req, True)
|
||||
_, _, image_data, *_ = out
|
||||
self.assertEqual(image_data, ["img"])
|
||||
|
||||
# ------------- template handling -------------
|
||||
def test_jinja_template_processing(self):
|
||||
req = ChatCompletionRequest(
|
||||
model="x", messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
self.tm.chat_template_name = None
|
||||
self.tm.tokenizer.chat_template = "<jinja>"
|
||||
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_apply_jinja_template",
|
||||
return_value=("processed", [1], None, None, [], ["</s>"]),
|
||||
), patch("builtins.hasattr", return_value=True):
|
||||
prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
|
||||
self.assertEqual(prompt, "processed")
|
||||
self.assertEqual(prompt_ids, [1])
|
||||
|
||||
# ------------- sampling-params -------------
|
||||
def test_sampling_param_build(self):
|
||||
req = ChatCompletionRequest(
|
||||
|
||||
@@ -5,6 +5,7 @@ Run with:
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
||||
@@ -12,6 +13,17 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = None
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = (
|
||||
None # Set to None to avoid template processing
|
||||
)
|
||||
|
||||
|
||||
class ServingCompletionTestCase(unittest.TestCase):
|
||||
"""Bundle all prompt/echo tests in one TestCase."""
|
||||
|
||||
@@ -31,7 +43,8 @@ class ServingCompletionTestCase(unittest.TestCase):
|
||||
tm.generate_request = AsyncMock()
|
||||
tm.create_abort_task = Mock()
|
||||
|
||||
self.sc = OpenAIServingCompletion(tm)
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.sc = OpenAIServingCompletion(tm, self.template_manager)
|
||||
|
||||
# ---------- prompt-handling ----------
|
||||
def test_single_string_prompt(self):
|
||||
@@ -44,20 +57,6 @@ class ServingCompletionTestCase(unittest.TestCase):
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
||||
|
||||
def test_completion_template_handling(self):
|
||||
req = CompletionRequest(
|
||||
model="x", prompt="def f():", suffix="return 1", max_tokens=100
|
||||
)
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
|
||||
return_value=True,
|
||||
), patch(
|
||||
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
|
||||
return_value="processed_prompt",
|
||||
):
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.text, "processed_prompt")
|
||||
|
||||
# ---------- echo-handling ----------
|
||||
def test_echo_with_string_prompt_streaming(self):
|
||||
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
|
||||
|
||||
@@ -5,25 +5,16 @@ These tests ensure that the embedding serving implementation maintains compatibi
|
||||
with the original adapter.py functionality and follows OpenAI API specifications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from pydantic_core import ValidationError
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
EmbeddingObject,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ErrorResponse,
|
||||
MultimodalEmbeddingInput,
|
||||
UsageInfo,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||
@@ -58,11 +49,22 @@ class _MockTokenizerManager:
|
||||
self.generate_request = Mock(return_value=mock_generate_embedding())
|
||||
|
||||
|
||||
# Mock TemplateManager for embedding tests
|
||||
class _MockTemplateManager:
|
||||
def __init__(self):
|
||||
self.chat_template_name = None # None for embeddings usually
|
||||
self.jinja_template_content_format = None
|
||||
self.completion_template_name = None
|
||||
|
||||
|
||||
class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.tokenizer_manager = _MockTokenizerManager()
|
||||
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.serving_embedding = OpenAIServingEmbedding(
|
||||
self.tokenizer_manager, self.template_manager
|
||||
)
|
||||
|
||||
self.request = Mock(spec=Request)
|
||||
self.request.headers = {}
|
||||
@@ -141,132 +143,6 @@ class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
self.assertIsNone(adapted_request.image_data[1])
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
|
||||
def test_build_single_embedding_response(self):
|
||||
"""Test building response for single embedding."""
|
||||
ret_data = [
|
||||
{
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"meta_info": {"prompt_tokens": 5},
|
||||
}
|
||||
]
|
||||
|
||||
response = self.serving_embedding._build_embedding_response(ret_data)
|
||||
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(response.model, "test-model")
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
self.assertEqual(response.data[0].index, 0)
|
||||
self.assertEqual(response.data[0].object, "embedding")
|
||||
self.assertEqual(response.usage.prompt_tokens, 5)
|
||||
self.assertEqual(response.usage.total_tokens, 5)
|
||||
self.assertEqual(response.usage.completion_tokens, 0)
|
||||
|
||||
def test_build_multiple_embedding_response(self):
|
||||
"""Test building response for multiple embeddings."""
|
||||
ret_data = [
|
||||
{
|
||||
"embedding": [0.1, 0.2, 0.3],
|
||||
"meta_info": {"prompt_tokens": 3},
|
||||
},
|
||||
{
|
||||
"embedding": [0.4, 0.5, 0.6],
|
||||
"meta_info": {"prompt_tokens": 4},
|
||||
},
|
||||
]
|
||||
|
||||
response = self.serving_embedding._build_embedding_response(ret_data)
|
||||
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(len(response.data), 2)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
|
||||
self.assertEqual(response.data[0].index, 0)
|
||||
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
|
||||
self.assertEqual(response.data[1].index, 1)
|
||||
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
|
||||
self.assertEqual(response.usage.total_tokens, 7)
|
||||
|
||||
def test_handle_request_success(self):
|
||||
"""Test successful embedding request handling."""
|
||||
|
||||
async def run_test():
|
||||
# Mock the generate_request to return expected data
|
||||
async def mock_generate():
|
||||
yield {
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"meta_info": {"prompt_tokens": 5},
|
||||
}
|
||||
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate()
|
||||
)
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_validation_error(self):
|
||||
"""Test handling request with validation error."""
|
||||
|
||||
async def run_test():
|
||||
invalid_request = EmbeddingRequest(model="test-model", input="")
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
invalid_request, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_generation_error(self):
|
||||
"""Test handling request with generation error."""
|
||||
|
||||
async def run_test():
|
||||
# Mock generate_request to raise an error
|
||||
async def mock_generate_error():
|
||||
raise ValueError("Generation failed")
|
||||
yield # This won't be reached but needed for async generator
|
||||
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate_error()
|
||||
)
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_internal_error(self):
|
||||
"""Test handling request with internal server error."""
|
||||
|
||||
async def run_test():
|
||||
# Mock _convert_to_internal_request to raise an exception
|
||||
with patch.object(
|
||||
self.serving_embedding,
|
||||
"_convert_to_internal_request",
|
||||
side_effect=Exception("Internal error"),
|
||||
):
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 500)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -29,6 +29,10 @@ suites = {
|
||||
TestFile("models/test_reward_models.py", 132),
|
||||
TestFile("models/test_vlm_models.py", 437),
|
||||
TestFile("models/test_transformers_models.py", 320),
|
||||
TestFile("openai/test_protocol.py", 10),
|
||||
TestFile("openai/test_serving_chat.py", 10),
|
||||
TestFile("openai/test_serving_completions.py", 10),
|
||||
TestFile("openai/test_serving_embedding.py", 10),
|
||||
TestFile("test_abort.py", 51),
|
||||
TestFile("test_block_int8.py", 22),
|
||||
TestFile("test_create_kvindices.py", 2),
|
||||
@@ -49,6 +53,7 @@ suites = {
|
||||
TestFile("test_hidden_states.py", 55),
|
||||
TestFile("test_int8_kernel.py", 8),
|
||||
TestFile("test_input_embeddings.py", 38),
|
||||
TestFile("test_jinja_template_utils.py", 1),
|
||||
TestFile("test_json_constrained.py", 98),
|
||||
TestFile("test_large_max_new_tokens.py", 41),
|
||||
TestFile("test_metrics.py", 32),
|
||||
@@ -59,14 +64,8 @@ suites = {
|
||||
TestFile("test_mla_fp8.py", 93),
|
||||
TestFile("test_no_chunked_prefill.py", 108),
|
||||
TestFile("test_no_overlap_scheduler.py", 234),
|
||||
TestFile("test_openai_adapter.py", 1),
|
||||
TestFile("test_openai_function_calling.py", 60),
|
||||
TestFile("test_openai_server.py", 149),
|
||||
TestFile("openai/test_server.py", 120),
|
||||
TestFile("openai/test_protocol.py", 60),
|
||||
TestFile("openai/test_serving_chat.py", 120),
|
||||
TestFile("openai/test_serving_completions.py", 120),
|
||||
TestFile("openai/test_serving_embedding.py", 120),
|
||||
TestFile("test_openai_server_hidden_states.py", 240),
|
||||
TestFile("test_penalty.py", 41),
|
||||
TestFile("test_page_size.py", 60),
|
||||
|
||||
@@ -3,6 +3,7 @@ import unittest
|
||||
|
||||
from xgrammar import GrammarCompiler, TokenizerInfo
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Function, Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||
@@ -10,7 +11,6 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.openai_api.protocol import Function, Tool
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ Unit tests for OpenAI adapter utils.
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from sglang.srt.openai_api.utils import (
|
||||
detect_template_content_format,
|
||||
from sglang.srt.jinja_template_utils import (
|
||||
detect_jinja_template_content_format,
|
||||
process_content_for_template_format,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
@@ -33,7 +33,7 @@ class TestTemplateContentFormatDetection(CustomTestCase):
|
||||
{%- endfor %}
|
||||
"""
|
||||
|
||||
result = detect_template_content_format(llama4_pattern)
|
||||
result = detect_jinja_template_content_format(llama4_pattern)
|
||||
self.assertEqual(result, "openai")
|
||||
|
||||
def test_detect_deepseek_string_format(self):
|
||||
@@ -46,19 +46,19 @@ class TestTemplateContentFormatDetection(CustomTestCase):
|
||||
{%- endfor %}
|
||||
"""
|
||||
|
||||
result = detect_template_content_format(deepseek_pattern)
|
||||
result = detect_jinja_template_content_format(deepseek_pattern)
|
||||
self.assertEqual(result, "string")
|
||||
|
||||
def test_detect_invalid_template(self):
|
||||
"""Test handling of invalid template (should default to 'string')."""
|
||||
invalid_pattern = "{{{{ invalid jinja syntax }}}}"
|
||||
|
||||
result = detect_template_content_format(invalid_pattern)
|
||||
result = detect_jinja_template_content_format(invalid_pattern)
|
||||
self.assertEqual(result, "string")
|
||||
|
||||
def test_detect_empty_template(self):
|
||||
"""Test handling of empty template (should default to 'string')."""
|
||||
result = detect_template_content_format("")
|
||||
result = detect_jinja_template_content_format("")
|
||||
self.assertEqual(result, "string")
|
||||
|
||||
def test_process_content_openai_format(self):
|
||||
@@ -235,6 +235,7 @@ class TestOpenAIServer(CustomTestCase):
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
is_finished = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
@@ -244,6 +245,10 @@ class TestOpenAIServer(CustomTestCase):
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
if finish_reason is not None:
|
||||
is_finished[index] = True
|
||||
|
||||
data = response.choices[0].delta
|
||||
|
||||
if is_firsts.get(index, True):
|
||||
@@ -253,7 +258,7 @@ class TestOpenAIServer(CustomTestCase):
|
||||
is_firsts[index] = False
|
||||
continue
|
||||
|
||||
if logprobs:
|
||||
if logprobs and not is_finished.get(index, False):
|
||||
assert response.choices[0].logprobs, f"logprobs was not returned"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
@@ -271,7 +276,7 @@ class TestOpenAIServer(CustomTestCase):
|
||||
assert (
|
||||
isinstance(data.content, str)
|
||||
or isinstance(data.reasoning_content, str)
|
||||
or len(data.tool_calls) > 0
|
||||
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
|
||||
or response.choices[0].finish_reason
|
||||
)
|
||||
assert response.id
|
||||
@@ -282,152 +287,6 @@ class TestOpenAIServer(CustomTestCase):
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
def _create_batch(self, mode, client):
|
||||
if mode == "completion":
|
||||
input_file_path = "complete_input.jsonl"
|
||||
# write content to input file
|
||||
content = [
|
||||
{
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "List 3 names of famous soccer player: ",
|
||||
"max_tokens": 20,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "request-2",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "List 6 names of famous basketball player: ",
|
||||
"max_tokens": 40,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "request-3",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "List 6 names of famous tenniss player: ",
|
||||
"max_tokens": 40,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
else:
|
||||
input_file_path = "chat_input.jsonl"
|
||||
content = [
|
||||
{
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello! List 3 NBA players and tell a story",
|
||||
},
|
||||
],
|
||||
"max_tokens": 30,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "request-2",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are an assistant. "},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello! List three capital and tell a story",
|
||||
},
|
||||
],
|
||||
"max_tokens": 50,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with open(input_file_path, "w") as file:
|
||||
for line in content:
|
||||
file.write(json.dumps(line) + "\n")
|
||||
|
||||
with open(input_file_path, "rb") as file:
|
||||
uploaded_file = client.files.create(file=file, purpose="batch")
|
||||
if mode == "completion":
|
||||
endpoint = "/v1/completions"
|
||||
elif mode == "chat":
|
||||
endpoint = "/v1/chat/completions"
|
||||
completion_window = "24h"
|
||||
batch_job = client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint=endpoint,
|
||||
completion_window=completion_window,
|
||||
)
|
||||
|
||||
return batch_job, content, uploaded_file
|
||||
|
||||
def run_batch(self, mode):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)
|
||||
|
||||
while batch_job.status not in ["completed", "failed", "cancelled"]:
|
||||
time.sleep(3)
|
||||
print(
|
||||
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
||||
)
|
||||
batch_job = client.batches.retrieve(batch_job.id)
|
||||
assert (
|
||||
batch_job.status == "completed"
|
||||
), f"Batch job status is not completed: {batch_job.status}"
|
||||
assert batch_job.request_counts.completed == len(content)
|
||||
assert batch_job.request_counts.failed == 0
|
||||
assert batch_job.request_counts.total == len(content)
|
||||
|
||||
result_file_id = batch_job.output_file_id
|
||||
file_response = client.files.content(result_file_id)
|
||||
result_content = file_response.read().decode("utf-8") # Decode bytes to string
|
||||
results = [
|
||||
json.loads(line)
|
||||
for line in result_content.split("\n")
|
||||
if line.strip() != ""
|
||||
]
|
||||
assert len(results) == len(content)
|
||||
for delete_fid in [uploaded_file.id, result_file_id]:
|
||||
del_pesponse = client.files.delete(delete_fid)
|
||||
assert del_pesponse.deleted
|
||||
|
||||
def run_cancel_batch(self, mode):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)
|
||||
|
||||
assert batch_job.status not in ["cancelling", "cancelled"]
|
||||
|
||||
batch_job = client.batches.cancel(batch_id=batch_job.id)
|
||||
assert batch_job.status == "cancelling"
|
||||
|
||||
while batch_job.status not in ["failed", "cancelled"]:
|
||||
batch_job = client.batches.retrieve(batch_job.id)
|
||||
print(
|
||||
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
||||
)
|
||||
time.sleep(3)
|
||||
|
||||
assert batch_job.status == "cancelled"
|
||||
del_response = client.files.delete(uploaded_file.id)
|
||||
assert del_response.deleted
|
||||
|
||||
def test_completion(self):
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
@@ -467,14 +326,6 @@ class TestOpenAIServer(CustomTestCase):
|
||||
for parallel_sample_num in [1, 2]:
|
||||
self.run_chat_completion_stream(logprobs, parallel_sample_num)
|
||||
|
||||
def test_batch(self):
|
||||
for mode in ["completion", "chat"]:
|
||||
self.run_batch(mode)
|
||||
|
||||
def test_cancel_batch(self):
|
||||
for mode in ["completion", "chat"]:
|
||||
self.run_cancel_batch(mode)
|
||||
|
||||
def test_regex(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
@@ -559,6 +410,18 @@ The SmartHome Mini is a compact smart home assistant available in black or white
|
||||
assert len(models) == 1
|
||||
assert isinstance(getattr(models[0], "max_model_len", None), int)
|
||||
|
||||
def test_retrieve_model(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
# Test retrieving an existing model
|
||||
retrieved_model = client.models.retrieve(self.model)
|
||||
self.assertEqual(retrieved_model.id, self.model)
|
||||
self.assertEqual(retrieved_model.root, self.model)
|
||||
|
||||
# Test retrieving a non-existent model
|
||||
with self.assertRaises(openai.NotFoundError):
|
||||
client.models.retrieve("non-existent-model")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# EBNF Test Class: TestOpenAIServerEBNF
|
||||
@@ -684,6 +547,31 @@ class TestOpenAIEmbedding(CustomTestCase):
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
self.assertTrue(len(response.data[1].embedding) > 0)
|
||||
|
||||
def test_embedding_single_batch_str(self):
|
||||
"""Test embedding with a List[str] and length equals to 1"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(model=self.model, input=["Hello world"])
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_embedding_single_int_list(self):
|
||||
"""Test embedding with a List[int] or List[List[int]]]"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model,
|
||||
input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
|
||||
)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model,
|
||||
input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
|
||||
)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_empty_string_embedding(self):
|
||||
"""Test embedding an empty string."""
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from transformers import (
|
||||
from sglang import Engine
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.conversation import generate_chat_conv
|
||||
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
@@ -31,7 +32,6 @@ from sglang.srt.managers.schedule_batch import (
|
||||
MultimodalInputs,
|
||||
)
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from transformers import (
|
||||
|
||||
from sglang import Engine
|
||||
from sglang.srt.conversation import generate_chat_conv
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
|
||||
TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user