diff --git a/benchmark/hicache/data_processing.py b/benchmark/hicache/data_processing.py index fcc44086e..0152406a8 100644 --- a/benchmark/hicache/data_processing.py +++ b/benchmark/hicache/data_processing.py @@ -20,7 +20,7 @@ from sglang.bench_serving import ( get_gen_prefix_cache_path, ) from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path -from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart +from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart from sglang.utils import encode_video_base64 # type of content fields, can be only prompts or with images/videos diff --git a/docs/backend/openai_api_embeddings.ipynb b/docs/backend/openai_api_embeddings.ipynb index e4a40cd5c..a8828daac 100644 --- a/docs/backend/openai_api_embeddings.ipynb +++ b/docs/backend/openai_api_embeddings.ipynb @@ -64,11 +64,14 @@ "text = \"Once upon a time\"\n", "\n", "curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n", + " -H \"Content-Type: application/json\" \\\n", " -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n", "\n", - "text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n", - " \"embedding\"\n", - "]\n", + "result = subprocess.check_output(curl_text, shell=True)\n", + "\n", + "print(result)\n", + "\n", + "text_embedding = json.loads(result)[\"data\"][0][\"embedding\"]\n", "\n", "print_highlight(f\"Text embedding (first 10): {text_embedding[:10]}\")" ] @@ -152,6 +155,7 @@ "input_ids = tokenizer.encode(text)\n", "\n", "curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n", + " -H \"Content-Type: application/json\" \\\n", " -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n", "\n", "input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n", diff --git a/docs/backend/openai_api_vision.ipynb 
b/docs/backend/openai_api_vision.ipynb index 0c80fdc0d..f183fd8a0 100644 --- a/docs/backend/openai_api_vision.ipynb +++ b/docs/backend/openai_api_vision.ipynb @@ -67,6 +67,7 @@ "\n", "curl_command = f\"\"\"\n", "curl -s http://localhost:{port}/v1/chat/completions \\\\\n", + " -H \"Content-Type: application/json\" \\\\\n", " -d '{{\n", " \"model\": \"Qwen/Qwen2.5-VL-7B-Instruct\",\n", " \"messages\": [\n", diff --git a/docs/backend/vlm_query.ipynb b/docs/backend/vlm_query.ipynb index 519811f75..b47d55580 100644 --- a/docs/backend/vlm_query.ipynb +++ b/docs/backend/vlm_query.ipynb @@ -36,7 +36,7 @@ "import requests\n", "from PIL import Image\n", "\n", - "from sglang.srt.openai_api.protocol import ChatCompletionRequest\n", + "from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest\n", "from sglang.srt.conversation import chat_templates\n", "\n", "image = Image.open(\n", diff --git a/python/sglang/srt/code_completion_parser.py b/python/sglang/srt/code_completion_parser.py index 5b32d8fb6..0067ac471 100644 --- a/python/sglang/srt/code_completion_parser.py +++ b/python/sglang/srt/code_completion_parser.py @@ -15,9 +15,7 @@ import dataclasses -import json import logging -import os from enum import auto from sglang.srt.entrypoints.openai.protocol import CompletionRequest @@ -57,46 +55,6 @@ class CompletionTemplate: completion_templates: dict[str, CompletionTemplate] = {} -def load_completion_template_for_openai_api(completion_template_arg): - global completion_template_name - - logger.info( - f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}" - ) - - if not completion_template_exists(completion_template_arg): - if not os.path.exists(completion_template_arg): - raise RuntimeError( - f"Completion template {completion_template_arg} is not a built-in template name " - "or a valid completion template file path." 
- ) - - assert completion_template_arg.endswith( - ".json" - ), "unrecognized format of completion template file" - with open(completion_template_arg, "r") as filep: - template = json.load(filep) - try: - fim_position = FimPosition[template["fim_position"]] - except KeyError: - raise ValueError( - f"Unknown fim position: {template['fim_position']}" - ) from None - register_completion_template( - CompletionTemplate( - name=template["name"], - fim_begin_token=template["fim_begin_token"], - fim_middle_token=template["fim_middle_token"], - fim_end_token=template["fim_end_token"], - fim_position=fim_position, - ), - override=True, - ) - completion_template_name = template["name"] - else: - completion_template_name = completion_template_arg - - def register_completion_template(template: CompletionTemplate, override: bool = False): """Register a new completion template.""" if not override: diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index ec5765a1f..661f47700 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -11,7 +11,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Conversation chat templates.""" +"""Conversation chat templates. + +This module provides conversation template definitions, data structures, and utilities +for managing chat templates across different model types in SGLang. 
+ +Key components: +- Conversation class: Defines the structure and behavior of chat templates +- SeparatorStyle enum: Different conversation formatting styles +- Template registry: Functions to register and retrieve templates by name or model path +- Built-in templates: Pre-defined templates for popular models +""" # Adapted from # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py @@ -20,7 +30,7 @@ import re from enum import IntEnum, auto from typing import Callable, Dict, List, Optional, Tuple, Union -from sglang.srt.openai_api.protocol import ChatCompletionRequest +from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.utils import read_system_prompt_from_file @@ -618,7 +628,7 @@ def generate_chat_conv( # llama2 template -# reference: https://huggingface.co/blog/codellama#conversational-instructions +# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212 register_conv_template( Conversation( diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 4c88a5289..49bb76dba 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None) import torch import uvloop -from sglang.srt.code_completion_parser import load_completion_template_for_openai_api from sglang.srt.entrypoints.EngineBase import EngineBase from sglang.srt.managers.data_parallel_controller import ( run_data_parallel_controller_process, @@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import ( UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process +from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import TokenizerManager -from 
sglang.srt.openai_api.adapter import ( - guess_chat_template_name_from_model_path, - load_chat_template_for_openai_api, -) from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( @@ -123,12 +119,13 @@ class Engine(EngineBase): logger.info(f"{server_args=}") # Launch subprocesses - tokenizer_manager, scheduler_info = _launch_subprocesses( + tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses( server_args=server_args, port_args=port_args, ) self.server_args = server_args self.tokenizer_manager = tokenizer_manager + self.template_manager = template_manager self.scheduler_info = scheduler_info context = zmq.Context(2) @@ -647,7 +644,7 @@ def _set_envs_and_config(server_args: ServerArgs): def _launch_subprocesses( server_args: ServerArgs, port_args: Optional[PortArgs] = None -) -> Tuple[TokenizerManager, Dict]: +) -> Tuple[TokenizerManager, TemplateManager, Dict]: """ Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. """ @@ -732,7 +729,7 @@ def _launch_subprocesses( if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": # When using `Engine` as a Python API, we don't want to block here. 
- return None, None + return None, None, None launch_dummy_health_check_server(server_args.host, server_args.port) @@ -741,7 +738,7 @@ def _launch_subprocesses( logger.error( f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" ) - return None, None + return None, None, None # Launch detokenizer process detoken_proc = mp.Process( @@ -755,15 +752,15 @@ def _launch_subprocesses( # Launch tokenizer process tokenizer_manager = TokenizerManager(server_args, port_args) - if server_args.chat_template: - load_chat_template_for_openai_api( - tokenizer_manager, server_args.chat_template, server_args.model_path - ) - else: - guess_chat_template_name_from_model_path(server_args.model_path) - if server_args.completion_template: - load_completion_template_for_openai_api(server_args.completion_template) + # Initialize templates + template_manager = TemplateManager() + template_manager.initialize_templates( + tokenizer_manager=tokenizer_manager, + model_path=server_args.model_path, + chat_template=server_args.chat_template, + completion_template=server_args.completion_template, + ) # Wait for the model to finish loading scheduler_infos = [] @@ -787,4 +784,4 @@ def _launch_subprocesses( # Assume all schedulers have the same scheduler_info scheduler_info = scheduler_infos[0] tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] - return tokenizer_manager, scheduler_info + return tokenizer_manager, template_manager, scheduler_info diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 9262d10a9..b3e646147 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -38,7 +38,8 @@ import orjson import requests import uvicorn import uvloop -from fastapi import FastAPI, File, Form, Request, UploadFile +from fastapi import Depends, FastAPI, Request, UploadFile +from fastapi.exceptions import RequestValidationError from 
fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse @@ -47,6 +48,20 @@ from sglang.srt.disaggregation.utils import ( register_disaggregation_server, ) from sglang.srt.entrypoints.engine import _launch_subprocesses +from sglang.srt.entrypoints.openai.protocol import ( + ChatCompletionRequest, + CompletionRequest, + EmbeddingRequest, + ModelCard, + ModelList, + ScoringRequest, + V1RerankReqInput, +) +from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat +from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion +from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank +from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( AbortReq, @@ -67,26 +82,11 @@ from sglang.srt.managers.io_struct import ( UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, - V1RerankReqInput, VertexGenerateReqInput, ) +from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.metrics.func_timer import enable_func_timer -from sglang.srt.openai_api.adapter import ( - v1_batches, - v1_cancel_batch, - v1_chat_completions, - v1_completions, - v1_delete_file, - v1_embeddings, - v1_files_create, - v1_rerank, - v1_retrieve_batch, - v1_retrieve_file, - v1_retrieve_file_content, - v1_score, -) -from sglang.srt.openai_api.protocol import ModelCard, ModelList from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -109,6 +109,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @dataclasses.dataclass class _GlobalState: 
tokenizer_manager: TokenizerManager + template_manager: TemplateManager scheduler_info: Dict @@ -123,6 +124,24 @@ def set_global_state(global_state: _GlobalState): @asynccontextmanager async def lifespan(fast_api_app: FastAPI): server_args: ServerArgs = fast_api_app.server_args + + # Initialize OpenAI serving handlers + fast_api_app.state.openai_serving_completion = OpenAIServingCompletion( + _global_state.tokenizer_manager, _global_state.template_manager + ) + fast_api_app.state.openai_serving_chat = OpenAIServingChat( + _global_state.tokenizer_manager, _global_state.template_manager + ) + fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding( + _global_state.tokenizer_manager, _global_state.template_manager + ) + fast_api_app.state.openai_serving_score = OpenAIServingScore( + _global_state.tokenizer_manager + ) + fast_api_app.state.openai_serving_rerank = OpenAIServingRerank( + _global_state.tokenizer_manager + ) + if server_args.warmups is not None: await execute_warmups( server_args.warmups.split(","), _global_state.tokenizer_manager @@ -148,6 +167,36 @@ app.add_middleware( allow_headers=["*"], ) + +# Custom exception handlers to change validation error status codes +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """Override FastAPI's default 422 validation error with 400""" + return ORJSONResponse( + status_code=400, + content={ + "detail": exc.errors(), + "body": exc.body, + }, + ) + + +async def validate_json_request(raw_request: Request): + """Validate that the request content-type is application/json.""" + content_type = raw_request.headers.get("content-type", "").lower() + media_type = content_type.split(";", maxsplit=1)[0] + if media_type != "application/json": + raise RequestValidationError( + errors=[ + { + "loc": ["header", "content-type"], + "msg": "Unsupported Media Type: Only 'application/json' is allowed", + "type": "value_error", + } + ] + ) + 
+ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20)) @@ -330,13 +379,14 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): return _create_error_response(e) -@app.api_route("/v1/rerank", methods=["POST", "PUT"]) -async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request): - try: - ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request) - return ret - except ValueError as e: - return _create_error_response(e) +@app.api_route( + "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)] +) +async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request): + """Endpoint for reranking documents based on query relevance.""" + return await raw_request.app.state.openai_serving_rerank.handle_request( + request, raw_request + ) @app.api_route("/flush_cache", methods=["GET", "POST"]) @@ -619,25 +669,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re ##### OpenAI-compatible API endpoints ##### -@app.post("/v1/completions") -async def openai_v1_completions(raw_request: Request): - return await v1_completions(_global_state.tokenizer_manager, raw_request) +@app.post("/v1/completions", dependencies=[Depends(validate_json_request)]) +async def openai_v1_completions(request: CompletionRequest, raw_request: Request): + """OpenAI-compatible text completion endpoint.""" + return await raw_request.app.state.openai_serving_completion.handle_request( + request, raw_request + ) -@app.post("/v1/chat/completions") -async def openai_v1_chat_completions(raw_request: Request): - return await v1_chat_completions(_global_state.tokenizer_manager, raw_request) +@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)]) +async def openai_v1_chat_completions( + request: ChatCompletionRequest, raw_request: Request +): + """OpenAI-compatible chat completion endpoint.""" + return await 
raw_request.app.state.openai_serving_chat.handle_request( + request, raw_request + ) -@app.post("/v1/embeddings", response_class=ORJSONResponse) -async def openai_v1_embeddings(raw_request: Request): - response = await v1_embeddings(_global_state.tokenizer_manager, raw_request) - return response +@app.post( + "/v1/embeddings", + response_class=ORJSONResponse, + dependencies=[Depends(validate_json_request)], +) +async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request): + """OpenAI-compatible embeddings endpoint.""" + return await raw_request.app.state.openai_serving_embedding.handle_request( + request, raw_request + ) @app.get("/v1/models", response_class=ORJSONResponse) -def available_models(): - """Show available models.""" +async def available_models(): + """Show available models. OpenAI-compatible endpoint.""" served_model_names = [_global_state.tokenizer_manager.served_model_name] model_cards = [] for served_model_name in served_model_names: @@ -651,47 +715,31 @@ def available_models(): return ModelList(data=model_cards) -@app.post("/v1/files") -async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): - return await v1_files_create( - file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path +@app.get("/v1/models/{model:path}", response_class=ORJSONResponse) +async def retrieve_model(model: str): + """Retrieves a model instance, providing basic information about the model.""" + served_model_names = [_global_state.tokenizer_manager.served_model_name] + + if model not in served_model_names: + return ORJSONResponse( + status_code=404, + content={ + "error": { + "message": f"The model '{model}' does not exist", + "type": "invalid_request_error", + "param": "model", + "code": "model_not_found", + } + }, + ) + + return ModelCard( + id=model, + root=model, + max_model_len=_global_state.tokenizer_manager.model_config.context_len, ) -@app.delete("/v1/files/{file_id}") -async def delete_file(file_id: 
str): - # https://platform.openai.com/docs/api-reference/files/delete - return await v1_delete_file(file_id) - - -@app.post("/v1/batches") -async def openai_v1_batches(raw_request: Request): - return await v1_batches(_global_state.tokenizer_manager, raw_request) - - -@app.post("/v1/batches/{batch_id}/cancel") -async def cancel_batches(batch_id: str): - # https://platform.openai.com/docs/api-reference/batch/cancel - return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id) - - -@app.get("/v1/batches/{batch_id}") -async def retrieve_batch(batch_id: str): - return await v1_retrieve_batch(batch_id) - - -@app.get("/v1/files/{file_id}") -async def retrieve_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve - return await v1_retrieve_file(file_id) - - -@app.get("/v1/files/{file_id}/content") -async def retrieve_file_content(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve-contents - return await v1_retrieve_file_content(file_id) - - ## SageMaker API @app.get("/ping") async def sagemaker_health() -> Response: @@ -700,8 +748,13 @@ async def sagemaker_health() -> Response: @app.post("/invocations") -async def sagemaker_chat_completions(raw_request: Request): - return await v1_chat_completions(_global_state.tokenizer_manager, raw_request) +async def sagemaker_chat_completions( + request: ChatCompletionRequest, raw_request: Request +): + """OpenAI-compatible chat completion endpoint.""" + return await raw_request.app.state.openai_serving_chat.handle_request( + request, raw_request + ) ## Vertex AI API @@ -732,10 +785,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque return ORJSONResponse({"predictions": ret}) -@app.post("/v1/score") -async def v1_score_request(raw_request: Request): +@app.post("/v1/score", dependencies=[Depends(validate_json_request)]) +async def v1_score_request(request: ScoringRequest, raw_request: Request): """Endpoint for the decoder-only 
scoring API. See Engine.score() for detailed documentation.""" - return await v1_score(_global_state.tokenizer_manager, raw_request) + return await raw_request.app.state.openai_serving_score.handle_request( + request, raw_request + ) def _create_error_response(e): @@ -764,10 +819,13 @@ def launch_server( 1. The HTTP server, Engine, and TokenizerManager both run in the main process. 2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library. """ - tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args) + tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses( + server_args=server_args + ) set_global_state( _GlobalState( tokenizer_manager=tokenizer_manager, + template_manager=template_manager, scheduler_info=scheduler_info, ) ) diff --git a/python/sglang/srt/entrypoints/openai/api_server.py b/python/sglang/srt/entrypoints/openai/api_server.py deleted file mode 100644 index a31643395..000000000 --- a/python/sglang/srt/entrypoints/openai/api_server.py +++ /dev/null @@ -1,388 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -SGLang OpenAI-Compatible API Server. - -This file implements OpenAI-compatible HTTP APIs for the inference engine via FastAPI. 
-""" - -import argparse -import asyncio -import logging -import multiprocessing -import os -import threading -import time -from contextlib import asynccontextmanager -from typing import Callable, Dict, Optional - -import numpy as np -import requests -import uvicorn -import uvloop -from fastapi import FastAPI, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import Response - -from sglang.srt.disaggregation.utils import ( - FAKE_BOOTSTRAP_HOST, - register_disaggregation_server, -) -from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses -from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from sglang.srt.managers.tokenizer_manager import TokenizerManager -from sglang.srt.metrics.func_timer import enable_func_timer -from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList -from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import ( - add_prometheus_middleware, - delete_directory, - get_bool_env_var, - kill_process_tree, - set_uvicorn_logging_configs, -) -from sglang.srt.warmup import execute_warmups -from sglang.utils import get_exception_traceback - -logger = logging.getLogger(__name__) -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - - -# Store global states -class AppState: - engine: Optional[Engine] = None - server_args: Optional[ServerArgs] = None - tokenizer_manager: Optional[TokenizerManager] = None - scheduler_info: Optional[Dict] = None - embedding_server: Optional[OpenAIServingEmbedding] = None - - -@asynccontextmanager -async def lifespan(app: FastAPI): - app.state.server_args.enable_metrics = True # By default, we enable metrics - - server_args = app.state.server_args - - # Initialize engine - logger.info(f"SGLang OpenAI server (PID: {os.getpid()}) is initializing...") - - tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args) - app.state.tokenizer_manager = tokenizer_manager - 
app.state.scheduler_info = scheduler_info - app.state.serving_embedding = OpenAIServingEmbedding( - tokenizer_manager=tokenizer_manager - ) - - if server_args.enable_metrics: - add_prometheus_middleware(app) - enable_func_timer() - - # Initialize engine state attribute to None for now - app.state.engine = None - - if server_args.warmups is not None: - await execute_warmups( - server_args.warmups.split(","), app.state.tokenizer_manager - ) - logger.info("Warmup ended") - - warmup_thread = getattr(app, "warmup_thread", None) - if warmup_thread is not None: - warmup_thread.start() - - yield - - # Lifespan shutdown - if hasattr(app.state, "engine") and app.state.engine is not None: - logger.info("SGLang engine is shutting down.") - # Add engine cleanup logic here when implemented - - -# Fast API app with CORS enabled -app = FastAPI( - lifespan=lifespan, - # TODO: check where /openai.json is created or why we use this - openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json", -) -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -@app.api_route("/health", methods=["GET"]) -async def health() -> Response: - """Health check. Used for readiness and liveness probes.""" - # In the future, this could check engine health more deeply - # For now, if the server is up, it's healthy. - return Response(status_code=200) - - -@app.api_route("/v1/models", methods=["GET"]) -async def show_models(): - """Show available models. Currently, it returns the served model name. - - This endpoint is compatible with the OpenAI API standard. 
- """ - served_model_names = [app.state.tokenizer_manager.served_model_name] - model_cards = [] - for served_model_name in served_model_names: - model_cards.append( - ModelCard( - id=served_model_name, - root=served_model_name, - max_model_len=app.state.tokenizer_manager.model_config.context_len, - ) - ) - return ModelList(data=model_cards) - - -@app.get("/get_model_info") -async def get_model_info(): - """Get the model information.""" - result = { - "model_path": app.state.tokenizer_manager.model_path, - "tokenizer_path": app.state.tokenizer_manager.server_args.tokenizer_path, - "is_generation": app.state.tokenizer_manager.is_generation, - } - return result - - -@app.post("/v1/completions") -async def openai_v1_completions(raw_request: Request): - pass - - -@app.post("/v1/chat/completions") -async def openai_v1_chat_completions(raw_request: Request): - pass - - -@app.post("/v1/embeddings") -async def openai_v1_embeddings(raw_request: Request): - try: - request_json = await raw_request.json() - request = EmbeddingRequest(**request_json) - except Exception as e: - return app.state.serving_embedding.create_error_response( - f"Invalid request body, error: {str(e)}" - ) - - ret = await app.state.serving_embedding.handle_request(request, raw_request) - return ret - - -@app.post("/v1/score") -async def v1_score_request(raw_request: Request): - """Endpoint for the decoder-only scoring API. 
See Engine.score() for detailed documentation.""" - pass - - -@app.api_route("/v1/models/{model_id}", methods=["GET"]) -async def show_model_detail(model_id: str): - served_model_name = app.state.tokenizer_manager.served_model_name - - return ModelCard( - id=served_model_name, - root=served_model_name, - max_model_len=app.state.tokenizer_manager.model_config.context_len, - ) - - -# Additional API endpoints will be implemented in separate serving_*.py modules -# and mounted as APIRouters in future PRs - - -def _wait_and_warmup( - server_args: ServerArgs, - pipe_finish_writer: Optional[multiprocessing.connection.Connection], - image_token_text: str, - launch_callback: Optional[Callable[[], None]] = None, -): - return - # TODO: Please wait until the /generate implementation is complete, - # or confirm if modifications are needed before removing this. - - headers = {} - url = server_args.url() - if server_args.api_key: - headers["Authorization"] = f"Bearer {server_args.api_key}" - - # Wait until the server is launched - success = False - for _ in range(120): - time.sleep(1) - try: - res = requests.get(url + "/get_model_info", timeout=5, headers=headers) - assert res.status_code == 200, f"{res=}, {res.text=}" - success = True - break - except (AssertionError, requests.exceptions.RequestException): - last_traceback = get_exception_traceback() - pass - - if not success: - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. 
warmup error: {last_traceback}") - kill_process_tree(os.getpid()) - return - - model_info = res.json() - - # Send a warmup request - request_name = "/generate" if model_info["is_generation"] else "/encode" - # TODO: Replace with OpenAI API - max_new_tokens = 8 if model_info["is_generation"] else 1 - json_data = { - "sampling_params": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - }, - } - if server_args.skip_tokenizer_init: - json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)] - # TODO Workaround the bug that embedding errors for list of size 1 - if server_args.dp_size == 1: - json_data["input_ids"] = json_data["input_ids"][0] - else: - json_data["text"] = ["The capital city of France is"] * server_args.dp_size - # TODO Workaround the bug that embedding errors for list of size 1 - if server_args.dp_size == 1: - json_data["text"] = json_data["text"][0] - - # Debug dumping - if server_args.debug_tensor_dump_input_file: - json_data.pop("text", None) - json_data["input_ids"] = np.load( - server_args.debug_tensor_dump_input_file - ).tolist() - json_data["sampling_params"]["max_new_tokens"] = 0 - - try: - if server_args.disaggregation_mode == "null": - res = requests.post( - url + request_name, - json=json_data, - headers=headers, - timeout=600, - ) - assert res.status_code == 200, f"{res}" - else: - logger.info(f"Start of prefill warmup ...") - json_data = { - "sampling_params": { - "temperature": 0.0, - "max_new_tokens": 8, - "ignore_eos": True, - }, - "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size, - # This is a hack to ensure fake transfer is enabled during prefill warmup - # ensure each dp rank has a unique bootstrap_room during prefill warmup - "bootstrap_room": [ - i * (2**63 // server_args.dp_size) + (i % server_args.tp_size) - for i in range(server_args.dp_size) - ], - "input_ids": [[0, 1, 2, 3]] * server_args.dp_size, - } - res = requests.post( - url + request_name, - json=json_data, - headers=headers, - 
timeout=1800, # because of deep gemm precache is very long if not precache. - ) - logger.info( - f"End of prefill warmup with status {res.status_code}, resp: {res.json()}" - ) - - except Exception: - last_traceback = get_exception_traceback() - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. warmup error: {last_traceback}") - kill_process_tree(os.getpid()) - return - - # Debug print - # logger.info(f"{res.json()=}") - - logger.info("The server is fired up and ready to roll!") - if pipe_finish_writer is not None: - pipe_finish_writer.send("ready") - - if server_args.delete_ckpt_after_loading: - delete_directory(server_args.model_path) - - if server_args.debug_tensor_dump_input_file: - kill_process_tree(os.getpid()) - - if server_args.pdlb_url is not None: - register_disaggregation_server( - server_args.disaggregation_mode, - server_args.port, - server_args.disaggregation_bootstrap_port, - server_args.pdlb_url, - ) - - if launch_callback is not None: - launch_callback() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="SGLang OpenAI-Compatible API Server") - # Add arguments from ServerArgs. This allows reuse of existing CLI definitions. - ServerArgs.add_cli_args(parser) - # Potentially add server-specific arguments here in the future if needed - - args = parser.parse_args() - server_args = ServerArgs.from_cli_args(args) - - # Store server_args in app.state for access in lifespan and endpoints - app.state.server_args = server_args - - # Configure logging - logging.basicConfig( - level=server_args.log_level.upper(), - format="%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s", - ) - - # Send a warmup request - we will create the thread launch it - # in the lifespan after all other warmups have fired. 
- warmup_thread = threading.Thread( - target=_wait_and_warmup, - args=( - server_args, - None, - None, # Never used - None, - ), - ) - app.warmup_thread = warmup_thread - - try: - # Start the server - set_uvicorn_logging_configs() - uvicorn.run( - app, - host=server_args.host, - port=server_args.port, - log_level=server_args.log_level.lower(), - timeout_keep_alive=60, # Increased keep-alive for potentially long requests - loop="uvloop", # Use uvloop for better performance if available - ) - finally: - warmup_thread.join() diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 9acd50d02..89b2d3ab6 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -207,7 +207,7 @@ class CompletionResponseChoice(BaseModel): index: int text: str logprobs: Optional[LogProbs] = None - finish_reason: Literal["stop", "length", "content_filter", "abort"] + finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None matched_stop: Union[None, int, str] = None hidden_states: Optional[object] = None @@ -404,21 +404,13 @@ class ChatCompletionRequest(BaseModel): @model_validator(mode="before") @classmethod def set_tool_choice_default(cls, values): - if isinstance(values, dict): - if values.get("tool_choice") is None: - if values.get("tools") is None: - values["tool_choice"] = "none" - else: - values["tool_choice"] = "auto" + if values.get("tool_choice") is None: + if values.get("tools") is None: + values["tool_choice"] = "none" + else: + values["tool_choice"] = "auto" return values - @field_validator("messages") - @classmethod - def validate_messages_not_empty(cls, v): - if not v: - raise ValueError("Messages cannot be empty") - return v - # Extra parameters for SRT backend only and will be ignored by OpenAI models. 
top_k: int = -1 min_p: float = 0.0 @@ -457,9 +449,11 @@ class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None - finish_reason: Literal[ - "stop", "length", "tool_calls", "content_filter", "function_call", "abort" - ] + finish_reason: Optional[ + Literal[ + "stop", "length", "tool_calls", "content_filter", "function_call", "abort" + ] + ] = None matched_stop: Union[None, int, str] = None hidden_states: Optional[object] = None @@ -530,7 +524,7 @@ class EmbeddingRequest(BaseModel): input: EmbeddingInput model: str encoding_format: str = "float" - dimensions: int = None + dimensions: Optional[int] = None user: Optional[str] = None # The request id. diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index 8e22c26c4..ba7514f0d 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -2,16 +2,12 @@ import json import logging import uuid from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from fastapi import Request from fastapi.responses import ORJSONResponse, StreamingResponse -from sglang.srt.entrypoints.openai.protocol import ( - ErrorResponse, - OpenAIServingRequest, - UsageInfo, -) +from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -51,7 +47,7 @@ class OpenAIServingBase(ABC): ) except Exception as e: - logger.error(f"Error in request: {e}") + logger.exception(f"Error in request: {e}") return self.create_error_response( message=f"Internal server error: {str(e)}", err_type="InternalServerError", @@ -63,8 +59,12 @@ class OpenAIServingBase(ABC): """Generate request ID based on request type""" pass - def 
_generate_request_id_base(self, request: OpenAIServingRequest) -> str: + def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]: """Generate request ID based on request type""" + return None + + # TODO(chang): the rid is used in io_struct check and often violates `The rid should be a list` AssertionError + # Temporarily return None in this function until the rid logic is clear. if rid := getattr(request, "rid", None): return rid @@ -83,7 +83,7 @@ class OpenAIServingBase(ABC): adapted_request: GenerateReqInput, request: OpenAIServingRequest, raw_request: Request, - ) -> StreamingResponse: + ) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]: """Handle streaming request Override this method in child classes that support streaming requests. @@ -99,7 +99,7 @@ class OpenAIServingBase(ABC): adapted_request: GenerateReqInput, request: OpenAIServingRequest, raw_request: Request, - ) -> Union[Any, ErrorResponse]: + ) -> Union[Any, ErrorResponse, ORJSONResponse]: """Handle non-streaming request Override this method in child classes that support non-streaming requests.
@@ -110,7 +110,7 @@ class OpenAIServingBase(ABC): status_code=501, ) - def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]: + def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]: """Validate request""" pass @@ -122,6 +122,7 @@ class OpenAIServingBase(ABC): param: Optional[str] = None, ) -> ORJSONResponse: """Create an error response""" + # TODO: remove fastapi dependency in openai and move response handling to the entrypoint error = ErrorResponse( object="error", message=message, diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index cb128ff41..b91fee18c 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -1,4 +1,3 @@ -import base64 import json import logging import time @@ -6,7 +5,7 @@ import uuid from typing import Any, AsyncGenerator, Dict, List, Optional, Union from fastapi import Request -from fastapi.responses import StreamingResponse +from fastapi.responses import ORJSONResponse, StreamingResponse from sglang.srt.conversation import generate_chat_conv from sglang.srt.entrypoints.openai.protocol import ( @@ -28,13 +27,14 @@ from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor from sglang.srt.entrypoints.openai.utils import ( - detect_template_content_format, - process_content_for_template_format, process_hidden_states_from_ret, to_openai_style_logprobs, ) from sglang.srt.function_call.function_call_parser import FunctionCallParser +from sglang.srt.jinja_template_utils import process_content_for_template_format from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.template_manager import TemplateManager +from sglang.srt.managers.tokenizer_manager import TokenizerManager from 
sglang.srt.reasoning_parser import ReasoningParser from sglang.utils import convert_json_schema_to_str @@ -42,13 +42,13 @@ logger = logging.getLogger(__name__) class OpenAIServingChat(OpenAIServingBase): - """Handler for chat completion requests""" + """Handler for /v1/chat/completions requests""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Instance-specific cache for template content format detection - self._cached_chat_template = None - self._cached_template_format = None + def __init__( + self, tokenizer_manager: TokenizerManager, template_manager: TemplateManager + ): + super().__init__(tokenizer_manager) + self.template_manager = template_manager def _request_id_prefix(self) -> str: return "chatcmpl-" @@ -142,19 +142,14 @@ class OpenAIServingChat(OpenAIServingBase): ) # Use chat template - if ( - hasattr(self.tokenizer_manager, "chat_template_name") - and self.tokenizer_manager.chat_template_name is None - ): + if self.template_manager.chat_template_name is None: prompt, prompt_ids, image_data, audio_data, modalities, stop = ( self._apply_jinja_template(request, tools, is_multimodal) ) else: - prompt, image_data, audio_data, modalities, stop = ( - self._apply_conversation_template(request) + prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + self._apply_conversation_template(request, is_multimodal) ) - if not is_multimodal: - prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt) else: # Use raw prompt prompt_ids = request.messages @@ -181,23 +176,14 @@ class OpenAIServingChat(OpenAIServingBase): is_multimodal: bool, ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]: """Apply Jinja chat template""" + prompt = "" + prompt_ids = [] openai_compatible_messages = [] image_data = [] audio_data = [] modalities = [] - # Detect template content format - current_template = self.tokenizer_manager.tokenizer.chat_template - if current_template != self._cached_chat_template: - 
self._cached_chat_template = current_template - self._cached_template_format = detect_template_content_format( - current_template - ) - logger.info( - f"Detected chat template content format: {self._cached_template_format}" - ) - - template_content_format = self._cached_template_format + template_content_format = self.template_manager.jinja_template_content_format for message in request.messages: if message.content is None: @@ -262,14 +248,21 @@ class OpenAIServingChat(OpenAIServingBase): if is_multimodal: prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids) - stop = request.stop or [] + stop = request.stop + image_data = image_data if image_data else None + audio_data = audio_data if audio_data else None + modalities = modalities if modalities else [] return prompt, prompt_ids, image_data, audio_data, modalities, stop def _apply_conversation_template( - self, request: ChatCompletionRequest - ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]: + self, + request: ChatCompletionRequest, + is_multimodal: bool, + ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]: """Apply conversation template""" - conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name) + prompt = "" + prompt_ids = [] + conv = generate_chat_conv(request, self.template_manager.chat_template_name) # If we should continue the final assistant message, adjust the conversation. 
if ( @@ -296,9 +289,9 @@ class OpenAIServingChat(OpenAIServingBase): else: prompt = conv.get_prompt() - image_data = conv.image_data - audio_data = conv.audio_data - modalities = conv.modalities + image_data = conv.image_data if conv.image_data else None + audio_data = conv.audio_data if conv.audio_data else None + modalities = conv.modalities if conv.modalities else [] stop = conv.stop_str or [] if not request.ignore_eos else [] if request.stop: @@ -307,7 +300,10 @@ class OpenAIServingChat(OpenAIServingBase): else: stop.extend(request.stop) - return prompt, image_data, audio_data, modalities, stop + if not is_multimodal: + prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt) + + return prompt, prompt_ids, image_data, audio_data, modalities, stop def _build_sampling_params( self, @@ -459,13 +455,9 @@ class OpenAIServingChat(OpenAIServingBase): stream_buffers[index] = stream_buffer + delta # Handle reasoning content - enable_thinking = getattr(request, "chat_template_kwargs", {}).get( - "enable_thinking", True - ) if ( self.tokenizer_manager.server_args.reasoning_parser and request.separate_reasoning - and enable_thinking ): reasoning_text, delta = self._process_reasoning_stream( index, delta, reasoning_parser_dict, content, request @@ -591,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase): ) yield f"data: {usage_chunk.model_dump_json()}\n\n" - except Exception as e: + except ValueError as e: error = self.create_streaming_error_response(str(e)) yield f"data: {error}\n\n" @@ -602,7 +594,7 @@ class OpenAIServingChat(OpenAIServingBase): adapted_request: GenerateReqInput, request: ChatCompletionRequest, raw_request: Request, - ) -> Union[ChatCompletionResponse, ErrorResponse]: + ) -> Union[ChatCompletionResponse, ErrorResponse, ORJSONResponse]: """Handle non-streaming chat completion request""" try: ret = await self.tokenizer_manager.generate_request( @@ -627,7 +619,7 @@ class OpenAIServingChat(OpenAIServingBase): request: ChatCompletionRequest, ret: 
List[Dict[str, Any]], created: int, - ) -> ChatCompletionResponse: + ) -> Union[ChatCompletionResponse, ORJSONResponse]: """Build chat completion response from generation results""" choices = [] @@ -645,11 +637,8 @@ class OpenAIServingChat(OpenAIServingBase): # Handle reasoning content reasoning_text = None - enable_thinking = getattr(request, "chat_template_kwargs", {}).get( - "enable_thinking", True - ) reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser - if reasoning_parser and request.separate_reasoning and enable_thinking: + if reasoning_parser and request.separate_reasoning: try: parser = ReasoningParser( model_type=reasoning_parser, stream_reasoning=False @@ -691,9 +680,10 @@ class OpenAIServingChat(OpenAIServingBase): choices.append(choice_data) # Calculate usage - cache_report = self.tokenizer_manager.server_args.enable_cache_report usage = UsageProcessor.calculate_response_usage( - ret, n_choices=request.n, enable_cache_report=cache_report + ret, + n_choices=request.n, + enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report, ) return ChatCompletionResponse( @@ -821,6 +811,25 @@ class OpenAIServingChat(OpenAIServingBase): reasoning_parser = reasoning_parser_dict[index] return reasoning_parser.parse_stream_chunk(delta) + def _get_enable_thinking_from_request(request: ChatCompletionRequest) -> bool: + """Extracts the 'enable_thinking' flag from request chat_template_kwargs. + + NOTE: This parameter is only useful for models that support enable_thinking + flag, such as Qwen3. + + Args: + request_obj: The request object (or an item from a list of requests). + Returns: + The boolean value of 'enable_thinking' if found and not True, otherwise True. 
+ """ + if ( + hasattr(request, "chat_template_kwargs") + and request.chat_template_kwargs + and request.chat_template_kwargs.get("enable_thinking") is not None + ): + return request.chat_template_kwargs.get("enable_thinking") + return True + async def _process_tool_call_stream( self, index: int, diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index af715b32d..3db881641 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -3,12 +3,9 @@ import time from typing import Any, AsyncGenerator, Dict, List, Union from fastapi import Request -from fastapi.responses import StreamingResponse +from fastapi.responses import ORJSONResponse, StreamingResponse -from sglang.srt.code_completion_parser import ( - generate_completion_prompt_from_request, - is_completion_template_defined, -) +from sglang.srt.code_completion_parser import generate_completion_prompt_from_request from sglang.srt.entrypoints.openai.protocol import ( CompletionRequest, CompletionResponse, @@ -24,12 +21,22 @@ from sglang.srt.entrypoints.openai.utils import ( to_openai_style_logprobs, ) from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.template_manager import TemplateManager +from sglang.srt.managers.tokenizer_manager import TokenizerManager logger = logging.getLogger(__name__) class OpenAIServingCompletion(OpenAIServingBase): - """Handler for completion requests""" + """Handler for /v1/completion requests""" + + def __init__( + self, + tokenizer_manager: TokenizerManager, + template_manager: TemplateManager, + ): + super().__init__(tokenizer_manager) + self.template_manager = template_manager def _request_id_prefix(self) -> str: return "cmpl-" @@ -47,7 +54,7 @@ class OpenAIServingCompletion(OpenAIServingBase): ) # Process prompt prompt = request.prompt - if is_completion_template_defined(): + if 
self.template_manager.completion_template_name is not None: prompt = generate_completion_prompt_from_request(request) # Set logprob start length based on echo and logprobs @@ -141,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServingBase): prompt_tokens = {} completion_tokens = {} cached_tokens = {} + hidden_states = {} try: async for content in self.tokenizer_manager.generate_request( @@ -152,6 +160,7 @@ class OpenAIServingCompletion(OpenAIServingBase): prompt_tokens[index] = content["meta_info"]["prompt_tokens"] completion_tokens[index] = content["meta_info"]["completion_tokens"] cached_tokens[index] = content["meta_info"].get("cached_tokens", 0) + hidden_states[index] = content["meta_info"].get("hidden_states", None) stream_buffer = stream_buffers.get(index, "") # Handle echo for first chunk @@ -192,7 +201,6 @@ class OpenAIServingCompletion(OpenAIServingBase): delta = text[len(stream_buffer) :] stream_buffers[index] = stream_buffer + delta finish_reason = content["meta_info"]["finish_reason"] - hidden_states = content["meta_info"].get("hidden_states", None) choice_data = CompletionResponseStreamChoice( index=index, @@ -269,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServingBase): adapted_request: GenerateReqInput, request: CompletionRequest, raw_request: Request, - ) -> Union[CompletionResponse, ErrorResponse]: + ) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]: """Handle non-streaming completion request""" try: generator = self.tokenizer_manager.generate_request( diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index 4fe60f230..4f2db1dbe 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Optional, Union from fastapi import Request +from fastapi.responses import ORJSONResponse from sglang.srt.conversation import 
generate_embedding_convs from sglang.srt.entrypoints.openai.protocol import ( @@ -13,10 +14,20 @@ from sglang.srt.entrypoints.openai.protocol import ( ) from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.managers.io_struct import EmbeddingReqInput +from sglang.srt.managers.template_manager import TemplateManager +from sglang.srt.managers.tokenizer_manager import TokenizerManager class OpenAIServingEmbedding(OpenAIServingBase): - """Handler for embedding requests""" + """Handler for v1/embeddings requests""" + + def __init__( + self, + tokenizer_manager: TokenizerManager, + template_manager: TemplateManager, + ): + super().__init__(tokenizer_manager) + self.template_manager = template_manager def _request_id_prefix(self) -> str: return "embd-" @@ -68,11 +79,7 @@ class OpenAIServingEmbedding(OpenAIServingBase): prompt_kwargs = {"text": prompt} elif isinstance(prompt, list): if len(prompt) > 0 and isinstance(prompt[0], str): - # List of strings - if it's a single string in a list, treat as single string - if len(prompt) == 1: - prompt_kwargs = {"text": prompt[0]} - else: - prompt_kwargs = {"text": prompt} + prompt_kwargs = {"text": prompt} elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput): # Handle multimodal embedding inputs texts = [] @@ -84,11 +91,10 @@ class OpenAIServingEmbedding(OpenAIServingBase): generate_prompts = [] # Check if we have a chat template for multimodal embeddings - chat_template_name = getattr( - self.tokenizer_manager, "chat_template_name", None - ) - if chat_template_name is not None: - convs = generate_embedding_convs(texts, images, chat_template_name) + if self.template_manager.chat_template_name is not None: + convs = generate_embedding_convs( + texts, images, self.template_manager.chat_template_name + ) for conv in convs: generate_prompts.append(conv.get_prompt()) else: @@ -122,7 +128,7 @@ class OpenAIServingEmbedding(OpenAIServingBase): adapted_request: EmbeddingReqInput, 
request: EmbeddingRequest, raw_request: Request, - ) -> Union[EmbeddingResponse, ErrorResponse]: + ) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]: """Handle the embedding request""" try: ret = await self.tokenizer_manager.generate_request( diff --git a/python/sglang/srt/entrypoints/openai/serving_rerank.py b/python/sglang/srt/entrypoints/openai/serving_rerank.py index 50be5c3cc..b053c55b3 100644 --- a/python/sglang/srt/entrypoints/openai/serving_rerank.py +++ b/python/sglang/srt/entrypoints/openai/serving_rerank.py @@ -2,6 +2,7 @@ import logging from typing import Any, Dict, List, Optional, Union from fastapi import Request +from fastapi.responses import ORJSONResponse from sglang.srt.entrypoints.openai.protocol import ( ErrorResponse, @@ -15,7 +16,10 @@ logger = logging.getLogger(__name__) class OpenAIServingRerank(OpenAIServingBase): - """Handler for rerank requests""" + """Handler for /v1/rerank requests""" + + # NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved + # to another module in the future. 
def _request_id_prefix(self) -> str: return "rerank-" @@ -61,7 +65,7 @@ adapted_request: EmbeddingReqInput, request: V1RerankReqInput, raw_request: Request, - ) -> Union[RerankResponse, ErrorResponse]: + ) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]: """Handle the rerank request""" try: ret = await self.tokenizer_manager.generate_request( @@ -74,16 +78,16 @@ if not isinstance(ret, list): ret = [ret] - response = self._build_rerank_response(ret, request) - return response + responses = self._build_rerank_response(ret, request) + return responses def _build_rerank_response( self, ret: List[Dict[str, Any]], request: V1RerankReqInput ) -> List[RerankResponse]: """Build the rerank response from generation results""" - response = [] + responses = [] for idx, ret_item in enumerate(ret): - response.append( + responses.append( RerankResponse( score=ret_item["embedding"], document=request.documents[idx], @@ -93,6 +97,6 @@ ) # Sort by score in descending order (highest relevance first) - response.sort(key=lambda x: x.score, reverse=True) + responses.sort(key=lambda x: x.score, reverse=True) - return response + return responses diff --git a/python/sglang/srt/entrypoints/openai/serving_score.py b/python/sglang/srt/entrypoints/openai/serving_score.py index af73a866a..fc8ce5dca 100644 --- a/python/sglang/srt/entrypoints/openai/serving_score.py +++ b/python/sglang/srt/entrypoints/openai/serving_score.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Union from fastapi import Request @@ -14,7 +14,10 @@ logger = logging.getLogger(__name__) class OpenAIServingScore(OpenAIServingBase): - """Handler for scoring requests""" + """Handler for /v1/score requests""" + + # NOTE: /v1/score is not an official OpenAI endpoint. 
This module may be moved + # to another module in the future. def _request_id_prefix(self) -> str: return "score-" diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py index e80125cf6..94ac5458d 100644 --- a/python/sglang/srt/entrypoints/openai/utils.py +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -1,9 +1,6 @@ import logging from typing import Any, Dict, List, Optional, Union -import jinja2.nodes -import transformers.utils.chat_template_utils as hf_chat_utils - from sglang.srt.entrypoints.openai.protocol import ( ChatCompletionRequest, CompletionRequest, @@ -13,168 +10,6 @@ from sglang.srt.entrypoints.openai.protocol import ( logger = logging.getLogger(__name__) -# ============================================================================ -# JINJA TEMPLATE CONTENT FORMAT DETECTION -# ============================================================================ -# -# This adapts vLLM's approach for detecting chat template content format: -# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313 -# - Analyzes Jinja template AST to detect content iteration patterns -# - 'openai' format: templates with {%- for content in message['content'] -%} loops -# - 'string' format: templates that expect simple string content -# - Processes content accordingly to match template expectations - - -def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: - """Check if node is a variable access like {{ varname }}""" - if isinstance(node, jinja2.nodes.Name): - return node.ctx == "load" and node.name == varname - return False - - -def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: - """Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}""" - if isinstance(node, jinja2.nodes.Getitem): - return ( - _is_var_access(node.node, varname) - and isinstance(node.arg, jinja2.nodes.Const) - and 
node.arg.value == key - ) - - if isinstance(node, jinja2.nodes.Getattr): - return _is_var_access(node.node, varname) and node.attr == key - - return False - - -def _is_var_or_elems_access( - node: jinja2.nodes.Node, - varname: str, - key: str = None, -) -> bool: - """Check if node accesses varname or varname[key] with filters/tests""" - if isinstance(node, jinja2.nodes.Filter): - return node.node is not None and _is_var_or_elems_access( - node.node, varname, key - ) - if isinstance(node, jinja2.nodes.Test): - return _is_var_or_elems_access(node.node, varname, key) - - if isinstance(node, jinja2.nodes.Getitem) and isinstance( - node.arg, jinja2.nodes.Slice - ): - return _is_var_or_elems_access(node.node, varname, key) - - return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) - - -def _try_extract_ast(chat_template: str): - """Try to parse the Jinja template into an AST""" - try: - jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) - return jinja_compiled.environment.parse(chat_template) - except Exception as e: - logger.debug(f"Error when compiling Jinja template: {e}") - return None - - -def detect_template_content_format(chat_template: str) -> str: - """ - Detect whether a chat template expects 'string' or 'openai' content format. 
- - - 'string': content is a simple string (like DeepSeek templates) - - 'openai': content is a list of structured dicts (like Llama4 templates) - - Detection logic: - - If template has loops like {%- for content in message['content'] -%} → 'openai' - - Otherwise → 'string' - """ - jinja_ast = _try_extract_ast(chat_template) - if jinja_ast is None: - return "string" - - try: - # Look for patterns like: {%- for content in message['content'] -%} - for loop_ast in jinja_ast.find_all(jinja2.nodes.For): - loop_iter = loop_ast.iter - - # Check if iterating over message['content'] or similar - if _is_var_or_elems_access(loop_iter, "message", "content"): - return "openai" # Found content iteration → openai format - - return "string" # No content loops found → string format - except Exception as e: - logger.debug(f"Error when parsing AST of Jinja template: {e}") - return "string" - - -def process_content_for_template_format( - msg_dict: dict, - content_format: str, - image_data: list, - audio_data: list, - modalities: list, -) -> dict: - """ - Process message content based on detected template format. 
- - Args: - msg_dict: Message dictionary with content - content_format: 'string' or 'openai' (detected via AST analysis) - image_data: List to append extracted image URLs - audio_data: List to append extracted audio URLs - modalities: List to append modalities - - Returns: - Processed message dictionary - """ - if not isinstance(msg_dict.get("content"), list): - # Already a string or None, no processing needed - return {k: v for k, v in msg_dict.items() if v is not None} - - if content_format == "openai": - # OpenAI format: preserve structured content list, normalize types - processed_content_parts = [] - for chunk in msg_dict["content"]: - if isinstance(chunk, dict): - chunk_type = chunk.get("type") - - if chunk_type == "image_url": - image_data.append(chunk["image_url"]["url"]) - if chunk.get("modalities"): - modalities.append(chunk.get("modalities")) - # Normalize to simple 'image' type for template compatibility - processed_content_parts.append({"type": "image"}) - elif chunk_type == "audio_url": - audio_data.append(chunk["audio_url"]["url"]) - # Normalize to simple 'audio' type - processed_content_parts.append({"type": "audio"}) - else: - # Keep other content as-is (text, etc.) 
- processed_content_parts.append(chunk) - - new_msg = { - k: v for k, v in msg_dict.items() if v is not None and k != "content" - } - new_msg["content"] = processed_content_parts - return new_msg - - else: # content_format == "string" - # String format: flatten to text only (for templates like DeepSeek) - text_parts = [] - for chunk in msg_dict["content"]: - if isinstance(chunk, dict) and chunk.get("type") == "text": - text_parts.append(chunk["text"]) - # Note: For string format, we ignore images/audio since the template - # doesn't expect structured content - multimodal placeholders would - # need to be inserted differently - - new_msg = msg_dict.copy() - new_msg["content"] = " ".join(text_parts) if text_parts else "" - new_msg = {k: v for k, v in new_msg.items() if v is not None} - return new_msg - - def to_openai_style_logprobs( input_token_logprobs=None, output_token_logprobs=None, diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py index e0e95cff4..d1e414df6 100644 --- a/python/sglang/srt/function_call/base_format_detector.py +++ b/python/sglang/srt/function_call/base_format_detector.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List from partial_json_parser.core.exceptions import MalformedJSON from partial_json_parser.core.options import Allow +from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.core_types import ( StreamingParseResult, ToolCallItem, @@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import ( _is_complete_json, _partial_json_loads, ) -from sglang.srt.openai_api.protocol import Tool logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/function_call/deepseekv3_detector.py b/python/sglang/srt/function_call/deepseekv3_detector.py index 1245c3db4..e3befca5b 100644 --- a/python/sglang/srt/function_call/deepseekv3_detector.py +++ b/python/sglang/srt/function_call/deepseekv3_detector.py @@ -3,6 +3,7 @@ import logging 
import re from typing import List +from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ( StreamingParseResult, @@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import ( ) from sglang.srt.function_call.ebnf_composer import EBNFComposer from sglang.srt.function_call.utils import _is_complete_json -from sglang.srt.openai_api.protocol import Tool logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index af1033a52..10b92a0af 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -1,6 +1,12 @@ import logging from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union +from sglang.srt.entrypoints.openai.protocol import ( + StructuralTagResponseFormat, + StructuresResponseFormat, + Tool, + ToolChoice, +) from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ToolCallItem from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector @@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector from sglang.srt.function_call.qwen25_detector import Qwen25Detector -from sglang.srt.openai_api.protocol import ( - StructuralTagResponseFormat, - StructuresResponseFormat, - Tool, - ToolChoice, -) logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/function_call/llama32_detector.py b/python/sglang/srt/function_call/llama32_detector.py index 065ffd7f6..e7afeddb0 100644 --- a/python/sglang/srt/function_call/llama32_detector.py +++ 
b/python/sglang/srt/function_call/llama32_detector.py @@ -2,6 +2,7 @@ import json import logging from typing import List +from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ( StreamingParseResult, @@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import ( _GetInfoFunc, ) from sglang.srt.function_call.ebnf_composer import EBNFComposer -from sglang.srt.openai_api.protocol import Tool logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py index 05d3bfead..031368006 100644 --- a/python/sglang/srt/function_call/mistral_detector.py +++ b/python/sglang/srt/function_call/mistral_detector.py @@ -3,6 +3,7 @@ import logging import re from typing import List +from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ( StreamingParseResult, @@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import ( _GetInfoFunc, ) from sglang.srt.function_call.ebnf_composer import EBNFComposer -from sglang.srt.openai_api.protocol import Tool logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py index 4ef2e3db5..d3096d919 100644 --- a/python/sglang/srt/function_call/pythonic_detector.py +++ b/python/sglang/srt/function_call/pythonic_detector.py @@ -4,6 +4,7 @@ import logging import re from typing import List, Optional +from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ( StreamingParseResult, @@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import ( _GetInfoFunc, ) from 
sglang.srt.function_call.ebnf_composer import EBNFComposer -from sglang.srt.openai_api.protocol import Tool logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py index ad1317777..cee3f18ea 100644 --- a/python/sglang/srt/function_call/qwen25_detector.py +++ b/python/sglang/srt/function_call/qwen25_detector.py @@ -3,6 +3,7 @@ import logging import re from typing import List +from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ( StreamingParseResult, @@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import ( _GetInfoFunc, ) from sglang.srt.function_call.ebnf_composer import EBNFComposer -from sglang.srt.openai_api.protocol import Tool logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/openai_api/utils.py b/python/sglang/srt/jinja_template_utils.py similarity index 95% rename from python/sglang/srt/openai_api/utils.py rename to python/sglang/srt/jinja_template_utils.py index 610251aff..14a94e487 100644 --- a/python/sglang/srt/openai_api/utils.py +++ b/python/sglang/srt/jinja_template_utils.py @@ -1,11 +1,12 @@ -""" -Utility functions for OpenAI API adapter. +"""Template utilities for Jinja template processing. + +This module provides utilities for analyzing and processing Jinja chat templates, +including content format detection and message processing. """ import logging -from typing import Dict, List -import jinja2.nodes +import jinja2 import transformers.utils.chat_template_utils as hf_chat_utils logger = logging.getLogger(__name__) @@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str): return None -def detect_template_content_format(chat_template: str) -> str: +def detect_jinja_template_content_format(chat_template: str) -> str: """ Detect whether a chat template expects 'string' or 'openai' content format. 
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 3c4bf2a42..0fcb38227 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -864,12 +864,6 @@ class SetInternalStateReq: server_args: Dict[str, Any] -@dataclass -class V1RerankReqInput: - query: str - documents: List[str] - - @dataclass class SetInternalStateReqOutput: updated: bool diff --git a/python/sglang/srt/managers/template_manager.py b/python/sglang/srt/managers/template_manager.py new file mode 100644 index 000000000..4684bf1a0 --- /dev/null +++ b/python/sglang/srt/managers/template_manager.py @@ -0,0 +1,226 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Centralized template management for chat templates and completion templates. + +This module provides a unified interface for managing both chat conversation templates +and code completion templates, eliminating global state and improving modularity. 
+""" + +import json +import logging +import os +from typing import Optional + +from sglang.srt.code_completion_parser import ( + CompletionTemplate, + FimPosition, + completion_template_exists, + register_completion_template, +) +from sglang.srt.conversation import ( + Conversation, + SeparatorStyle, + chat_template_exists, + get_conv_template_by_model_path, + register_conv_template, +) +from sglang.srt.jinja_template_utils import detect_jinja_template_content_format + +logger = logging.getLogger(__name__) + + +class TemplateManager: + """ + Centralized manager for chat and completion templates. + + This class encapsulates all template-related state and operations, + eliminating the need for global variables and providing a clean + interface for template management. + """ + + def __init__(self): + self._chat_template_name: Optional[str] = None + self._completion_template_name: Optional[str] = None + self._jinja_template_content_format: Optional[str] = None + + @property + def chat_template_name(self) -> Optional[str]: + """Get the current chat template name.""" + return self._chat_template_name + + @property + def completion_template_name(self) -> Optional[str]: + """Get the current completion template name.""" + return self._completion_template_name + + @property + def jinja_template_content_format(self) -> Optional[str]: + """Get the detected template content format ('string' or 'openai' or None).""" + return self._jinja_template_content_format + + def load_chat_template( + self, tokenizer_manager, chat_template_arg: str, model_path: str + ) -> None: + """ + Load a chat template from various sources. 
+ + Args: + tokenizer_manager: The tokenizer manager instance + chat_template_arg: Template name or file path + model_path: Path to the model + """ + logger.info(f"Loading chat template: {chat_template_arg}") + + if not chat_template_exists(chat_template_arg): + if not os.path.exists(chat_template_arg): + raise RuntimeError( + f"Chat template {chat_template_arg} is not a built-in template name " + "or a valid chat template file path." + ) + + if chat_template_arg.endswith(".jinja"): + self._load_jinja_template(tokenizer_manager, chat_template_arg) + else: + self._load_json_chat_template(chat_template_arg) + else: + self._chat_template_name = chat_template_arg + + def guess_chat_template_from_model_path(self, model_path: str) -> None: + """ + Infer chat template name from model path. + + Args: + model_path: Path to the model + """ + template_name = get_conv_template_by_model_path(model_path) + if template_name is not None: + logger.info(f"Inferred chat template from model path: {template_name}") + self._chat_template_name = template_name + + def load_completion_template(self, completion_template_arg: str) -> None: + """ + Load completion template for code completion. + + Args: + completion_template_arg: Template name or file path + """ + logger.info(f"Loading completion template: {completion_template_arg}") + + if not completion_template_exists(completion_template_arg): + if not os.path.exists(completion_template_arg): + raise RuntimeError( + f"Completion template {completion_template_arg} is not a built-in template name " + "or a valid completion template file path." + ) + + self._load_json_completion_template(completion_template_arg) + else: + self._completion_template_name = completion_template_arg + + def initialize_templates( + self, + tokenizer_manager, + model_path: str, + chat_template: Optional[str] = None, + completion_template: Optional[str] = None, + ) -> None: + """ + Initialize all templates based on provided configuration. 
+ + Args: + tokenizer_manager: The tokenizer manager instance + model_path: Path to the model + chat_template: Optional chat template name/path + completion_template: Optional completion template name/path + """ + # Load chat template + if chat_template: + self.load_chat_template(tokenizer_manager, chat_template, model_path) + else: + self.guess_chat_template_from_model_path(model_path) + + # Load completion template + if completion_template: + self.load_completion_template(completion_template) + + def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None: + """Load a Jinja template file.""" + with open(template_path, "r") as f: + chat_template = "".join(f.readlines()).strip("\n") + tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n") + self._chat_template_name = None + # Detect content format from the loaded template + self._jinja_template_content_format = detect_jinja_template_content_format( + chat_template + ) + logger.info( + f"Detected chat template content format: {self._jinja_template_content_format}" + ) + + def _load_json_chat_template(self, template_path: str) -> None: + """Load a JSON chat template file.""" + assert template_path.endswith( + ".json" + ), "unrecognized format of chat template file" + + with open(template_path, "r") as filep: + template = json.load(filep) + try: + sep_style = SeparatorStyle[template["sep_style"]] + except KeyError: + raise ValueError( + f"Unknown separator style: {template['sep_style']}" + ) from None + + register_conv_template( + Conversation( + name=template["name"], + system_template=template["system"] + "\n{system_message}", + system_message=template.get("system_message", ""), + roles=(template["user"], template["assistant"]), + sep_style=sep_style, + sep=template.get("sep", "\n"), + stop_str=template["stop_str"], + ), + override=True, + ) + self._chat_template_name = template["name"] + + def _load_json_completion_template(self, template_path: str) -> None: + """Load a 
JSON completion template file.""" + assert template_path.endswith( + ".json" + ), "unrecognized format of completion template file" + + with open(template_path, "r") as filep: + template = json.load(filep) + try: + fim_position = FimPosition[template["fim_position"]] + except KeyError: + raise ValueError( + f"Unknown fim position: {template['fim_position']}" + ) from None + + register_completion_template( + CompletionTemplate( + name=template["name"], + fim_begin_token=template["fim_begin_token"], + fim_middle_token=template["fim_middle_token"], + fim_end_token=template["fim_end_token"], + fim_position=fim_position, + ), + override=True, + ) + self._completion_template_name = template["name"] diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index fbab668a4..959f23277 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1058,12 +1058,7 @@ class TokenizerManager: "lora_path", ] ) - out_skip_names = set( - [ - "text", - "output_ids", - ] - ) + out_skip_names = set(["text", "output_ids", "embedding"]) elif self.log_requests_level == 1: max_length = 2048 elif self.log_requests_level == 2: diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py deleted file mode 100644 index aba1a5afd..000000000 --- a/python/sglang/srt/openai_api/adapter.py +++ /dev/null @@ -1,2148 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Conversion between OpenAI APIs and native SRT APIs""" - -import asyncio -import base64 -import json -import logging -import os -import time -import uuid -from http import HTTPStatus -from typing import Dict, List - -from fastapi import HTTPException, Request, UploadFile -from fastapi.responses import ORJSONResponse, StreamingResponse -from pydantic import ValidationError - -from sglang.srt.code_completion_parser import ( - generate_completion_prompt_from_request, - is_completion_template_defined, -) -from sglang.srt.conversation import ( - Conversation, - SeparatorStyle, - chat_template_exists, - generate_chat_conv, - generate_embedding_convs, - get_conv_template_by_model_path, - register_conv_template, -) -from sglang.srt.function_call.function_call_parser import FunctionCallParser -from sglang.srt.managers.io_struct import ( - EmbeddingReqInput, - GenerateReqInput, - V1RerankReqInput, -) -from sglang.srt.openai_api.protocol import ( - BatchRequest, - BatchResponse, - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, - ChatCompletionTokenLogprob, - ChatMessage, - ChoiceLogprobs, - CompletionRequest, - CompletionResponse, - CompletionResponseChoice, - CompletionResponseStreamChoice, - CompletionStreamResponse, - DeltaMessage, - EmbeddingObject, - EmbeddingRequest, - EmbeddingResponse, - ErrorResponse, - FileDeleteResponse, - FileRequest, - FileResponse, - FunctionResponse, - LogProbs, - MultimodalEmbeddingInput, - RerankResponse, - ScoringRequest, - ScoringResponse, - ToolCall, - TopLogprob, - UsageInfo, -) -from sglang.srt.openai_api.utils import ( - detect_template_content_format, - process_content_for_template_format, -) -from sglang.srt.reasoning_parser import 
ReasoningParser -from sglang.utils import convert_json_schema_to_str, get_exception_traceback - -logger = logging.getLogger(__name__) - -chat_template_name = None - -# Global cache for template content format detection (one model/template per instance) -# NOTE: A better approach would be to initialize the chat template format when the endpoint is created -_cached_chat_template = None -_cached_template_format = None - - -class FileMetadata: - def __init__(self, filename: str, purpose: str): - self.filename = filename - self.purpose = purpose - - -# In-memory storage for batch jobs and files -batch_storage: Dict[str, BatchResponse] = {} -file_id_request: Dict[str, FileMetadata] = {} -file_id_response: Dict[str, FileResponse] = {} -# map file id to file path in SGLang backend -file_id_storage: Dict[str, str] = {} - -# backend storage directory -storage_dir = None - - -def create_error_response( - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, -): - error = ErrorResponse(message=message, type=err_type, code=status_code.value) - return ORJSONResponse(content=error.model_dump(), status_code=error.code) - - -def create_streaming_error_response( - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, -) -> str: - error = ErrorResponse(message=message, type=err_type, code=status_code.value) - json_str = json.dumps({"error": error.model_dump()}) - return json_str - - -def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, model_path): - global chat_template_name - - logger.info( - f"Use chat template for the OpenAI-compatible API server: {chat_template_arg}" - ) - - if not chat_template_exists(chat_template_arg): - if not os.path.exists(chat_template_arg): - raise RuntimeError( - f"Chat template {chat_template_arg} is not a built-in template name " - "or a valid chat template file path." 
- ) - if chat_template_arg.endswith(".jinja"): - with open(chat_template_arg, "r") as f: - chat_template = "".join(f.readlines()).strip("\n") - tokenizer_manager.tokenizer.chat_template = chat_template.replace( - "\\n", "\n" - ) - chat_template_name = None - else: - assert chat_template_arg.endswith( - ".json" - ), "unrecognized format of chat template file" - with open(chat_template_arg, "r") as filep: - template = json.load(filep) - try: - sep_style = SeparatorStyle[template["sep_style"]] - except KeyError: - raise ValueError( - f"Unknown separator style: {template['sep_style']}" - ) from None - register_conv_template( - Conversation( - name=template["name"], - system_template=template["system"] + "\n{system_message}", - system_message=template.get("system_message", ""), - roles=(template["user"], template["assistant"]), - sep_style=sep_style, - sep=template.get("sep", "\n"), - stop_str=template["stop_str"], - ), - override=True, - ) - chat_template_name = template["name"] - else: - chat_template_name = chat_template_arg - - -def guess_chat_template_name_from_model_path(model_path): - global chat_template_name - chat_template_name = get_conv_template_by_model_path(model_path) - if chat_template_name is not None: - logger.info( - f"Infer the chat template name from the model path and obtain the result: {chat_template_name}." 
- ) - - -def _validate_prompt(prompt: str): - """Validate that the prompt is not empty or whitespace only.""" - is_invalid = False - - # Check for empty/whitespace string - if isinstance(prompt, str): - is_invalid = not prompt.strip() - # Check for various invalid list cases: [], [""], [" "], [[]] - elif isinstance(prompt, list): - is_invalid = not prompt or ( - len(prompt) == 1 - and ( - (isinstance(prompt[0], str) and not prompt[0].strip()) - or (isinstance(prompt[0], list) and not prompt[0]) - ) - ) - - if is_invalid: - raise HTTPException( - status_code=400, - detail="Input cannot be empty or contain only whitespace.", - ) - - return prompt - - -async def v1_files_create( - file: UploadFile, purpose: str, file_storage_path: str = None -): - try: - global storage_dir - if file_storage_path: - storage_dir = file_storage_path - # Read the file content - file_content = await file.read() - - # Create an instance of RequestBody - request_body = FileRequest(file=file_content, purpose=purpose) - - # Save the file to the sglang_oai_storage directory - os.makedirs(storage_dir, exist_ok=True) - file_id = f"backend_input_file-{uuid.uuid4()}" - filename = f"{file_id}.jsonl" - file_path = os.path.join(storage_dir, filename) - - with open(file_path, "wb") as f: - f.write(request_body.file) - - # add info to global file map - file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose) - file_id_storage[file_id] = file_path - - # Return the response in the required format - response = FileResponse( - id=file_id, - bytes=len(request_body.file), - created_at=int(time.time()), - filename=file.filename, - purpose=request_body.purpose, - ) - file_id_response[file_id] = response - - return response - except ValidationError as e: - return {"error": "Invalid input", "details": e.errors()} - - -async def v1_delete_file(file_id: str): - # Retrieve the file job from the in-memory storage - file_response = file_id_response.get(file_id) - if file_response is None: - 
raise HTTPException(status_code=404, detail="File not found") - file_path = file_id_storage.get(file_id) - if file_path is None: - raise HTTPException(status_code=404, detail="File not found") - os.remove(file_path) - del file_id_response[file_id] - del file_id_storage[file_id] - return FileDeleteResponse(id=file_id, deleted=True) - - -async def v1_batches(tokenizer_manager, raw_request: Request): - try: - body = await raw_request.json() - - batch_request = BatchRequest(**body) - - batch_id = f"batch_{uuid.uuid4()}" - - # Create an instance of BatchResponse - batch_response = BatchResponse( - id=batch_id, - endpoint=batch_request.endpoint, - input_file_id=batch_request.input_file_id, - completion_window=batch_request.completion_window, - created_at=int(time.time()), - metadata=batch_request.metadata, - ) - - batch_storage[batch_id] = batch_response - - # Start processing the batch asynchronously - asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request)) - - # Return the initial batch_response - return batch_response - - except ValidationError as e: - return {"error": "Invalid input", "details": e.errors()} - except Exception as e: - return {"error": str(e)} - - -async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest): - try: - # Update the batch status to "in_progress" - batch_storage[batch_id].status = "in_progress" - batch_storage[batch_id].in_progress_at = int(time.time()) - - # Retrieve the input file content - input_file_request = file_id_request.get(batch_request.input_file_id) - if not input_file_request: - raise ValueError("Input file not found") - - # Parse the JSONL file and process each request - input_file_path = file_id_storage.get(batch_request.input_file_id) - with open(input_file_path, "r", encoding="utf-8") as f: - lines = f.readlines() - - total_requests = len(lines) - completed_requests = 0 - failed_requests = 0 - - all_ret = [] - end_point = batch_storage[batch_id].endpoint - file_request_list 
= [] - all_requests = [] - request_ids = [] - for line_id, line in enumerate(lines): - request_data = json.loads(line) - file_request_list.append(request_data) - body = request_data["body"] - request_ids.append(f"{batch_id}-req_{line_id}") - - # Although streaming is supported for standalone completions, it is not supported in - # batch mode (multiple completions in single request). - if body.get("stream", False): - raise ValueError("Streaming requests are not supported in batch mode") - - if end_point == "/v1/chat/completions": - all_requests.append(ChatCompletionRequest(**body)) - elif end_point == "/v1/completions": - all_requests.append(CompletionRequest(**body)) - - if end_point == "/v1/chat/completions": - adapted_request, request = v1_chat_generate_request( - all_requests, tokenizer_manager, request_ids=request_ids - ) - elif end_point == "/v1/completions": - adapted_request, request = v1_generate_request( - all_requests, request_ids=request_ids - ) - - try: - created = int(time.time()) - ret = await tokenizer_manager.generate_request(adapted_request).__anext__() - if not isinstance(ret, list): - ret = [ret] - if end_point == "/v1/chat/completions": - responses = v1_chat_generate_response( - request, - ret, - created, - to_file=True, - cache_report=tokenizer_manager.server_args.enable_cache_report, - tool_call_parser=tokenizer_manager.server_args.tool_call_parser, - ) - else: - responses = v1_generate_response( - request, - ret, - tokenizer_manager, - created, - to_file=True, - cache_report=tokenizer_manager.server_args.enable_cache_report, - ) - - except Exception as e: - logger.error(f"error: {get_exception_traceback()}") - responses = [] - error_json = { - "id": f"batch_req_{uuid.uuid4()}", - "custom_id": request_data.get("custom_id"), - "response": None, - "error": {"message": str(e)}, - } - all_ret.append(error_json) - failed_requests += len(file_request_list) - - for idx, response in enumerate(responses): - # the batch_req here can be changed to be 
named within a batch granularity - response_json = { - "id": f"batch_req_{uuid.uuid4()}", - "custom_id": file_request_list[idx].get("custom_id"), - "response": response, - "error": None, - } - all_ret.append(response_json) - completed_requests += 1 - - # Write results to a new file - output_file_id = f"backend_result_file-{uuid.uuid4()}" - global storage_dir - output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl") - with open(output_file_path, "w", encoding="utf-8") as f: - for ret in all_ret: - f.write(json.dumps(ret) + "\n") - - # Update batch response with output file information - retrieve_batch = batch_storage[batch_id] - retrieve_batch.output_file_id = output_file_id - file_id_storage[output_file_id] = output_file_path - file_id_response[output_file_id] = FileResponse( - id=output_file_id, - bytes=os.path.getsize(output_file_path), - created_at=int(time.time()), - filename=f"{output_file_id}.jsonl", - purpose="batch_result", - ) - # Update batch status to "completed" - retrieve_batch.status = "completed" - retrieve_batch.completed_at = int(time.time()) - retrieve_batch.request_counts = { - "total": total_requests, - "completed": completed_requests, - "failed": failed_requests, - } - - except Exception as e: - logger.error(f"error: {e}") - # Update batch status to "failed" - retrieve_batch = batch_storage[batch_id] - retrieve_batch.status = "failed" - retrieve_batch.failed_at = int(time.time()) - retrieve_batch.errors = {"message": str(e)} - - -async def v1_retrieve_batch(batch_id: str): - # Retrieve the batch job from the in-memory storage - batch_response = batch_storage.get(batch_id) - if batch_response is None: - raise HTTPException(status_code=404, detail="Batch not found") - - return batch_response - - -async def v1_cancel_batch(tokenizer_manager, batch_id: str): - # Retrieve the batch job from the in-memory storage - batch_response = batch_storage.get(batch_id) - if batch_response is None: - raise HTTPException(status_code=404, 
detail="Batch not found") - - # Only do cancal when status is "validating" or "in_progress" - if batch_response.status in ["validating", "in_progress"]: - # Start cancelling the batch asynchronously - asyncio.create_task( - cancel_batch( - tokenizer_manager=tokenizer_manager, - batch_id=batch_id, - input_file_id=batch_response.input_file_id, - ) - ) - - # Update batch status to "cancelling" - batch_response.status = "cancelling" - - return batch_response - else: - raise HTTPException( - status_code=500, - detail=f"Current status is {batch_response.status}, no need to cancel", - ) - - -async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str): - try: - # Update the batch status to "cancelling" - batch_storage[batch_id].status = "cancelling" - - # Retrieve the input file content - input_file_request = file_id_request.get(input_file_id) - if not input_file_request: - raise ValueError("Input file not found") - - # Parse the JSONL file and process each request - input_file_path = file_id_storage.get(input_file_id) - with open(input_file_path, "r", encoding="utf-8") as f: - lines = f.readlines() - - # Cancel requests by request_ids - for line_id in range(len(lines)): - rid = f"{batch_id}-req_{line_id}" - tokenizer_manager.abort_request(rid=rid) - - retrieve_batch = batch_storage[batch_id] - retrieve_batch.status = "cancelled" - - except Exception as e: - logger.error("error in SGLang:", e) - # Update batch status to "failed" - retrieve_batch = batch_storage[batch_id] - retrieve_batch.status = "failed" - retrieve_batch.failed_at = int(time.time()) - retrieve_batch.errors = {"message": str(e)} - - -async def v1_retrieve_file(file_id: str): - # Retrieve the batch job from the in-memory storage - file_response = file_id_response.get(file_id) - if file_response is None: - raise HTTPException(status_code=404, detail="File not found") - return file_response - - -async def v1_retrieve_file_content(file_id: str): - file_pth = file_id_storage.get(file_id) - if 
not file_pth or not os.path.exists(file_pth): - raise HTTPException(status_code=404, detail="File not found") - - def iter_file(): - with open(file_pth, mode="rb") as file_like: - yield from file_like - - return StreamingResponse(iter_file(), media_type="application/octet-stream") - - -def v1_generate_request( - all_requests: List[CompletionRequest], request_ids: List[str] = None -): - if len(all_requests) > 1: - first_prompt_type = type(all_requests[0].prompt) - for request in all_requests: - assert ( - type(request.prompt) is first_prompt_type - ), "All prompts must be of the same type in file input settings" - if request.n > 1: - raise ValueError( - "Parallel sampling is not supported for completions from files" - ) - - prompts = [] - sampling_params_list = [] - return_logprobs = [] - logprob_start_lens = [] - top_logprobs_nums = [] - lora_paths = [] - return_hidden_states = [] - - for request in all_requests: - # NOTE: with openai API, the prompt's logprobs are always not computed - if request.echo and request.logprobs: - logger.warning( - "Echo is not compatible with logprobs. " - "To compute logprobs of input prompt, please use the native /generate API." 
- ) - - prompt = request.prompt - if is_completion_template_defined(): - prompt = generate_completion_prompt_from_request(request) - prompts.append(prompt) - - lora_paths.append(request.lora_path) - if request.echo and request.logprobs: - current_logprob_start_len = 0 - else: - current_logprob_start_len = -1 - sampling_params_list.append( - { - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "min_new_tokens": request.min_tokens, - "stop": request.stop, - "stop_token_ids": request.stop_token_ids, - "top_p": request.top_p, - "top_k": request.top_k, - "min_p": request.min_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "repetition_penalty": request.repetition_penalty, - "regex": request.regex, - "json_schema": request.json_schema, - "ebnf": request.ebnf, - "n": request.n, - "no_stop_trim": request.no_stop_trim, - "ignore_eos": request.ignore_eos, - "skip_special_tokens": request.skip_special_tokens, - "logit_bias": request.logit_bias, - } - ) - return_logprobs.append(request.logprobs is not None) - logprob_start_lens.append(current_logprob_start_len) - top_logprobs_nums.append( - request.logprobs if request.logprobs is not None else 0 - ) - return_hidden_states.append(request.return_hidden_states) - - if len(all_requests) == 1: - if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): - prompt_kwargs = {"text": prompts[0]} - else: - prompt_kwargs = {"input_ids": prompts[0]} - sampling_params_list = sampling_params_list[0] - return_logprobs = return_logprobs[0] - logprob_start_lens = logprob_start_lens[0] - top_logprobs_nums = top_logprobs_nums[0] - lora_paths = lora_paths[0] - return_hidden_states = return_hidden_states[0] - else: - if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): - prompt_kwargs = {"text": prompts} - else: - prompt_kwargs = {"input_ids": prompts} - - adapted_request = GenerateReqInput( - **prompt_kwargs, - 
sampling_params=sampling_params_list, - return_logprob=return_logprobs, - top_logprobs_num=top_logprobs_nums, - logprob_start_len=logprob_start_lens, - return_text_in_logprobs=True, - stream=all_requests[0].stream, - rid=request_ids, - lora_path=lora_paths, - return_hidden_states=return_hidden_states, - bootstrap_host=all_requests[0].bootstrap_host, - bootstrap_port=all_requests[0].bootstrap_port, - bootstrap_room=all_requests[0].bootstrap_room, - ) - - return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] - - -def v1_generate_response( - request, ret, tokenizer_manager, created, to_file=False, cache_report=False -): - choices = [] - echo = False - - if (not isinstance(request, list)) and request.echo: - # TODO: handle the case prompt is token ids - if isinstance(request.prompt, list) and isinstance(request.prompt[0], str): - # for the case of multiple str prompts - prompts = request.prompt - elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list): - # for the case of multiple token ids prompts - prompts = [ - tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True) - for prompt in request.prompt - ] - elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int): - # for the case of single token ids prompt - prompts = [ - tokenizer_manager.tokenizer.decode( - request.prompt, skip_special_tokens=True - ) - ] - else: - # for the case of single str prompt - prompts = [request.prompt] - echo = True - - for idx, ret_item in enumerate(ret): - text = ret_item["text"] - if isinstance(request, list) and request[idx].echo: - echo = True - text = request[idx].prompt + text - if echo and not isinstance(request, list): - prompt_index = idx // request.n - text = prompts[prompt_index] + text - - logprobs = False - if isinstance(request, list) and request[idx].logprobs is not None: - logprobs = True - elif (not isinstance(request, list)) and request.logprobs is not None: - logprobs = True - if logprobs: 
- if echo: - input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"] - input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"] - else: - input_token_logprobs = None - input_top_logprobs = None - - logprobs = to_openai_style_logprobs( - input_token_logprobs=input_token_logprobs, - input_top_logprobs=input_top_logprobs, - output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], - output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], - ) - else: - logprobs = None - - hidden_states = None - if isinstance(request, list) and request[idx].return_hidden_states: - hidden_states = ret_item["meta_info"].get("hidden_states", None) - elif (not isinstance(request, list)) and request.return_hidden_states: - hidden_states = ret_item["meta_info"].get("hidden_states", None) - if hidden_states is not None: - hidden_states = ( - hidden_states[-1] if hidden_states and len(hidden_states) > 1 else [] - ) - - finish_reason = ret_item["meta_info"]["finish_reason"] - - if to_file: - # to make the choice data json serializable - choice_data = { - "index": 0, - "text": text, - "logprobs": logprobs, - "finish_reason": finish_reason["type"] if finish_reason else None, - "matched_stop": ( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - } - if hidden_states is not None: - choice_data["hidden_states"] = hidden_states - else: - choice_data = CompletionResponseChoice( - index=idx, - text=text, - logprobs=logprobs, - finish_reason=finish_reason["type"] if finish_reason else None, - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - hidden_states=hidden_states, - ) - - choices.append(choice_data) - - if to_file: - responses = [] - for i, choice in enumerate(choices): - response = { - "status_code": 200, - "request_id": ret[i]["meta_info"]["id"], - "body": { - # remain the same but if needed we can change that - "id": ret[i]["meta_info"]["id"], - 
"object": "text_completion", - "created": created, - "model": request[i].model, - "choices": choice, - "usage": { - "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"], - "completion_tokens": ret[i]["meta_info"]["completion_tokens"], - "total_tokens": ret[i]["meta_info"]["prompt_tokens"] - + ret[i]["meta_info"]["completion_tokens"], - }, - "system_fingerprint": None, - }, - } - responses.append(response) - return responses - else: - prompt_tokens = sum( - ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n) - ) - completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret) - cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret) - response = CompletionResponse( - id=ret[0]["meta_info"]["id"], - model=request.model, - created=created, - choices=choices, - usage=UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - prompt_tokens_details=( - {"cached_tokens": cached_tokens} if cache_report else None - ), - ), - ) - return response - - -async def v1_completions(tokenizer_manager, raw_request: Request): - try: - request_json = await raw_request.json() - except Exception as e: - return create_error_response("Invalid request body, error: ", str(e)) - all_requests = [CompletionRequest(**request_json)] - created = int(time.time()) - adapted_request, request = v1_generate_request(all_requests) - - if adapted_request.stream: - - async def generate_stream_resp(): - stream_buffers = {} - n_prev_tokens = {} - prompt_tokens = {} - completion_tokens = {} - cached_tokens = {} - hidden_states = {} - - try: - async for content in tokenizer_manager.generate_request( - adapted_request, raw_request - ): - index = content.get("index", 0) - - stream_buffer = stream_buffers.get(index, "") - n_prev_token = n_prev_tokens.get(index, 0) - - text = content["text"] - prompt_tokens[index] = content["meta_info"]["prompt_tokens"] - 
completion_tokens[index] = content["meta_info"]["completion_tokens"] - cached_tokens[index] = content["meta_info"].get("cached_tokens", 0) - hidden_states[index] = content["meta_info"].get( - "hidden_states", None - ) or hidden_states.get(index) - - if not stream_buffer: # The first chunk - if request.echo: - if isinstance(request.prompt, str): - # for the case of single str prompts - prompts = request.prompt - elif isinstance(request.prompt, list): - if isinstance(request.prompt[0], str): - # for the case of multiple str prompts - prompts = request.prompt[index // request.n] - elif isinstance(request.prompt[0], int): - # for the case of single token ids prompt - prompts = tokenizer_manager.tokenizer.decode( - request.prompt, skip_special_tokens=True - ) - elif isinstance(request.prompt[0], list) and isinstance( - request.prompt[0][0], int - ): - # for the case of multiple token ids prompts - prompts = tokenizer_manager.tokenizer.decode( - request.prompt[index // request.n], - skip_special_tokens=True, - ) - - # Prepend prompt in response text. - text = prompts + text - - if request.logprobs is not None: - # The first chunk and echo is enabled. 
- if not stream_buffer and request.echo: - input_token_logprobs = content["meta_info"][ - "input_token_logprobs" - ] - input_top_logprobs = content["meta_info"][ - "input_top_logprobs" - ] - else: - input_token_logprobs = None - input_top_logprobs = None - - logprobs = to_openai_style_logprobs( - input_token_logprobs=input_token_logprobs, - input_top_logprobs=input_top_logprobs, - output_token_logprobs=content["meta_info"][ - "output_token_logprobs" - ][n_prev_token:], - output_top_logprobs=content["meta_info"][ - "output_top_logprobs" - ][n_prev_token:], - ) - n_prev_token = len( - content["meta_info"]["output_token_logprobs"] - ) - else: - logprobs = None - - delta = text[len(stream_buffer) :] - stream_buffer = stream_buffer + delta - finish_reason = content["meta_info"]["finish_reason"] - choice_data = CompletionResponseStreamChoice( - index=index, - text=delta, - logprobs=logprobs, - finish_reason=finish_reason["type"] if finish_reason else None, - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - ) - chunk = CompletionStreamResponse( - id=content["meta_info"]["id"], - created=created, - object="text_completion", - choices=[choice_data], - model=request.model, - ) - - stream_buffers[index] = stream_buffer - n_prev_tokens[index] = n_prev_token - - yield f"data: {chunk.model_dump_json()}\n\n" - if request.return_hidden_states and hidden_states: - for index, choice_hidden_states in hidden_states.items(): - last_token_hidden_states = ( - choice_hidden_states[-1] - if choice_hidden_states and len(choice_hidden_states) > 1 - else [] - ) - hidden_states_chunk = CompletionStreamResponse( - id=content["meta_info"]["id"], - created=created, - choices=[ - CompletionResponseStreamChoice( - text="", - index=index, - hidden_states=last_token_hidden_states, - finish_reason=None, - ) - ], - model=request.model, - ) - yield f"data: {hidden_states_chunk.model_dump_json()}\n\n" - if request.stream_options and 
request.stream_options.include_usage: - total_prompt_tokens = sum( - tokens - for i, tokens in prompt_tokens.items() - if i % request.n == 0 - ) - total_completion_tokens = sum( - tokens for tokens in completion_tokens.values() - ) - cache_report = tokenizer_manager.server_args.enable_cache_report - if cache_report: - cached_tokens_sum = sum( - tokens for tokens in cached_tokens.values() - ) - prompt_tokens_details = {"cached_tokens": cached_tokens_sum} - else: - prompt_tokens_details = None - usage = UsageInfo( - prompt_tokens=total_prompt_tokens, - completion_tokens=total_completion_tokens, - total_tokens=total_prompt_tokens + total_completion_tokens, - prompt_tokens_details=prompt_tokens_details, - ) - - final_usage_chunk = CompletionStreamResponse( - id=content["meta_info"]["id"], - created=created, - choices=[], - model=request.model, - usage=usage, - ) - final_usage_data = final_usage_chunk.model_dump_json( - exclude_none=True - ) - yield f"data: {final_usage_data}\n\n" - except ValueError as e: - error = create_streaming_error_response(str(e)) - yield f"data: {error}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse( - generate_stream_resp(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request), - ) - - # Non-streaming response. - try: - ret = await tokenizer_manager.generate_request( - adapted_request, raw_request - ).__anext__() - except ValueError as e: - return create_error_response(str(e)) - - if not isinstance(ret, list): - ret = [ret] - - response = v1_generate_response( - request, - ret, - tokenizer_manager, - created, - cache_report=tokenizer_manager.server_args.enable_cache_report, - ) - return response - - -def _get_enable_thinking_from_request(request_obj): - """Extracts the 'enable_thinking' flag from request chat_template_kwargs. - - Args: - request_obj: The request object (or an item from a list of requests). 
- - Returns: - The boolean value of 'enable_thinking' if found and not True, otherwise True. - """ - if ( - hasattr(request_obj, "chat_template_kwargs") - and request_obj.chat_template_kwargs - and request_obj.chat_template_kwargs.get("enable_thinking") is not None - ): - return request_obj.chat_template_kwargs.get("enable_thinking") - return True - - -def v1_chat_generate_request( - all_requests: List[ChatCompletionRequest], - tokenizer_manager, - request_ids: List[str] = None, -): - input_ids = [] - prompts = [] - sampling_params_list = [] - image_data_list = [] - audio_data_list = [] - return_logprobs = [] - logprob_start_lens = [] - top_logprobs_nums = [] - modalities_list = [] - lora_paths = [] - return_hidden_states = [] - - # NOTE: with openai API, the prompt's logprobs are always not computed - - is_multimodal = tokenizer_manager.model_config.is_multimodal - for request in all_requests: - # Prep the data needed for the underlying GenerateReqInput: - # - prompt: The full prompt string. - # - stop: Custom stop tokens. - # - image_data: None or a list of image strings (URLs or base64 strings). - # - audio_data: None or a list of audio strings (URLs). - # None skips any image processing in GenerateReqInput. - tool_call_constraint = None - prompt = "" - prompt_ids = [] - if not isinstance(request.messages, str): - # Apply chat template and its stop strings. 
- tools = None - if request.tools and request.tool_choice != "none": - request.skip_special_tokens = False - if not isinstance(request.tool_choice, str): - tools = [ - item.function.model_dump() - for item in request.tools - if item.function.name == request.tool_choice.function.name - ] - else: - tools = [item.function.model_dump() for item in request.tools] - - tool_call_parser = tokenizer_manager.server_args.tool_call_parser - parser = FunctionCallParser(request.tools, tool_call_parser) - tool_call_constraint = parser.get_structure_constraint( - request.tool_choice - ) - - if chat_template_name is None: - openai_compatible_messages = [] - image_data = [] - audio_data = [] - modalities = [] - - # Detect template content format by analyzing the jinja template (cached globally) - global _cached_chat_template, _cached_template_format - current_template = tokenizer_manager.tokenizer.chat_template - - if current_template != _cached_chat_template: - # Template changed or first time - analyze it - _cached_chat_template = current_template - _cached_template_format = detect_template_content_format( - current_template - ) - logger.info( - f"Detected chat template content format: {_cached_template_format}" - ) - - template_content_format = _cached_template_format - - for message in request.messages: - if message.content is None: - message.content = "" - msg_dict = message.model_dump() - - # Process content based on detected template format - processed_msg = process_content_for_template_format( - msg_dict, - template_content_format, - image_data, - audio_data, - modalities, - ) - openai_compatible_messages.append(processed_msg) - - # Handle assistant prefix for continue_final_message - if ( - openai_compatible_messages - and openai_compatible_messages[-1]["role"] == "assistant" - ): - if request.continue_final_message: - # Remove the final assistant message so its content can be continued. 
- assistant_prefix = openai_compatible_messages[-1]["content"] - openai_compatible_messages = openai_compatible_messages[:-1] - else: - assistant_prefix = None - else: - assistant_prefix = None - - try: - prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( - openai_compatible_messages, - tokenize=True, - add_generation_prompt=True, - tools=tools, - **( - request.chat_template_kwargs - if request.chat_template_kwargs - else {} - ), - ) - except: - # This except branch will be triggered when the chosen model - # has a different tools input format that is not compatible - # with openAI's apply_chat_template tool_call format, like Mistral. - tools = [t if "function" in t else {"function": t} for t in tools] - prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( - openai_compatible_messages, - tokenize=True, - add_generation_prompt=True, - tools=tools, - **( - request.chat_template_kwargs - if request.chat_template_kwargs - else {} - ), - ) - - if assistant_prefix: - encoded = tokenizer_manager.tokenizer.encode(assistant_prefix) - if ( - encoded - and encoded[0] == tokenizer_manager.tokenizer.bos_token_id - ): - encoded = encoded[1:] - prompt_ids += encoded - if is_multimodal: - prompt = tokenizer_manager.tokenizer.decode(prompt_ids) - stop = request.stop - image_data = image_data if image_data else None - audio_data = audio_data if audio_data else None - modalities = modalities if modalities else [] - else: - conv = generate_chat_conv(request, chat_template_name) - # If we should continue the final assistant message, adjust the conversation. - if ( - request.continue_final_message - and request.messages - and request.messages[-1].role == "assistant" - ): - # Remove the auto-added blank assistant turn, if present. - if conv.messages and conv.messages[-1][1] is None: - conv.messages.pop() - # Rebuild the prompt from the conversation. - prompt = conv.get_prompt() - # Strip any trailing stop tokens or separators that indicate end-of-assistant. 
- if isinstance(conv.stop_str, list): - for stop_token in conv.stop_str: - if prompt.endswith(stop_token): - prompt = prompt[: -len(stop_token)] - elif isinstance(conv.stop_str, str) and prompt.endswith( - conv.stop_str - ): - prompt = prompt[: -len(conv.stop_str)] - if conv.sep and prompt.endswith(conv.sep): - prompt = prompt[: -len(conv.sep)] - if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2): - prompt = prompt[: -len(conv.sep2)] - else: - prompt = conv.get_prompt() - - image_data = conv.image_data - audio_data = conv.audio_data - modalities = conv.modalities - stop = conv.stop_str or [] if not request.ignore_eos else [] - - if request.stop: - if isinstance(request.stop, str): - stop.append(request.stop) - else: - stop.extend(request.stop) - - if not is_multimodal: - prompt_ids = tokenizer_manager.tokenizer.encode(prompt) - else: - # Use the raw prompt and stop strings if the messages is already a string. - prompt_ids = request.messages - stop = request.stop - image_data = None - audio_data = None - modalities = [] - prompt = request.messages - input_ids.append(prompt_ids) - return_logprobs.append(request.logprobs) - logprob_start_lens.append(-1) - top_logprobs_nums.append(request.top_logprobs or 0) - lora_paths.append(request.lora_path) - prompts.append(prompt) - - sampling_params = { - "temperature": request.temperature, - "max_new_tokens": request.max_tokens or request.max_completion_tokens, - "min_new_tokens": request.min_tokens, - "stop": stop, - "stop_token_ids": request.stop_token_ids, - "top_p": request.top_p, - "top_k": request.top_k, - "min_p": request.min_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "repetition_penalty": request.repetition_penalty, - "regex": request.regex, - "ebnf": request.ebnf, - "n": request.n, - "no_stop_trim": request.no_stop_trim, - "ignore_eos": request.ignore_eos, - "skip_special_tokens": request.skip_special_tokens, - "logit_bias": request.logit_bias, - } 
- - if request.response_format and request.response_format.type == "json_schema": - sampling_params["json_schema"] = convert_json_schema_to_str( - request.response_format.json_schema.schema_ - ) - elif request.response_format and request.response_format.type == "json_object": - sampling_params["json_schema"] = '{"type": "object"}' - elif ( - request.response_format and request.response_format.type == "structural_tag" - ): - sampling_params["structural_tag"] = convert_json_schema_to_str( - request.response_format.model_dump(by_alias=True) - ) - - # Check if there are already existing output constraints - has_existing_constraints = ( - sampling_params.get("regex") - or sampling_params.get("ebnf") - or sampling_params.get("structural_tag") - or sampling_params.get("json_schema") - ) - - if tool_call_constraint and has_existing_constraints: - logger.warning("Constrained decoding is not compatible with tool calls.") - elif tool_call_constraint: - constraint_type, constraint_value = tool_call_constraint - if constraint_type == "structural_tag": - sampling_params[constraint_type] = convert_json_schema_to_str( - constraint_value.model_dump(by_alias=True) - ) - else: - sampling_params[constraint_type] = constraint_value - - sampling_params_list.append(sampling_params) - - image_data_list.append(image_data) - audio_data_list.append(audio_data) - modalities_list.append(modalities) - return_hidden_states.append(request.return_hidden_states) - if len(all_requests) == 1: - if is_multimodal: - # processor will need text input - prompt_kwargs = {"text": prompts[0]} - else: - if isinstance(input_ids[0], str): - prompt_kwargs = {"text": input_ids[0]} - else: - prompt_kwargs = {"input_ids": input_ids[0]} - sampling_params_list = sampling_params_list[0] - image_data_list = image_data_list[0] - audio_data_list = audio_data_list[0] - return_logprobs = return_logprobs[0] - logprob_start_lens = logprob_start_lens[0] - top_logprobs_nums = top_logprobs_nums[0] - modalities_list = 
modalities_list[0] - lora_paths = lora_paths[0] - request_ids = request_ids[0] - return_hidden_states = return_hidden_states[0] - else: - if tokenizer_manager.model_config.is_multimodal: - # processor will need text input - prompt_kwargs = {"text": prompts} - else: - if isinstance(input_ids[0], str): - prompt_kwargs = {"text": input_ids} - else: - prompt_kwargs = {"input_ids": input_ids} - - adapted_request = GenerateReqInput( - **prompt_kwargs, - image_data=image_data_list, - audio_data=audio_data_list, - sampling_params=sampling_params_list, - return_logprob=return_logprobs, - logprob_start_len=logprob_start_lens, - top_logprobs_num=top_logprobs_nums, - stream=all_requests[0].stream, - return_text_in_logprobs=True, - rid=request_ids, - modalities=modalities_list, - lora_path=lora_paths, - bootstrap_host=all_requests[0].bootstrap_host, - bootstrap_port=all_requests[0].bootstrap_port, - bootstrap_room=all_requests[0].bootstrap_room, - return_hidden_states=return_hidden_states, - ) - - return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] - - -def v1_chat_generate_response( - request, - ret, - created, - to_file=False, - cache_report=False, - tool_call_parser=None, - reasoning_parser=None, -): - choices = [] - - for idx, ret_item in enumerate(ret): - logprobs = False - if isinstance(request, list) and request[idx].logprobs: - logprobs = True - elif (not isinstance(request, list)) and request.logprobs: - logprobs = True - if logprobs: - logprobs = to_openai_style_logprobs( - output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], - output_top_logprobs=ret_item["meta_info"].get( - "output_top_logprobs", None - ), - ) - token_logprobs = [] - for token_idx, (token, logprob) in enumerate( - zip(logprobs.tokens, logprobs.token_logprobs) - ): - token_bytes = list(token.encode("utf-8")) - top_logprobs = [] - if logprobs.top_logprobs: - for top_token, top_logprob in logprobs.top_logprobs[ - token_idx - ].items(): - top_token_bytes = 
list(top_token.encode("utf-8")) - top_logprobs.append( - TopLogprob( - token=top_token, - bytes=top_token_bytes, - logprob=top_logprob, - ) - ) - token_logprobs.append( - ChatCompletionTokenLogprob( - token=token, - bytes=token_bytes, - logprob=logprob, - top_logprobs=top_logprobs, - ) - ) - - choice_logprobs = ChoiceLogprobs(content=token_logprobs) - else: - choice_logprobs = None - - if isinstance(request, list) and request[idx].return_hidden_states: - include_hidden_states = True - elif not isinstance(request, list) and request.return_hidden_states: - include_hidden_states = True - else: - include_hidden_states = False - if include_hidden_states and ret_item["meta_info"].get("hidden_states", None): - hidden_states = ret_item["meta_info"]["hidden_states"] - hidden_states = ( - hidden_states[-1] if hidden_states and len(hidden_states) > 1 else [] - ) - else: - hidden_states = None - - finish_reason = ret_item["meta_info"]["finish_reason"] - - tool_calls = None - text = ret_item["text"] - - if isinstance(request, list): - tool_choice = request[idx].tool_choice - tools = request[idx].tools - separate_reasoning = request[idx].separate_reasoning - enable_thinking = _get_enable_thinking_from_request(request[idx]) - else: - tool_choice = request.tool_choice - tools = request.tools - separate_reasoning = request.separate_reasoning - enable_thinking = _get_enable_thinking_from_request(request) - - reasoning_text = None - if reasoning_parser and separate_reasoning and enable_thinking: - try: - parser = ReasoningParser( - model_type=reasoning_parser, stream_reasoning=False - ) - reasoning_text, text = parser.parse_non_stream(text) - except Exception as e: - logger.error(f"Exception: {e}") - return create_error_response( - HTTPStatus.BAD_REQUEST, - "Failed to parse reasoning related info to json format!", - ) - - if tool_choice != "none" and tools: - parser = FunctionCallParser(tools, tool_call_parser) - if parser.has_tool_call(text): - if finish_reason["type"] == "stop": - 
finish_reason["type"] = "tool_calls" - finish_reason["matched"] = None - try: - text, call_info_list = parser.parse_non_stream(text) - tool_calls = [ - ToolCall( - id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}", - function=FunctionResponse( - name=call_info.name, arguments=call_info.parameters - ), - ) - for call_info in call_info_list - ] - except Exception as e: - logger.error(f"Exception: {e}") - return create_error_response( - HTTPStatus.BAD_REQUEST, - "Failed to parse fc related info to json format!", - ) - - if to_file: - # to make the choice data json serializable - choice_data = { - "index": 0, - "message": { - "role": "assistant", - "content": text if text else None, - "tool_calls": tool_calls, - "reasoning_content": reasoning_text if reasoning_text else None, - }, - "logprobs": choice_logprobs.model_dump() if choice_logprobs else None, - "finish_reason": finish_reason["type"] if finish_reason else None, - "matched_stop": ( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - } - if hidden_states is not None: - choice_data["hidden_states"] = hidden_states - else: - choice_data = ChatCompletionResponseChoice( - index=idx, - message=ChatMessage( - role="assistant", - content=text if text else None, - tool_calls=tool_calls, - reasoning_content=reasoning_text if reasoning_text else None, - ), - logprobs=choice_logprobs, - finish_reason=finish_reason["type"] if finish_reason else None, - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - hidden_states=hidden_states, - ) - - choices.append(choice_data) - - if to_file: - responses = [] - - for i, choice in enumerate(choices): - response = { - "status_code": 200, - "request_id": ret[i]["meta_info"]["id"], - "body": { - # remain the same but if needed we can change that - "id": ret[i]["meta_info"]["id"], - "object": "chat.completion", - "created": created, - "model": ( - 
request[i].model if isinstance(request, list) else request.model - ), - "choices": choice, - "usage": { - "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"], - "completion_tokens": ret[i]["meta_info"]["completion_tokens"], - "total_tokens": ret[i]["meta_info"]["prompt_tokens"] - + ret[i]["meta_info"]["completion_tokens"], - }, - "system_fingerprint": None, - }, - } - responses.append(response) - return responses - else: - prompt_tokens = sum( - ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n) - ) - completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret) - cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret) - response = ChatCompletionResponse( - id=ret[0]["meta_info"]["id"], - created=created, - model=request.model, - choices=choices, - usage=UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - prompt_tokens_details=( - {"cached_tokens": cached_tokens} if cache_report else None - ), - ), - ) - return response - - -async def v1_chat_completions( - tokenizer_manager, raw_request: Request, cache_report=False -): - try: - request_json = await raw_request.json() - except Exception as e: - return create_error_response("Invalid request body, error: ", str(e)) - all_requests = [ChatCompletionRequest(**request_json)] - created = int(time.time()) - adapted_request, request = v1_chat_generate_request( - all_requests, tokenizer_manager, request_ids=[all_requests[0].rid] - ) - - if adapted_request.stream: - parser_dict = {} - reasoning_parser_dict = {} - - async def generate_stream_resp(): - tool_index_previous = -1 - is_firsts = {} - stream_buffers = {} - n_prev_tokens = {} - prompt_tokens = {} - completion_tokens = {} - cached_tokens = {} - hidden_states = {} - try: - async for content in tokenizer_manager.generate_request( - adapted_request, raw_request - ): - index = content.get("index", 0) - text = content["text"] - 
hidden_states[index] = content["meta_info"].get( - "hidden_states", None - ) or hidden_states.get(index) - - is_first = is_firsts.get(index, True) - stream_buffer = stream_buffers.get(index, "") - n_prev_token = n_prev_tokens.get(index, 0) - - prompt_tokens[index] = content["meta_info"]["prompt_tokens"] - completion_tokens[index] = content["meta_info"]["completion_tokens"] - cached_tokens[index] = content["meta_info"].get("cached_tokens", 0) - if request.logprobs: - logprobs = to_openai_style_logprobs( - output_token_logprobs=content["meta_info"][ - "output_token_logprobs" - ][n_prev_token:], - output_top_logprobs=content["meta_info"].get( - "output_top_logprobs", [] - )[n_prev_token:], - ) - - n_prev_token = len( - content["meta_info"]["output_token_logprobs"] - ) - token_logprobs = [] - for token, logprob in zip( - logprobs.tokens, logprobs.token_logprobs - ): - token_bytes = list(token.encode("utf-8")) - top_logprobs = [] - if logprobs.top_logprobs: - for top_token, top_logprob in logprobs.top_logprobs[ - 0 - ].items(): - top_token_bytes = list(top_token.encode("utf-8")) - top_logprobs.append( - TopLogprob( - token=top_token, - bytes=top_token_bytes, - logprob=top_logprob, - ) - ) - token_logprobs.append( - ChatCompletionTokenLogprob( - token=token, - bytes=token_bytes, - logprob=logprob, - top_logprobs=top_logprobs, - ) - ) - - choice_logprobs = ChoiceLogprobs(content=token_logprobs) - - else: - choice_logprobs = None - - finish_reason = content["meta_info"]["finish_reason"] - finish_reason_type = ( - finish_reason["type"] if finish_reason else None - ) - - if is_first: - # First chunk with role - is_first = False - delta = DeltaMessage(role="assistant") - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=delta, - finish_reason=finish_reason_type, - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - logprobs=choice_logprobs, - ) - chunk = ChatCompletionStreamResponse( - 
id=content["meta_info"]["id"], - created=created, - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - - text = content["text"] - delta = text[len(stream_buffer) :] - new_stream_buffer = stream_buffer + delta - - enable_thinking = _get_enable_thinking_from_request(request) - - if ( - tokenizer_manager.server_args.reasoning_parser - and request.separate_reasoning - and enable_thinking - ): - if index not in reasoning_parser_dict: - reasoning_parser_dict[index] = ReasoningParser( - tokenizer_manager.server_args.reasoning_parser, - request.stream_reasoning, - ) - reasoning_parser = reasoning_parser_dict[index] - reasoning_text, delta = reasoning_parser.parse_stream_chunk( - delta - ) - if reasoning_text: - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage( - reasoning_content=( - reasoning_text if reasoning_text else None - ) - ), - finish_reason=finish_reason_type, - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - created=created, - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - if (delta and len(delta) == 0) or not delta: - stream_buffers[index] = new_stream_buffer - is_firsts[index] = is_first - n_prev_tokens[index] = n_prev_token - continue - - if request.tool_choice != "none" and request.tools: - if index not in parser_dict: - parser_dict[index] = FunctionCallParser( - tools=request.tools, - tool_call_parser=tokenizer_manager.server_args.tool_call_parser, - ) - parser = parser_dict[index] - - # parse_increment => returns (normal_text, calls) - normal_text, calls = parser.parse_stream_chunk(delta) - - # 1) if there's normal_text, output it as normal content - if normal_text: - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage( - content=normal_text if normal_text else None - ), - finish_reason=finish_reason_type, - ) - chunk = ChatCompletionStreamResponse( - 
id=content["meta_info"]["id"], - created=created, - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - - # 2) if we found calls, we output them as separate chunk(s) - for call_item in calls: - tool_index_current = call_item.tool_index - # transform call_item -> FunctionResponse + ToolCall - if finish_reason_type == "stop": - latest_delta_len = 0 - if isinstance(call_item.parameters, str): - latest_delta_len = len(call_item.parameters) - - expected_call = json.dumps( - parser.detector.prev_tool_call_arr[index].get( - "arguments", {} - ), - ensure_ascii=False, - ) - actual_call = parser.detector.streamed_args_for_tool[ - index - ] - if latest_delta_len > 0: - actual_call = actual_call[:-latest_delta_len] - remaining_call = expected_call.replace( - actual_call, "", 1 - ) - call_item.parameters = remaining_call - - finish_reason_type = "tool_calls" - tool_call = ToolCall( - id=( - f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}" - if tool_index_previous != tool_index_current - else None - ), - index=call_item.tool_index, - function=FunctionResponse( - name=call_item.name, - arguments=call_item.parameters, - ), - ) - tool_index_previous = tool_index_current - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(tool_calls=[tool_call]), - finish_reason=( - None - if request.stream_options - and request.stream_options.include_usage - else finish_reason_type - ), # additional chunk will be return - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - created=created, - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - - stream_buffers[index] = new_stream_buffer - is_firsts[index] = is_first - n_prev_tokens[index] = n_prev_token - - else: - # No tool calls => just treat this as normal text - if delta or not ( - request.stream_options - and request.stream_options.include_usage - ): - choice_data 
= ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(content=delta if delta else None), - finish_reason=( - None - if request.stream_options - and request.stream_options.include_usage - else finish_reason_type - ), - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - logprobs=choice_logprobs, - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - created=created, - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - stream_buffers[index] = new_stream_buffer - is_firsts[index] = is_first - n_prev_tokens[index] = n_prev_token - if finish_reason_type == "stop" and request.tool_choice != "none": - parser = FunctionCallParser( - tools=request.tools, - tool_call_parser=tokenizer_manager.server_args.tool_call_parser, - ) - if parser.has_tool_call(new_stream_buffer): - # if the stream ends with empty string after tool calls - finish_reason_type = "tool_calls" - - if request.stream_options and request.stream_options.include_usage: - total_prompt_tokens = sum( - tokens - for i, tokens in prompt_tokens.items() - if i % request.n == 0 - ) - total_completion_tokens = sum( - tokens for tokens in completion_tokens.values() - ) - cache_report = tokenizer_manager.server_args.enable_cache_report - if cache_report: - cached_tokens_sum = sum( - tokens for tokens in cached_tokens.values() - ) - prompt_tokens_details = {"cached_tokens": cached_tokens_sum} - else: - prompt_tokens_details = None - usage = UsageInfo( - prompt_tokens=total_prompt_tokens, - completion_tokens=total_completion_tokens, - total_tokens=total_prompt_tokens + total_completion_tokens, - prompt_tokens_details=prompt_tokens_details, - ) - - else: - usage = None - if request.return_hidden_states and hidden_states: - for index, choice_hidden_states in hidden_states.items(): - last_token_hidden_states = ( - choice_hidden_states[-1] - if choice_hidden_states and 
len(choice_hidden_states) > 1 - else [] - ) - hidden_states_chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - created=created, - choices=[ - ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage( - hidden_states=last_token_hidden_states - ), - finish_reason=finish_reason_type, - ) - ], - model=request.model, - ) - yield f"data: {hidden_states_chunk.model_dump_json()}\n\n" - final_usage_chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - created=created, - choices=[ - ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(), - finish_reason=finish_reason_type, - ) - ], - model=request.model, - usage=usage, - ) - yield f"data: {final_usage_chunk.model_dump_json()}\n\n" - except ValueError as e: - error = create_streaming_error_response(str(e)) - yield f"data: {error}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse( - generate_stream_resp(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request), - ) - - # Non-streaming response. 
- try: - ret = await tokenizer_manager.generate_request( - adapted_request, raw_request - ).__anext__() - except ValueError as e: - return create_error_response(str(e)) - if not isinstance(ret, list): - ret = [ret] - - response = v1_chat_generate_response( - request, - ret, - created, - cache_report=tokenizer_manager.server_args.enable_cache_report, - tool_call_parser=tokenizer_manager.server_args.tool_call_parser, - reasoning_parser=tokenizer_manager.server_args.reasoning_parser, - ) - - return response - - -def v1_embedding_request(all_requests, tokenizer_manager): - prompts = [] - sampling_params_list = [] - first_prompt_type = type(all_requests[0].input) - - for request in all_requests: - prompt = request.input - # Check for empty/whitespace string - prompt = _validate_prompt(request.input) - assert ( - type(prompt) is first_prompt_type - ), "All prompts must be of the same type in file input settings" - prompts.append(prompt) - - if len(all_requests) == 1: - prompt = prompts[0] - if isinstance(prompt, str) or isinstance(prompt[0], str): - prompt_kwargs = {"text": prompt} - elif isinstance(prompt, list) and isinstance( - prompt[0], MultimodalEmbeddingInput - ): - texts = [] - images = [] - for item in prompt: - # TODO simply use padding for text, we should use a better way to handle this - texts.append(item.text if item.text is not None else "padding") - images.append(item.image if item.image is not None else None) - generate_prompts = [] - if chat_template_name is not None: - convs = generate_embedding_convs(texts, images, chat_template_name) - for conv in convs: - generate_prompts.append(conv.get_prompt()) - else: - generate_prompts = texts - if len(generate_prompts) == 1: - prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]} - else: - prompt_kwargs = {"text": generate_prompts, "image_data": images} - else: - prompt_kwargs = {"input_ids": prompt} - request_ids = all_requests[0].rid - else: - if isinstance(prompts[0], str) or 
isinstance(prompts[0][0], str): - prompt_kwargs = {"text": prompts} - elif isinstance(prompts[0], list) and isinstance( - prompts[0][0], MultimodalEmbeddingInput - ): - # TODO: multiple requests - raise NotImplementedError( - "Multiple requests with multimodal inputs are not supported yet" - ) - else: - prompt_kwargs = {"input_ids": prompts} - request_ids = [req.rid for req in all_requests] - - adapted_request = EmbeddingReqInput( - rid=request_ids, - **prompt_kwargs, - ) - - if len(all_requests) == 1: - return adapted_request, all_requests[0] - return adapted_request, all_requests - - -def v1_embedding_response(ret, model_path, to_file=False): - embedding_objects = [] - prompt_tokens = 0 - for idx, ret_item in enumerate(ret): - embedding_objects.append( - EmbeddingObject( - embedding=ret[idx]["embedding"], - index=idx, - ) - ) - prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"] - - return EmbeddingResponse( - data=embedding_objects, - model=model_path, - usage=UsageInfo( - prompt_tokens=prompt_tokens, - total_tokens=prompt_tokens, - ), - ) - - -async def v1_embeddings(tokenizer_manager, raw_request: Request): - try: - request_json = await raw_request.json() - except Exception as e: - return create_error_response("Invalid request body, error: ", str(e)) - all_requests = [EmbeddingRequest(**request_json)] - adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager) - - try: - ret = await tokenizer_manager.generate_request( - adapted_request, raw_request - ).__anext__() - except ValueError as e: - return create_error_response(str(e)) - - if not isinstance(ret, list): - ret = [ret] - - response = v1_embedding_response(ret, tokenizer_manager.model_path) - - return response - - -def v1_rerank_request(obj: V1RerankReqInput): - if obj.query is None: - raise ValueError("query is required") - if obj.documents is None or len(obj.documents) == 0: - raise ValueError("documents is required") - - pairs = [] - for doc in obj.documents: - 
pairs.append([obj.query, doc]) - - adapted_request = EmbeddingReqInput( - text=pairs, - is_cross_encoder_request=True, - ) - - return adapted_request - - -def v1_rerank_response(ret, obj: V1RerankReqInput): - - response = [] - for idx, ret_item in enumerate(ret): - response.append( - RerankResponse( - score=ret[idx]["embedding"], - document=obj.documents[idx], - index=idx, - meta_info=ret[idx]["meta_info"], - ) - ) - - response.sort(key=lambda x: x.score, reverse=True) - - return response - - -async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request): - adapted_request = v1_rerank_request(obj) - - try: - ret = await tokenizer_manager.generate_request( - adapted_request, raw_request - ).__anext__() - - except ValueError as e: - return create_error_response(str(e)) - - if not isinstance(ret, list): - ret = [ret] - - response = v1_rerank_response( - ret, - obj, - ) - - return response - - -def to_openai_style_logprobs( - input_token_logprobs=None, - output_token_logprobs=None, - input_top_logprobs=None, - output_top_logprobs=None, -): - ret_logprobs = LogProbs() - - def append_token_logprobs(token_logprobs): - for logprob, _, token_text in token_logprobs: - ret_logprobs.tokens.append(token_text) - ret_logprobs.token_logprobs.append(logprob) - - # Not supported yet - ret_logprobs.text_offset.append(-1) - - def append_top_logprobs(top_logprobs): - for tokens in top_logprobs: - if tokens is not None: - ret_logprobs.top_logprobs.append( - {token[2]: token[0] for token in tokens} - ) - else: - ret_logprobs.top_logprobs.append(None) - - if input_token_logprobs is not None: - append_token_logprobs(input_token_logprobs) - if output_token_logprobs is not None: - append_token_logprobs(output_token_logprobs) - if input_top_logprobs is not None: - append_top_logprobs(input_top_logprobs) - if output_top_logprobs is not None: - append_top_logprobs(output_top_logprobs) - - return ret_logprobs - - -async def v1_score(tokenizer_manager, raw_request): - try: - 
# Parse request - request_data = await raw_request.json() - request = ScoringRequest(**request_data) - - # Use tokenizer_manager's score_request method directly - scores = await tokenizer_manager.score_request( - query=request.query, - items=request.items, - label_token_ids=request.label_token_ids, - apply_softmax=request.apply_softmax, - item_first=request.item_first, - request=request, - ) - - # Create response with just the scores, without usage info - response = ScoringResponse( - scores=scores, - model=request.model, - ) - return response - - except Exception as e: - logger.error(f"Error in v1_score: {str(e)}") - return create_error_response(str(e)) diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py deleted file mode 100644 index 71153b912..000000000 --- a/python/sglang/srt/openai_api/protocol.py +++ /dev/null @@ -1,551 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Pydantic models for OpenAI API protocol""" - -import time -from typing import Dict, List, Optional, Union - -from pydantic import BaseModel, Field, model_serializer, root_validator -from typing_extensions import Literal - - -class ModelCard(BaseModel): - """Model cards.""" - - id: str - object: str = "model" - created: int = Field(default_factory=lambda: int(time.time())) - owned_by: str = "sglang" - root: Optional[str] = None - max_model_len: Optional[int] = None - - -class ModelList(BaseModel): - """Model list consists of model cards.""" - - object: str = "list" - data: List[ModelCard] = Field(default_factory=list) - - -class ErrorResponse(BaseModel): - object: str = "error" - message: str - type: str - param: Optional[str] = None - code: int - - -class LogProbs(BaseModel): - text_offset: List[int] = Field(default_factory=list) - token_logprobs: List[Optional[float]] = Field(default_factory=list) - tokens: List[str] = Field(default_factory=list) - top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) - - -class TopLogprob(BaseModel): - token: str - bytes: List[int] - logprob: float - - -class ChatCompletionTokenLogprob(BaseModel): - token: str - bytes: List[int] - logprob: float - top_logprobs: List[TopLogprob] - - -class ChoiceLogprobs(BaseModel): - # build for v1/chat/completions response - content: List[ChatCompletionTokenLogprob] - - -class UsageInfo(BaseModel): - prompt_tokens: int = 0 - total_tokens: int = 0 - completion_tokens: Optional[int] = 0 - # only used to return cached tokens when --enable-cache-report is set - prompt_tokens_details: Optional[Dict[str, int]] = None - - -class StreamOptions(BaseModel): - include_usage: Optional[bool] = False - - -class JsonSchemaResponseFormat(BaseModel): - name: str - description: Optional[str] = None - # use alias to workaround pydantic conflict - schema_: Optional[Dict[str, object]] = Field(alias="schema", 
default=None) - strict: Optional[bool] = False - - -class FileRequest(BaseModel): - # https://platform.openai.com/docs/api-reference/files/create - file: bytes # The File object (not file name) to be uploaded - purpose: str = ( - "batch" # The intended purpose of the uploaded file, default is "batch" - ) - - -class FileResponse(BaseModel): - id: str - object: str = "file" - bytes: int - created_at: int - filename: str - purpose: str - - -class FileDeleteResponse(BaseModel): - id: str - object: str = "file" - deleted: bool - - -class BatchRequest(BaseModel): - input_file_id: ( - str # The ID of an uploaded file that contains requests for the new batch - ) - endpoint: str # The endpoint to be used for all requests in the batch - completion_window: str # The time frame within which the batch should be processed - metadata: Optional[dict] = None # Optional custom metadata for the batch - - -class BatchResponse(BaseModel): - id: str - object: str = "batch" - endpoint: str - errors: Optional[dict] = None - input_file_id: str - completion_window: str - status: str = "validating" - output_file_id: Optional[str] = None - error_file_id: Optional[str] = None - created_at: int - in_progress_at: Optional[int] = None - expires_at: Optional[int] = None - finalizing_at: Optional[int] = None - completed_at: Optional[int] = None - failed_at: Optional[int] = None - expired_at: Optional[int] = None - cancelling_at: Optional[int] = None - cancelled_at: Optional[int] = None - request_counts: Optional[dict] = None - metadata: Optional[dict] = None - - -class CompletionRequest(BaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/completions/create - model: str - prompt: Union[List[int], List[List[int]], str, List[str]] - best_of: Optional[int] = None - echo: bool = False - frequency_penalty: float = 0.0 - logit_bias: Optional[Dict[str, float]] = None - logprobs: Optional[int] = None - max_tokens: int = 16 - n: int = 1 - 
presence_penalty: float = 0.0 - seed: Optional[int] = None - stop: Optional[Union[str, List[str]]] = None - stream: bool = False - stream_options: Optional[StreamOptions] = None - suffix: Optional[str] = None - temperature: float = 1.0 - top_p: float = 1.0 - user: Optional[str] = None - - # Extra parameters for SRT backend only and will be ignored by OpenAI models. - top_k: int = -1 - min_p: float = 0.0 - min_tokens: int = 0 - json_schema: Optional[str] = None - regex: Optional[str] = None - ebnf: Optional[str] = None - repetition_penalty: float = 1.0 - stop_token_ids: Optional[List[int]] = None - no_stop_trim: bool = False - ignore_eos: bool = False - skip_special_tokens: bool = True - lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None - session_params: Optional[Dict] = None - return_hidden_states: Optional[bool] = False - - # For PD disaggregation - bootstrap_host: Optional[str] = None - bootstrap_port: Optional[int] = None - bootstrap_room: Optional[int] = None - - -class CompletionResponseChoice(BaseModel): - index: int - text: str - logprobs: Optional[LogProbs] = None - finish_reason: Literal["stop", "length", "content_filter", "abort"] - matched_stop: Union[None, int, str] = None - hidden_states: Optional[object] = None - - @model_serializer - def _serialize(self): - return exclude_if_none(self, ["hidden_states"]) - - -class CompletionResponse(BaseModel): - id: str - object: str = "text_completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[CompletionResponseChoice] - usage: UsageInfo - - -class CompletionResponseStreamChoice(BaseModel): - index: int - text: str - logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None - matched_stop: Union[None, int, str] = None - hidden_states: Optional[object] = None - - @model_serializer - def _serialize(self): - return exclude_if_none(self, ["hidden_states"]) - - -class 
CompletionStreamResponse(BaseModel): - id: str - object: str = "text_completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[CompletionResponseStreamChoice] - usage: Optional[UsageInfo] = None - - -class ChatCompletionMessageContentTextPart(BaseModel): - type: Literal["text"] - text: str - - -class ChatCompletionMessageContentImageURL(BaseModel): - url: str - detail: Optional[Literal["auto", "low", "high"]] = "auto" - - -class ChatCompletionMessageContentAudioURL(BaseModel): - url: str - - -class ChatCompletionMessageContentImagePart(BaseModel): - type: Literal["image_url"] - image_url: ChatCompletionMessageContentImageURL - modalities: Optional[Literal["image", "multi-images", "video"]] = "image" - - -class ChatCompletionMessageContentAudioPart(BaseModel): - type: Literal["audio_url"] - audio_url: ChatCompletionMessageContentAudioURL - - -ChatCompletionMessageContentPart = Union[ - ChatCompletionMessageContentTextPart, - ChatCompletionMessageContentImagePart, - ChatCompletionMessageContentAudioPart, -] - - -class FunctionResponse(BaseModel): - """Function response.""" - - name: Optional[str] = None - arguments: Optional[str] = None - - -class ToolCall(BaseModel): - """Tool call response.""" - - id: Optional[str] = None - index: Optional[int] = None - type: Literal["function"] = "function" - function: FunctionResponse - - -class ChatCompletionMessageGenericParam(BaseModel): - role: Literal["system", "assistant", "tool"] - content: Union[str, List[ChatCompletionMessageContentTextPart], None] - tool_call_id: Optional[str] = None - name: Optional[str] = None - reasoning_content: Optional[str] = None - tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) - - -class ChatCompletionMessageUserParam(BaseModel): - role: Literal["user"] - content: Union[str, List[ChatCompletionMessageContentPart]] - - -ChatCompletionMessageParam = Union[ - ChatCompletionMessageGenericParam, 
ChatCompletionMessageUserParam -] - - -class ResponseFormat(BaseModel): - type: Literal["text", "json_object", "json_schema"] - json_schema: Optional[JsonSchemaResponseFormat] = None - - -class StructuresResponseFormat(BaseModel): - begin: str - schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) - end: str - - -class StructuralTagResponseFormat(BaseModel): - type: Literal["structural_tag"] - structures: List[StructuresResponseFormat] - triggers: List[str] - - -class Function(BaseModel): - """Function descriptions.""" - - description: Optional[str] = Field(default=None, examples=[None]) - name: Optional[str] = None - parameters: Optional[object] = None - strict: bool = False - - -class Tool(BaseModel): - """Function wrapper.""" - - type: str = Field(default="function", examples=["function"]) - function: Function - - -class ToolChoiceFuncName(BaseModel): - """The name of tool choice function.""" - - name: Optional[str] = None - - -class ToolChoice(BaseModel): - """The tool choice definition.""" - - function: ToolChoiceFuncName - type: Literal["function"] = Field(default="function", examples=["function"]) - - -class ChatCompletionRequest(BaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/chat/create - messages: List[ChatCompletionMessageParam] - model: str - frequency_penalty: float = 0.0 - logit_bias: Optional[Dict[str, float]] = None - logprobs: bool = False - top_logprobs: Optional[int] = None - max_tokens: Optional[int] = Field( - default=None, - deprecated="max_tokens is deprecated in favor of the max_completion_tokens field", - description="The maximum number of tokens that can be generated in the chat completion. ", - ) - max_completion_tokens: Optional[int] = Field( - default=None, - description="The maximum number of completion tokens for a chat completion request, " - "including visible output tokens and reasoning tokens. Input tokens are not included. 
", - ) - n: int = 1 - presence_penalty: float = 0.0 - response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None - seed: Optional[int] = None - stop: Optional[Union[str, List[str]]] = None - stream: bool = False - stream_options: Optional[StreamOptions] = None - temperature: float = 0.7 - top_p: float = 1.0 - user: Optional[str] = None - tools: Optional[List[Tool]] = Field(default=None, examples=[None]) - tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field( - default="auto", examples=["none"] - ) # noqa - - @root_validator(pre=True) - def set_tool_choice_default(cls, values): - if values.get("tool_choice") is None: - if values.get("tools") is None: - values["tool_choice"] = "none" - else: - values["tool_choice"] = "auto" - return values - - # Extra parameters for SRT backend only and will be ignored by OpenAI models. - top_k: int = -1 - min_p: float = 0.0 - min_tokens: int = 0 - regex: Optional[str] = None - ebnf: Optional[str] = None - repetition_penalty: float = 1.0 - stop_token_ids: Optional[List[int]] = None - no_stop_trim: bool = False - ignore_eos: bool = False - continue_final_message: bool = False - skip_special_tokens: bool = True - lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None - session_params: Optional[Dict] = None - separate_reasoning: bool = True - stream_reasoning: bool = True - chat_template_kwargs: Optional[Dict] = None - - # The request id. 
- rid: Optional[str] = None - - # For PD disaggregation - bootstrap_host: Optional[str] = None - bootstrap_port: Optional[int] = None - bootstrap_room: Optional[int] = None - - # Hidden States - return_hidden_states: Optional[bool] = False - - -class ChatMessage(BaseModel): - role: Optional[str] = None - content: Optional[str] = None - reasoning_content: Optional[str] = None - tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) - - -class ChatCompletionResponseChoice(BaseModel): - index: int - message: ChatMessage - logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None - finish_reason: Literal[ - "stop", "length", "tool_calls", "content_filter", "function_call", "abort" - ] - matched_stop: Union[None, int, str] = None - hidden_states: Optional[object] = None - - @model_serializer - def _serialize(self): - return exclude_if_none(self, ["hidden_states"]) - - -class ChatCompletionResponse(BaseModel): - id: str - object: str = "chat.completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[ChatCompletionResponseChoice] - usage: UsageInfo - - -class DeltaMessage(BaseModel): - role: Optional[str] = None - content: Optional[str] = None - reasoning_content: Optional[str] = None - tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) - hidden_states: Optional[object] = None - - @model_serializer - def _serialize(self): - return exclude_if_none(self, ["hidden_states"]) - - -class ChatCompletionResponseStreamChoice(BaseModel): - index: int - delta: DeltaMessage - logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None - finish_reason: Optional[ - Literal["stop", "length", "tool_calls", "content_filter", "function_call"] - ] = None - matched_stop: Union[None, int, str] = None - - -class ChatCompletionStreamResponse(BaseModel): - id: str - object: str = "chat.completion.chunk" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: 
List[ChatCompletionResponseStreamChoice] - usage: Optional[UsageInfo] = None - - -class MultimodalEmbeddingInput(BaseModel): - text: Optional[str] = None - image: Optional[str] = None - - -class EmbeddingRequest(BaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/embeddings/create - input: Union[ - List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput] - ] - model: str - encoding_format: str = "float" - dimensions: int = None - user: Optional[str] = None - - # The request id. - rid: Optional[str] = None - - -class EmbeddingObject(BaseModel): - embedding: List[float] - index: int - object: str = "embedding" - - -class EmbeddingResponse(BaseModel): - data: List[EmbeddingObject] - model: str - object: str = "list" - usage: Optional[UsageInfo] = None - - -class ScoringRequest(BaseModel): - query: Optional[Union[str, List[int]]] = ( - None # Query text or pre-tokenized token IDs - ) - items: Optional[Union[str, List[str], List[List[int]]]] = ( - None # Item text(s) or pre-tokenized token IDs - ) - label_token_ids: Optional[List[int]] = ( - None # Token IDs to compute probabilities for - ) - apply_softmax: bool = False - item_first: bool = False - model: str - - -class ScoringResponse(BaseModel): - scores: List[ - List[float] - ] # List of lists of probabilities, each in the order of label_token_ids - model: str - usage: Optional[UsageInfo] = None - object: str = "scoring" - - -class RerankResponse(BaseModel): - score: float - document: str - index: int - meta_info: Optional[dict] = None - - -def exclude_if_none(obj, field_names: List[str]): - omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names} - return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None} diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py index d8bf8f09c..746445bd9 100644 --- a/python/sglang/srt/reasoning_parser.py +++ 
b/python/sglang/srt/reasoning_parser.py @@ -1,4 +1,4 @@ -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple, Type class StreamingParseResult: @@ -32,17 +32,26 @@ class BaseReasoningFormatDetector: One-time parsing: Detects and parses reasoning sections in the provided text. Returns both reasoning content and normal text separately. """ - text = text.replace(self.think_start_token, "").strip() - if self.think_end_token not in text: + in_reasoning = self._in_reasoning or text.startswith(self.think_start_token) + + if not in_reasoning: + return StreamingParseResult(normal_text=text) + + # The text is considered to be in a reasoning block. + processed_text = text.replace(self.think_start_token, "").strip() + + if self.think_end_token not in processed_text: # Assume reasoning was truncated before `` token - return StreamingParseResult(reasoning_text=text) + return StreamingParseResult(reasoning_text=processed_text) # Extract reasoning content - splits = text.split(self.think_end_token, maxsplit=1) + splits = processed_text.split(self.think_end_token, maxsplit=1) reasoning_text = splits[0] - text = splits[1].strip() + normal_text = splits[1].strip() - return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text) + return StreamingParseResult( + normal_text=normal_text, reasoning_text=reasoning_text + ) def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: """ @@ -61,6 +70,7 @@ class BaseReasoningFormatDetector: if not self.stripped_think_start and self.think_start_token in current_text: current_text = current_text.replace(self.think_start_token, "") self.stripped_think_start = True + self._in_reasoning = True # Handle end of reasoning block if self._in_reasoning and self.think_end_token in current_text: @@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector): """ def __init__(self, stream_reasoning: bool = True): - # Qwen3 is assumed to be reasoning until `` token + # Qwen3 won't be in reasoning 
mode when user passes `enable_thinking=False` super().__init__( "", "", - force_reasoning=True, + force_reasoning=False, stream_reasoning=stream_reasoning, ) @@ -151,12 +161,12 @@ class ReasoningParser: If True, streams reasoning content as it arrives. """ - DetectorMap: Dict[str, BaseReasoningFormatDetector] = { + DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = { "deepseek-r1": DeepSeekR1Detector, "qwen3": Qwen3Detector, } - def __init__(self, model_type: str = None, stream_reasoning: bool = True): + def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True): if not model_type: raise ValueError("Model type must be specified") diff --git a/test/srt/openai/conftest.py b/test/srt/openai/conftest.py deleted file mode 100644 index ed88d624b..000000000 --- a/test/srt/openai/conftest.py +++ /dev/null @@ -1,87 +0,0 @@ -# sglang/test/srt/openai/conftest.py -import os -import socket -import subprocess -import sys -import tempfile -import time -from contextlib import closing -from typing import Generator - -import pytest -import requests - -from sglang.srt.utils import kill_process_tree # reuse SGLang helper -from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST - -SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server" -DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST -STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120)) - - -def _pick_free_port() -> int: - with closing(socket.socket()) as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> None: - start = time.perf_counter() - while time.perf_counter() - start < timeout: - if proc.poll() is not None: # crashed - raise RuntimeError("api_server terminated prematurely") - try: - if requests.get(f"{base}/health", timeout=1).status_code == 200: - return - except requests.RequestException: - pass - time.sleep(0.4) - raise RuntimeError("api_server readiness probe 
timed out") - - -def launch_openai_server(model: str = DEFAULT_MODEL, **kw): - """Spawn the draft OpenAI-compatible server and wait until it's ready.""" - port = _pick_free_port() - cmd = [ - sys.executable, - "-m", - SERVER_MODULE, - "--model-path", - model, - "--host", - "127.0.0.1", - "--port", - str(port), - *map(str, kw.get("args", [])), - ] - env = {**os.environ, **kw.get("env", {})} - - # Write logs to a temp file so the child never blocks on a full pipe. - log_file = tempfile.NamedTemporaryFile("w+", delete=False) - proc = subprocess.Popen( - cmd, - env=env, - stdout=log_file, - stderr=subprocess.STDOUT, - text=True, - ) - - base = f"http://127.0.0.1:{port}" - try: - _wait_until_healthy(proc, base, STARTUP_TIMEOUT) - except Exception as e: - proc.terminate() - proc.wait(5) - log_file.seek(0) - print("\n--- api_server log ---\n", log_file.read(), file=sys.stderr) - raise e - return proc, base, log_file - - -@pytest.fixture(scope="session") -def openai_server() -> Generator[str, None, None]: - """PyTest fixture that provides the server's base URL and cleans up.""" - proc, base, log_file = launch_openai_server() - yield base - kill_process_tree(proc.pid) - log_file.close() diff --git a/test/srt/openai/test_protocol.py b/test/srt/openai/test_protocol.py index 06260024a..65b4e4c50 100644 --- a/test/srt/openai/test_protocol.py +++ b/test/srt/openai/test_protocol.py @@ -67,29 +67,6 @@ from sglang.srt.entrypoints.openai.protocol import ( class TestModelCard(unittest.TestCase): """Test ModelCard protocol model""" - def test_basic_model_card_creation(self): - """Test basic model card creation with required fields""" - card = ModelCard(id="test-model") - self.assertEqual(card.id, "test-model") - self.assertEqual(card.object, "model") - self.assertEqual(card.owned_by, "sglang") - self.assertIsInstance(card.created, int) - self.assertIsNone(card.root) - self.assertIsNone(card.max_model_len) - - def test_model_card_with_optional_fields(self): - """Test model card with 
optional fields""" - card = ModelCard( - id="test-model", - root="/path/to/model", - max_model_len=2048, - created=1234567890, - ) - self.assertEqual(card.id, "test-model") - self.assertEqual(card.root, "/path/to/model") - self.assertEqual(card.max_model_len, 2048) - self.assertEqual(card.created, 1234567890) - def test_model_card_serialization(self): """Test model card JSON serialization""" card = ModelCard(id="test-model", max_model_len=4096) @@ -120,53 +97,6 @@ class TestModelList(unittest.TestCase): self.assertEqual(model_list.data[1].id, "model-2") -class TestErrorResponse(unittest.TestCase): - """Test ErrorResponse protocol model""" - - def test_basic_error_response(self): - """Test basic error response creation""" - error = ErrorResponse( - message="Invalid request", type="BadRequestError", code=400 - ) - self.assertEqual(error.object, "error") - self.assertEqual(error.message, "Invalid request") - self.assertEqual(error.type, "BadRequestError") - self.assertEqual(error.code, 400) - self.assertIsNone(error.param) - - def test_error_response_with_param(self): - """Test error response with parameter""" - error = ErrorResponse( - message="Invalid temperature", - type="ValidationError", - code=422, - param="temperature", - ) - self.assertEqual(error.param, "temperature") - - -class TestUsageInfo(unittest.TestCase): - """Test UsageInfo protocol model""" - - def test_basic_usage_info(self): - """Test basic usage info creation""" - usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30) - self.assertEqual(usage.prompt_tokens, 10) - self.assertEqual(usage.completion_tokens, 20) - self.assertEqual(usage.total_tokens, 30) - self.assertIsNone(usage.prompt_tokens_details) - - def test_usage_info_with_cache_details(self): - """Test usage info with cache details""" - usage = UsageInfo( - prompt_tokens=10, - completion_tokens=20, - total_tokens=30, - prompt_tokens_details={"cached_tokens": 5}, - ) - self.assertEqual(usage.prompt_tokens_details, 
{"cached_tokens": 5}) - - class TestCompletionRequest(unittest.TestCase): """Test CompletionRequest protocol model""" @@ -181,30 +111,6 @@ class TestCompletionRequest(unittest.TestCase): self.assertFalse(request.stream) # default self.assertFalse(request.echo) # default - def test_completion_request_with_options(self): - """Test completion request with various options""" - request = CompletionRequest( - model="test-model", - prompt=["Hello", "world"], - max_tokens=100, - temperature=0.7, - top_p=0.9, - n=2, - stream=True, - echo=True, - stop=[".", "!"], - logprobs=5, - ) - self.assertEqual(request.prompt, ["Hello", "world"]) - self.assertEqual(request.max_tokens, 100) - self.assertEqual(request.temperature, 0.7) - self.assertEqual(request.top_p, 0.9) - self.assertEqual(request.n, 2) - self.assertTrue(request.stream) - self.assertTrue(request.echo) - self.assertEqual(request.stop, [".", "!"]) - self.assertEqual(request.logprobs, 5) - def test_completion_request_sglang_extensions(self): """Test completion request with SGLang-specific extensions""" request = CompletionRequest( @@ -233,26 +139,6 @@ class TestCompletionRequest(unittest.TestCase): CompletionRequest(model="test-model") # missing prompt -class TestCompletionResponse(unittest.TestCase): - """Test CompletionResponse protocol model""" - - def test_basic_completion_response(self): - """Test basic completion response""" - choice = CompletionResponseChoice( - index=0, text="Hello world!", finish_reason="stop" - ) - usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5) - response = CompletionResponse( - id="test-id", model="test-model", choices=[choice], usage=usage - ) - self.assertEqual(response.id, "test-id") - self.assertEqual(response.object, "text_completion") - self.assertEqual(response.model, "test-model") - self.assertEqual(len(response.choices), 1) - self.assertEqual(response.choices[0].text, "Hello world!") - self.assertEqual(response.usage.total_tokens, 5) - - class 
TestChatCompletionRequest(unittest.TestCase): """Test ChatCompletionRequest protocol model""" @@ -268,48 +154,6 @@ class TestChatCompletionRequest(unittest.TestCase): self.assertFalse(request.stream) # default self.assertEqual(request.tool_choice, "none") # default when no tools - def test_chat_completion_with_multimodal_content(self): - """Test chat completion with multimodal content""" - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."}, - }, - ], - } - ] - request = ChatCompletionRequest(model="test-model", messages=messages) - self.assertEqual(len(request.messages[0].content), 2) - self.assertEqual(request.messages[0].content[0].type, "text") - self.assertEqual(request.messages[0].content[1].type, "image_url") - - def test_chat_completion_with_tools(self): - """Test chat completion with tools""" - messages = [{"role": "user", "content": "What's the weather?"}] - tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get weather information", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - }, - }, - } - ] - request = ChatCompletionRequest( - model="test-model", messages=messages, tools=tools - ) - self.assertEqual(len(request.tools), 1) - self.assertEqual(request.tools[0].function.name, "get_weather") - self.assertEqual(request.tool_choice, "auto") # default when tools present - def test_chat_completion_tool_choice_validation(self): """Test tool choice validation logic""" messages = [{"role": "user", "content": "Hello"}] @@ -349,289 +193,6 @@ class TestChatCompletionRequest(unittest.TestCase): self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"}) -class TestChatCompletionResponse(unittest.TestCase): - """Test ChatCompletionResponse protocol model""" - - def test_basic_chat_completion_response(self): - """Test basic chat 
completion response""" - message = ChatMessage(role="assistant", content="Hello there!") - choice = ChatCompletionResponseChoice( - index=0, message=message, finish_reason="stop" - ) - usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5) - response = ChatCompletionResponse( - id="test-id", model="test-model", choices=[choice], usage=usage - ) - self.assertEqual(response.id, "test-id") - self.assertEqual(response.object, "chat.completion") - self.assertEqual(response.model, "test-model") - self.assertEqual(len(response.choices), 1) - self.assertEqual(response.choices[0].message.content, "Hello there!") - - def test_chat_completion_response_with_tool_calls(self): - """Test chat completion response with tool calls""" - tool_call = ToolCall( - id="call_123", - function=FunctionResponse( - name="get_weather", arguments='{"location": "San Francisco"}' - ), - ) - message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call]) - choice = ChatCompletionResponseChoice( - index=0, message=message, finish_reason="tool_calls" - ) - usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15) - response = ChatCompletionResponse( - id="test-id", model="test-model", choices=[choice], usage=usage - ) - self.assertEqual( - response.choices[0].message.tool_calls[0].function.name, "get_weather" - ) - self.assertEqual(response.choices[0].finish_reason, "tool_calls") - - -class TestEmbeddingRequest(unittest.TestCase): - """Test EmbeddingRequest protocol model""" - - def test_basic_embedding_request(self): - """Test basic embedding request""" - request = EmbeddingRequest(model="test-model", input="Hello world") - self.assertEqual(request.model, "test-model") - self.assertEqual(request.input, "Hello world") - self.assertEqual(request.encoding_format, "float") # default - self.assertIsNone(request.dimensions) # default - - def test_embedding_request_with_list_input(self): - """Test embedding request with list input""" - request = 
EmbeddingRequest( - model="test-model", input=["Hello", "world"], dimensions=512 - ) - self.assertEqual(request.input, ["Hello", "world"]) - self.assertEqual(request.dimensions, 512) - - def test_multimodal_embedding_request(self): - """Test multimodal embedding request""" - multimodal_input = [ - MultimodalEmbeddingInput(text="Hello", image="base64_image_data"), - MultimodalEmbeddingInput(text="World", image=None), - ] - request = EmbeddingRequest(model="test-model", input=multimodal_input) - self.assertEqual(len(request.input), 2) - self.assertEqual(request.input[0].text, "Hello") - self.assertEqual(request.input[0].image, "base64_image_data") - self.assertEqual(request.input[1].text, "World") - self.assertIsNone(request.input[1].image) - - -class TestEmbeddingResponse(unittest.TestCase): - """Test EmbeddingResponse protocol model""" - - def test_basic_embedding_response(self): - """Test basic embedding response""" - embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0) - usage = UsageInfo(prompt_tokens=3, total_tokens=3) - response = EmbeddingResponse( - data=[embedding_obj], model="test-model", usage=usage - ) - self.assertEqual(response.object, "list") - self.assertEqual(len(response.data), 1) - self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3]) - self.assertEqual(response.data[0].index, 0) - self.assertEqual(response.usage.prompt_tokens, 3) - - -class TestScoringRequest(unittest.TestCase): - """Test ScoringRequest protocol model""" - - def test_basic_scoring_request(self): - """Test basic scoring request""" - request = ScoringRequest( - model="test-model", query="Hello", items=["World", "Earth"] - ) - self.assertEqual(request.model, "test-model") - self.assertEqual(request.query, "Hello") - self.assertEqual(request.items, ["World", "Earth"]) - self.assertFalse(request.apply_softmax) # default - self.assertFalse(request.item_first) # default - - def test_scoring_request_with_token_ids(self): - """Test scoring request with token IDs""" 
- request = ScoringRequest( - model="test-model", - query=[1, 2, 3], - items=[[4, 5], [6, 7]], - label_token_ids=[8, 9], - apply_softmax=True, - item_first=True, - ) - self.assertEqual(request.query, [1, 2, 3]) - self.assertEqual(request.items, [[4, 5], [6, 7]]) - self.assertEqual(request.label_token_ids, [8, 9]) - self.assertTrue(request.apply_softmax) - self.assertTrue(request.item_first) - - -class TestScoringResponse(unittest.TestCase): - """Test ScoringResponse protocol model""" - - def test_basic_scoring_response(self): - """Test basic scoring response""" - response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model") - self.assertEqual(response.object, "scoring") - self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]]) - self.assertEqual(response.model, "test-model") - self.assertIsNone(response.usage) # default - - -class TestFileOperations(unittest.TestCase): - """Test file operation protocol models""" - - def test_file_request(self): - """Test file request model""" - file_data = b"test file content" - request = FileRequest(file=file_data, purpose="batch") - self.assertEqual(request.file, file_data) - self.assertEqual(request.purpose, "batch") - - def test_file_response(self): - """Test file response model""" - response = FileResponse( - id="file-123", - bytes=1024, - created_at=1234567890, - filename="test.jsonl", - purpose="batch", - ) - self.assertEqual(response.id, "file-123") - self.assertEqual(response.object, "file") - self.assertEqual(response.bytes, 1024) - self.assertEqual(response.filename, "test.jsonl") - - def test_file_delete_response(self): - """Test file delete response model""" - response = FileDeleteResponse(id="file-123", deleted=True) - self.assertEqual(response.id, "file-123") - self.assertEqual(response.object, "file") - self.assertTrue(response.deleted) - - -class TestBatchOperations(unittest.TestCase): - """Test batch operation protocol models""" - - def test_batch_request(self): - """Test batch request 
model""" - request = BatchRequest( - input_file_id="file-123", - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={"custom": "value"}, - ) - self.assertEqual(request.input_file_id, "file-123") - self.assertEqual(request.endpoint, "/v1/chat/completions") - self.assertEqual(request.completion_window, "24h") - self.assertEqual(request.metadata, {"custom": "value"}) - - def test_batch_response(self): - """Test batch response model""" - response = BatchResponse( - id="batch-123", - endpoint="/v1/chat/completions", - input_file_id="file-123", - completion_window="24h", - created_at=1234567890, - ) - self.assertEqual(response.id, "batch-123") - self.assertEqual(response.object, "batch") - self.assertEqual(response.status, "validating") # default - self.assertEqual(response.endpoint, "/v1/chat/completions") - - -class TestResponseFormats(unittest.TestCase): - """Test response format protocol models""" - - def test_basic_response_format(self): - """Test basic response format""" - format_obj = ResponseFormat(type="json_object") - self.assertEqual(format_obj.type, "json_object") - self.assertIsNone(format_obj.json_schema) - - def test_json_schema_response_format(self): - """Test JSON schema response format""" - schema = {"type": "object", "properties": {"name": {"type": "string"}}} - json_schema = JsonSchemaResponseFormat( - name="person_schema", description="Person schema", schema=schema - ) - format_obj = ResponseFormat(type="json_schema", json_schema=json_schema) - self.assertEqual(format_obj.type, "json_schema") - self.assertEqual(format_obj.json_schema.name, "person_schema") - self.assertEqual(format_obj.json_schema.schema_, schema) - - def test_structural_tag_response_format(self): - """Test structural tag response format""" - structures = [ - { - "begin": "", - "schema_": {"type": "string"}, - "end": "", - } - ] - format_obj = StructuralTagResponseFormat( - type="structural_tag", structures=structures, triggers=["think"] - ) - 
self.assertEqual(format_obj.type, "structural_tag") - self.assertEqual(len(format_obj.structures), 1) - self.assertEqual(format_obj.triggers, ["think"]) - - -class TestLogProbs(unittest.TestCase): - """Test LogProbs protocol models""" - - def test_basic_logprobs(self): - """Test basic LogProbs model""" - logprobs = LogProbs( - text_offset=[0, 5, 11], - token_logprobs=[-0.1, -0.2, -0.3], - tokens=["Hello", " ", "world"], - top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}], - ) - self.assertEqual(len(logprobs.tokens), 3) - self.assertEqual(logprobs.tokens, ["Hello", " ", "world"]) - self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3]) - - def test_choice_logprobs(self): - """Test ChoiceLogprobs model""" - token_logprob = ChatCompletionTokenLogprob( - token="Hello", - bytes=[72, 101, 108, 108, 111], - logprob=-0.1, - top_logprobs=[ - TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1) - ], - ) - choice_logprobs = ChoiceLogprobs(content=[token_logprob]) - self.assertEqual(len(choice_logprobs.content), 1) - self.assertEqual(choice_logprobs.content[0].token, "Hello") - - -class TestStreamingModels(unittest.TestCase): - """Test streaming response models""" - - def test_stream_options(self): - """Test StreamOptions model""" - options = StreamOptions(include_usage=True) - self.assertTrue(options.include_usage) - - def test_chat_completion_stream_response(self): - """Test ChatCompletionStreamResponse model""" - delta = DeltaMessage(role="assistant", content="Hello") - choice = ChatCompletionResponseStreamChoice(index=0, delta=delta) - response = ChatCompletionStreamResponse( - id="test-id", model="test-model", choices=[choice] - ) - self.assertEqual(response.object, "chat.completion.chunk") - self.assertEqual(response.choices[0].delta.content, "Hello") - - class TestModelSerialization(unittest.TestCase): """Test model serialization with hidden states""" @@ -680,11 +241,6 @@ class TestModelSerialization(unittest.TestCase): class 
TestValidationEdgeCases(unittest.TestCase): """Test edge cases and validation scenarios""" - def test_empty_messages_validation(self): - """Test validation with empty messages""" - with self.assertRaises(ValidationError): - ChatCompletionRequest(model="test-model", messages=[]) - def test_invalid_tool_choice_type(self): """Test invalid tool choice type""" messages = [{"role": "user", "content": "Hello"}] @@ -698,13 +254,6 @@ class TestValidationEdgeCases(unittest.TestCase): with self.assertRaises(ValidationError): CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1) - def test_invalid_temperature_range(self): - """Test invalid temperature values""" - # Note: The current protocol doesn't enforce temperature range, - # but this test documents expected behavior - request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0) - self.assertEqual(request.temperature, 5.0) # Currently allowed - def test_model_serialization_roundtrip(self): """Test that models can be serialized and deserialized""" original_request = ChatCompletionRequest( diff --git a/test/srt/openai/test_server.py b/test/srt/openai/test_server.py deleted file mode 100644 index 3de52f4cd..000000000 --- a/test/srt/openai/test_server.py +++ /dev/null @@ -1,52 +0,0 @@ -# sglang/test/srt/openai/test_server.py -import requests - -from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID - - -def test_health(openai_server: str): - r = requests.get(f"{openai_server}/health") - assert r.status_code == 200 - # FastAPI returns an empty body → r.text == "" - assert r.text == "" - - -def test_models_endpoint(openai_server: str): - r = requests.get(f"{openai_server}/v1/models") - assert r.status_code == 200, r.text - payload = r.json() - - # Basic contract - assert "data" in payload and isinstance(payload["data"], list) and payload["data"] - - # Validate fields of the first model card - first = payload["data"][0] - for key in ("id", "root", "max_model_len"): - 
assert key in first, f"missing {key} in {first}" - - # max_model_len must be positive - assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0 - - # The server should report the same model id we launched it with - ids = {m["id"] for m in payload["data"]} - assert MODEL_ID in ids - - -def test_get_model_info(openai_server: str): - r = requests.get(f"{openai_server}/get_model_info") - assert r.status_code == 200, r.text - info = r.json() - - expected_keys = {"model_path", "tokenizer_path", "is_generation"} - assert expected_keys.issubset(info.keys()) - - # model_path must end with the one we passed on the CLI - assert info["model_path"].endswith(MODEL_ID) - - # is_generation is documented as a boolean - assert isinstance(info["is_generation"], bool) - - -def test_unknown_route_returns_404(openai_server: str): - r = requests.get(f"{openai_server}/definitely-not-a-real-route") - assert r.status_code == 404 diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py index ff38fccc7..701dc2e55 100644 --- a/test/srt/openai/test_serving_chat.py +++ b/test/srt/openai/test_serving_chat.py @@ -57,11 +57,21 @@ class _MockTokenizerManager: self.create_abort_task = Mock() +class _MockTemplateManager: + """Minimal mock for TemplateManager.""" + + def __init__(self): + self.chat_template_name: Optional[str] = "llama-3" + self.jinja_template_content_format: Optional[str] = None + self.completion_template_name: Optional[str] = None + + class ServingChatTestCase(unittest.TestCase): # ------------- common fixtures ------------- def setUp(self): self.tm = _MockTokenizerManager() - self.chat = OpenAIServingChat(self.tm) + self.template_manager = _MockTemplateManager() + self.chat = OpenAIServingChat(self.tm, self.template_manager) # frequently reused requests self.basic_req = ChatCompletionRequest( @@ -109,96 +119,6 @@ class ServingChatTestCase(unittest.TestCase): self.assertFalse(adapted.stream) self.assertEqual(processed, 
self.basic_req) - # # ------------- tool-call branch ------------- - # def test_tool_call_request_conversion(self): - # req = ChatCompletionRequest( - # model="x", - # messages=[{"role": "user", "content": "Weather?"}], - # tools=[ - # { - # "type": "function", - # "function": { - # "name": "get_weather", - # "parameters": {"type": "object", "properties": {}}, - # }, - # } - # ], - # tool_choice="auto", - # ) - - # with patch.object( - # self.chat, - # "_process_messages", - # return_value=("Prompt", [1, 2, 3], None, None, [], [""], None), - # ): - # adapted, _ = self.chat._convert_to_internal_request(req, "rid") - # self.assertEqual(adapted.rid, "rid") - - # def test_tool_choice_none(self): - # req = ChatCompletionRequest( - # model="x", - # messages=[{"role": "user", "content": "Hi"}], - # tools=[{"type": "function", "function": {"name": "noop"}}], - # tool_choice="none", - # ) - # with patch.object( - # self.chat, - # "_process_messages", - # return_value=("Prompt", [1, 2, 3], None, None, [], [""], None), - # ): - # adapted, _ = self.chat._convert_to_internal_request(req, "rid") - # self.assertEqual(adapted.rid, "rid") - - # ------------- multimodal branch ------------- - def test_multimodal_request_with_images(self): - self.tm.model_config.is_multimodal = True - - req = ChatCompletionRequest( - model="x", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in the image?"}, - { - "type": "image_url", - "image_url": {"url": "data:image/jpeg;base64,"}, - }, - ], - } - ], - ) - - with patch.object( - self.chat, - "_apply_jinja_template", - return_value=("prompt", [1, 2], ["img"], None, [], []), - ), patch.object( - self.chat, - "_apply_conversation_template", - return_value=("prompt", ["img"], None, [], []), - ): - out = self.chat._process_messages(req, True) - _, _, image_data, *_ = out - self.assertEqual(image_data, ["img"]) - - # ------------- template handling ------------- - def test_jinja_template_processing(self): - req = 
ChatCompletionRequest( - model="x", messages=[{"role": "user", "content": "Hello"}] - ) - self.tm.chat_template_name = None - self.tm.tokenizer.chat_template = "" - - with patch.object( - self.chat, - "_apply_jinja_template", - return_value=("processed", [1], None, None, [], [""]), - ), patch("builtins.hasattr", return_value=True): - prompt, prompt_ids, *_ = self.chat._process_messages(req, False) - self.assertEqual(prompt, "processed") - self.assertEqual(prompt_ids, [1]) - # ------------- sampling-params ------------- def test_sampling_param_build(self): req = ChatCompletionRequest( diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py index 7a42523c7..c0568e93b 100644 --- a/test/srt/openai/test_serving_completions.py +++ b/test/srt/openai/test_serving_completions.py @@ -5,6 +5,7 @@ Run with: """ import unittest +from typing import Optional from unittest.mock import AsyncMock, Mock, patch from sglang.srt.entrypoints.openai.protocol import CompletionRequest @@ -12,6 +13,17 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl from sglang.srt.managers.tokenizer_manager import TokenizerManager +class _MockTemplateManager: + """Minimal mock for TemplateManager.""" + + def __init__(self): + self.chat_template_name: Optional[str] = None + self.jinja_template_content_format: Optional[str] = None + self.completion_template_name: Optional[str] = ( + None # Set to None to avoid template processing + ) + + class ServingCompletionTestCase(unittest.TestCase): """Bundle all prompt/echo tests in one TestCase.""" @@ -31,7 +43,8 @@ class ServingCompletionTestCase(unittest.TestCase): tm.generate_request = AsyncMock() tm.create_abort_task = Mock() - self.sc = OpenAIServingCompletion(tm) + self.template_manager = _MockTemplateManager() + self.sc = OpenAIServingCompletion(tm, self.template_manager) # ---------- prompt-handling ---------- def test_single_string_prompt(self): @@ -44,20 +57,6 @@ class 
ServingCompletionTestCase(unittest.TestCase): internal, _ = self.sc._convert_to_internal_request(req) self.assertEqual(internal.input_ids, [1, 2, 3, 4]) - def test_completion_template_handling(self): - req = CompletionRequest( - model="x", prompt="def f():", suffix="return 1", max_tokens=100 - ) - with patch( - "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined", - return_value=True, - ), patch( - "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request", - return_value="processed_prompt", - ): - internal, _ = self.sc._convert_to_internal_request(req) - self.assertEqual(internal.text, "processed_prompt") - # ---------- echo-handling ---------- def test_echo_with_string_prompt_streaming(self): req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True) diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py index b6e3094df..1305e668d 100644 --- a/test/srt/openai/test_serving_embedding.py +++ b/test/srt/openai/test_serving_embedding.py @@ -5,25 +5,16 @@ These tests ensure that the embedding serving implementation maintains compatibi with the original adapter.py functionality and follows OpenAI API specifications. 
""" -import asyncio -import json -import time import unittest import uuid -from typing import Any, Dict, List -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock from fastapi import Request -from fastapi.responses import ORJSONResponse -from pydantic_core import ValidationError from sglang.srt.entrypoints.openai.protocol import ( - EmbeddingObject, EmbeddingRequest, EmbeddingResponse, - ErrorResponse, MultimodalEmbeddingInput, - UsageInfo, ) from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from sglang.srt.managers.io_struct import EmbeddingReqInput @@ -58,11 +49,22 @@ class _MockTokenizerManager: self.generate_request = Mock(return_value=mock_generate_embedding()) +# Mock TemplateManager for embedding tests +class _MockTemplateManager: + def __init__(self): + self.chat_template_name = None # None for embeddings usually + self.jinja_template_content_format = None + self.completion_template_name = None + + class ServingEmbeddingTestCase(unittest.TestCase): def setUp(self): """Set up test fixtures.""" self.tokenizer_manager = _MockTokenizerManager() - self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager) + self.template_manager = _MockTemplateManager() + self.serving_embedding = OpenAIServingEmbedding( + self.tokenizer_manager, self.template_manager + ) self.request = Mock(spec=Request) self.request.headers = {} @@ -141,132 +143,6 @@ class ServingEmbeddingTestCase(unittest.TestCase): self.assertIsNone(adapted_request.image_data[1]) # self.assertEqual(adapted_request.rid, "test-id") - def test_build_single_embedding_response(self): - """Test building response for single embedding.""" - ret_data = [ - { - "embedding": [0.1, 0.2, 0.3, 0.4, 0.5], - "meta_info": {"prompt_tokens": 5}, - } - ] - - response = self.serving_embedding._build_embedding_response(ret_data) - - self.assertIsInstance(response, EmbeddingResponse) - self.assertEqual(response.model, "test-model") - 
self.assertEqual(len(response.data), 1) - self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5]) - self.assertEqual(response.data[0].index, 0) - self.assertEqual(response.data[0].object, "embedding") - self.assertEqual(response.usage.prompt_tokens, 5) - self.assertEqual(response.usage.total_tokens, 5) - self.assertEqual(response.usage.completion_tokens, 0) - - def test_build_multiple_embedding_response(self): - """Test building response for multiple embeddings.""" - ret_data = [ - { - "embedding": [0.1, 0.2, 0.3], - "meta_info": {"prompt_tokens": 3}, - }, - { - "embedding": [0.4, 0.5, 0.6], - "meta_info": {"prompt_tokens": 4}, - }, - ] - - response = self.serving_embedding._build_embedding_response(ret_data) - - self.assertIsInstance(response, EmbeddingResponse) - self.assertEqual(len(response.data), 2) - self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3]) - self.assertEqual(response.data[0].index, 0) - self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6]) - self.assertEqual(response.data[1].index, 1) - self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4 - self.assertEqual(response.usage.total_tokens, 7) - - def test_handle_request_success(self): - """Test successful embedding request handling.""" - - async def run_test(): - # Mock the generate_request to return expected data - async def mock_generate(): - yield { - "embedding": [0.1, 0.2, 0.3, 0.4, 0.5], - "meta_info": {"prompt_tokens": 5}, - } - - self.serving_embedding.tokenizer_manager.generate_request = Mock( - return_value=mock_generate() - ) - - response = await self.serving_embedding.handle_request( - self.basic_req, self.request - ) - - self.assertIsInstance(response, EmbeddingResponse) - self.assertEqual(len(response.data), 1) - self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5]) - - asyncio.run(run_test()) - - def test_handle_request_validation_error(self): - """Test handling request with validation error.""" - - async def run_test(): - 
invalid_request = EmbeddingRequest(model="test-model", input="") - - response = await self.serving_embedding.handle_request( - invalid_request, self.request - ) - - self.assertIsInstance(response, ORJSONResponse) - self.assertEqual(response.status_code, 400) - - asyncio.run(run_test()) - - def test_handle_request_generation_error(self): - """Test handling request with generation error.""" - - async def run_test(): - # Mock generate_request to raise an error - async def mock_generate_error(): - raise ValueError("Generation failed") - yield # This won't be reached but needed for async generator - - self.serving_embedding.tokenizer_manager.generate_request = Mock( - return_value=mock_generate_error() - ) - - response = await self.serving_embedding.handle_request( - self.basic_req, self.request - ) - - self.assertIsInstance(response, ORJSONResponse) - self.assertEqual(response.status_code, 400) - - asyncio.run(run_test()) - - def test_handle_request_internal_error(self): - """Test handling request with internal server error.""" - - async def run_test(): - # Mock _convert_to_internal_request to raise an exception - with patch.object( - self.serving_embedding, - "_convert_to_internal_request", - side_effect=Exception("Internal error"), - ): - response = await self.serving_embedding.handle_request( - self.basic_req, self.request - ) - - self.assertIsInstance(response, ORJSONResponse) - self.assertEqual(response.status_code, 500) - - asyncio.run(run_test()) - if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 327d084d9..26ff2b31b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -29,6 +29,10 @@ suites = { TestFile("models/test_reward_models.py", 132), TestFile("models/test_vlm_models.py", 437), TestFile("models/test_transformers_models.py", 320), + TestFile("openai/test_protocol.py", 10), + TestFile("openai/test_serving_chat.py", 10), + TestFile("openai/test_serving_completions.py", 10), + 
TestFile("openai/test_serving_embedding.py", 10), TestFile("test_abort.py", 51), TestFile("test_block_int8.py", 22), TestFile("test_create_kvindices.py", 2), @@ -49,6 +53,7 @@ suites = { TestFile("test_hidden_states.py", 55), TestFile("test_int8_kernel.py", 8), TestFile("test_input_embeddings.py", 38), + TestFile("test_jinja_template_utils.py", 1), TestFile("test_json_constrained.py", 98), TestFile("test_large_max_new_tokens.py", 41), TestFile("test_metrics.py", 32), @@ -59,14 +64,8 @@ suites = { TestFile("test_mla_fp8.py", 93), TestFile("test_no_chunked_prefill.py", 108), TestFile("test_no_overlap_scheduler.py", 234), - TestFile("test_openai_adapter.py", 1), TestFile("test_openai_function_calling.py", 60), TestFile("test_openai_server.py", 149), - TestFile("openai/test_server.py", 120), - TestFile("openai/test_protocol.py", 60), - TestFile("openai/test_serving_chat.py", 120), - TestFile("openai/test_serving_completions.py", 120), - TestFile("openai/test_serving_embedding.py", 120), TestFile("test_openai_server_hidden_states.py", 240), TestFile("test_penalty.py", 41), TestFile("test_page_size.py", 60), diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index ab6c8b999..35b75d715 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -3,6 +3,7 @@ import unittest from xgrammar import GrammarCompiler, TokenizerInfo +from sglang.srt.entrypoints.openai.protocol import Function, Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.llama32_detector import Llama32Detector @@ -10,7 +11,6 @@ from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector from sglang.srt.function_call.qwen25_detector import Qwen25Detector from sglang.srt.hf_transformers_utils import get_tokenizer -from 
sglang.srt.openai_api.protocol import Function, Tool from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST diff --git a/test/srt/test_openai_adapter.py b/test/srt/test_jinja_template_utils.py similarity index 95% rename from test/srt/test_openai_adapter.py rename to test/srt/test_jinja_template_utils.py index 598ddfd49..b6bacd12f 100644 --- a/test/srt/test_openai_adapter.py +++ b/test/srt/test_jinja_template_utils.py @@ -5,8 +5,8 @@ Unit tests for OpenAI adapter utils. import unittest from unittest.mock import patch -from sglang.srt.openai_api.utils import ( - detect_template_content_format, +from sglang.srt.jinja_template_utils import ( + detect_jinja_template_content_format, process_content_for_template_format, ) from sglang.test.test_utils import CustomTestCase @@ -33,7 +33,7 @@ class TestTemplateContentFormatDetection(CustomTestCase): {%- endfor %} """ - result = detect_template_content_format(llama4_pattern) + result = detect_jinja_template_content_format(llama4_pattern) self.assertEqual(result, "openai") def test_detect_deepseek_string_format(self): @@ -46,19 +46,19 @@ class TestTemplateContentFormatDetection(CustomTestCase): {%- endfor %} """ - result = detect_template_content_format(deepseek_pattern) + result = detect_jinja_template_content_format(deepseek_pattern) self.assertEqual(result, "string") def test_detect_invalid_template(self): """Test handling of invalid template (should default to 'string').""" invalid_pattern = "{{{{ invalid jinja syntax }}}}" - result = detect_template_content_format(invalid_pattern) + result = detect_jinja_template_content_format(invalid_pattern) self.assertEqual(result, "string") def test_detect_empty_template(self): """Test handling of empty template (should default to 'string').""" - result = detect_template_content_format("") + result = detect_jinja_template_content_format("") self.assertEqual(result, "string") def test_process_content_openai_format(self): diff --git a/test/srt/test_openai_server.py 
b/test/srt/test_openai_server.py index 4913eb38c..8d5f5ad89 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -235,6 +235,7 @@ class TestOpenAIServer(CustomTestCase): ) is_firsts = {} + is_finished = {} for response in generator: usage = response.usage if usage is not None: @@ -244,6 +245,10 @@ class TestOpenAIServer(CustomTestCase): continue index = response.choices[0].index + finish_reason = response.choices[0].finish_reason + if finish_reason is not None: + is_finished[index] = True + data = response.choices[0].delta if is_firsts.get(index, True): @@ -253,7 +258,7 @@ class TestOpenAIServer(CustomTestCase): is_firsts[index] = False continue - if logprobs: + if logprobs and not is_finished.get(index, False): assert response.choices[0].logprobs, f"logprobs was not returned" assert isinstance( response.choices[0].logprobs.content[0].top_logprobs[0].token, str @@ -271,7 +276,7 @@ class TestOpenAIServer(CustomTestCase): assert ( isinstance(data.content, str) or isinstance(data.reasoning_content, str) - or len(data.tool_calls) > 0 + or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0) or response.choices[0].finish_reason ) assert response.id @@ -282,152 +287,6 @@ class TestOpenAIServer(CustomTestCase): index, True ), f"index {index} is not found in the response" - def _create_batch(self, mode, client): - if mode == "completion": - input_file_path = "complete_input.jsonl" - # write content to input file - content = [ - { - "custom_id": "request-1", - "method": "POST", - "url": "/v1/completions", - "body": { - "model": "gpt-3.5-turbo-instruct", - "prompt": "List 3 names of famous soccer player: ", - "max_tokens": 20, - }, - }, - { - "custom_id": "request-2", - "method": "POST", - "url": "/v1/completions", - "body": { - "model": "gpt-3.5-turbo-instruct", - "prompt": "List 6 names of famous basketball player: ", - "max_tokens": 40, - }, - }, - { - "custom_id": "request-3", - "method": "POST", - "url": "/v1/completions", - 
"body": { - "model": "gpt-3.5-turbo-instruct", - "prompt": "List 6 names of famous tenniss player: ", - "max_tokens": 40, - }, - }, - ] - - else: - input_file_path = "chat_input.jsonl" - content = [ - { - "custom_id": "request-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": "gpt-3.5-turbo-0125", - "messages": [ - { - "role": "system", - "content": "You are a helpful assistant.", - }, - { - "role": "user", - "content": "Hello! List 3 NBA players and tell a story", - }, - ], - "max_tokens": 30, - }, - }, - { - "custom_id": "request-2", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": "gpt-3.5-turbo-0125", - "messages": [ - {"role": "system", "content": "You are an assistant. "}, - { - "role": "user", - "content": "Hello! List three capital and tell a story", - }, - ], - "max_tokens": 50, - }, - }, - ] - - with open(input_file_path, "w") as file: - for line in content: - file.write(json.dumps(line) + "\n") - - with open(input_file_path, "rb") as file: - uploaded_file = client.files.create(file=file, purpose="batch") - if mode == "completion": - endpoint = "/v1/completions" - elif mode == "chat": - endpoint = "/v1/chat/completions" - completion_window = "24h" - batch_job = client.batches.create( - input_file_id=uploaded_file.id, - endpoint=endpoint, - completion_window=completion_window, - ) - - return batch_job, content, uploaded_file - - def run_batch(self, mode): - client = openai.Client(api_key=self.api_key, base_url=self.base_url) - batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client) - - while batch_job.status not in ["completed", "failed", "cancelled"]: - time.sleep(3) - print( - f"Batch job status: {batch_job.status}...trying again in 3 seconds..." 
- ) - batch_job = client.batches.retrieve(batch_job.id) - assert ( - batch_job.status == "completed" - ), f"Batch job status is not completed: {batch_job.status}" - assert batch_job.request_counts.completed == len(content) - assert batch_job.request_counts.failed == 0 - assert batch_job.request_counts.total == len(content) - - result_file_id = batch_job.output_file_id - file_response = client.files.content(result_file_id) - result_content = file_response.read().decode("utf-8") # Decode bytes to string - results = [ - json.loads(line) - for line in result_content.split("\n") - if line.strip() != "" - ] - assert len(results) == len(content) - for delete_fid in [uploaded_file.id, result_file_id]: - del_pesponse = client.files.delete(delete_fid) - assert del_pesponse.deleted - - def run_cancel_batch(self, mode): - client = openai.Client(api_key=self.api_key, base_url=self.base_url) - batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client) - - assert batch_job.status not in ["cancelling", "cancelled"] - - batch_job = client.batches.cancel(batch_id=batch_job.id) - assert batch_job.status == "cancelling" - - while batch_job.status not in ["failed", "cancelled"]: - batch_job = client.batches.retrieve(batch_job.id) - print( - f"Batch job status: {batch_job.status}...trying again in 3 seconds..." 
- ) - time.sleep(3) - - assert batch_job.status == "cancelled" - del_response = client.files.delete(uploaded_file.id) - assert del_response.deleted - def test_completion(self): for echo in [False, True]: for logprobs in [None, 5]: @@ -467,14 +326,6 @@ class TestOpenAIServer(CustomTestCase): for parallel_sample_num in [1, 2]: self.run_chat_completion_stream(logprobs, parallel_sample_num) - def test_batch(self): - for mode in ["completion", "chat"]: - self.run_batch(mode) - - def test_cancel_batch(self): - for mode in ["completion", "chat"]: - self.run_cancel_batch(mode) - def test_regex(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url) @@ -559,6 +410,18 @@ The SmartHome Mini is a compact smart home assistant available in black or white assert len(models) == 1 assert isinstance(getattr(models[0], "max_model_len", None), int) + def test_retrieve_model(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + # Test retrieving an existing model + retrieved_model = client.models.retrieve(self.model) + self.assertEqual(retrieved_model.id, self.model) + self.assertEqual(retrieved_model.root, self.model) + + # Test retrieving a non-existent model + with self.assertRaises(openai.NotFoundError): + client.models.retrieve("non-existent-model") + # ------------------------------------------------------------------------- # EBNF Test Class: TestOpenAIServerEBNF @@ -684,6 +547,31 @@ class TestOpenAIEmbedding(CustomTestCase): self.assertTrue(len(response.data[0].embedding) > 0) self.assertTrue(len(response.data[1].embedding) > 0) + def test_embedding_single_batch_str(self): + """Test embedding with a List[str] and length equals to 1""" + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.embeddings.create(model=self.model, input=["Hello world"]) + self.assertEqual(len(response.data), 1) + self.assertTrue(len(response.data[0].embedding) > 0) + + def test_embedding_single_int_list(self): + 
"""Test embedding with a List[int] or List[List[int]]]""" + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.embeddings.create( + model=self.model, + input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]], + ) + self.assertEqual(len(response.data), 1) + self.assertTrue(len(response.data[0].embedding) > 0) + + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.embeddings.create( + model=self.model, + input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061], + ) + self.assertEqual(len(response.data), 1) + self.assertTrue(len(response.data[0].embedding) > 0) + def test_empty_string_embedding(self): """Test embedding an empty string.""" diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py index 5a90f0853..518ec8671 100644 --- a/test/srt/test_vlm_accuracy.py +++ b/test/srt/test_vlm_accuracy.py @@ -21,6 +21,7 @@ from transformers import ( from sglang import Engine from sglang.srt.configs.model_config import ModelConfig from sglang.srt.conversation import generate_chat_conv +from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache from sglang.srt.managers.multimodal_processors.base_processor import ( BaseMultimodalProcessor, @@ -31,7 +32,6 @@ from sglang.srt.managers.schedule_batch import ( MultimodalInputs, ) from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.openai_api.protocol import ChatCompletionRequest from sglang.srt.server_args import ServerArgs diff --git a/test/srt/test_vlm_input_format.py b/test/srt/test_vlm_input_format.py index 2911e04d1..d2670ecac 100644 --- a/test/srt/test_vlm_input_format.py +++ b/test/srt/test_vlm_input_format.py @@ -15,7 +15,7 @@ from transformers import ( from sglang import Engine from sglang.srt.conversation import generate_chat_conv -from sglang.srt.openai_api.protocol import ChatCompletionRequest 
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"