feat(oai refactor): Replace openai_api with entrypoints/openai (#7351)
Co-authored-by: Jin Pan <jpan236@wisc.edu>
This commit is contained in:
@@ -20,7 +20,7 @@ from sglang.bench_serving import (
|
|||||||
get_gen_prefix_cache_path,
|
get_gen_prefix_cache_path,
|
||||||
)
|
)
|
||||||
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
|
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
|
||||||
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
|
from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
|
||||||
from sglang.utils import encode_video_base64
|
from sglang.utils import encode_video_base64
|
||||||
|
|
||||||
# type of content fields, can be only prompts or with images/videos
|
# type of content fields, can be only prompts or with images/videos
|
||||||
|
|||||||
@@ -64,11 +64,14 @@
|
|||||||
"text = \"Once upon a time\"\n",
|
"text = \"Once upon a time\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
|
"curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
|
||||||
|
" -H \"Content-Type: application/json\" \\\n",
|
||||||
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n",
|
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n",
|
"result = subprocess.check_output(curl_text, shell=True)\n",
|
||||||
" \"embedding\"\n",
|
"\n",
|
||||||
"]\n",
|
"print(result)\n",
|
||||||
|
"\n",
|
||||||
|
"text_embedding = json.loads(result)[\"data\"][0][\"embedding\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print_highlight(f\"Text embedding (first 10): {text_embedding[:10]}\")"
|
"print_highlight(f\"Text embedding (first 10): {text_embedding[:10]}\")"
|
||||||
]
|
]
|
||||||
@@ -152,6 +155,7 @@
|
|||||||
"input_ids = tokenizer.encode(text)\n",
|
"input_ids = tokenizer.encode(text)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
|
"curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
|
||||||
|
" -H \"Content-Type: application/json\" \\\n",
|
||||||
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n",
|
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n",
|
"input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n",
|
||||||
|
|||||||
@@ -67,6 +67,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"curl_command = f\"\"\"\n",
|
"curl_command = f\"\"\"\n",
|
||||||
"curl -s http://localhost:{port}/v1/chat/completions \\\\\n",
|
"curl -s http://localhost:{port}/v1/chat/completions \\\\\n",
|
||||||
|
" -H \"Content-Type: application/json\" \\\\\n",
|
||||||
" -d '{{\n",
|
" -d '{{\n",
|
||||||
" \"model\": \"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
|
" \"model\": \"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
|
||||||
" \"messages\": [\n",
|
" \"messages\": [\n",
|
||||||
|
|||||||
@@ -36,7 +36,7 @@
|
|||||||
"import requests\n",
|
"import requests\n",
|
||||||
"from PIL import Image\n",
|
"from PIL import Image\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
|
"from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest\n",
|
||||||
"from sglang.srt.conversation import chat_templates\n",
|
"from sglang.srt.conversation import chat_templates\n",
|
||||||
"\n",
|
"\n",
|
||||||
"image = Image.open(\n",
|
"image = Image.open(\n",
|
||||||
|
|||||||
@@ -15,9 +15,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from enum import auto
|
from enum import auto
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
||||||
@@ -57,46 +55,6 @@ class CompletionTemplate:
|
|||||||
completion_templates: dict[str, CompletionTemplate] = {}
|
completion_templates: dict[str, CompletionTemplate] = {}
|
||||||
|
|
||||||
|
|
||||||
def load_completion_template_for_openai_api(completion_template_arg):
|
|
||||||
global completion_template_name
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not completion_template_exists(completion_template_arg):
|
|
||||||
if not os.path.exists(completion_template_arg):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Completion template {completion_template_arg} is not a built-in template name "
|
|
||||||
"or a valid completion template file path."
|
|
||||||
)
|
|
||||||
|
|
||||||
assert completion_template_arg.endswith(
|
|
||||||
".json"
|
|
||||||
), "unrecognized format of completion template file"
|
|
||||||
with open(completion_template_arg, "r") as filep:
|
|
||||||
template = json.load(filep)
|
|
||||||
try:
|
|
||||||
fim_position = FimPosition[template["fim_position"]]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown fim position: {template['fim_position']}"
|
|
||||||
) from None
|
|
||||||
register_completion_template(
|
|
||||||
CompletionTemplate(
|
|
||||||
name=template["name"],
|
|
||||||
fim_begin_token=template["fim_begin_token"],
|
|
||||||
fim_middle_token=template["fim_middle_token"],
|
|
||||||
fim_end_token=template["fim_end_token"],
|
|
||||||
fim_position=fim_position,
|
|
||||||
),
|
|
||||||
override=True,
|
|
||||||
)
|
|
||||||
completion_template_name = template["name"]
|
|
||||||
else:
|
|
||||||
completion_template_name = completion_template_arg
|
|
||||||
|
|
||||||
|
|
||||||
def register_completion_template(template: CompletionTemplate, override: bool = False):
|
def register_completion_template(template: CompletionTemplate, override: bool = False):
|
||||||
"""Register a new completion template."""
|
"""Register a new completion template."""
|
||||||
if not override:
|
if not override:
|
||||||
|
|||||||
@@ -11,7 +11,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Conversation chat templates."""
|
"""Conversation chat templates.
|
||||||
|
|
||||||
|
This module provides conversation template definitions, data structures, and utilities
|
||||||
|
for managing chat templates across different model types in SGLang.
|
||||||
|
|
||||||
|
Key components:
|
||||||
|
- Conversation class: Defines the structure and behavior of chat templates
|
||||||
|
- SeparatorStyle enum: Different conversation formatting styles
|
||||||
|
- Template registry: Functions to register and retrieve templates by name or model path
|
||||||
|
- Built-in templates: Pre-defined templates for popular models
|
||||||
|
"""
|
||||||
|
|
||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
@@ -20,7 +30,7 @@ import re
|
|||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
from sglang.srt.utils import read_system_prompt_from_file
|
from sglang.srt.utils import read_system_prompt_from_file
|
||||||
|
|
||||||
|
|
||||||
@@ -618,7 +628,7 @@ def generate_chat_conv(
|
|||||||
|
|
||||||
|
|
||||||
# llama2 template
|
# llama2 template
|
||||||
# reference: https://huggingface.co/blog/codellama#conversational-instructions
|
# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
|
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
|
||||||
register_conv_template(
|
register_conv_template(
|
||||||
Conversation(
|
Conversation(
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
|||||||
import torch
|
import torch
|
||||||
import uvloop
|
import uvloop
|
||||||
|
|
||||||
from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
|
|
||||||
from sglang.srt.entrypoints.EngineBase import EngineBase
|
from sglang.srt.entrypoints.EngineBase import EngineBase
|
||||||
from sglang.srt.managers.data_parallel_controller import (
|
from sglang.srt.managers.data_parallel_controller import (
|
||||||
run_data_parallel_controller_process,
|
run_data_parallel_controller_process,
|
||||||
@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
|
from sglang.srt.managers.template_manager import TemplateManager
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.openai_api.adapter import (
|
|
||||||
guess_chat_template_name_from_model_path,
|
|
||||||
load_chat_template_for_openai_api,
|
|
||||||
)
|
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -123,12 +119,13 @@ class Engine(EngineBase):
|
|||||||
logger.info(f"{server_args=}")
|
logger.info(f"{server_args=}")
|
||||||
|
|
||||||
# Launch subprocesses
|
# Launch subprocesses
|
||||||
tokenizer_manager, scheduler_info = _launch_subprocesses(
|
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
port_args=port_args,
|
port_args=port_args,
|
||||||
)
|
)
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.tokenizer_manager = tokenizer_manager
|
self.tokenizer_manager = tokenizer_manager
|
||||||
|
self.template_manager = template_manager
|
||||||
self.scheduler_info = scheduler_info
|
self.scheduler_info = scheduler_info
|
||||||
|
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
@@ -647,7 +644,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
|
|
||||||
def _launch_subprocesses(
|
def _launch_subprocesses(
|
||||||
server_args: ServerArgs, port_args: Optional[PortArgs] = None
|
server_args: ServerArgs, port_args: Optional[PortArgs] = None
|
||||||
) -> Tuple[TokenizerManager, Dict]:
|
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
|
||||||
"""
|
"""
|
||||||
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
|
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
|
||||||
"""
|
"""
|
||||||
@@ -732,7 +729,7 @@ def _launch_subprocesses(
|
|||||||
|
|
||||||
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
|
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
|
||||||
# When using `Engine` as a Python API, we don't want to block here.
|
# When using `Engine` as a Python API, we don't want to block here.
|
||||||
return None, None
|
return None, None, None
|
||||||
|
|
||||||
launch_dummy_health_check_server(server_args.host, server_args.port)
|
launch_dummy_health_check_server(server_args.host, server_args.port)
|
||||||
|
|
||||||
@@ -741,7 +738,7 @@ def _launch_subprocesses(
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
|
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
|
||||||
)
|
)
|
||||||
return None, None
|
return None, None, None
|
||||||
|
|
||||||
# Launch detokenizer process
|
# Launch detokenizer process
|
||||||
detoken_proc = mp.Process(
|
detoken_proc = mp.Process(
|
||||||
@@ -755,15 +752,15 @@ def _launch_subprocesses(
|
|||||||
|
|
||||||
# Launch tokenizer process
|
# Launch tokenizer process
|
||||||
tokenizer_manager = TokenizerManager(server_args, port_args)
|
tokenizer_manager = TokenizerManager(server_args, port_args)
|
||||||
if server_args.chat_template:
|
|
||||||
load_chat_template_for_openai_api(
|
|
||||||
tokenizer_manager, server_args.chat_template, server_args.model_path
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
guess_chat_template_name_from_model_path(server_args.model_path)
|
|
||||||
|
|
||||||
if server_args.completion_template:
|
# Initialize templates
|
||||||
load_completion_template_for_openai_api(server_args.completion_template)
|
template_manager = TemplateManager()
|
||||||
|
template_manager.initialize_templates(
|
||||||
|
tokenizer_manager=tokenizer_manager,
|
||||||
|
model_path=server_args.model_path,
|
||||||
|
chat_template=server_args.chat_template,
|
||||||
|
completion_template=server_args.completion_template,
|
||||||
|
)
|
||||||
|
|
||||||
# Wait for the model to finish loading
|
# Wait for the model to finish loading
|
||||||
scheduler_infos = []
|
scheduler_infos = []
|
||||||
@@ -787,4 +784,4 @@ def _launch_subprocesses(
|
|||||||
# Assume all schedulers have the same scheduler_info
|
# Assume all schedulers have the same scheduler_info
|
||||||
scheduler_info = scheduler_infos[0]
|
scheduler_info = scheduler_infos[0]
|
||||||
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
||||||
return tokenizer_manager, scheduler_info
|
return tokenizer_manager, template_manager, scheduler_info
|
||||||
|
|||||||
@@ -38,7 +38,8 @@ import orjson
|
|||||||
import requests
|
import requests
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import uvloop
|
import uvloop
|
||||||
from fastapi import FastAPI, File, Form, Request, UploadFile
|
from fastapi import Depends, FastAPI, Request, UploadFile
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
@@ -47,6 +48,20 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
register_disaggregation_server,
|
register_disaggregation_server,
|
||||||
)
|
)
|
||||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
CompletionRequest,
|
||||||
|
EmbeddingRequest,
|
||||||
|
ModelCard,
|
||||||
|
ModelList,
|
||||||
|
ScoringRequest,
|
||||||
|
V1RerankReqInput,
|
||||||
|
)
|
||||||
|
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
|
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
|
||||||
|
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||||
|
from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
|
||||||
|
from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
|
||||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
@@ -67,26 +82,11 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
V1RerankReqInput,
|
|
||||||
VertexGenerateReqInput,
|
VertexGenerateReqInput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.template_manager import TemplateManager
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||||
from sglang.srt.openai_api.adapter import (
|
|
||||||
v1_batches,
|
|
||||||
v1_cancel_batch,
|
|
||||||
v1_chat_completions,
|
|
||||||
v1_completions,
|
|
||||||
v1_delete_file,
|
|
||||||
v1_embeddings,
|
|
||||||
v1_files_create,
|
|
||||||
v1_rerank,
|
|
||||||
v1_retrieve_batch,
|
|
||||||
v1_retrieve_file,
|
|
||||||
v1_retrieve_file_content,
|
|
||||||
v1_score,
|
|
||||||
)
|
|
||||||
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -109,6 +109,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class _GlobalState:
|
class _GlobalState:
|
||||||
tokenizer_manager: TokenizerManager
|
tokenizer_manager: TokenizerManager
|
||||||
|
template_manager: TemplateManager
|
||||||
scheduler_info: Dict
|
scheduler_info: Dict
|
||||||
|
|
||||||
|
|
||||||
@@ -123,6 +124,24 @@ def set_global_state(global_state: _GlobalState):
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(fast_api_app: FastAPI):
|
async def lifespan(fast_api_app: FastAPI):
|
||||||
server_args: ServerArgs = fast_api_app.server_args
|
server_args: ServerArgs = fast_api_app.server_args
|
||||||
|
|
||||||
|
# Initialize OpenAI serving handlers
|
||||||
|
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
|
||||||
|
_global_state.tokenizer_manager, _global_state.template_manager
|
||||||
|
)
|
||||||
|
fast_api_app.state.openai_serving_chat = OpenAIServingChat(
|
||||||
|
_global_state.tokenizer_manager, _global_state.template_manager
|
||||||
|
)
|
||||||
|
fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||||
|
_global_state.tokenizer_manager, _global_state.template_manager
|
||||||
|
)
|
||||||
|
fast_api_app.state.openai_serving_score = OpenAIServingScore(
|
||||||
|
_global_state.tokenizer_manager
|
||||||
|
)
|
||||||
|
fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
|
||||||
|
_global_state.tokenizer_manager
|
||||||
|
)
|
||||||
|
|
||||||
if server_args.warmups is not None:
|
if server_args.warmups is not None:
|
||||||
await execute_warmups(
|
await execute_warmups(
|
||||||
server_args.warmups.split(","), _global_state.tokenizer_manager
|
server_args.warmups.split(","), _global_state.tokenizer_manager
|
||||||
@@ -148,6 +167,36 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Custom exception handlers to change validation error status codes
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
|
"""Override FastAPI's default 422 validation error with 400"""
|
||||||
|
return ORJSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={
|
||||||
|
"detail": exc.errors(),
|
||||||
|
"body": exc.body,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_json_request(raw_request: Request):
|
||||||
|
"""Validate that the request content-type is application/json."""
|
||||||
|
content_type = raw_request.headers.get("content-type", "").lower()
|
||||||
|
media_type = content_type.split(";", maxsplit=1)[0]
|
||||||
|
if media_type != "application/json":
|
||||||
|
raise RequestValidationError(
|
||||||
|
errors=[
|
||||||
|
{
|
||||||
|
"loc": ["header", "content-type"],
|
||||||
|
"msg": "Unsupported Media Type: Only 'application/json' is allowed",
|
||||||
|
"type": "value_error",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
||||||
|
|
||||||
|
|
||||||
@@ -330,13 +379,14 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
|
|||||||
return _create_error_response(e)
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/v1/rerank", methods=["POST", "PUT"])
|
@app.api_route(
|
||||||
async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
|
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
|
||||||
try:
|
)
|
||||||
ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
|
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
|
||||||
return ret
|
"""Endpoint for reranking documents based on query relevance."""
|
||||||
except ValueError as e:
|
return await raw_request.app.state.openai_serving_rerank.handle_request(
|
||||||
return _create_error_response(e)
|
request, raw_request
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/flush_cache", methods=["GET", "POST"])
|
@app.api_route("/flush_cache", methods=["GET", "POST"])
|
||||||
@@ -619,25 +669,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
|
|||||||
##### OpenAI-compatible API endpoints #####
|
##### OpenAI-compatible API endpoints #####
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
|
||||||
async def openai_v1_completions(raw_request: Request):
|
async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
|
||||||
return await v1_completions(_global_state.tokenizer_manager, raw_request)
|
"""OpenAI-compatible text completion endpoint."""
|
||||||
|
return await raw_request.app.state.openai_serving_completion.handle_request(
|
||||||
|
request, raw_request
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
|
||||||
async def openai_v1_chat_completions(raw_request: Request):
|
async def openai_v1_chat_completions(
|
||||||
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
|
request: ChatCompletionRequest, raw_request: Request
|
||||||
|
):
|
||||||
|
"""OpenAI-compatible chat completion endpoint."""
|
||||||
|
return await raw_request.app.state.openai_serving_chat.handle_request(
|
||||||
|
request, raw_request
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/embeddings", response_class=ORJSONResponse)
|
@app.post(
|
||||||
async def openai_v1_embeddings(raw_request: Request):
|
"/v1/embeddings",
|
||||||
response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
|
response_class=ORJSONResponse,
|
||||||
return response
|
dependencies=[Depends(validate_json_request)],
|
||||||
|
)
|
||||||
|
async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
|
||||||
|
"""OpenAI-compatible embeddings endpoint."""
|
||||||
|
return await raw_request.app.state.openai_serving_embedding.handle_request(
|
||||||
|
request, raw_request
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models", response_class=ORJSONResponse)
|
@app.get("/v1/models", response_class=ORJSONResponse)
|
||||||
def available_models():
|
async def available_models():
|
||||||
"""Show available models."""
|
"""Show available models. OpenAI-compatible endpoint."""
|
||||||
served_model_names = [_global_state.tokenizer_manager.served_model_name]
|
served_model_names = [_global_state.tokenizer_manager.served_model_name]
|
||||||
model_cards = []
|
model_cards = []
|
||||||
for served_model_name in served_model_names:
|
for served_model_name in served_model_names:
|
||||||
@@ -651,45 +715,29 @@ def available_models():
|
|||||||
return ModelList(data=model_cards)
|
return ModelList(data=model_cards)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/files")
|
@app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
|
||||||
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
|
async def retrieve_model(model: str):
|
||||||
return await v1_files_create(
|
"""Retrieves a model instance, providing basic information about the model."""
|
||||||
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
|
served_model_names = [_global_state.tokenizer_manager.served_model_name]
|
||||||
|
|
||||||
|
if model not in served_model_names:
|
||||||
|
return ORJSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"message": f"The model '{model}' does not exist",
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": "model",
|
||||||
|
"code": "model_not_found",
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return ModelCard(
|
||||||
@app.delete("/v1/files/{file_id}")
|
id=model,
|
||||||
async def delete_file(file_id: str):
|
root=model,
|
||||||
# https://platform.openai.com/docs/api-reference/files/delete
|
max_model_len=_global_state.tokenizer_manager.model_config.context_len,
|
||||||
return await v1_delete_file(file_id)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/batches")
|
|
||||||
async def openai_v1_batches(raw_request: Request):
|
|
||||||
return await v1_batches(_global_state.tokenizer_manager, raw_request)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/batches/{batch_id}/cancel")
|
|
||||||
async def cancel_batches(batch_id: str):
|
|
||||||
# https://platform.openai.com/docs/api-reference/batch/cancel
|
|
||||||
return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/batches/{batch_id}")
|
|
||||||
async def retrieve_batch(batch_id: str):
|
|
||||||
return await v1_retrieve_batch(batch_id)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/files/{file_id}")
|
|
||||||
async def retrieve_file(file_id: str):
|
|
||||||
# https://platform.openai.com/docs/api-reference/files/retrieve
|
|
||||||
return await v1_retrieve_file(file_id)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/files/{file_id}/content")
|
|
||||||
async def retrieve_file_content(file_id: str):
|
|
||||||
# https://platform.openai.com/docs/api-reference/files/retrieve-contents
|
|
||||||
return await v1_retrieve_file_content(file_id)
|
|
||||||
|
|
||||||
|
|
||||||
## SageMaker API
|
## SageMaker API
|
||||||
@@ -700,8 +748,13 @@ async def sagemaker_health() -> Response:
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/invocations")
|
@app.post("/invocations")
|
||||||
async def sagemaker_chat_completions(raw_request: Request):
|
async def sagemaker_chat_completions(
|
||||||
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
|
request: ChatCompletionRequest, raw_request: Request
|
||||||
|
):
|
||||||
|
"""OpenAI-compatible chat completion endpoint."""
|
||||||
|
return await raw_request.app.state.openai_serving_chat.handle_request(
|
||||||
|
request, raw_request
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
## Vertex AI API
|
## Vertex AI API
|
||||||
@@ -732,10 +785,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
|
|||||||
return ORJSONResponse({"predictions": ret})
|
return ORJSONResponse({"predictions": ret})
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/score")
|
@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
|
||||||
async def v1_score_request(raw_request: Request):
|
async def v1_score_request(request: ScoringRequest, raw_request: Request):
|
||||||
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
|
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
|
||||||
return await v1_score(_global_state.tokenizer_manager, raw_request)
|
return await raw_request.app.state.openai_serving_score.handle_request(
|
||||||
|
request, raw_request
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _create_error_response(e):
|
def _create_error_response(e):
|
||||||
@@ -764,10 +819,13 @@ def launch_server(
|
|||||||
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
|
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
|
||||||
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
|
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
|
||||||
"""
|
"""
|
||||||
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
|
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
|
||||||
|
server_args=server_args
|
||||||
|
)
|
||||||
set_global_state(
|
set_global_state(
|
||||||
_GlobalState(
|
_GlobalState(
|
||||||
tokenizer_manager=tokenizer_manager,
|
tokenizer_manager=tokenizer_manager,
|
||||||
|
template_manager=template_manager,
|
||||||
scheduler_info=scheduler_info,
|
scheduler_info=scheduler_info,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,388 +0,0 @@
|
|||||||
# Copyright 2023-2024 SGLang Team
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""
|
|
||||||
SGLang OpenAI-Compatible API Server.
|
|
||||||
|
|
||||||
This file implements OpenAI-compatible HTTP APIs for the inference engine via FastAPI.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import multiprocessing
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import Callable, Dict, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import requests
|
|
||||||
import uvicorn
|
|
||||||
import uvloop
|
|
||||||
from fastapi import FastAPI, Request
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.responses import Response
|
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import (
|
|
||||||
FAKE_BOOTSTRAP_HOST,
|
|
||||||
register_disaggregation_server,
|
|
||||||
)
|
|
||||||
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
|
|
||||||
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
|
||||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
|
||||||
from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
|
||||||
from sglang.srt.utils import (
|
|
||||||
add_prometheus_middleware,
|
|
||||||
delete_directory,
|
|
||||||
get_bool_env_var,
|
|
||||||
kill_process_tree,
|
|
||||||
set_uvicorn_logging_configs,
|
|
||||||
)
|
|
||||||
from sglang.srt.warmup import execute_warmups
|
|
||||||
from sglang.utils import get_exception_traceback
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|
||||||
|
|
||||||
|
|
||||||
# Store global states
|
|
||||||
class AppState:
|
|
||||||
engine: Optional[Engine] = None
|
|
||||||
server_args: Optional[ServerArgs] = None
|
|
||||||
tokenizer_manager: Optional[TokenizerManager] = None
|
|
||||||
scheduler_info: Optional[Dict] = None
|
|
||||||
embedding_server: Optional[OpenAIServingEmbedding] = None
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
app.state.server_args.enable_metrics = True # By default, we enable metrics
|
|
||||||
|
|
||||||
server_args = app.state.server_args
|
|
||||||
|
|
||||||
# Initialize engine
|
|
||||||
logger.info(f"SGLang OpenAI server (PID: {os.getpid()}) is initializing...")
|
|
||||||
|
|
||||||
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
|
|
||||||
app.state.tokenizer_manager = tokenizer_manager
|
|
||||||
app.state.scheduler_info = scheduler_info
|
|
||||||
app.state.serving_embedding = OpenAIServingEmbedding(
|
|
||||||
tokenizer_manager=tokenizer_manager
|
|
||||||
)
|
|
||||||
|
|
||||||
if server_args.enable_metrics:
|
|
||||||
add_prometheus_middleware(app)
|
|
||||||
enable_func_timer()
|
|
||||||
|
|
||||||
# Initialize engine state attribute to None for now
|
|
||||||
app.state.engine = None
|
|
||||||
|
|
||||||
if server_args.warmups is not None:
|
|
||||||
await execute_warmups(
|
|
||||||
server_args.warmups.split(","), app.state.tokenizer_manager
|
|
||||||
)
|
|
||||||
logger.info("Warmup ended")
|
|
||||||
|
|
||||||
warmup_thread = getattr(app, "warmup_thread", None)
|
|
||||||
if warmup_thread is not None:
|
|
||||||
warmup_thread.start()
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
# Lifespan shutdown
|
|
||||||
if hasattr(app.state, "engine") and app.state.engine is not None:
|
|
||||||
logger.info("SGLang engine is shutting down.")
|
|
||||||
# Add engine cleanup logic here when implemented
|
|
||||||
|
|
||||||
|
|
||||||
# Fast API app with CORS enabled
|
|
||||||
app = FastAPI(
|
|
||||||
lifespan=lifespan,
|
|
||||||
# TODO: check where /openai.json is created or why we use this
|
|
||||||
openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
|
|
||||||
)
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/health", methods=["GET"])
|
|
||||||
async def health() -> Response:
|
|
||||||
"""Health check. Used for readiness and liveness probes."""
|
|
||||||
# In the future, this could check engine health more deeply
|
|
||||||
# For now, if the server is up, it's healthy.
|
|
||||||
return Response(status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/v1/models", methods=["GET"])
|
|
||||||
async def show_models():
|
|
||||||
"""Show available models. Currently, it returns the served model name.
|
|
||||||
|
|
||||||
This endpoint is compatible with the OpenAI API standard.
|
|
||||||
"""
|
|
||||||
served_model_names = [app.state.tokenizer_manager.served_model_name]
|
|
||||||
model_cards = []
|
|
||||||
for served_model_name in served_model_names:
|
|
||||||
model_cards.append(
|
|
||||||
ModelCard(
|
|
||||||
id=served_model_name,
|
|
||||||
root=served_model_name,
|
|
||||||
max_model_len=app.state.tokenizer_manager.model_config.context_len,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return ModelList(data=model_cards)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_model_info")
|
|
||||||
async def get_model_info():
|
|
||||||
"""Get the model information."""
|
|
||||||
result = {
|
|
||||||
"model_path": app.state.tokenizer_manager.model_path,
|
|
||||||
"tokenizer_path": app.state.tokenizer_manager.server_args.tokenizer_path,
|
|
||||||
"is_generation": app.state.tokenizer_manager.is_generation,
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
|
||||||
async def openai_v1_completions(raw_request: Request):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
|
||||||
async def openai_v1_chat_completions(raw_request: Request):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/embeddings")
|
|
||||||
async def openai_v1_embeddings(raw_request: Request):
|
|
||||||
try:
|
|
||||||
request_json = await raw_request.json()
|
|
||||||
request = EmbeddingRequest(**request_json)
|
|
||||||
except Exception as e:
|
|
||||||
return app.state.serving_embedding.create_error_response(
|
|
||||||
f"Invalid request body, error: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
ret = await app.state.serving_embedding.handle_request(request, raw_request)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/score")
|
|
||||||
async def v1_score_request(raw_request: Request):
|
|
||||||
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/v1/models/{model_id}", methods=["GET"])
|
|
||||||
async def show_model_detail(model_id: str):
|
|
||||||
served_model_name = app.state.tokenizer_manager.served_model_name
|
|
||||||
|
|
||||||
return ModelCard(
|
|
||||||
id=served_model_name,
|
|
||||||
root=served_model_name,
|
|
||||||
max_model_len=app.state.tokenizer_manager.model_config.context_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Additional API endpoints will be implemented in separate serving_*.py modules
|
|
||||||
# and mounted as APIRouters in future PRs
|
|
||||||
|
|
||||||
|
|
||||||
def _wait_and_warmup(
|
|
||||||
server_args: ServerArgs,
|
|
||||||
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
|
|
||||||
image_token_text: str,
|
|
||||||
launch_callback: Optional[Callable[[], None]] = None,
|
|
||||||
):
|
|
||||||
return
|
|
||||||
# TODO: Please wait until the /generate implementation is complete,
|
|
||||||
# or confirm if modifications are needed before removing this.
|
|
||||||
|
|
||||||
headers = {}
|
|
||||||
url = server_args.url()
|
|
||||||
if server_args.api_key:
|
|
||||||
headers["Authorization"] = f"Bearer {server_args.api_key}"
|
|
||||||
|
|
||||||
# Wait until the server is launched
|
|
||||||
success = False
|
|
||||||
for _ in range(120):
|
|
||||||
time.sleep(1)
|
|
||||||
try:
|
|
||||||
res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
|
||||||
assert res.status_code == 200, f"{res=}, {res.text=}"
|
|
||||||
success = True
|
|
||||||
break
|
|
||||||
except (AssertionError, requests.exceptions.RequestException):
|
|
||||||
last_traceback = get_exception_traceback()
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
if pipe_finish_writer is not None:
|
|
||||||
pipe_finish_writer.send(last_traceback)
|
|
||||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
|
||||||
kill_process_tree(os.getpid())
|
|
||||||
return
|
|
||||||
|
|
||||||
model_info = res.json()
|
|
||||||
|
|
||||||
# Send a warmup request
|
|
||||||
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
|
||||||
# TODO: Replace with OpenAI API
|
|
||||||
max_new_tokens = 8 if model_info["is_generation"] else 1
|
|
||||||
json_data = {
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": max_new_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if server_args.skip_tokenizer_init:
|
|
||||||
json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)]
|
|
||||||
# TODO Workaround the bug that embedding errors for list of size 1
|
|
||||||
if server_args.dp_size == 1:
|
|
||||||
json_data["input_ids"] = json_data["input_ids"][0]
|
|
||||||
else:
|
|
||||||
json_data["text"] = ["The capital city of France is"] * server_args.dp_size
|
|
||||||
# TODO Workaround the bug that embedding errors for list of size 1
|
|
||||||
if server_args.dp_size == 1:
|
|
||||||
json_data["text"] = json_data["text"][0]
|
|
||||||
|
|
||||||
# Debug dumping
|
|
||||||
if server_args.debug_tensor_dump_input_file:
|
|
||||||
json_data.pop("text", None)
|
|
||||||
json_data["input_ids"] = np.load(
|
|
||||||
server_args.debug_tensor_dump_input_file
|
|
||||||
).tolist()
|
|
||||||
json_data["sampling_params"]["max_new_tokens"] = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
if server_args.disaggregation_mode == "null":
|
|
||||||
res = requests.post(
|
|
||||||
url + request_name,
|
|
||||||
json=json_data,
|
|
||||||
headers=headers,
|
|
||||||
timeout=600,
|
|
||||||
)
|
|
||||||
assert res.status_code == 200, f"{res}"
|
|
||||||
else:
|
|
||||||
logger.info(f"Start of prefill warmup ...")
|
|
||||||
json_data = {
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0.0,
|
|
||||||
"max_new_tokens": 8,
|
|
||||||
"ignore_eos": True,
|
|
||||||
},
|
|
||||||
"bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
|
|
||||||
# This is a hack to ensure fake transfer is enabled during prefill warmup
|
|
||||||
# ensure each dp rank has a unique bootstrap_room during prefill warmup
|
|
||||||
"bootstrap_room": [
|
|
||||||
i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
|
|
||||||
for i in range(server_args.dp_size)
|
|
||||||
],
|
|
||||||
"input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
|
|
||||||
}
|
|
||||||
res = requests.post(
|
|
||||||
url + request_name,
|
|
||||||
json=json_data,
|
|
||||||
headers=headers,
|
|
||||||
timeout=1800, # because of deep gemm precache is very long if not precache.
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
last_traceback = get_exception_traceback()
|
|
||||||
if pipe_finish_writer is not None:
|
|
||||||
pipe_finish_writer.send(last_traceback)
|
|
||||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
|
||||||
kill_process_tree(os.getpid())
|
|
||||||
return
|
|
||||||
|
|
||||||
# Debug print
|
|
||||||
# logger.info(f"{res.json()=}")
|
|
||||||
|
|
||||||
logger.info("The server is fired up and ready to roll!")
|
|
||||||
if pipe_finish_writer is not None:
|
|
||||||
pipe_finish_writer.send("ready")
|
|
||||||
|
|
||||||
if server_args.delete_ckpt_after_loading:
|
|
||||||
delete_directory(server_args.model_path)
|
|
||||||
|
|
||||||
if server_args.debug_tensor_dump_input_file:
|
|
||||||
kill_process_tree(os.getpid())
|
|
||||||
|
|
||||||
if server_args.pdlb_url is not None:
|
|
||||||
register_disaggregation_server(
|
|
||||||
server_args.disaggregation_mode,
|
|
||||||
server_args.port,
|
|
||||||
server_args.disaggregation_bootstrap_port,
|
|
||||||
server_args.pdlb_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
if launch_callback is not None:
|
|
||||||
launch_callback()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="SGLang OpenAI-Compatible API Server")
|
|
||||||
# Add arguments from ServerArgs. This allows reuse of existing CLI definitions.
|
|
||||||
ServerArgs.add_cli_args(parser)
|
|
||||||
# Potentially add server-specific arguments here in the future if needed
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
server_args = ServerArgs.from_cli_args(args)
|
|
||||||
|
|
||||||
# Store server_args in app.state for access in lifespan and endpoints
|
|
||||||
app.state.server_args = server_args
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=server_args.log_level.upper(),
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send a warmup request - we will create the thread launch it
|
|
||||||
# in the lifespan after all other warmups have fired.
|
|
||||||
warmup_thread = threading.Thread(
|
|
||||||
target=_wait_and_warmup,
|
|
||||||
args=(
|
|
||||||
server_args,
|
|
||||||
None,
|
|
||||||
None, # Never used
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
app.warmup_thread = warmup_thread
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Start the server
|
|
||||||
set_uvicorn_logging_configs()
|
|
||||||
uvicorn.run(
|
|
||||||
app,
|
|
||||||
host=server_args.host,
|
|
||||||
port=server_args.port,
|
|
||||||
log_level=server_args.log_level.lower(),
|
|
||||||
timeout_keep_alive=60, # Increased keep-alive for potentially long requests
|
|
||||||
loop="uvloop", # Use uvloop for better performance if available
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
warmup_thread.join()
|
|
||||||
@@ -207,7 +207,7 @@ class CompletionResponseChoice(BaseModel):
|
|||||||
index: int
|
index: int
|
||||||
text: str
|
text: str
|
||||||
logprobs: Optional[LogProbs] = None
|
logprobs: Optional[LogProbs] = None
|
||||||
finish_reason: Literal["stop", "length", "content_filter", "abort"]
|
finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
|
||||||
matched_stop: Union[None, int, str] = None
|
matched_stop: Union[None, int, str] = None
|
||||||
hidden_states: Optional[object] = None
|
hidden_states: Optional[object] = None
|
||||||
|
|
||||||
@@ -404,7 +404,6 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_tool_choice_default(cls, values):
|
def set_tool_choice_default(cls, values):
|
||||||
if isinstance(values, dict):
|
|
||||||
if values.get("tool_choice") is None:
|
if values.get("tool_choice") is None:
|
||||||
if values.get("tools") is None:
|
if values.get("tools") is None:
|
||||||
values["tool_choice"] = "none"
|
values["tool_choice"] = "none"
|
||||||
@@ -412,13 +411,6 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
values["tool_choice"] = "auto"
|
values["tool_choice"] = "auto"
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@field_validator("messages")
|
|
||||||
@classmethod
|
|
||||||
def validate_messages_not_empty(cls, v):
|
|
||||||
if not v:
|
|
||||||
raise ValueError("Messages cannot be empty")
|
|
||||||
return v
|
|
||||||
|
|
||||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||||
top_k: int = -1
|
top_k: int = -1
|
||||||
min_p: float = 0.0
|
min_p: float = 0.0
|
||||||
@@ -457,9 +449,11 @@ class ChatCompletionResponseChoice(BaseModel):
|
|||||||
index: int
|
index: int
|
||||||
message: ChatMessage
|
message: ChatMessage
|
||||||
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
||||||
finish_reason: Literal[
|
finish_reason: Optional[
|
||||||
|
Literal[
|
||||||
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
|
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
|
||||||
]
|
]
|
||||||
|
] = None
|
||||||
matched_stop: Union[None, int, str] = None
|
matched_stop: Union[None, int, str] = None
|
||||||
hidden_states: Optional[object] = None
|
hidden_states: Optional[object] = None
|
||||||
|
|
||||||
@@ -530,7 +524,7 @@ class EmbeddingRequest(BaseModel):
|
|||||||
input: EmbeddingInput
|
input: EmbeddingInput
|
||||||
model: str
|
model: str
|
||||||
encoding_format: str = "float"
|
encoding_format: str = "float"
|
||||||
dimensions: int = None
|
dimensions: Optional[int] = None
|
||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
|
|
||||||
# The request id.
|
# The request id.
|
||||||
|
|||||||
@@ -2,16 +2,12 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
|
||||||
ErrorResponse,
|
|
||||||
OpenAIServingRequest,
|
|
||||||
UsageInfo,
|
|
||||||
)
|
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
|
||||||
@@ -51,7 +47,7 @@ class OpenAIServingBase(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in request: {e}")
|
logger.exception(f"Error in request: {e}")
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
message=f"Internal server error: {str(e)}",
|
message=f"Internal server error: {str(e)}",
|
||||||
err_type="InternalServerError",
|
err_type="InternalServerError",
|
||||||
@@ -63,8 +59,12 @@ class OpenAIServingBase(ABC):
|
|||||||
"""Generate request ID based on request type"""
|
"""Generate request ID based on request type"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
|
def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]:
|
||||||
"""Generate request ID based on request type"""
|
"""Generate request ID based on request type"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
# TODO(chang): the rid is used in io_strcut check and often violates `The rid should be a list` AssertionError
|
||||||
|
# Temporarily return None in this function until the rid logic is clear.
|
||||||
if rid := getattr(request, "rid", None):
|
if rid := getattr(request, "rid", None):
|
||||||
return rid
|
return rid
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ class OpenAIServingBase(ABC):
|
|||||||
adapted_request: GenerateReqInput,
|
adapted_request: GenerateReqInput,
|
||||||
request: OpenAIServingRequest,
|
request: OpenAIServingRequest,
|
||||||
raw_request: Request,
|
raw_request: Request,
|
||||||
) -> StreamingResponse:
|
) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]:
|
||||||
"""Handle streaming request
|
"""Handle streaming request
|
||||||
|
|
||||||
Override this method in child classes that support streaming requests.
|
Override this method in child classes that support streaming requests.
|
||||||
@@ -99,7 +99,7 @@ class OpenAIServingBase(ABC):
|
|||||||
adapted_request: GenerateReqInput,
|
adapted_request: GenerateReqInput,
|
||||||
request: OpenAIServingRequest,
|
request: OpenAIServingRequest,
|
||||||
raw_request: Request,
|
raw_request: Request,
|
||||||
) -> Union[Any, ErrorResponse]:
|
) -> Union[Any, ErrorResponse, ORJSONResponse]:
|
||||||
"""Handle non-streaming request
|
"""Handle non-streaming request
|
||||||
|
|
||||||
Override this method in child classes that support non-streaming requests.
|
Override this method in child classes that support non-streaming requests.
|
||||||
@@ -110,7 +110,7 @@ class OpenAIServingBase(ABC):
|
|||||||
status_code=501,
|
status_code=501,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
|
def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]:
|
||||||
"""Validate request"""
|
"""Validate request"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -122,6 +122,7 @@ class OpenAIServingBase(ABC):
|
|||||||
param: Optional[str] = None,
|
param: Optional[str] = None,
|
||||||
) -> ORJSONResponse:
|
) -> ORJSONResponse:
|
||||||
"""Create an error response"""
|
"""Create an error response"""
|
||||||
|
# TODO: remove fastapi dependency in openai and move response handling to the entrypoint
|
||||||
error = ErrorResponse(
|
error = ErrorResponse(
|
||||||
object="error",
|
object="error",
|
||||||
message=message,
|
message=message,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import base64
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
@@ -6,7 +5,7 @@ import uuid
|
|||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
|
|
||||||
from sglang.srt.conversation import generate_chat_conv
|
from sglang.srt.conversation import generate_chat_conv
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
@@ -28,13 +27,14 @@ from sglang.srt.entrypoints.openai.protocol import (
|
|||||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||||
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
|
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
|
||||||
from sglang.srt.entrypoints.openai.utils import (
|
from sglang.srt.entrypoints.openai.utils import (
|
||||||
detect_template_content_format,
|
|
||||||
process_content_for_template_format,
|
|
||||||
process_hidden_states_from_ret,
|
process_hidden_states_from_ret,
|
||||||
to_openai_style_logprobs,
|
to_openai_style_logprobs,
|
||||||
)
|
)
|
||||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
|
from sglang.srt.jinja_template_utils import process_content_for_template_format
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
|
from sglang.srt.managers.template_manager import TemplateManager
|
||||||
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
from sglang.utils import convert_json_schema_to_str
|
from sglang.utils import convert_json_schema_to_str
|
||||||
|
|
||||||
@@ -42,13 +42,13 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIServingChat(OpenAIServingBase):
|
class OpenAIServingChat(OpenAIServingBase):
|
||||||
"""Handler for chat completion requests"""
|
"""Handler for /v1/chat/completions requests"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(
|
||||||
super().__init__(*args, **kwargs)
|
self, tokenizer_manager: TokenizerManager, template_manager: TemplateManager
|
||||||
# Instance-specific cache for template content format detection
|
):
|
||||||
self._cached_chat_template = None
|
super().__init__(tokenizer_manager)
|
||||||
self._cached_template_format = None
|
self.template_manager = template_manager
|
||||||
|
|
||||||
def _request_id_prefix(self) -> str:
|
def _request_id_prefix(self) -> str:
|
||||||
return "chatcmpl-"
|
return "chatcmpl-"
|
||||||
@@ -142,19 +142,14 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Use chat template
|
# Use chat template
|
||||||
if (
|
if self.template_manager.chat_template_name is None:
|
||||||
hasattr(self.tokenizer_manager, "chat_template_name")
|
|
||||||
and self.tokenizer_manager.chat_template_name is None
|
|
||||||
):
|
|
||||||
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
|
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
|
||||||
self._apply_jinja_template(request, tools, is_multimodal)
|
self._apply_jinja_template(request, tools, is_multimodal)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt, image_data, audio_data, modalities, stop = (
|
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
|
||||||
self._apply_conversation_template(request)
|
self._apply_conversation_template(request, is_multimodal)
|
||||||
)
|
)
|
||||||
if not is_multimodal:
|
|
||||||
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
|
|
||||||
else:
|
else:
|
||||||
# Use raw prompt
|
# Use raw prompt
|
||||||
prompt_ids = request.messages
|
prompt_ids = request.messages
|
||||||
@@ -181,23 +176,14 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
is_multimodal: bool,
|
is_multimodal: bool,
|
||||||
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
|
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
|
||||||
"""Apply Jinja chat template"""
|
"""Apply Jinja chat template"""
|
||||||
|
prompt = ""
|
||||||
|
prompt_ids = []
|
||||||
openai_compatible_messages = []
|
openai_compatible_messages = []
|
||||||
image_data = []
|
image_data = []
|
||||||
audio_data = []
|
audio_data = []
|
||||||
modalities = []
|
modalities = []
|
||||||
|
|
||||||
# Detect template content format
|
template_content_format = self.template_manager.jinja_template_content_format
|
||||||
current_template = self.tokenizer_manager.tokenizer.chat_template
|
|
||||||
if current_template != self._cached_chat_template:
|
|
||||||
self._cached_chat_template = current_template
|
|
||||||
self._cached_template_format = detect_template_content_format(
|
|
||||||
current_template
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Detected chat template content format: {self._cached_template_format}"
|
|
||||||
)
|
|
||||||
|
|
||||||
template_content_format = self._cached_template_format
|
|
||||||
|
|
||||||
for message in request.messages:
|
for message in request.messages:
|
||||||
if message.content is None:
|
if message.content is None:
|
||||||
@@ -262,14 +248,21 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
if is_multimodal:
|
if is_multimodal:
|
||||||
prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)
|
prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)
|
||||||
|
|
||||||
stop = request.stop or []
|
stop = request.stop
|
||||||
|
image_data = image_data if image_data else None
|
||||||
|
audio_data = audio_data if audio_data else None
|
||||||
|
modalities = modalities if modalities else []
|
||||||
return prompt, prompt_ids, image_data, audio_data, modalities, stop
|
return prompt, prompt_ids, image_data, audio_data, modalities, stop
|
||||||
|
|
||||||
def _apply_conversation_template(
|
def _apply_conversation_template(
|
||||||
self, request: ChatCompletionRequest
|
self,
|
||||||
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]:
|
request: ChatCompletionRequest,
|
||||||
|
is_multimodal: bool,
|
||||||
|
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
|
||||||
"""Apply conversation template"""
|
"""Apply conversation template"""
|
||||||
conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name)
|
prompt = ""
|
||||||
|
prompt_ids = []
|
||||||
|
conv = generate_chat_conv(request, self.template_manager.chat_template_name)
|
||||||
|
|
||||||
# If we should continue the final assistant message, adjust the conversation.
|
# If we should continue the final assistant message, adjust the conversation.
|
||||||
if (
|
if (
|
||||||
@@ -296,9 +289,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
else:
|
else:
|
||||||
prompt = conv.get_prompt()
|
prompt = conv.get_prompt()
|
||||||
|
|
||||||
image_data = conv.image_data
|
image_data = conv.image_data if conv.image_data else None
|
||||||
audio_data = conv.audio_data
|
audio_data = conv.audio_data if conv.audio_data else None
|
||||||
modalities = conv.modalities
|
modalities = conv.modalities if conv.modalities else []
|
||||||
stop = conv.stop_str or [] if not request.ignore_eos else []
|
stop = conv.stop_str or [] if not request.ignore_eos else []
|
||||||
|
|
||||||
if request.stop:
|
if request.stop:
|
||||||
@@ -307,7 +300,10 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
else:
|
else:
|
||||||
stop.extend(request.stop)
|
stop.extend(request.stop)
|
||||||
|
|
||||||
return prompt, image_data, audio_data, modalities, stop
|
if not is_multimodal:
|
||||||
|
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
return prompt, prompt_ids, image_data, audio_data, modalities, stop
|
||||||
|
|
||||||
def _build_sampling_params(
|
def _build_sampling_params(
|
||||||
self,
|
self,
|
||||||
@@ -459,13 +455,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
stream_buffers[index] = stream_buffer + delta
|
stream_buffers[index] = stream_buffer + delta
|
||||||
|
|
||||||
# Handle reasoning content
|
# Handle reasoning content
|
||||||
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
|
|
||||||
"enable_thinking", True
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
self.tokenizer_manager.server_args.reasoning_parser
|
self.tokenizer_manager.server_args.reasoning_parser
|
||||||
and request.separate_reasoning
|
and request.separate_reasoning
|
||||||
and enable_thinking
|
|
||||||
):
|
):
|
||||||
reasoning_text, delta = self._process_reasoning_stream(
|
reasoning_text, delta = self._process_reasoning_stream(
|
||||||
index, delta, reasoning_parser_dict, content, request
|
index, delta, reasoning_parser_dict, content, request
|
||||||
@@ -591,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
)
|
)
|
||||||
yield f"data: {usage_chunk.model_dump_json()}\n\n"
|
yield f"data: {usage_chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
except Exception as e:
|
except ValueError as e:
|
||||||
error = self.create_streaming_error_response(str(e))
|
error = self.create_streaming_error_response(str(e))
|
||||||
yield f"data: {error}\n\n"
|
yield f"data: {error}\n\n"
|
||||||
|
|
||||||
@@ -602,7 +594,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
adapted_request: GenerateReqInput,
|
adapted_request: GenerateReqInput,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
raw_request: Request,
|
raw_request: Request,
|
||||||
) -> Union[ChatCompletionResponse, ErrorResponse]:
|
) -> Union[ChatCompletionResponse, ErrorResponse, ORJSONResponse]:
|
||||||
"""Handle non-streaming chat completion request"""
|
"""Handle non-streaming chat completion request"""
|
||||||
try:
|
try:
|
||||||
ret = await self.tokenizer_manager.generate_request(
|
ret = await self.tokenizer_manager.generate_request(
|
||||||
@@ -627,7 +619,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
ret: List[Dict[str, Any]],
|
ret: List[Dict[str, Any]],
|
||||||
created: int,
|
created: int,
|
||||||
) -> ChatCompletionResponse:
|
) -> Union[ChatCompletionResponse, ORJSONResponse]:
|
||||||
"""Build chat completion response from generation results"""
|
"""Build chat completion response from generation results"""
|
||||||
choices = []
|
choices = []
|
||||||
|
|
||||||
@@ -645,11 +637,8 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
|
|
||||||
# Handle reasoning content
|
# Handle reasoning content
|
||||||
reasoning_text = None
|
reasoning_text = None
|
||||||
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
|
|
||||||
"enable_thinking", True
|
|
||||||
)
|
|
||||||
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
|
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
|
||||||
if reasoning_parser and request.separate_reasoning and enable_thinking:
|
if reasoning_parser and request.separate_reasoning:
|
||||||
try:
|
try:
|
||||||
parser = ReasoningParser(
|
parser = ReasoningParser(
|
||||||
model_type=reasoning_parser, stream_reasoning=False
|
model_type=reasoning_parser, stream_reasoning=False
|
||||||
@@ -691,9 +680,10 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
choices.append(choice_data)
|
choices.append(choice_data)
|
||||||
|
|
||||||
# Calculate usage
|
# Calculate usage
|
||||||
cache_report = self.tokenizer_manager.server_args.enable_cache_report
|
|
||||||
usage = UsageProcessor.calculate_response_usage(
|
usage = UsageProcessor.calculate_response_usage(
|
||||||
ret, n_choices=request.n, enable_cache_report=cache_report
|
ret,
|
||||||
|
n_choices=request.n,
|
||||||
|
enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatCompletionResponse(
|
return ChatCompletionResponse(
|
||||||
@@ -821,6 +811,25 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
reasoning_parser = reasoning_parser_dict[index]
|
reasoning_parser = reasoning_parser_dict[index]
|
||||||
return reasoning_parser.parse_stream_chunk(delta)
|
return reasoning_parser.parse_stream_chunk(delta)
|
||||||
|
|
||||||
|
def _get_enable_thinking_from_request(request: ChatCompletionRequest) -> bool:
|
||||||
|
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
|
||||||
|
|
||||||
|
NOTE: This parameter is only useful for models that support enable_thinking
|
||||||
|
flag, such as Qwen3.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_obj: The request object (or an item from a list of requests).
|
||||||
|
Returns:
|
||||||
|
The boolean value of 'enable_thinking' if found and not True, otherwise True.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
hasattr(request, "chat_template_kwargs")
|
||||||
|
and request.chat_template_kwargs
|
||||||
|
and request.chat_template_kwargs.get("enable_thinking") is not None
|
||||||
|
):
|
||||||
|
return request.chat_template_kwargs.get("enable_thinking")
|
||||||
|
return True
|
||||||
|
|
||||||
async def _process_tool_call_stream(
|
async def _process_tool_call_stream(
|
||||||
self,
|
self,
|
||||||
index: int,
|
index: int,
|
||||||
|
|||||||
@@ -3,12 +3,9 @@ import time
|
|||||||
from typing import Any, AsyncGenerator, Dict, List, Union
|
from typing import Any, AsyncGenerator, Dict, List, Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
|
|
||||||
from sglang.srt.code_completion_parser import (
|
from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
|
||||||
generate_completion_prompt_from_request,
|
|
||||||
is_completion_template_defined,
|
|
||||||
)
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
@@ -24,12 +21,22 @@ from sglang.srt.entrypoints.openai.utils import (
|
|||||||
to_openai_style_logprobs,
|
to_openai_style_logprobs,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
|
from sglang.srt.managers.template_manager import TemplateManager
|
||||||
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIServingCompletion(OpenAIServingBase):
|
class OpenAIServingCompletion(OpenAIServingBase):
|
||||||
"""Handler for completion requests"""
|
"""Handler for /v1/completion requests"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer_manager: TokenizerManager,
|
||||||
|
template_manager: TemplateManager,
|
||||||
|
):
|
||||||
|
super().__init__(tokenizer_manager)
|
||||||
|
self.template_manager = template_manager
|
||||||
|
|
||||||
def _request_id_prefix(self) -> str:
|
def _request_id_prefix(self) -> str:
|
||||||
return "cmpl-"
|
return "cmpl-"
|
||||||
@@ -47,7 +54,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
)
|
)
|
||||||
# Process prompt
|
# Process prompt
|
||||||
prompt = request.prompt
|
prompt = request.prompt
|
||||||
if is_completion_template_defined():
|
if self.template_manager.completion_template_name is not None:
|
||||||
prompt = generate_completion_prompt_from_request(request)
|
prompt = generate_completion_prompt_from_request(request)
|
||||||
|
|
||||||
# Set logprob start length based on echo and logprobs
|
# Set logprob start length based on echo and logprobs
|
||||||
@@ -141,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
prompt_tokens = {}
|
prompt_tokens = {}
|
||||||
completion_tokens = {}
|
completion_tokens = {}
|
||||||
cached_tokens = {}
|
cached_tokens = {}
|
||||||
|
hidden_states = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for content in self.tokenizer_manager.generate_request(
|
async for content in self.tokenizer_manager.generate_request(
|
||||||
@@ -152,6 +160,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||||
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
||||||
|
hidden_states[index] = content["meta_info"].get("hidden_states", None)
|
||||||
|
|
||||||
stream_buffer = stream_buffers.get(index, "")
|
stream_buffer = stream_buffers.get(index, "")
|
||||||
# Handle echo for first chunk
|
# Handle echo for first chunk
|
||||||
@@ -192,7 +201,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
delta = text[len(stream_buffer) :]
|
delta = text[len(stream_buffer) :]
|
||||||
stream_buffers[index] = stream_buffer + delta
|
stream_buffers[index] = stream_buffer + delta
|
||||||
finish_reason = content["meta_info"]["finish_reason"]
|
finish_reason = content["meta_info"]["finish_reason"]
|
||||||
hidden_states = content["meta_info"].get("hidden_states", None)
|
|
||||||
|
|
||||||
choice_data = CompletionResponseStreamChoice(
|
choice_data = CompletionResponseStreamChoice(
|
||||||
index=index,
|
index=index,
|
||||||
@@ -269,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
adapted_request: GenerateReqInput,
|
adapted_request: GenerateReqInput,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
raw_request: Request,
|
raw_request: Request,
|
||||||
) -> Union[CompletionResponse, ErrorResponse]:
|
) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]:
|
||||||
"""Handle non-streaming completion request"""
|
"""Handle non-streaming completion request"""
|
||||||
try:
|
try:
|
||||||
generator = self.tokenizer_manager.generate_request(
|
generator = self.tokenizer_manager.generate_request(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
from fastapi.responses import ORJSONResponse
|
||||||
|
|
||||||
from sglang.srt.conversation import generate_embedding_convs
|
from sglang.srt.conversation import generate_embedding_convs
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
@@ -13,10 +14,20 @@ from sglang.srt.entrypoints.openai.protocol import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||||
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||||
|
from sglang.srt.managers.template_manager import TemplateManager
|
||||||
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
|
||||||
|
|
||||||
class OpenAIServingEmbedding(OpenAIServingBase):
|
class OpenAIServingEmbedding(OpenAIServingBase):
|
||||||
"""Handler for embedding requests"""
|
"""Handler for v1/embeddings requests"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer_manager: TokenizerManager,
|
||||||
|
template_manager: TemplateManager,
|
||||||
|
):
|
||||||
|
super().__init__(tokenizer_manager)
|
||||||
|
self.template_manager = template_manager
|
||||||
|
|
||||||
def _request_id_prefix(self) -> str:
|
def _request_id_prefix(self) -> str:
|
||||||
return "embd-"
|
return "embd-"
|
||||||
@@ -68,10 +79,6 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
|||||||
prompt_kwargs = {"text": prompt}
|
prompt_kwargs = {"text": prompt}
|
||||||
elif isinstance(prompt, list):
|
elif isinstance(prompt, list):
|
||||||
if len(prompt) > 0 and isinstance(prompt[0], str):
|
if len(prompt) > 0 and isinstance(prompt[0], str):
|
||||||
# List of strings - if it's a single string in a list, treat as single string
|
|
||||||
if len(prompt) == 1:
|
|
||||||
prompt_kwargs = {"text": prompt[0]}
|
|
||||||
else:
|
|
||||||
prompt_kwargs = {"text": prompt}
|
prompt_kwargs = {"text": prompt}
|
||||||
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
|
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
|
||||||
# Handle multimodal embedding inputs
|
# Handle multimodal embedding inputs
|
||||||
@@ -84,11 +91,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
|||||||
|
|
||||||
generate_prompts = []
|
generate_prompts = []
|
||||||
# Check if we have a chat template for multimodal embeddings
|
# Check if we have a chat template for multimodal embeddings
|
||||||
chat_template_name = getattr(
|
if self.template_manager.chat_template_name is not None:
|
||||||
self.tokenizer_manager, "chat_template_name", None
|
convs = generate_embedding_convs(
|
||||||
|
texts, images, self.template_manager.chat_template_name
|
||||||
)
|
)
|
||||||
if chat_template_name is not None:
|
|
||||||
convs = generate_embedding_convs(texts, images, chat_template_name)
|
|
||||||
for conv in convs:
|
for conv in convs:
|
||||||
generate_prompts.append(conv.get_prompt())
|
generate_prompts.append(conv.get_prompt())
|
||||||
else:
|
else:
|
||||||
@@ -122,7 +128,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
|||||||
adapted_request: EmbeddingReqInput,
|
adapted_request: EmbeddingReqInput,
|
||||||
request: EmbeddingRequest,
|
request: EmbeddingRequest,
|
||||||
raw_request: Request,
|
raw_request: Request,
|
||||||
) -> Union[EmbeddingResponse, ErrorResponse]:
|
) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]:
|
||||||
"""Handle the embedding request"""
|
"""Handle the embedding request"""
|
||||||
try:
|
try:
|
||||||
ret = await self.tokenizer_manager.generate_request(
|
ret = await self.tokenizer_manager.generate_request(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import logging
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
from fastapi.responses import ORJSONResponse
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -15,7 +16,10 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIServingRerank(OpenAIServingBase):
|
class OpenAIServingRerank(OpenAIServingBase):
|
||||||
"""Handler for rerank requests"""
|
"""Handler for /v1/rerank requests"""
|
||||||
|
|
||||||
|
# NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
|
||||||
|
# to another module in the future.
|
||||||
|
|
||||||
def _request_id_prefix(self) -> str:
|
def _request_id_prefix(self) -> str:
|
||||||
return "rerank-"
|
return "rerank-"
|
||||||
@@ -61,7 +65,7 @@ class OpenAIServingRerank(OpenAIServingBase):
|
|||||||
adapted_request: EmbeddingReqInput,
|
adapted_request: EmbeddingReqInput,
|
||||||
request: V1RerankReqInput,
|
request: V1RerankReqInput,
|
||||||
raw_request: Request,
|
raw_request: Request,
|
||||||
) -> Union[RerankResponse, ErrorResponse]:
|
) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]:
|
||||||
"""Handle the rerank request"""
|
"""Handle the rerank request"""
|
||||||
try:
|
try:
|
||||||
ret = await self.tokenizer_manager.generate_request(
|
ret = await self.tokenizer_manager.generate_request(
|
||||||
@@ -74,16 +78,16 @@ class OpenAIServingRerank(OpenAIServingBase):
|
|||||||
if not isinstance(ret, list):
|
if not isinstance(ret, list):
|
||||||
ret = [ret]
|
ret = [ret]
|
||||||
|
|
||||||
response = self._build_rerank_response(ret, request)
|
responses = self._build_rerank_response(ret, request)
|
||||||
return response
|
return responses
|
||||||
|
|
||||||
def _build_rerank_response(
|
def _build_rerank_response(
|
||||||
self, ret: List[Dict[str, Any]], request: V1RerankReqInput
|
self, ret: List[Dict[str, Any]], request: V1RerankReqInput
|
||||||
) -> List[RerankResponse]:
|
) -> List[RerankResponse]:
|
||||||
"""Build the rerank response from generation results"""
|
"""Build the rerank response from generation results"""
|
||||||
response = []
|
responses = []
|
||||||
for idx, ret_item in enumerate(ret):
|
for idx, ret_item in enumerate(ret):
|
||||||
response.append(
|
responses.append(
|
||||||
RerankResponse(
|
RerankResponse(
|
||||||
score=ret_item["embedding"],
|
score=ret_item["embedding"],
|
||||||
document=request.documents[idx],
|
document=request.documents[idx],
|
||||||
@@ -93,6 +97,6 @@ class OpenAIServingRerank(OpenAIServingBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Sort by score in descending order (highest relevance first)
|
# Sort by score in descending order (highest relevance first)
|
||||||
response.sort(key=lambda x: x.score, reverse=True)
|
responses.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
|
||||||
return response
|
return responses
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
@@ -14,7 +14,10 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIServingScore(OpenAIServingBase):
|
class OpenAIServingScore(OpenAIServingBase):
|
||||||
"""Handler for scoring requests"""
|
"""Handler for /v1/score requests"""
|
||||||
|
|
||||||
|
# NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
|
||||||
|
# to another module in the future.
|
||||||
|
|
||||||
def _request_id_prefix(self) -> str:
|
def _request_id_prefix(self) -> str:
|
||||||
return "score-"
|
return "score-"
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import jinja2.nodes
|
|
||||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
@@ -13,168 +10,6 @@ from sglang.srt.entrypoints.openai.protocol import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# JINJA TEMPLATE CONTENT FORMAT DETECTION
|
|
||||||
# ============================================================================
|
|
||||||
#
|
|
||||||
# This adapts vLLM's approach for detecting chat template content format:
|
|
||||||
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
|
|
||||||
# - Analyzes Jinja template AST to detect content iteration patterns
|
|
||||||
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
|
|
||||||
# - 'string' format: templates that expect simple string content
|
|
||||||
# - Processes content accordingly to match template expectations
|
|
||||||
|
|
||||||
|
|
||||||
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
|
|
||||||
"""Check if node is a variable access like {{ varname }}"""
|
|
||||||
if isinstance(node, jinja2.nodes.Name):
|
|
||||||
return node.ctx == "load" and node.name == varname
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
|
|
||||||
"""Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
|
|
||||||
if isinstance(node, jinja2.nodes.Getitem):
|
|
||||||
return (
|
|
||||||
_is_var_access(node.node, varname)
|
|
||||||
and isinstance(node.arg, jinja2.nodes.Const)
|
|
||||||
and node.arg.value == key
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(node, jinja2.nodes.Getattr):
|
|
||||||
return _is_var_access(node.node, varname) and node.attr == key
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_var_or_elems_access(
|
|
||||||
node: jinja2.nodes.Node,
|
|
||||||
varname: str,
|
|
||||||
key: str = None,
|
|
||||||
) -> bool:
|
|
||||||
"""Check if node accesses varname or varname[key] with filters/tests"""
|
|
||||||
if isinstance(node, jinja2.nodes.Filter):
|
|
||||||
return node.node is not None and _is_var_or_elems_access(
|
|
||||||
node.node, varname, key
|
|
||||||
)
|
|
||||||
if isinstance(node, jinja2.nodes.Test):
|
|
||||||
return _is_var_or_elems_access(node.node, varname, key)
|
|
||||||
|
|
||||||
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
|
|
||||||
node.arg, jinja2.nodes.Slice
|
|
||||||
):
|
|
||||||
return _is_var_or_elems_access(node.node, varname, key)
|
|
||||||
|
|
||||||
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
|
|
||||||
|
|
||||||
|
|
||||||
def _try_extract_ast(chat_template: str):
|
|
||||||
"""Try to parse the Jinja template into an AST"""
|
|
||||||
try:
|
|
||||||
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
|
|
||||||
return jinja_compiled.environment.parse(chat_template)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Error when compiling Jinja template: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def detect_template_content_format(chat_template: str) -> str:
|
|
||||||
"""
|
|
||||||
Detect whether a chat template expects 'string' or 'openai' content format.
|
|
||||||
|
|
||||||
- 'string': content is a simple string (like DeepSeek templates)
|
|
||||||
- 'openai': content is a list of structured dicts (like Llama4 templates)
|
|
||||||
|
|
||||||
Detection logic:
|
|
||||||
- If template has loops like {%- for content in message['content'] -%} → 'openai'
|
|
||||||
- Otherwise → 'string'
|
|
||||||
"""
|
|
||||||
jinja_ast = _try_extract_ast(chat_template)
|
|
||||||
if jinja_ast is None:
|
|
||||||
return "string"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Look for patterns like: {%- for content in message['content'] -%}
|
|
||||||
for loop_ast in jinja_ast.find_all(jinja2.nodes.For):
|
|
||||||
loop_iter = loop_ast.iter
|
|
||||||
|
|
||||||
# Check if iterating over message['content'] or similar
|
|
||||||
if _is_var_or_elems_access(loop_iter, "message", "content"):
|
|
||||||
return "openai" # Found content iteration → openai format
|
|
||||||
|
|
||||||
return "string" # No content loops found → string format
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Error when parsing AST of Jinja template: {e}")
|
|
||||||
return "string"
|
|
||||||
|
|
||||||
|
|
||||||
def process_content_for_template_format(
|
|
||||||
msg_dict: dict,
|
|
||||||
content_format: str,
|
|
||||||
image_data: list,
|
|
||||||
audio_data: list,
|
|
||||||
modalities: list,
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Process message content based on detected template format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msg_dict: Message dictionary with content
|
|
||||||
content_format: 'string' or 'openai' (detected via AST analysis)
|
|
||||||
image_data: List to append extracted image URLs
|
|
||||||
audio_data: List to append extracted audio URLs
|
|
||||||
modalities: List to append modalities
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Processed message dictionary
|
|
||||||
"""
|
|
||||||
if not isinstance(msg_dict.get("content"), list):
|
|
||||||
# Already a string or None, no processing needed
|
|
||||||
return {k: v for k, v in msg_dict.items() if v is not None}
|
|
||||||
|
|
||||||
if content_format == "openai":
|
|
||||||
# OpenAI format: preserve structured content list, normalize types
|
|
||||||
processed_content_parts = []
|
|
||||||
for chunk in msg_dict["content"]:
|
|
||||||
if isinstance(chunk, dict):
|
|
||||||
chunk_type = chunk.get("type")
|
|
||||||
|
|
||||||
if chunk_type == "image_url":
|
|
||||||
image_data.append(chunk["image_url"]["url"])
|
|
||||||
if chunk.get("modalities"):
|
|
||||||
modalities.append(chunk.get("modalities"))
|
|
||||||
# Normalize to simple 'image' type for template compatibility
|
|
||||||
processed_content_parts.append({"type": "image"})
|
|
||||||
elif chunk_type == "audio_url":
|
|
||||||
audio_data.append(chunk["audio_url"]["url"])
|
|
||||||
# Normalize to simple 'audio' type
|
|
||||||
processed_content_parts.append({"type": "audio"})
|
|
||||||
else:
|
|
||||||
# Keep other content as-is (text, etc.)
|
|
||||||
processed_content_parts.append(chunk)
|
|
||||||
|
|
||||||
new_msg = {
|
|
||||||
k: v for k, v in msg_dict.items() if v is not None and k != "content"
|
|
||||||
}
|
|
||||||
new_msg["content"] = processed_content_parts
|
|
||||||
return new_msg
|
|
||||||
|
|
||||||
else: # content_format == "string"
|
|
||||||
# String format: flatten to text only (for templates like DeepSeek)
|
|
||||||
text_parts = []
|
|
||||||
for chunk in msg_dict["content"]:
|
|
||||||
if isinstance(chunk, dict) and chunk.get("type") == "text":
|
|
||||||
text_parts.append(chunk["text"])
|
|
||||||
# Note: For string format, we ignore images/audio since the template
|
|
||||||
# doesn't expect structured content - multimodal placeholders would
|
|
||||||
# need to be inserted differently
|
|
||||||
|
|
||||||
new_msg = msg_dict.copy()
|
|
||||||
new_msg["content"] = " ".join(text_parts) if text_parts else ""
|
|
||||||
new_msg = {k: v for k, v in new_msg.items() if v is not None}
|
|
||||||
return new_msg
|
|
||||||
|
|
||||||
|
|
||||||
def to_openai_style_logprobs(
|
def to_openai_style_logprobs(
|
||||||
input_token_logprobs=None,
|
input_token_logprobs=None,
|
||||||
output_token_logprobs=None,
|
output_token_logprobs=None,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
|
|||||||
from partial_json_parser.core.exceptions import MalformedJSON
|
from partial_json_parser.core.exceptions import MalformedJSON
|
||||||
from partial_json_parser.core.options import Allow
|
from partial_json_parser.core.options import Allow
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
from sglang.srt.function_call.core_types import (
|
from sglang.srt.function_call.core_types import (
|
||||||
StreamingParseResult,
|
StreamingParseResult,
|
||||||
ToolCallItem,
|
ToolCallItem,
|
||||||
@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
|
|||||||
_is_complete_json,
|
_is_complete_json,
|
||||||
_partial_json_loads,
|
_partial_json_loads,
|
||||||
)
|
)
|
||||||
from sglang.srt.openai_api.protocol import Tool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.core_types import (
|
from sglang.srt.function_call.core_types import (
|
||||||
StreamingParseResult,
|
StreamingParseResult,
|
||||||
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
from sglang.srt.function_call.utils import _is_complete_json
|
from sglang.srt.function_call.utils import _is_complete_json
|
||||||
from sglang.srt.openai_api.protocol import Tool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
|
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
|
StructuralTagResponseFormat,
|
||||||
|
StructuresResponseFormat,
|
||||||
|
Tool,
|
||||||
|
ToolChoice,
|
||||||
|
)
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.core_types import ToolCallItem
|
from sglang.srt.function_call.core_types import ToolCallItem
|
||||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||||
@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
|
|||||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||||
from sglang.srt.openai_api.protocol import (
|
|
||||||
StructuralTagResponseFormat,
|
|
||||||
StructuresResponseFormat,
|
|
||||||
Tool,
|
|
||||||
ToolChoice,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.core_types import (
|
from sglang.srt.function_call.core_types import (
|
||||||
StreamingParseResult,
|
StreamingParseResult,
|
||||||
@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
|
|||||||
_GetInfoFunc,
|
_GetInfoFunc,
|
||||||
)
|
)
|
||||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
from sglang.srt.openai_api.protocol import Tool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.core_types import (
|
from sglang.srt.function_call.core_types import (
|
||||||
StreamingParseResult,
|
StreamingParseResult,
|
||||||
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
|
|||||||
_GetInfoFunc,
|
_GetInfoFunc,
|
||||||
)
|
)
|
||||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
from sglang.srt.openai_api.protocol import Tool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.core_types import (
|
from sglang.srt.function_call.core_types import (
|
||||||
StreamingParseResult,
|
StreamingParseResult,
|
||||||
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
|
|||||||
_GetInfoFunc,
|
_GetInfoFunc,
|
||||||
)
|
)
|
||||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
from sglang.srt.openai_api.protocol import Tool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.core_types import (
|
from sglang.srt.function_call.core_types import (
|
||||||
StreamingParseResult,
|
StreamingParseResult,
|
||||||
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
|
|||||||
_GetInfoFunc,
|
_GetInfoFunc,
|
||||||
)
|
)
|
||||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
from sglang.srt.openai_api.protocol import Tool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
"""
|
"""Template utilities for Jinja template processing.
|
||||||
Utility functions for OpenAI API adapter.
|
|
||||||
|
This module provides utilities for analyzing and processing Jinja chat templates,
|
||||||
|
including content format detection and message processing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import jinja2.nodes
|
import jinja2
|
||||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def detect_template_content_format(chat_template: str) -> str:
|
def detect_jinja_template_content_format(chat_template: str) -> str:
|
||||||
"""
|
"""
|
||||||
Detect whether a chat template expects 'string' or 'openai' content format.
|
Detect whether a chat template expects 'string' or 'openai' content format.
|
||||||
|
|
||||||
@@ -864,12 +864,6 @@ class SetInternalStateReq:
|
|||||||
server_args: Dict[str, Any]
|
server_args: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class V1RerankReqInput:
|
|
||||||
query: str
|
|
||||||
documents: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SetInternalStateReqOutput:
|
class SetInternalStateReqOutput:
|
||||||
updated: bool
|
updated: bool
|
||||||
|
|||||||
226
python/sglang/srt/managers/template_manager.py
Normal file
226
python/sglang/srt/managers/template_manager.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Centralized template management for chat templates and completion templates.
|
||||||
|
|
||||||
|
This module provides a unified interface for managing both chat conversation templates
|
||||||
|
and code completion templates, eliminating global state and improving modularity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sglang.srt.code_completion_parser import (
|
||||||
|
CompletionTemplate,
|
||||||
|
FimPosition,
|
||||||
|
completion_template_exists,
|
||||||
|
register_completion_template,
|
||||||
|
)
|
||||||
|
from sglang.srt.conversation import (
|
||||||
|
Conversation,
|
||||||
|
SeparatorStyle,
|
||||||
|
chat_template_exists,
|
||||||
|
get_conv_template_by_model_path,
|
||||||
|
register_conv_template,
|
||||||
|
)
|
||||||
|
from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateManager:
|
||||||
|
"""
|
||||||
|
Centralized manager for chat and completion templates.
|
||||||
|
|
||||||
|
This class encapsulates all template-related state and operations,
|
||||||
|
eliminating the need for global variables and providing a clean
|
||||||
|
interface for template management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._chat_template_name: Optional[str] = None
|
||||||
|
self._completion_template_name: Optional[str] = None
|
||||||
|
self._jinja_template_content_format: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat_template_name(self) -> Optional[str]:
|
||||||
|
"""Get the current chat template name."""
|
||||||
|
return self._chat_template_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def completion_template_name(self) -> Optional[str]:
|
||||||
|
"""Get the current completion template name."""
|
||||||
|
return self._completion_template_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def jinja_template_content_format(self) -> Optional[str]:
|
||||||
|
"""Get the detected template content format ('string' or 'openai' or None)."""
|
||||||
|
return self._jinja_template_content_format
|
||||||
|
|
||||||
|
def load_chat_template(
|
||||||
|
self, tokenizer_manager, chat_template_arg: str, model_path: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Load a chat template from various sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer_manager: The tokenizer manager instance
|
||||||
|
chat_template_arg: Template name or file path
|
||||||
|
model_path: Path to the model
|
||||||
|
"""
|
||||||
|
logger.info(f"Loading chat template: {chat_template_arg}")
|
||||||
|
|
||||||
|
if not chat_template_exists(chat_template_arg):
|
||||||
|
if not os.path.exists(chat_template_arg):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Chat template {chat_template_arg} is not a built-in template name "
|
||||||
|
"or a valid chat template file path."
|
||||||
|
)
|
||||||
|
|
||||||
|
if chat_template_arg.endswith(".jinja"):
|
||||||
|
self._load_jinja_template(tokenizer_manager, chat_template_arg)
|
||||||
|
else:
|
||||||
|
self._load_json_chat_template(chat_template_arg)
|
||||||
|
else:
|
||||||
|
self._chat_template_name = chat_template_arg
|
||||||
|
|
||||||
|
def guess_chat_template_from_model_path(self, model_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Infer chat template name from model path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the model
|
||||||
|
"""
|
||||||
|
template_name = get_conv_template_by_model_path(model_path)
|
||||||
|
if template_name is not None:
|
||||||
|
logger.info(f"Inferred chat template from model path: {template_name}")
|
||||||
|
self._chat_template_name = template_name
|
||||||
|
|
||||||
|
def load_completion_template(self, completion_template_arg: str) -> None:
|
||||||
|
"""
|
||||||
|
Load completion template for code completion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
completion_template_arg: Template name or file path
|
||||||
|
"""
|
||||||
|
logger.info(f"Loading completion template: {completion_template_arg}")
|
||||||
|
|
||||||
|
if not completion_template_exists(completion_template_arg):
|
||||||
|
if not os.path.exists(completion_template_arg):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Completion template {completion_template_arg} is not a built-in template name "
|
||||||
|
"or a valid completion template file path."
|
||||||
|
)
|
||||||
|
|
||||||
|
self._load_json_completion_template(completion_template_arg)
|
||||||
|
else:
|
||||||
|
self._completion_template_name = completion_template_arg
|
||||||
|
|
||||||
|
def initialize_templates(
|
||||||
|
self,
|
||||||
|
tokenizer_manager,
|
||||||
|
model_path: str,
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
completion_template: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize all templates based on provided configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer_manager: The tokenizer manager instance
|
||||||
|
model_path: Path to the model
|
||||||
|
chat_template: Optional chat template name/path
|
||||||
|
completion_template: Optional completion template name/path
|
||||||
|
"""
|
||||||
|
# Load chat template
|
||||||
|
if chat_template:
|
||||||
|
self.load_chat_template(tokenizer_manager, chat_template, model_path)
|
||||||
|
else:
|
||||||
|
self.guess_chat_template_from_model_path(model_path)
|
||||||
|
|
||||||
|
# Load completion template
|
||||||
|
if completion_template:
|
||||||
|
self.load_completion_template(completion_template)
|
||||||
|
|
||||||
|
def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
|
||||||
|
"""Load a Jinja template file."""
|
||||||
|
with open(template_path, "r") as f:
|
||||||
|
chat_template = "".join(f.readlines()).strip("\n")
|
||||||
|
tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
|
||||||
|
self._chat_template_name = None
|
||||||
|
# Detect content format from the loaded template
|
||||||
|
self._jinja_template_content_format = detect_jinja_template_content_format(
|
||||||
|
chat_template
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Detected chat template content format: {self._jinja_template_content_format}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_json_chat_template(self, template_path: str) -> None:
|
||||||
|
"""Load a JSON chat template file."""
|
||||||
|
assert template_path.endswith(
|
||||||
|
".json"
|
||||||
|
), "unrecognized format of chat template file"
|
||||||
|
|
||||||
|
with open(template_path, "r") as filep:
|
||||||
|
template = json.load(filep)
|
||||||
|
try:
|
||||||
|
sep_style = SeparatorStyle[template["sep_style"]]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown separator style: {template['sep_style']}"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name=template["name"],
|
||||||
|
system_template=template["system"] + "\n{system_message}",
|
||||||
|
system_message=template.get("system_message", ""),
|
||||||
|
roles=(template["user"], template["assistant"]),
|
||||||
|
sep_style=sep_style,
|
||||||
|
sep=template.get("sep", "\n"),
|
||||||
|
stop_str=template["stop_str"],
|
||||||
|
),
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
self._chat_template_name = template["name"]
|
||||||
|
|
||||||
|
def _load_json_completion_template(self, template_path: str) -> None:
|
||||||
|
"""Load a JSON completion template file."""
|
||||||
|
assert template_path.endswith(
|
||||||
|
".json"
|
||||||
|
), "unrecognized format of completion template file"
|
||||||
|
|
||||||
|
with open(template_path, "r") as filep:
|
||||||
|
template = json.load(filep)
|
||||||
|
try:
|
||||||
|
fim_position = FimPosition[template["fim_position"]]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown fim position: {template['fim_position']}"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
register_completion_template(
|
||||||
|
CompletionTemplate(
|
||||||
|
name=template["name"],
|
||||||
|
fim_begin_token=template["fim_begin_token"],
|
||||||
|
fim_middle_token=template["fim_middle_token"],
|
||||||
|
fim_end_token=template["fim_end_token"],
|
||||||
|
fim_position=fim_position,
|
||||||
|
),
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
self._completion_template_name = template["name"]
|
||||||
@@ -1058,12 +1058,7 @@ class TokenizerManager:
|
|||||||
"lora_path",
|
"lora_path",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
out_skip_names = set(
|
out_skip_names = set(["text", "output_ids", "embedding"])
|
||||||
[
|
|
||||||
"text",
|
|
||||||
"output_ids",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif self.log_requests_level == 1:
|
elif self.log_requests_level == 1:
|
||||||
max_length = 2048
|
max_length = 2048
|
||||||
elif self.log_requests_level == 2:
|
elif self.log_requests_level == 2:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,551 +0,0 @@
|
|||||||
# Copyright 2023-2024 SGLang Team
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Pydantic models for OpenAI API protocol"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_serializer, root_validator
|
|
||||||
from typing_extensions import Literal
|
|
||||||
|
|
||||||
|
|
||||||
class ModelCard(BaseModel):
|
|
||||||
"""Model cards."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
object: str = "model"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
owned_by: str = "sglang"
|
|
||||||
root: Optional[str] = None
|
|
||||||
max_model_len: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ModelList(BaseModel):
|
|
||||||
"""Model list consists of model cards."""
|
|
||||||
|
|
||||||
object: str = "list"
|
|
||||||
data: List[ModelCard] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(BaseModel):
|
|
||||||
object: str = "error"
|
|
||||||
message: str
|
|
||||||
type: str
|
|
||||||
param: Optional[str] = None
|
|
||||||
code: int
|
|
||||||
|
|
||||||
|
|
||||||
class LogProbs(BaseModel):
|
|
||||||
text_offset: List[int] = Field(default_factory=list)
|
|
||||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
|
||||||
tokens: List[str] = Field(default_factory=list)
|
|
||||||
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class TopLogprob(BaseModel):
|
|
||||||
token: str
|
|
||||||
bytes: List[int]
|
|
||||||
logprob: float
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionTokenLogprob(BaseModel):
|
|
||||||
token: str
|
|
||||||
bytes: List[int]
|
|
||||||
logprob: float
|
|
||||||
top_logprobs: List[TopLogprob]
|
|
||||||
|
|
||||||
|
|
||||||
class ChoiceLogprobs(BaseModel):
|
|
||||||
# build for v1/chat/completions response
|
|
||||||
content: List[ChatCompletionTokenLogprob]
|
|
||||||
|
|
||||||
|
|
||||||
class UsageInfo(BaseModel):
|
|
||||||
prompt_tokens: int = 0
|
|
||||||
total_tokens: int = 0
|
|
||||||
completion_tokens: Optional[int] = 0
|
|
||||||
# only used to return cached tokens when --enable-cache-report is set
|
|
||||||
prompt_tokens_details: Optional[Dict[str, int]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class StreamOptions(BaseModel):
|
|
||||||
include_usage: Optional[bool] = False
|
|
||||||
|
|
||||||
|
|
||||||
class JsonSchemaResponseFormat(BaseModel):
|
|
||||||
name: str
|
|
||||||
description: Optional[str] = None
|
|
||||||
# use alias to workaround pydantic conflict
|
|
||||||
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
|
||||||
strict: Optional[bool] = False
|
|
||||||
|
|
||||||
|
|
||||||
class FileRequest(BaseModel):
|
|
||||||
# https://platform.openai.com/docs/api-reference/files/create
|
|
||||||
file: bytes # The File object (not file name) to be uploaded
|
|
||||||
purpose: str = (
|
|
||||||
"batch" # The intended purpose of the uploaded file, default is "batch"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FileResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
object: str = "file"
|
|
||||||
bytes: int
|
|
||||||
created_at: int
|
|
||||||
filename: str
|
|
||||||
purpose: str
|
|
||||||
|
|
||||||
|
|
||||||
class FileDeleteResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
object: str = "file"
|
|
||||||
deleted: bool
|
|
||||||
|
|
||||||
|
|
||||||
class BatchRequest(BaseModel):
|
|
||||||
input_file_id: (
|
|
||||||
str # The ID of an uploaded file that contains requests for the new batch
|
|
||||||
)
|
|
||||||
endpoint: str # The endpoint to be used for all requests in the batch
|
|
||||||
completion_window: str # The time frame within which the batch should be processed
|
|
||||||
metadata: Optional[dict] = None # Optional custom metadata for the batch
|
|
||||||
|
|
||||||
|
|
||||||
class BatchResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
object: str = "batch"
|
|
||||||
endpoint: str
|
|
||||||
errors: Optional[dict] = None
|
|
||||||
input_file_id: str
|
|
||||||
completion_window: str
|
|
||||||
status: str = "validating"
|
|
||||||
output_file_id: Optional[str] = None
|
|
||||||
error_file_id: Optional[str] = None
|
|
||||||
created_at: int
|
|
||||||
in_progress_at: Optional[int] = None
|
|
||||||
expires_at: Optional[int] = None
|
|
||||||
finalizing_at: Optional[int] = None
|
|
||||||
completed_at: Optional[int] = None
|
|
||||||
failed_at: Optional[int] = None
|
|
||||||
expired_at: Optional[int] = None
|
|
||||||
cancelling_at: Optional[int] = None
|
|
||||||
cancelled_at: Optional[int] = None
|
|
||||||
request_counts: Optional[dict] = None
|
|
||||||
metadata: Optional[dict] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequest(BaseModel):
|
|
||||||
# Ordered by official OpenAI API documentation
|
|
||||||
# https://platform.openai.com/docs/api-reference/completions/create
|
|
||||||
model: str
|
|
||||||
prompt: Union[List[int], List[List[int]], str, List[str]]
|
|
||||||
best_of: Optional[int] = None
|
|
||||||
echo: bool = False
|
|
||||||
frequency_penalty: float = 0.0
|
|
||||||
logit_bias: Optional[Dict[str, float]] = None
|
|
||||||
logprobs: Optional[int] = None
|
|
||||||
max_tokens: int = 16
|
|
||||||
n: int = 1
|
|
||||||
presence_penalty: float = 0.0
|
|
||||||
seed: Optional[int] = None
|
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
|
||||||
stream: bool = False
|
|
||||||
stream_options: Optional[StreamOptions] = None
|
|
||||||
suffix: Optional[str] = None
|
|
||||||
temperature: float = 1.0
|
|
||||||
top_p: float = 1.0
|
|
||||||
user: Optional[str] = None
|
|
||||||
|
|
||||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
|
||||||
top_k: int = -1
|
|
||||||
min_p: float = 0.0
|
|
||||||
min_tokens: int = 0
|
|
||||||
json_schema: Optional[str] = None
|
|
||||||
regex: Optional[str] = None
|
|
||||||
ebnf: Optional[str] = None
|
|
||||||
repetition_penalty: float = 1.0
|
|
||||||
stop_token_ids: Optional[List[int]] = None
|
|
||||||
no_stop_trim: bool = False
|
|
||||||
ignore_eos: bool = False
|
|
||||||
skip_special_tokens: bool = True
|
|
||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
|
||||||
session_params: Optional[Dict] = None
|
|
||||||
return_hidden_states: Optional[bool] = False
|
|
||||||
|
|
||||||
# For PD disaggregation
|
|
||||||
bootstrap_host: Optional[str] = None
|
|
||||||
bootstrap_port: Optional[int] = None
|
|
||||||
bootstrap_room: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionResponseChoice(BaseModel):
|
|
||||||
index: int
|
|
||||||
text: str
|
|
||||||
logprobs: Optional[LogProbs] = None
|
|
||||||
finish_reason: Literal["stop", "length", "content_filter", "abort"]
|
|
||||||
matched_stop: Union[None, int, str] = None
|
|
||||||
hidden_states: Optional[object] = None
|
|
||||||
|
|
||||||
@model_serializer
|
|
||||||
def _serialize(self):
|
|
||||||
return exclude_if_none(self, ["hidden_states"])
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
object: str = "text_completion"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
|
||||||
choices: List[CompletionResponseChoice]
|
|
||||||
usage: UsageInfo
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionResponseStreamChoice(BaseModel):
|
|
||||||
index: int
|
|
||||||
text: str
|
|
||||||
logprobs: Optional[LogProbs] = None
|
|
||||||
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
|
|
||||||
matched_stop: Union[None, int, str] = None
|
|
||||||
hidden_states: Optional[object] = None
|
|
||||||
|
|
||||||
@model_serializer
|
|
||||||
def _serialize(self):
|
|
||||||
return exclude_if_none(self, ["hidden_states"])
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionStreamResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
object: str = "text_completion"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
|
||||||
choices: List[CompletionResponseStreamChoice]
|
|
||||||
usage: Optional[UsageInfo] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMessageContentTextPart(BaseModel):
|
|
||||||
type: Literal["text"]
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMessageContentImageURL(BaseModel):
|
|
||||||
url: str
|
|
||||||
detail: Optional[Literal["auto", "low", "high"]] = "auto"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMessageContentAudioURL(BaseModel):
|
|
||||||
url: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMessageContentImagePart(BaseModel):
|
|
||||||
type: Literal["image_url"]
|
|
||||||
image_url: ChatCompletionMessageContentImageURL
|
|
||||||
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMessageContentAudioPart(BaseModel):
|
|
||||||
type: Literal["audio_url"]
|
|
||||||
audio_url: ChatCompletionMessageContentAudioURL
|
|
||||||
|
|
||||||
|
|
||||||
ChatCompletionMessageContentPart = Union[
|
|
||||||
ChatCompletionMessageContentTextPart,
|
|
||||||
ChatCompletionMessageContentImagePart,
|
|
||||||
ChatCompletionMessageContentAudioPart,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionResponse(BaseModel):
|
|
||||||
"""Function response."""
|
|
||||||
|
|
||||||
name: Optional[str] = None
|
|
||||||
arguments: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ToolCall(BaseModel):
|
|
||||||
"""Tool call response."""
|
|
||||||
|
|
||||||
id: Optional[str] = None
|
|
||||||
index: Optional[int] = None
|
|
||||||
type: Literal["function"] = "function"
|
|
||||||
function: FunctionResponse
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMessageGenericParam(BaseModel):
|
|
||||||
role: Literal["system", "assistant", "tool"]
|
|
||||||
content: Union[str, List[ChatCompletionMessageContentTextPart], None]
|
|
||||||
tool_call_id: Optional[str] = None
|
|
||||||
name: Optional[str] = None
|
|
||||||
reasoning_content: Optional[str] = None
|
|
||||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMessageUserParam(BaseModel):
|
|
||||||
role: Literal["user"]
|
|
||||||
content: Union[str, List[ChatCompletionMessageContentPart]]
|
|
||||||
|
|
||||||
|
|
||||||
ChatCompletionMessageParam = Union[
|
|
||||||
ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseFormat(BaseModel):
|
|
||||||
type: Literal["text", "json_object", "json_schema"]
|
|
||||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
|
||||||
|
|
||||||
|
|
||||||
class StructuresResponseFormat(BaseModel):
|
|
||||||
begin: str
|
|
||||||
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
|
||||||
end: str
|
|
||||||
|
|
||||||
|
|
||||||
class StructuralTagResponseFormat(BaseModel):
|
|
||||||
type: Literal["structural_tag"]
|
|
||||||
structures: List[StructuresResponseFormat]
|
|
||||||
triggers: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
class Function(BaseModel):
|
|
||||||
"""Function descriptions."""
|
|
||||||
|
|
||||||
description: Optional[str] = Field(default=None, examples=[None])
|
|
||||||
name: Optional[str] = None
|
|
||||||
parameters: Optional[object] = None
|
|
||||||
strict: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseModel):
|
|
||||||
"""Function wrapper."""
|
|
||||||
|
|
||||||
type: str = Field(default="function", examples=["function"])
|
|
||||||
function: Function
|
|
||||||
|
|
||||||
|
|
||||||
class ToolChoiceFuncName(BaseModel):
|
|
||||||
"""The name of tool choice function."""
|
|
||||||
|
|
||||||
name: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ToolChoice(BaseModel):
|
|
||||||
"""The tool choice definition."""
|
|
||||||
|
|
||||||
function: ToolChoiceFuncName
|
|
||||||
type: Literal["function"] = Field(default="function", examples=["function"])
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
|
||||||
# Ordered by official OpenAI API documentation
|
|
||||||
# https://platform.openai.com/docs/api-reference/chat/create
|
|
||||||
messages: List[ChatCompletionMessageParam]
|
|
||||||
model: str
|
|
||||||
frequency_penalty: float = 0.0
|
|
||||||
logit_bias: Optional[Dict[str, float]] = None
|
|
||||||
logprobs: bool = False
|
|
||||||
top_logprobs: Optional[int] = None
|
|
||||||
max_tokens: Optional[int] = Field(
|
|
||||||
default=None,
|
|
||||||
deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
|
|
||||||
description="The maximum number of tokens that can be generated in the chat completion. ",
|
|
||||||
)
|
|
||||||
max_completion_tokens: Optional[int] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The maximum number of completion tokens for a chat completion request, "
|
|
||||||
"including visible output tokens and reasoning tokens. Input tokens are not included. ",
|
|
||||||
)
|
|
||||||
n: int = 1
|
|
||||||
presence_penalty: float = 0.0
|
|
||||||
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
|
|
||||||
seed: Optional[int] = None
|
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
|
||||||
stream: bool = False
|
|
||||||
stream_options: Optional[StreamOptions] = None
|
|
||||||
temperature: float = 0.7
|
|
||||||
top_p: float = 1.0
|
|
||||||
user: Optional[str] = None
|
|
||||||
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
|
|
||||||
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
|
|
||||||
default="auto", examples=["none"]
|
|
||||||
) # noqa
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
|
||||||
def set_tool_choice_default(cls, values):
|
|
||||||
if values.get("tool_choice") is None:
|
|
||||||
if values.get("tools") is None:
|
|
||||||
values["tool_choice"] = "none"
|
|
||||||
else:
|
|
||||||
values["tool_choice"] = "auto"
|
|
||||||
return values
|
|
||||||
|
|
||||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
|
||||||
top_k: int = -1
|
|
||||||
min_p: float = 0.0
|
|
||||||
min_tokens: int = 0
|
|
||||||
regex: Optional[str] = None
|
|
||||||
ebnf: Optional[str] = None
|
|
||||||
repetition_penalty: float = 1.0
|
|
||||||
stop_token_ids: Optional[List[int]] = None
|
|
||||||
no_stop_trim: bool = False
|
|
||||||
ignore_eos: bool = False
|
|
||||||
continue_final_message: bool = False
|
|
||||||
skip_special_tokens: bool = True
|
|
||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
|
||||||
session_params: Optional[Dict] = None
|
|
||||||
separate_reasoning: bool = True
|
|
||||||
stream_reasoning: bool = True
|
|
||||||
chat_template_kwargs: Optional[Dict] = None
|
|
||||||
|
|
||||||
# The request id.
|
|
||||||
rid: Optional[str] = None
|
|
||||||
|
|
||||||
# For PD disaggregation
|
|
||||||
bootstrap_host: Optional[str] = None
|
|
||||||
bootstrap_port: Optional[int] = None
|
|
||||||
bootstrap_room: Optional[int] = None
|
|
||||||
|
|
||||||
# Hidden States
|
|
||||||
return_hidden_states: Optional[bool] = False
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
|
||||||
role: Optional[str] = None
|
|
||||||
content: Optional[str] = None
|
|
||||||
reasoning_content: Optional[str] = None
|
|
||||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponseChoice(BaseModel):
|
|
||||||
index: int
|
|
||||||
message: ChatMessage
|
|
||||||
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
|
||||||
finish_reason: Literal[
|
|
||||||
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
|
|
||||||
]
|
|
||||||
matched_stop: Union[None, int, str] = None
|
|
||||||
hidden_states: Optional[object] = None
|
|
||||||
|
|
||||||
@model_serializer
|
|
||||||
def _serialize(self):
|
|
||||||
return exclude_if_none(self, ["hidden_states"])
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
object: str = "chat.completion"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
|
||||||
choices: List[ChatCompletionResponseChoice]
|
|
||||||
usage: UsageInfo
|
|
||||||
|
|
||||||
|
|
||||||
class DeltaMessage(BaseModel):
|
|
||||||
role: Optional[str] = None
|
|
||||||
content: Optional[str] = None
|
|
||||||
reasoning_content: Optional[str] = None
|
|
||||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
|
||||||
hidden_states: Optional[object] = None
|
|
||||||
|
|
||||||
@model_serializer
|
|
||||||
def _serialize(self):
|
|
||||||
return exclude_if_none(self, ["hidden_states"])
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
|
||||||
index: int
|
|
||||||
delta: DeltaMessage
|
|
||||||
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
|
||||||
finish_reason: Optional[
|
|
||||||
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
|
|
||||||
] = None
|
|
||||||
matched_stop: Union[None, int, str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionStreamResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
object: str = "chat.completion.chunk"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
|
||||||
choices: List[ChatCompletionResponseStreamChoice]
|
|
||||||
usage: Optional[UsageInfo] = None
|
|
||||||
|
|
||||||
|
|
||||||
class MultimodalEmbeddingInput(BaseModel):
|
|
||||||
text: Optional[str] = None
|
|
||||||
image: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingRequest(BaseModel):
|
|
||||||
# Ordered by official OpenAI API documentation
|
|
||||||
# https://platform.openai.com/docs/api-reference/embeddings/create
|
|
||||||
input: Union[
|
|
||||||
List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
|
|
||||||
]
|
|
||||||
model: str
|
|
||||||
encoding_format: str = "float"
|
|
||||||
dimensions: int = None
|
|
||||||
user: Optional[str] = None
|
|
||||||
|
|
||||||
# The request id.
|
|
||||||
rid: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingObject(BaseModel):
|
|
||||||
embedding: List[float]
|
|
||||||
index: int
|
|
||||||
object: str = "embedding"
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingResponse(BaseModel):
|
|
||||||
data: List[EmbeddingObject]
|
|
||||||
model: str
|
|
||||||
object: str = "list"
|
|
||||||
usage: Optional[UsageInfo] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ScoringRequest(BaseModel):
|
|
||||||
query: Optional[Union[str, List[int]]] = (
|
|
||||||
None # Query text or pre-tokenized token IDs
|
|
||||||
)
|
|
||||||
items: Optional[Union[str, List[str], List[List[int]]]] = (
|
|
||||||
None # Item text(s) or pre-tokenized token IDs
|
|
||||||
)
|
|
||||||
label_token_ids: Optional[List[int]] = (
|
|
||||||
None # Token IDs to compute probabilities for
|
|
||||||
)
|
|
||||||
apply_softmax: bool = False
|
|
||||||
item_first: bool = False
|
|
||||||
model: str
|
|
||||||
|
|
||||||
|
|
||||||
class ScoringResponse(BaseModel):
|
|
||||||
scores: List[
|
|
||||||
List[float]
|
|
||||||
] # List of lists of probabilities, each in the order of label_token_ids
|
|
||||||
model: str
|
|
||||||
usage: Optional[UsageInfo] = None
|
|
||||||
object: str = "scoring"
|
|
||||||
|
|
||||||
|
|
||||||
class RerankResponse(BaseModel):
|
|
||||||
score: float
|
|
||||||
document: str
|
|
||||||
index: int
|
|
||||||
meta_info: Optional[dict] = None
|
|
||||||
|
|
||||||
|
|
||||||
def exclude_if_none(obj, field_names: List[str]):
|
|
||||||
omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
|
|
||||||
return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, Tuple
|
from typing import Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
|
||||||
class StreamingParseResult:
|
class StreamingParseResult:
|
||||||
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
|
|||||||
One-time parsing: Detects and parses reasoning sections in the provided text.
|
One-time parsing: Detects and parses reasoning sections in the provided text.
|
||||||
Returns both reasoning content and normal text separately.
|
Returns both reasoning content and normal text separately.
|
||||||
"""
|
"""
|
||||||
text = text.replace(self.think_start_token, "").strip()
|
in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
|
||||||
if self.think_end_token not in text:
|
|
||||||
|
if not in_reasoning:
|
||||||
|
return StreamingParseResult(normal_text=text)
|
||||||
|
|
||||||
|
# The text is considered to be in a reasoning block.
|
||||||
|
processed_text = text.replace(self.think_start_token, "").strip()
|
||||||
|
|
||||||
|
if self.think_end_token not in processed_text:
|
||||||
# Assume reasoning was truncated before `</think>` token
|
# Assume reasoning was truncated before `</think>` token
|
||||||
return StreamingParseResult(reasoning_text=text)
|
return StreamingParseResult(reasoning_text=processed_text)
|
||||||
|
|
||||||
# Extract reasoning content
|
# Extract reasoning content
|
||||||
splits = text.split(self.think_end_token, maxsplit=1)
|
splits = processed_text.split(self.think_end_token, maxsplit=1)
|
||||||
reasoning_text = splits[0]
|
reasoning_text = splits[0]
|
||||||
text = splits[1].strip()
|
normal_text = splits[1].strip()
|
||||||
|
|
||||||
return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
|
return StreamingParseResult(
|
||||||
|
normal_text=normal_text, reasoning_text=reasoning_text
|
||||||
|
)
|
||||||
|
|
||||||
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
|
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
|
||||||
"""
|
"""
|
||||||
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
|
|||||||
if not self.stripped_think_start and self.think_start_token in current_text:
|
if not self.stripped_think_start and self.think_start_token in current_text:
|
||||||
current_text = current_text.replace(self.think_start_token, "")
|
current_text = current_text.replace(self.think_start_token, "")
|
||||||
self.stripped_think_start = True
|
self.stripped_think_start = True
|
||||||
|
self._in_reasoning = True
|
||||||
|
|
||||||
# Handle end of reasoning block
|
# Handle end of reasoning block
|
||||||
if self._in_reasoning and self.think_end_token in current_text:
|
if self._in_reasoning and self.think_end_token in current_text:
|
||||||
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, stream_reasoning: bool = True):
|
def __init__(self, stream_reasoning: bool = True):
|
||||||
# Qwen3 is assumed to be reasoning until `</think>` token
|
# Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
|
||||||
super().__init__(
|
super().__init__(
|
||||||
"<think>",
|
"<think>",
|
||||||
"</think>",
|
"</think>",
|
||||||
force_reasoning=True,
|
force_reasoning=False,
|
||||||
stream_reasoning=stream_reasoning,
|
stream_reasoning=stream_reasoning,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -151,12 +161,12 @@ class ReasoningParser:
|
|||||||
If True, streams reasoning content as it arrives.
|
If True, streams reasoning content as it arrives.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
|
DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
|
||||||
"deepseek-r1": DeepSeekR1Detector,
|
"deepseek-r1": DeepSeekR1Detector,
|
||||||
"qwen3": Qwen3Detector,
|
"qwen3": Qwen3Detector,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, model_type: str = None, stream_reasoning: bool = True):
|
def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
|
||||||
if not model_type:
|
if not model_type:
|
||||||
raise ValueError("Model type must be specified")
|
raise ValueError("Model type must be specified")
|
||||||
|
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
# sglang/test/srt/openai/conftest.py
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import time
|
|
||||||
from contextlib import closing
|
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree # reuse SGLang helper
|
|
||||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
|
||||||
|
|
||||||
SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
|
|
||||||
DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
|
||||||
STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
|
|
||||||
|
|
||||||
|
|
||||||
def _pick_free_port() -> int:
|
|
||||||
with closing(socket.socket()) as s:
|
|
||||||
s.bind(("127.0.0.1", 0))
|
|
||||||
return s.getsockname()[1]
|
|
||||||
|
|
||||||
|
|
||||||
def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> None:
|
|
||||||
start = time.perf_counter()
|
|
||||||
while time.perf_counter() - start < timeout:
|
|
||||||
if proc.poll() is not None: # crashed
|
|
||||||
raise RuntimeError("api_server terminated prematurely")
|
|
||||||
try:
|
|
||||||
if requests.get(f"{base}/health", timeout=1).status_code == 200:
|
|
||||||
return
|
|
||||||
except requests.RequestException:
|
|
||||||
pass
|
|
||||||
time.sleep(0.4)
|
|
||||||
raise RuntimeError("api_server readiness probe timed out")
|
|
||||||
|
|
||||||
|
|
||||||
def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
|
|
||||||
"""Spawn the draft OpenAI-compatible server and wait until it's ready."""
|
|
||||||
port = _pick_free_port()
|
|
||||||
cmd = [
|
|
||||||
sys.executable,
|
|
||||||
"-m",
|
|
||||||
SERVER_MODULE,
|
|
||||||
"--model-path",
|
|
||||||
model,
|
|
||||||
"--host",
|
|
||||||
"127.0.0.1",
|
|
||||||
"--port",
|
|
||||||
str(port),
|
|
||||||
*map(str, kw.get("args", [])),
|
|
||||||
]
|
|
||||||
env = {**os.environ, **kw.get("env", {})}
|
|
||||||
|
|
||||||
# Write logs to a temp file so the child never blocks on a full pipe.
|
|
||||||
log_file = tempfile.NamedTemporaryFile("w+", delete=False)
|
|
||||||
proc = subprocess.Popen(
|
|
||||||
cmd,
|
|
||||||
env=env,
|
|
||||||
stdout=log_file,
|
|
||||||
stderr=subprocess.STDOUT,
|
|
||||||
text=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
base = f"http://127.0.0.1:{port}"
|
|
||||||
try:
|
|
||||||
_wait_until_healthy(proc, base, STARTUP_TIMEOUT)
|
|
||||||
except Exception as e:
|
|
||||||
proc.terminate()
|
|
||||||
proc.wait(5)
|
|
||||||
log_file.seek(0)
|
|
||||||
print("\n--- api_server log ---\n", log_file.read(), file=sys.stderr)
|
|
||||||
raise e
|
|
||||||
return proc, base, log_file
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def openai_server() -> Generator[str, None, None]:
|
|
||||||
"""PyTest fixture that provides the server's base URL and cleans up."""
|
|
||||||
proc, base, log_file = launch_openai_server()
|
|
||||||
yield base
|
|
||||||
kill_process_tree(proc.pid)
|
|
||||||
log_file.close()
|
|
||||||
@@ -67,29 +67,6 @@ from sglang.srt.entrypoints.openai.protocol import (
|
|||||||
class TestModelCard(unittest.TestCase):
|
class TestModelCard(unittest.TestCase):
|
||||||
"""Test ModelCard protocol model"""
|
"""Test ModelCard protocol model"""
|
||||||
|
|
||||||
def test_basic_model_card_creation(self):
|
|
||||||
"""Test basic model card creation with required fields"""
|
|
||||||
card = ModelCard(id="test-model")
|
|
||||||
self.assertEqual(card.id, "test-model")
|
|
||||||
self.assertEqual(card.object, "model")
|
|
||||||
self.assertEqual(card.owned_by, "sglang")
|
|
||||||
self.assertIsInstance(card.created, int)
|
|
||||||
self.assertIsNone(card.root)
|
|
||||||
self.assertIsNone(card.max_model_len)
|
|
||||||
|
|
||||||
def test_model_card_with_optional_fields(self):
|
|
||||||
"""Test model card with optional fields"""
|
|
||||||
card = ModelCard(
|
|
||||||
id="test-model",
|
|
||||||
root="/path/to/model",
|
|
||||||
max_model_len=2048,
|
|
||||||
created=1234567890,
|
|
||||||
)
|
|
||||||
self.assertEqual(card.id, "test-model")
|
|
||||||
self.assertEqual(card.root, "/path/to/model")
|
|
||||||
self.assertEqual(card.max_model_len, 2048)
|
|
||||||
self.assertEqual(card.created, 1234567890)
|
|
||||||
|
|
||||||
def test_model_card_serialization(self):
|
def test_model_card_serialization(self):
|
||||||
"""Test model card JSON serialization"""
|
"""Test model card JSON serialization"""
|
||||||
card = ModelCard(id="test-model", max_model_len=4096)
|
card = ModelCard(id="test-model", max_model_len=4096)
|
||||||
@@ -120,53 +97,6 @@ class TestModelList(unittest.TestCase):
|
|||||||
self.assertEqual(model_list.data[1].id, "model-2")
|
self.assertEqual(model_list.data[1].id, "model-2")
|
||||||
|
|
||||||
|
|
||||||
class TestErrorResponse(unittest.TestCase):
|
|
||||||
"""Test ErrorResponse protocol model"""
|
|
||||||
|
|
||||||
def test_basic_error_response(self):
|
|
||||||
"""Test basic error response creation"""
|
|
||||||
error = ErrorResponse(
|
|
||||||
message="Invalid request", type="BadRequestError", code=400
|
|
||||||
)
|
|
||||||
self.assertEqual(error.object, "error")
|
|
||||||
self.assertEqual(error.message, "Invalid request")
|
|
||||||
self.assertEqual(error.type, "BadRequestError")
|
|
||||||
self.assertEqual(error.code, 400)
|
|
||||||
self.assertIsNone(error.param)
|
|
||||||
|
|
||||||
def test_error_response_with_param(self):
|
|
||||||
"""Test error response with parameter"""
|
|
||||||
error = ErrorResponse(
|
|
||||||
message="Invalid temperature",
|
|
||||||
type="ValidationError",
|
|
||||||
code=422,
|
|
||||||
param="temperature",
|
|
||||||
)
|
|
||||||
self.assertEqual(error.param, "temperature")
|
|
||||||
|
|
||||||
|
|
||||||
class TestUsageInfo(unittest.TestCase):
|
|
||||||
"""Test UsageInfo protocol model"""
|
|
||||||
|
|
||||||
def test_basic_usage_info(self):
|
|
||||||
"""Test basic usage info creation"""
|
|
||||||
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
|
||||||
self.assertEqual(usage.prompt_tokens, 10)
|
|
||||||
self.assertEqual(usage.completion_tokens, 20)
|
|
||||||
self.assertEqual(usage.total_tokens, 30)
|
|
||||||
self.assertIsNone(usage.prompt_tokens_details)
|
|
||||||
|
|
||||||
def test_usage_info_with_cache_details(self):
|
|
||||||
"""Test usage info with cache details"""
|
|
||||||
usage = UsageInfo(
|
|
||||||
prompt_tokens=10,
|
|
||||||
completion_tokens=20,
|
|
||||||
total_tokens=30,
|
|
||||||
prompt_tokens_details={"cached_tokens": 5},
|
|
||||||
)
|
|
||||||
self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5})
|
|
||||||
|
|
||||||
|
|
||||||
class TestCompletionRequest(unittest.TestCase):
|
class TestCompletionRequest(unittest.TestCase):
|
||||||
"""Test CompletionRequest protocol model"""
|
"""Test CompletionRequest protocol model"""
|
||||||
|
|
||||||
@@ -181,30 +111,6 @@ class TestCompletionRequest(unittest.TestCase):
|
|||||||
self.assertFalse(request.stream) # default
|
self.assertFalse(request.stream) # default
|
||||||
self.assertFalse(request.echo) # default
|
self.assertFalse(request.echo) # default
|
||||||
|
|
||||||
def test_completion_request_with_options(self):
|
|
||||||
"""Test completion request with various options"""
|
|
||||||
request = CompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
prompt=["Hello", "world"],
|
|
||||||
max_tokens=100,
|
|
||||||
temperature=0.7,
|
|
||||||
top_p=0.9,
|
|
||||||
n=2,
|
|
||||||
stream=True,
|
|
||||||
echo=True,
|
|
||||||
stop=[".", "!"],
|
|
||||||
logprobs=5,
|
|
||||||
)
|
|
||||||
self.assertEqual(request.prompt, ["Hello", "world"])
|
|
||||||
self.assertEqual(request.max_tokens, 100)
|
|
||||||
self.assertEqual(request.temperature, 0.7)
|
|
||||||
self.assertEqual(request.top_p, 0.9)
|
|
||||||
self.assertEqual(request.n, 2)
|
|
||||||
self.assertTrue(request.stream)
|
|
||||||
self.assertTrue(request.echo)
|
|
||||||
self.assertEqual(request.stop, [".", "!"])
|
|
||||||
self.assertEqual(request.logprobs, 5)
|
|
||||||
|
|
||||||
def test_completion_request_sglang_extensions(self):
|
def test_completion_request_sglang_extensions(self):
|
||||||
"""Test completion request with SGLang-specific extensions"""
|
"""Test completion request with SGLang-specific extensions"""
|
||||||
request = CompletionRequest(
|
request = CompletionRequest(
|
||||||
@@ -233,26 +139,6 @@ class TestCompletionRequest(unittest.TestCase):
|
|||||||
CompletionRequest(model="test-model") # missing prompt
|
CompletionRequest(model="test-model") # missing prompt
|
||||||
|
|
||||||
|
|
||||||
class TestCompletionResponse(unittest.TestCase):
|
|
||||||
"""Test CompletionResponse protocol model"""
|
|
||||||
|
|
||||||
def test_basic_completion_response(self):
|
|
||||||
"""Test basic completion response"""
|
|
||||||
choice = CompletionResponseChoice(
|
|
||||||
index=0, text="Hello world!", finish_reason="stop"
|
|
||||||
)
|
|
||||||
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
|
|
||||||
response = CompletionResponse(
|
|
||||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
|
||||||
)
|
|
||||||
self.assertEqual(response.id, "test-id")
|
|
||||||
self.assertEqual(response.object, "text_completion")
|
|
||||||
self.assertEqual(response.model, "test-model")
|
|
||||||
self.assertEqual(len(response.choices), 1)
|
|
||||||
self.assertEqual(response.choices[0].text, "Hello world!")
|
|
||||||
self.assertEqual(response.usage.total_tokens, 5)
|
|
||||||
|
|
||||||
|
|
||||||
class TestChatCompletionRequest(unittest.TestCase):
|
class TestChatCompletionRequest(unittest.TestCase):
|
||||||
"""Test ChatCompletionRequest protocol model"""
|
"""Test ChatCompletionRequest protocol model"""
|
||||||
|
|
||||||
@@ -268,48 +154,6 @@ class TestChatCompletionRequest(unittest.TestCase):
|
|||||||
self.assertFalse(request.stream) # default
|
self.assertFalse(request.stream) # default
|
||||||
self.assertEqual(request.tool_choice, "none") # default when no tools
|
self.assertEqual(request.tool_choice, "none") # default when no tools
|
||||||
|
|
||||||
def test_chat_completion_with_multimodal_content(self):
|
|
||||||
"""Test chat completion with multimodal content"""
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "What's in this image?"},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
request = ChatCompletionRequest(model="test-model", messages=messages)
|
|
||||||
self.assertEqual(len(request.messages[0].content), 2)
|
|
||||||
self.assertEqual(request.messages[0].content[0].type, "text")
|
|
||||||
self.assertEqual(request.messages[0].content[1].type, "image_url")
|
|
||||||
|
|
||||||
def test_chat_completion_with_tools(self):
|
|
||||||
"""Test chat completion with tools"""
|
|
||||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
|
||||||
tools = [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get weather information",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"location": {"type": "string"}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model", messages=messages, tools=tools
|
|
||||||
)
|
|
||||||
self.assertEqual(len(request.tools), 1)
|
|
||||||
self.assertEqual(request.tools[0].function.name, "get_weather")
|
|
||||||
self.assertEqual(request.tool_choice, "auto") # default when tools present
|
|
||||||
|
|
||||||
def test_chat_completion_tool_choice_validation(self):
|
def test_chat_completion_tool_choice_validation(self):
|
||||||
"""Test tool choice validation logic"""
|
"""Test tool choice validation logic"""
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
@@ -349,289 +193,6 @@ class TestChatCompletionRequest(unittest.TestCase):
|
|||||||
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
||||||
|
|
||||||
|
|
||||||
class TestChatCompletionResponse(unittest.TestCase):
|
|
||||||
"""Test ChatCompletionResponse protocol model"""
|
|
||||||
|
|
||||||
def test_basic_chat_completion_response(self):
|
|
||||||
"""Test basic chat completion response"""
|
|
||||||
message = ChatMessage(role="assistant", content="Hello there!")
|
|
||||||
choice = ChatCompletionResponseChoice(
|
|
||||||
index=0, message=message, finish_reason="stop"
|
|
||||||
)
|
|
||||||
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
|
|
||||||
response = ChatCompletionResponse(
|
|
||||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
|
||||||
)
|
|
||||||
self.assertEqual(response.id, "test-id")
|
|
||||||
self.assertEqual(response.object, "chat.completion")
|
|
||||||
self.assertEqual(response.model, "test-model")
|
|
||||||
self.assertEqual(len(response.choices), 1)
|
|
||||||
self.assertEqual(response.choices[0].message.content, "Hello there!")
|
|
||||||
|
|
||||||
def test_chat_completion_response_with_tool_calls(self):
|
|
||||||
"""Test chat completion response with tool calls"""
|
|
||||||
tool_call = ToolCall(
|
|
||||||
id="call_123",
|
|
||||||
function=FunctionResponse(
|
|
||||||
name="get_weather", arguments='{"location": "San Francisco"}'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call])
|
|
||||||
choice = ChatCompletionResponseChoice(
|
|
||||||
index=0, message=message, finish_reason="tool_calls"
|
|
||||||
)
|
|
||||||
usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
|
||||||
response = ChatCompletionResponse(
|
|
||||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
response.choices[0].message.tool_calls[0].function.name, "get_weather"
|
|
||||||
)
|
|
||||||
self.assertEqual(response.choices[0].finish_reason, "tool_calls")
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingRequest(unittest.TestCase):
|
|
||||||
"""Test EmbeddingRequest protocol model"""
|
|
||||||
|
|
||||||
def test_basic_embedding_request(self):
|
|
||||||
"""Test basic embedding request"""
|
|
||||||
request = EmbeddingRequest(model="test-model", input="Hello world")
|
|
||||||
self.assertEqual(request.model, "test-model")
|
|
||||||
self.assertEqual(request.input, "Hello world")
|
|
||||||
self.assertEqual(request.encoding_format, "float") # default
|
|
||||||
self.assertIsNone(request.dimensions) # default
|
|
||||||
|
|
||||||
def test_embedding_request_with_list_input(self):
|
|
||||||
"""Test embedding request with list input"""
|
|
||||||
request = EmbeddingRequest(
|
|
||||||
model="test-model", input=["Hello", "world"], dimensions=512
|
|
||||||
)
|
|
||||||
self.assertEqual(request.input, ["Hello", "world"])
|
|
||||||
self.assertEqual(request.dimensions, 512)
|
|
||||||
|
|
||||||
def test_multimodal_embedding_request(self):
|
|
||||||
"""Test multimodal embedding request"""
|
|
||||||
multimodal_input = [
|
|
||||||
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
|
|
||||||
MultimodalEmbeddingInput(text="World", image=None),
|
|
||||||
]
|
|
||||||
request = EmbeddingRequest(model="test-model", input=multimodal_input)
|
|
||||||
self.assertEqual(len(request.input), 2)
|
|
||||||
self.assertEqual(request.input[0].text, "Hello")
|
|
||||||
self.assertEqual(request.input[0].image, "base64_image_data")
|
|
||||||
self.assertEqual(request.input[1].text, "World")
|
|
||||||
self.assertIsNone(request.input[1].image)
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingResponse(unittest.TestCase):
|
|
||||||
"""Test EmbeddingResponse protocol model"""
|
|
||||||
|
|
||||||
def test_basic_embedding_response(self):
|
|
||||||
"""Test basic embedding response"""
|
|
||||||
embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)
|
|
||||||
usage = UsageInfo(prompt_tokens=3, total_tokens=3)
|
|
||||||
response = EmbeddingResponse(
|
|
||||||
data=[embedding_obj], model="test-model", usage=usage
|
|
||||||
)
|
|
||||||
self.assertEqual(response.object, "list")
|
|
||||||
self.assertEqual(len(response.data), 1)
|
|
||||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
|
|
||||||
self.assertEqual(response.data[0].index, 0)
|
|
||||||
self.assertEqual(response.usage.prompt_tokens, 3)
|
|
||||||
|
|
||||||
|
|
||||||
class TestScoringRequest(unittest.TestCase):
|
|
||||||
"""Test ScoringRequest protocol model"""
|
|
||||||
|
|
||||||
def test_basic_scoring_request(self):
|
|
||||||
"""Test basic scoring request"""
|
|
||||||
request = ScoringRequest(
|
|
||||||
model="test-model", query="Hello", items=["World", "Earth"]
|
|
||||||
)
|
|
||||||
self.assertEqual(request.model, "test-model")
|
|
||||||
self.assertEqual(request.query, "Hello")
|
|
||||||
self.assertEqual(request.items, ["World", "Earth"])
|
|
||||||
self.assertFalse(request.apply_softmax) # default
|
|
||||||
self.assertFalse(request.item_first) # default
|
|
||||||
|
|
||||||
def test_scoring_request_with_token_ids(self):
|
|
||||||
"""Test scoring request with token IDs"""
|
|
||||||
request = ScoringRequest(
|
|
||||||
model="test-model",
|
|
||||||
query=[1, 2, 3],
|
|
||||||
items=[[4, 5], [6, 7]],
|
|
||||||
label_token_ids=[8, 9],
|
|
||||||
apply_softmax=True,
|
|
||||||
item_first=True,
|
|
||||||
)
|
|
||||||
self.assertEqual(request.query, [1, 2, 3])
|
|
||||||
self.assertEqual(request.items, [[4, 5], [6, 7]])
|
|
||||||
self.assertEqual(request.label_token_ids, [8, 9])
|
|
||||||
self.assertTrue(request.apply_softmax)
|
|
||||||
self.assertTrue(request.item_first)
|
|
||||||
|
|
||||||
|
|
||||||
class TestScoringResponse(unittest.TestCase):
|
|
||||||
"""Test ScoringResponse protocol model"""
|
|
||||||
|
|
||||||
def test_basic_scoring_response(self):
|
|
||||||
"""Test basic scoring response"""
|
|
||||||
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
|
|
||||||
self.assertEqual(response.object, "scoring")
|
|
||||||
self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]])
|
|
||||||
self.assertEqual(response.model, "test-model")
|
|
||||||
self.assertIsNone(response.usage) # default
|
|
||||||
|
|
||||||
|
|
||||||
class TestFileOperations(unittest.TestCase):
|
|
||||||
"""Test file operation protocol models"""
|
|
||||||
|
|
||||||
def test_file_request(self):
|
|
||||||
"""Test file request model"""
|
|
||||||
file_data = b"test file content"
|
|
||||||
request = FileRequest(file=file_data, purpose="batch")
|
|
||||||
self.assertEqual(request.file, file_data)
|
|
||||||
self.assertEqual(request.purpose, "batch")
|
|
||||||
|
|
||||||
def test_file_response(self):
|
|
||||||
"""Test file response model"""
|
|
||||||
response = FileResponse(
|
|
||||||
id="file-123",
|
|
||||||
bytes=1024,
|
|
||||||
created_at=1234567890,
|
|
||||||
filename="test.jsonl",
|
|
||||||
purpose="batch",
|
|
||||||
)
|
|
||||||
self.assertEqual(response.id, "file-123")
|
|
||||||
self.assertEqual(response.object, "file")
|
|
||||||
self.assertEqual(response.bytes, 1024)
|
|
||||||
self.assertEqual(response.filename, "test.jsonl")
|
|
||||||
|
|
||||||
def test_file_delete_response(self):
|
|
||||||
"""Test file delete response model"""
|
|
||||||
response = FileDeleteResponse(id="file-123", deleted=True)
|
|
||||||
self.assertEqual(response.id, "file-123")
|
|
||||||
self.assertEqual(response.object, "file")
|
|
||||||
self.assertTrue(response.deleted)
|
|
||||||
|
|
||||||
|
|
||||||
class TestBatchOperations(unittest.TestCase):
|
|
||||||
"""Test batch operation protocol models"""
|
|
||||||
|
|
||||||
def test_batch_request(self):
|
|
||||||
"""Test batch request model"""
|
|
||||||
request = BatchRequest(
|
|
||||||
input_file_id="file-123",
|
|
||||||
endpoint="/v1/chat/completions",
|
|
||||||
completion_window="24h",
|
|
||||||
metadata={"custom": "value"},
|
|
||||||
)
|
|
||||||
self.assertEqual(request.input_file_id, "file-123")
|
|
||||||
self.assertEqual(request.endpoint, "/v1/chat/completions")
|
|
||||||
self.assertEqual(request.completion_window, "24h")
|
|
||||||
self.assertEqual(request.metadata, {"custom": "value"})
|
|
||||||
|
|
||||||
def test_batch_response(self):
|
|
||||||
"""Test batch response model"""
|
|
||||||
response = BatchResponse(
|
|
||||||
id="batch-123",
|
|
||||||
endpoint="/v1/chat/completions",
|
|
||||||
input_file_id="file-123",
|
|
||||||
completion_window="24h",
|
|
||||||
created_at=1234567890,
|
|
||||||
)
|
|
||||||
self.assertEqual(response.id, "batch-123")
|
|
||||||
self.assertEqual(response.object, "batch")
|
|
||||||
self.assertEqual(response.status, "validating") # default
|
|
||||||
self.assertEqual(response.endpoint, "/v1/chat/completions")
|
|
||||||
|
|
||||||
|
|
||||||
class TestResponseFormats(unittest.TestCase):
|
|
||||||
"""Test response format protocol models"""
|
|
||||||
|
|
||||||
def test_basic_response_format(self):
|
|
||||||
"""Test basic response format"""
|
|
||||||
format_obj = ResponseFormat(type="json_object")
|
|
||||||
self.assertEqual(format_obj.type, "json_object")
|
|
||||||
self.assertIsNone(format_obj.json_schema)
|
|
||||||
|
|
||||||
def test_json_schema_response_format(self):
|
|
||||||
"""Test JSON schema response format"""
|
|
||||||
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
|
|
||||||
json_schema = JsonSchemaResponseFormat(
|
|
||||||
name="person_schema", description="Person schema", schema=schema
|
|
||||||
)
|
|
||||||
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
|
|
||||||
self.assertEqual(format_obj.type, "json_schema")
|
|
||||||
self.assertEqual(format_obj.json_schema.name, "person_schema")
|
|
||||||
self.assertEqual(format_obj.json_schema.schema_, schema)
|
|
||||||
|
|
||||||
def test_structural_tag_response_format(self):
|
|
||||||
"""Test structural tag response format"""
|
|
||||||
structures = [
|
|
||||||
{
|
|
||||||
"begin": "<thinking>",
|
|
||||||
"schema_": {"type": "string"},
|
|
||||||
"end": "</thinking>",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
format_obj = StructuralTagResponseFormat(
|
|
||||||
type="structural_tag", structures=structures, triggers=["think"]
|
|
||||||
)
|
|
||||||
self.assertEqual(format_obj.type, "structural_tag")
|
|
||||||
self.assertEqual(len(format_obj.structures), 1)
|
|
||||||
self.assertEqual(format_obj.triggers, ["think"])
|
|
||||||
|
|
||||||
|
|
||||||
class TestLogProbs(unittest.TestCase):
|
|
||||||
"""Test LogProbs protocol models"""
|
|
||||||
|
|
||||||
def test_basic_logprobs(self):
|
|
||||||
"""Test basic LogProbs model"""
|
|
||||||
logprobs = LogProbs(
|
|
||||||
text_offset=[0, 5, 11],
|
|
||||||
token_logprobs=[-0.1, -0.2, -0.3],
|
|
||||||
tokens=["Hello", " ", "world"],
|
|
||||||
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
|
|
||||||
)
|
|
||||||
self.assertEqual(len(logprobs.tokens), 3)
|
|
||||||
self.assertEqual(logprobs.tokens, ["Hello", " ", "world"])
|
|
||||||
self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3])
|
|
||||||
|
|
||||||
def test_choice_logprobs(self):
|
|
||||||
"""Test ChoiceLogprobs model"""
|
|
||||||
token_logprob = ChatCompletionTokenLogprob(
|
|
||||||
token="Hello",
|
|
||||||
bytes=[72, 101, 108, 108, 111],
|
|
||||||
logprob=-0.1,
|
|
||||||
top_logprobs=[
|
|
||||||
TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
choice_logprobs = ChoiceLogprobs(content=[token_logprob])
|
|
||||||
self.assertEqual(len(choice_logprobs.content), 1)
|
|
||||||
self.assertEqual(choice_logprobs.content[0].token, "Hello")
|
|
||||||
|
|
||||||
|
|
||||||
class TestStreamingModels(unittest.TestCase):
|
|
||||||
"""Test streaming response models"""
|
|
||||||
|
|
||||||
def test_stream_options(self):
|
|
||||||
"""Test StreamOptions model"""
|
|
||||||
options = StreamOptions(include_usage=True)
|
|
||||||
self.assertTrue(options.include_usage)
|
|
||||||
|
|
||||||
def test_chat_completion_stream_response(self):
|
|
||||||
"""Test ChatCompletionStreamResponse model"""
|
|
||||||
delta = DeltaMessage(role="assistant", content="Hello")
|
|
||||||
choice = ChatCompletionResponseStreamChoice(index=0, delta=delta)
|
|
||||||
response = ChatCompletionStreamResponse(
|
|
||||||
id="test-id", model="test-model", choices=[choice]
|
|
||||||
)
|
|
||||||
self.assertEqual(response.object, "chat.completion.chunk")
|
|
||||||
self.assertEqual(response.choices[0].delta.content, "Hello")
|
|
||||||
|
|
||||||
|
|
||||||
class TestModelSerialization(unittest.TestCase):
|
class TestModelSerialization(unittest.TestCase):
|
||||||
"""Test model serialization with hidden states"""
|
"""Test model serialization with hidden states"""
|
||||||
|
|
||||||
@@ -680,11 +241,6 @@ class TestModelSerialization(unittest.TestCase):
|
|||||||
class TestValidationEdgeCases(unittest.TestCase):
|
class TestValidationEdgeCases(unittest.TestCase):
|
||||||
"""Test edge cases and validation scenarios"""
|
"""Test edge cases and validation scenarios"""
|
||||||
|
|
||||||
def test_empty_messages_validation(self):
|
|
||||||
"""Test validation with empty messages"""
|
|
||||||
with self.assertRaises(ValidationError):
|
|
||||||
ChatCompletionRequest(model="test-model", messages=[])
|
|
||||||
|
|
||||||
def test_invalid_tool_choice_type(self):
|
def test_invalid_tool_choice_type(self):
|
||||||
"""Test invalid tool choice type"""
|
"""Test invalid tool choice type"""
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
@@ -698,13 +254,6 @@ class TestValidationEdgeCases(unittest.TestCase):
|
|||||||
with self.assertRaises(ValidationError):
|
with self.assertRaises(ValidationError):
|
||||||
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
|
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
|
||||||
|
|
||||||
def test_invalid_temperature_range(self):
|
|
||||||
"""Test invalid temperature values"""
|
|
||||||
# Note: The current protocol doesn't enforce temperature range,
|
|
||||||
# but this test documents expected behavior
|
|
||||||
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
|
|
||||||
self.assertEqual(request.temperature, 5.0) # Currently allowed
|
|
||||||
|
|
||||||
def test_model_serialization_roundtrip(self):
|
def test_model_serialization_roundtrip(self):
|
||||||
"""Test that models can be serialized and deserialized"""
|
"""Test that models can be serialized and deserialized"""
|
||||||
original_request = ChatCompletionRequest(
|
original_request = ChatCompletionRequest(
|
||||||
|
|||||||
@@ -1,52 +0,0 @@
|
|||||||
# sglang/test/srt/openai/test_server.py
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
|
|
||||||
|
|
||||||
|
|
||||||
def test_health(openai_server: str):
|
|
||||||
r = requests.get(f"{openai_server}/health")
|
|
||||||
assert r.status_code == 200
|
|
||||||
# FastAPI returns an empty body → r.text == ""
|
|
||||||
assert r.text == ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_models_endpoint(openai_server: str):
|
|
||||||
r = requests.get(f"{openai_server}/v1/models")
|
|
||||||
assert r.status_code == 200, r.text
|
|
||||||
payload = r.json()
|
|
||||||
|
|
||||||
# Basic contract
|
|
||||||
assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
|
|
||||||
|
|
||||||
# Validate fields of the first model card
|
|
||||||
first = payload["data"][0]
|
|
||||||
for key in ("id", "root", "max_model_len"):
|
|
||||||
assert key in first, f"missing {key} in {first}"
|
|
||||||
|
|
||||||
# max_model_len must be positive
|
|
||||||
assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
|
|
||||||
|
|
||||||
# The server should report the same model id we launched it with
|
|
||||||
ids = {m["id"] for m in payload["data"]}
|
|
||||||
assert MODEL_ID in ids
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_info(openai_server: str):
|
|
||||||
r = requests.get(f"{openai_server}/get_model_info")
|
|
||||||
assert r.status_code == 200, r.text
|
|
||||||
info = r.json()
|
|
||||||
|
|
||||||
expected_keys = {"model_path", "tokenizer_path", "is_generation"}
|
|
||||||
assert expected_keys.issubset(info.keys())
|
|
||||||
|
|
||||||
# model_path must end with the one we passed on the CLI
|
|
||||||
assert info["model_path"].endswith(MODEL_ID)
|
|
||||||
|
|
||||||
# is_generation is documented as a boolean
|
|
||||||
assert isinstance(info["is_generation"], bool)
|
|
||||||
|
|
||||||
|
|
||||||
def test_unknown_route_returns_404(openai_server: str):
|
|
||||||
r = requests.get(f"{openai_server}/definitely-not-a-real-route")
|
|
||||||
assert r.status_code == 404
|
|
||||||
@@ -57,11 +57,21 @@ class _MockTokenizerManager:
|
|||||||
self.create_abort_task = Mock()
|
self.create_abort_task = Mock()
|
||||||
|
|
||||||
|
|
||||||
|
class _MockTemplateManager:
|
||||||
|
"""Minimal mock for TemplateManager."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.chat_template_name: Optional[str] = "llama-3"
|
||||||
|
self.jinja_template_content_format: Optional[str] = None
|
||||||
|
self.completion_template_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class ServingChatTestCase(unittest.TestCase):
|
class ServingChatTestCase(unittest.TestCase):
|
||||||
# ------------- common fixtures -------------
|
# ------------- common fixtures -------------
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tm = _MockTokenizerManager()
|
self.tm = _MockTokenizerManager()
|
||||||
self.chat = OpenAIServingChat(self.tm)
|
self.template_manager = _MockTemplateManager()
|
||||||
|
self.chat = OpenAIServingChat(self.tm, self.template_manager)
|
||||||
|
|
||||||
# frequently reused requests
|
# frequently reused requests
|
||||||
self.basic_req = ChatCompletionRequest(
|
self.basic_req = ChatCompletionRequest(
|
||||||
@@ -109,96 +119,6 @@ class ServingChatTestCase(unittest.TestCase):
|
|||||||
self.assertFalse(adapted.stream)
|
self.assertFalse(adapted.stream)
|
||||||
self.assertEqual(processed, self.basic_req)
|
self.assertEqual(processed, self.basic_req)
|
||||||
|
|
||||||
# # ------------- tool-call branch -------------
|
|
||||||
# def test_tool_call_request_conversion(self):
|
|
||||||
# req = ChatCompletionRequest(
|
|
||||||
# model="x",
|
|
||||||
# messages=[{"role": "user", "content": "Weather?"}],
|
|
||||||
# tools=[
|
|
||||||
# {
|
|
||||||
# "type": "function",
|
|
||||||
# "function": {
|
|
||||||
# "name": "get_weather",
|
|
||||||
# "parameters": {"type": "object", "properties": {}},
|
|
||||||
# },
|
|
||||||
# }
|
|
||||||
# ],
|
|
||||||
# tool_choice="auto",
|
|
||||||
# )
|
|
||||||
|
|
||||||
# with patch.object(
|
|
||||||
# self.chat,
|
|
||||||
# "_process_messages",
|
|
||||||
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
|
||||||
# ):
|
|
||||||
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
|
||||||
# self.assertEqual(adapted.rid, "rid")
|
|
||||||
|
|
||||||
# def test_tool_choice_none(self):
|
|
||||||
# req = ChatCompletionRequest(
|
|
||||||
# model="x",
|
|
||||||
# messages=[{"role": "user", "content": "Hi"}],
|
|
||||||
# tools=[{"type": "function", "function": {"name": "noop"}}],
|
|
||||||
# tool_choice="none",
|
|
||||||
# )
|
|
||||||
# with patch.object(
|
|
||||||
# self.chat,
|
|
||||||
# "_process_messages",
|
|
||||||
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
|
||||||
# ):
|
|
||||||
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
|
||||||
# self.assertEqual(adapted.rid, "rid")
|
|
||||||
|
|
||||||
# ------------- multimodal branch -------------
|
|
||||||
def test_multimodal_request_with_images(self):
|
|
||||||
self.tm.model_config.is_multimodal = True
|
|
||||||
|
|
||||||
req = ChatCompletionRequest(
|
|
||||||
model="x",
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "What's in the image?"},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": "data:image/jpeg;base64,"},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(
|
|
||||||
self.chat,
|
|
||||||
"_apply_jinja_template",
|
|
||||||
return_value=("prompt", [1, 2], ["img"], None, [], []),
|
|
||||||
), patch.object(
|
|
||||||
self.chat,
|
|
||||||
"_apply_conversation_template",
|
|
||||||
return_value=("prompt", ["img"], None, [], []),
|
|
||||||
):
|
|
||||||
out = self.chat._process_messages(req, True)
|
|
||||||
_, _, image_data, *_ = out
|
|
||||||
self.assertEqual(image_data, ["img"])
|
|
||||||
|
|
||||||
# ------------- template handling -------------
|
|
||||||
def test_jinja_template_processing(self):
|
|
||||||
req = ChatCompletionRequest(
|
|
||||||
model="x", messages=[{"role": "user", "content": "Hello"}]
|
|
||||||
)
|
|
||||||
self.tm.chat_template_name = None
|
|
||||||
self.tm.tokenizer.chat_template = "<jinja>"
|
|
||||||
|
|
||||||
with patch.object(
|
|
||||||
self.chat,
|
|
||||||
"_apply_jinja_template",
|
|
||||||
return_value=("processed", [1], None, None, [], ["</s>"]),
|
|
||||||
), patch("builtins.hasattr", return_value=True):
|
|
||||||
prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
|
|
||||||
self.assertEqual(prompt, "processed")
|
|
||||||
self.assertEqual(prompt_ids, [1])
|
|
||||||
|
|
||||||
# ------------- sampling-params -------------
|
# ------------- sampling-params -------------
|
||||||
def test_sampling_param_build(self):
|
def test_sampling_param_build(self):
|
||||||
req = ChatCompletionRequest(
|
req = ChatCompletionRequest(
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ Run with:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import Optional
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
||||||
@@ -12,6 +13,17 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl
|
|||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
|
||||||
|
|
||||||
|
class _MockTemplateManager:
|
||||||
|
"""Minimal mock for TemplateManager."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.chat_template_name: Optional[str] = None
|
||||||
|
self.jinja_template_content_format: Optional[str] = None
|
||||||
|
self.completion_template_name: Optional[str] = (
|
||||||
|
None # Set to None to avoid template processing
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ServingCompletionTestCase(unittest.TestCase):
|
class ServingCompletionTestCase(unittest.TestCase):
|
||||||
"""Bundle all prompt/echo tests in one TestCase."""
|
"""Bundle all prompt/echo tests in one TestCase."""
|
||||||
|
|
||||||
@@ -31,7 +43,8 @@ class ServingCompletionTestCase(unittest.TestCase):
|
|||||||
tm.generate_request = AsyncMock()
|
tm.generate_request = AsyncMock()
|
||||||
tm.create_abort_task = Mock()
|
tm.create_abort_task = Mock()
|
||||||
|
|
||||||
self.sc = OpenAIServingCompletion(tm)
|
self.template_manager = _MockTemplateManager()
|
||||||
|
self.sc = OpenAIServingCompletion(tm, self.template_manager)
|
||||||
|
|
||||||
# ---------- prompt-handling ----------
|
# ---------- prompt-handling ----------
|
||||||
def test_single_string_prompt(self):
|
def test_single_string_prompt(self):
|
||||||
@@ -44,20 +57,6 @@ class ServingCompletionTestCase(unittest.TestCase):
|
|||||||
internal, _ = self.sc._convert_to_internal_request(req)
|
internal, _ = self.sc._convert_to_internal_request(req)
|
||||||
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
||||||
|
|
||||||
def test_completion_template_handling(self):
|
|
||||||
req = CompletionRequest(
|
|
||||||
model="x", prompt="def f():", suffix="return 1", max_tokens=100
|
|
||||||
)
|
|
||||||
with patch(
|
|
||||||
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
|
|
||||||
return_value=True,
|
|
||||||
), patch(
|
|
||||||
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
|
|
||||||
return_value="processed_prompt",
|
|
||||||
):
|
|
||||||
internal, _ = self.sc._convert_to_internal_request(req)
|
|
||||||
self.assertEqual(internal.text, "processed_prompt")
|
|
||||||
|
|
||||||
# ---------- echo-handling ----------
|
# ---------- echo-handling ----------
|
||||||
def test_echo_with_string_prompt_streaming(self):
|
def test_echo_with_string_prompt_streaming(self):
|
||||||
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
|
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
|
||||||
|
|||||||
@@ -5,25 +5,16 @@ These tests ensure that the embedding serving implementation maintains compatibi
|
|||||||
with the original adapter.py functionality and follows OpenAI API specifications.
|
with the original adapter.py functionality and follows OpenAI API specifications.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List
|
from unittest.mock import Mock
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import ORJSONResponse
|
|
||||||
from pydantic_core import ValidationError
|
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
EmbeddingObject,
|
|
||||||
EmbeddingRequest,
|
EmbeddingRequest,
|
||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
ErrorResponse,
|
|
||||||
MultimodalEmbeddingInput,
|
MultimodalEmbeddingInput,
|
||||||
UsageInfo,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||||
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||||
@@ -58,11 +49,22 @@ class _MockTokenizerManager:
|
|||||||
self.generate_request = Mock(return_value=mock_generate_embedding())
|
self.generate_request = Mock(return_value=mock_generate_embedding())
|
||||||
|
|
||||||
|
|
||||||
|
# Mock TemplateManager for embedding tests
|
||||||
|
class _MockTemplateManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.chat_template_name = None # None for embeddings usually
|
||||||
|
self.jinja_template_content_format = None
|
||||||
|
self.completion_template_name = None
|
||||||
|
|
||||||
|
|
||||||
class ServingEmbeddingTestCase(unittest.TestCase):
|
class ServingEmbeddingTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Set up test fixtures."""
|
"""Set up test fixtures."""
|
||||||
self.tokenizer_manager = _MockTokenizerManager()
|
self.tokenizer_manager = _MockTokenizerManager()
|
||||||
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
|
self.template_manager = _MockTemplateManager()
|
||||||
|
self.serving_embedding = OpenAIServingEmbedding(
|
||||||
|
self.tokenizer_manager, self.template_manager
|
||||||
|
)
|
||||||
|
|
||||||
self.request = Mock(spec=Request)
|
self.request = Mock(spec=Request)
|
||||||
self.request.headers = {}
|
self.request.headers = {}
|
||||||
@@ -141,132 +143,6 @@ class ServingEmbeddingTestCase(unittest.TestCase):
|
|||||||
self.assertIsNone(adapted_request.image_data[1])
|
self.assertIsNone(adapted_request.image_data[1])
|
||||||
# self.assertEqual(adapted_request.rid, "test-id")
|
# self.assertEqual(adapted_request.rid, "test-id")
|
||||||
|
|
||||||
def test_build_single_embedding_response(self):
|
|
||||||
"""Test building response for single embedding."""
|
|
||||||
ret_data = [
|
|
||||||
{
|
|
||||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
|
||||||
"meta_info": {"prompt_tokens": 5},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
response = self.serving_embedding._build_embedding_response(ret_data)
|
|
||||||
|
|
||||||
self.assertIsInstance(response, EmbeddingResponse)
|
|
||||||
self.assertEqual(response.model, "test-model")
|
|
||||||
self.assertEqual(len(response.data), 1)
|
|
||||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
|
||||||
self.assertEqual(response.data[0].index, 0)
|
|
||||||
self.assertEqual(response.data[0].object, "embedding")
|
|
||||||
self.assertEqual(response.usage.prompt_tokens, 5)
|
|
||||||
self.assertEqual(response.usage.total_tokens, 5)
|
|
||||||
self.assertEqual(response.usage.completion_tokens, 0)
|
|
||||||
|
|
||||||
def test_build_multiple_embedding_response(self):
|
|
||||||
"""Test building response for multiple embeddings."""
|
|
||||||
ret_data = [
|
|
||||||
{
|
|
||||||
"embedding": [0.1, 0.2, 0.3],
|
|
||||||
"meta_info": {"prompt_tokens": 3},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"embedding": [0.4, 0.5, 0.6],
|
|
||||||
"meta_info": {"prompt_tokens": 4},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
response = self.serving_embedding._build_embedding_response(ret_data)
|
|
||||||
|
|
||||||
self.assertIsInstance(response, EmbeddingResponse)
|
|
||||||
self.assertEqual(len(response.data), 2)
|
|
||||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
|
|
||||||
self.assertEqual(response.data[0].index, 0)
|
|
||||||
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
|
|
||||||
self.assertEqual(response.data[1].index, 1)
|
|
||||||
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
|
|
||||||
self.assertEqual(response.usage.total_tokens, 7)
|
|
||||||
|
|
||||||
def test_handle_request_success(self):
|
|
||||||
"""Test successful embedding request handling."""
|
|
||||||
|
|
||||||
async def run_test():
|
|
||||||
# Mock the generate_request to return expected data
|
|
||||||
async def mock_generate():
|
|
||||||
yield {
|
|
||||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
|
||||||
"meta_info": {"prompt_tokens": 5},
|
|
||||||
}
|
|
||||||
|
|
||||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
|
||||||
return_value=mock_generate()
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await self.serving_embedding.handle_request(
|
|
||||||
self.basic_req, self.request
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(response, EmbeddingResponse)
|
|
||||||
self.assertEqual(len(response.data), 1)
|
|
||||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
|
||||||
|
|
||||||
asyncio.run(run_test())
|
|
||||||
|
|
||||||
def test_handle_request_validation_error(self):
|
|
||||||
"""Test handling request with validation error."""
|
|
||||||
|
|
||||||
async def run_test():
|
|
||||||
invalid_request = EmbeddingRequest(model="test-model", input="")
|
|
||||||
|
|
||||||
response = await self.serving_embedding.handle_request(
|
|
||||||
invalid_request, self.request
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(response, ORJSONResponse)
|
|
||||||
self.assertEqual(response.status_code, 400)
|
|
||||||
|
|
||||||
asyncio.run(run_test())
|
|
||||||
|
|
||||||
def test_handle_request_generation_error(self):
|
|
||||||
"""Test handling request with generation error."""
|
|
||||||
|
|
||||||
async def run_test():
|
|
||||||
# Mock generate_request to raise an error
|
|
||||||
async def mock_generate_error():
|
|
||||||
raise ValueError("Generation failed")
|
|
||||||
yield # This won't be reached but needed for async generator
|
|
||||||
|
|
||||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
|
||||||
return_value=mock_generate_error()
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await self.serving_embedding.handle_request(
|
|
||||||
self.basic_req, self.request
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(response, ORJSONResponse)
|
|
||||||
self.assertEqual(response.status_code, 400)
|
|
||||||
|
|
||||||
asyncio.run(run_test())
|
|
||||||
|
|
||||||
def test_handle_request_internal_error(self):
|
|
||||||
"""Test handling request with internal server error."""
|
|
||||||
|
|
||||||
async def run_test():
|
|
||||||
# Mock _convert_to_internal_request to raise an exception
|
|
||||||
with patch.object(
|
|
||||||
self.serving_embedding,
|
|
||||||
"_convert_to_internal_request",
|
|
||||||
side_effect=Exception("Internal error"),
|
|
||||||
):
|
|
||||||
response = await self.serving_embedding.handle_request(
|
|
||||||
self.basic_req, self.request
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(response, ORJSONResponse)
|
|
||||||
self.assertEqual(response.status_code, 500)
|
|
||||||
|
|
||||||
asyncio.run(run_test())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(verbosity=2)
|
unittest.main(verbosity=2)
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ suites = {
|
|||||||
TestFile("models/test_reward_models.py", 132),
|
TestFile("models/test_reward_models.py", 132),
|
||||||
TestFile("models/test_vlm_models.py", 437),
|
TestFile("models/test_vlm_models.py", 437),
|
||||||
TestFile("models/test_transformers_models.py", 320),
|
TestFile("models/test_transformers_models.py", 320),
|
||||||
|
TestFile("openai/test_protocol.py", 10),
|
||||||
|
TestFile("openai/test_serving_chat.py", 10),
|
||||||
|
TestFile("openai/test_serving_completions.py", 10),
|
||||||
|
TestFile("openai/test_serving_embedding.py", 10),
|
||||||
TestFile("test_abort.py", 51),
|
TestFile("test_abort.py", 51),
|
||||||
TestFile("test_block_int8.py", 22),
|
TestFile("test_block_int8.py", 22),
|
||||||
TestFile("test_create_kvindices.py", 2),
|
TestFile("test_create_kvindices.py", 2),
|
||||||
@@ -49,6 +53,7 @@ suites = {
|
|||||||
TestFile("test_hidden_states.py", 55),
|
TestFile("test_hidden_states.py", 55),
|
||||||
TestFile("test_int8_kernel.py", 8),
|
TestFile("test_int8_kernel.py", 8),
|
||||||
TestFile("test_input_embeddings.py", 38),
|
TestFile("test_input_embeddings.py", 38),
|
||||||
|
TestFile("test_jinja_template_utils.py", 1),
|
||||||
TestFile("test_json_constrained.py", 98),
|
TestFile("test_json_constrained.py", 98),
|
||||||
TestFile("test_large_max_new_tokens.py", 41),
|
TestFile("test_large_max_new_tokens.py", 41),
|
||||||
TestFile("test_metrics.py", 32),
|
TestFile("test_metrics.py", 32),
|
||||||
@@ -59,14 +64,8 @@ suites = {
|
|||||||
TestFile("test_mla_fp8.py", 93),
|
TestFile("test_mla_fp8.py", 93),
|
||||||
TestFile("test_no_chunked_prefill.py", 108),
|
TestFile("test_no_chunked_prefill.py", 108),
|
||||||
TestFile("test_no_overlap_scheduler.py", 234),
|
TestFile("test_no_overlap_scheduler.py", 234),
|
||||||
TestFile("test_openai_adapter.py", 1),
|
|
||||||
TestFile("test_openai_function_calling.py", 60),
|
TestFile("test_openai_function_calling.py", 60),
|
||||||
TestFile("test_openai_server.py", 149),
|
TestFile("test_openai_server.py", 149),
|
||||||
TestFile("openai/test_server.py", 120),
|
|
||||||
TestFile("openai/test_protocol.py", 60),
|
|
||||||
TestFile("openai/test_serving_chat.py", 120),
|
|
||||||
TestFile("openai/test_serving_completions.py", 120),
|
|
||||||
TestFile("openai/test_serving_embedding.py", 120),
|
|
||||||
TestFile("test_openai_server_hidden_states.py", 240),
|
TestFile("test_openai_server_hidden_states.py", 240),
|
||||||
TestFile("test_penalty.py", 41),
|
TestFile("test_penalty.py", 41),
|
||||||
TestFile("test_page_size.py", 60),
|
TestFile("test_page_size.py", 60),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import unittest
|
|||||||
|
|
||||||
from xgrammar import GrammarCompiler, TokenizerInfo
|
from xgrammar import GrammarCompiler, TokenizerInfo
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Function, Tool
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||||
@@ -10,7 +11,6 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
|
|||||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.openai_api.protocol import Function, Tool
|
|
||||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ Unit tests for OpenAI adapter utils.
|
|||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from sglang.srt.openai_api.utils import (
|
from sglang.srt.jinja_template_utils import (
|
||||||
detect_template_content_format,
|
detect_jinja_template_content_format,
|
||||||
process_content_for_template_format,
|
process_content_for_template_format,
|
||||||
)
|
)
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
@@ -33,7 +33,7 @@ class TestTemplateContentFormatDetection(CustomTestCase):
|
|||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = detect_template_content_format(llama4_pattern)
|
result = detect_jinja_template_content_format(llama4_pattern)
|
||||||
self.assertEqual(result, "openai")
|
self.assertEqual(result, "openai")
|
||||||
|
|
||||||
def test_detect_deepseek_string_format(self):
|
def test_detect_deepseek_string_format(self):
|
||||||
@@ -46,19 +46,19 @@ class TestTemplateContentFormatDetection(CustomTestCase):
|
|||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = detect_template_content_format(deepseek_pattern)
|
result = detect_jinja_template_content_format(deepseek_pattern)
|
||||||
self.assertEqual(result, "string")
|
self.assertEqual(result, "string")
|
||||||
|
|
||||||
def test_detect_invalid_template(self):
|
def test_detect_invalid_template(self):
|
||||||
"""Test handling of invalid template (should default to 'string')."""
|
"""Test handling of invalid template (should default to 'string')."""
|
||||||
invalid_pattern = "{{{{ invalid jinja syntax }}}}"
|
invalid_pattern = "{{{{ invalid jinja syntax }}}}"
|
||||||
|
|
||||||
result = detect_template_content_format(invalid_pattern)
|
result = detect_jinja_template_content_format(invalid_pattern)
|
||||||
self.assertEqual(result, "string")
|
self.assertEqual(result, "string")
|
||||||
|
|
||||||
def test_detect_empty_template(self):
|
def test_detect_empty_template(self):
|
||||||
"""Test handling of empty template (should default to 'string')."""
|
"""Test handling of empty template (should default to 'string')."""
|
||||||
result = detect_template_content_format("")
|
result = detect_jinja_template_content_format("")
|
||||||
self.assertEqual(result, "string")
|
self.assertEqual(result, "string")
|
||||||
|
|
||||||
def test_process_content_openai_format(self):
|
def test_process_content_openai_format(self):
|
||||||
@@ -235,6 +235,7 @@ class TestOpenAIServer(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
is_firsts = {}
|
is_firsts = {}
|
||||||
|
is_finished = {}
|
||||||
for response in generator:
|
for response in generator:
|
||||||
usage = response.usage
|
usage = response.usage
|
||||||
if usage is not None:
|
if usage is not None:
|
||||||
@@ -244,6 +245,10 @@ class TestOpenAIServer(CustomTestCase):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
index = response.choices[0].index
|
index = response.choices[0].index
|
||||||
|
finish_reason = response.choices[0].finish_reason
|
||||||
|
if finish_reason is not None:
|
||||||
|
is_finished[index] = True
|
||||||
|
|
||||||
data = response.choices[0].delta
|
data = response.choices[0].delta
|
||||||
|
|
||||||
if is_firsts.get(index, True):
|
if is_firsts.get(index, True):
|
||||||
@@ -253,7 +258,7 @@ class TestOpenAIServer(CustomTestCase):
|
|||||||
is_firsts[index] = False
|
is_firsts[index] = False
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if logprobs:
|
if logprobs and not is_finished.get(index, False):
|
||||||
assert response.choices[0].logprobs, f"logprobs was not returned"
|
assert response.choices[0].logprobs, f"logprobs was not returned"
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||||
@@ -271,7 +276,7 @@ class TestOpenAIServer(CustomTestCase):
|
|||||||
assert (
|
assert (
|
||||||
isinstance(data.content, str)
|
isinstance(data.content, str)
|
||||||
or isinstance(data.reasoning_content, str)
|
or isinstance(data.reasoning_content, str)
|
||||||
or len(data.tool_calls) > 0
|
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
|
||||||
or response.choices[0].finish_reason
|
or response.choices[0].finish_reason
|
||||||
)
|
)
|
||||||
assert response.id
|
assert response.id
|
||||||
@@ -282,152 +287,6 @@ class TestOpenAIServer(CustomTestCase):
|
|||||||
index, True
|
index, True
|
||||||
), f"index {index} is not found in the response"
|
), f"index {index} is not found in the response"
|
||||||
|
|
||||||
def _create_batch(self, mode, client):
|
|
||||||
if mode == "completion":
|
|
||||||
input_file_path = "complete_input.jsonl"
|
|
||||||
# write content to input file
|
|
||||||
content = [
|
|
||||||
{
|
|
||||||
"custom_id": "request-1",
|
|
||||||
"method": "POST",
|
|
||||||
"url": "/v1/completions",
|
|
||||||
"body": {
|
|
||||||
"model": "gpt-3.5-turbo-instruct",
|
|
||||||
"prompt": "List 3 names of famous soccer player: ",
|
|
||||||
"max_tokens": 20,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"custom_id": "request-2",
|
|
||||||
"method": "POST",
|
|
||||||
"url": "/v1/completions",
|
|
||||||
"body": {
|
|
||||||
"model": "gpt-3.5-turbo-instruct",
|
|
||||||
"prompt": "List 6 names of famous basketball player: ",
|
|
||||||
"max_tokens": 40,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"custom_id": "request-3",
|
|
||||||
"method": "POST",
|
|
||||||
"url": "/v1/completions",
|
|
||||||
"body": {
|
|
||||||
"model": "gpt-3.5-turbo-instruct",
|
|
||||||
"prompt": "List 6 names of famous tenniss player: ",
|
|
||||||
"max_tokens": 40,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
else:
|
|
||||||
input_file_path = "chat_input.jsonl"
|
|
||||||
content = [
|
|
||||||
{
|
|
||||||
"custom_id": "request-1",
|
|
||||||
"method": "POST",
|
|
||||||
"url": "/v1/chat/completions",
|
|
||||||
"body": {
|
|
||||||
"model": "gpt-3.5-turbo-0125",
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": "You are a helpful assistant.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Hello! List 3 NBA players and tell a story",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"max_tokens": 30,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"custom_id": "request-2",
|
|
||||||
"method": "POST",
|
|
||||||
"url": "/v1/chat/completions",
|
|
||||||
"body": {
|
|
||||||
"model": "gpt-3.5-turbo-0125",
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are an assistant. "},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Hello! List three capital and tell a story",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"max_tokens": 50,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
with open(input_file_path, "w") as file:
|
|
||||||
for line in content:
|
|
||||||
file.write(json.dumps(line) + "\n")
|
|
||||||
|
|
||||||
with open(input_file_path, "rb") as file:
|
|
||||||
uploaded_file = client.files.create(file=file, purpose="batch")
|
|
||||||
if mode == "completion":
|
|
||||||
endpoint = "/v1/completions"
|
|
||||||
elif mode == "chat":
|
|
||||||
endpoint = "/v1/chat/completions"
|
|
||||||
completion_window = "24h"
|
|
||||||
batch_job = client.batches.create(
|
|
||||||
input_file_id=uploaded_file.id,
|
|
||||||
endpoint=endpoint,
|
|
||||||
completion_window=completion_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
return batch_job, content, uploaded_file
|
|
||||||
|
|
||||||
def run_batch(self, mode):
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
|
||||||
batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)
|
|
||||||
|
|
||||||
while batch_job.status not in ["completed", "failed", "cancelled"]:
|
|
||||||
time.sleep(3)
|
|
||||||
print(
|
|
||||||
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
|
||||||
)
|
|
||||||
batch_job = client.batches.retrieve(batch_job.id)
|
|
||||||
assert (
|
|
||||||
batch_job.status == "completed"
|
|
||||||
), f"Batch job status is not completed: {batch_job.status}"
|
|
||||||
assert batch_job.request_counts.completed == len(content)
|
|
||||||
assert batch_job.request_counts.failed == 0
|
|
||||||
assert batch_job.request_counts.total == len(content)
|
|
||||||
|
|
||||||
result_file_id = batch_job.output_file_id
|
|
||||||
file_response = client.files.content(result_file_id)
|
|
||||||
result_content = file_response.read().decode("utf-8") # Decode bytes to string
|
|
||||||
results = [
|
|
||||||
json.loads(line)
|
|
||||||
for line in result_content.split("\n")
|
|
||||||
if line.strip() != ""
|
|
||||||
]
|
|
||||||
assert len(results) == len(content)
|
|
||||||
for delete_fid in [uploaded_file.id, result_file_id]:
|
|
||||||
del_pesponse = client.files.delete(delete_fid)
|
|
||||||
assert del_pesponse.deleted
|
|
||||||
|
|
||||||
def run_cancel_batch(self, mode):
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
|
||||||
batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)
|
|
||||||
|
|
||||||
assert batch_job.status not in ["cancelling", "cancelled"]
|
|
||||||
|
|
||||||
batch_job = client.batches.cancel(batch_id=batch_job.id)
|
|
||||||
assert batch_job.status == "cancelling"
|
|
||||||
|
|
||||||
while batch_job.status not in ["failed", "cancelled"]:
|
|
||||||
batch_job = client.batches.retrieve(batch_job.id)
|
|
||||||
print(
|
|
||||||
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
|
||||||
)
|
|
||||||
time.sleep(3)
|
|
||||||
|
|
||||||
assert batch_job.status == "cancelled"
|
|
||||||
del_response = client.files.delete(uploaded_file.id)
|
|
||||||
assert del_response.deleted
|
|
||||||
|
|
||||||
def test_completion(self):
|
def test_completion(self):
|
||||||
for echo in [False, True]:
|
for echo in [False, True]:
|
||||||
for logprobs in [None, 5]:
|
for logprobs in [None, 5]:
|
||||||
@@ -467,14 +326,6 @@ class TestOpenAIServer(CustomTestCase):
|
|||||||
for parallel_sample_num in [1, 2]:
|
for parallel_sample_num in [1, 2]:
|
||||||
self.run_chat_completion_stream(logprobs, parallel_sample_num)
|
self.run_chat_completion_stream(logprobs, parallel_sample_num)
|
||||||
|
|
||||||
def test_batch(self):
|
|
||||||
for mode in ["completion", "chat"]:
|
|
||||||
self.run_batch(mode)
|
|
||||||
|
|
||||||
def test_cancel_batch(self):
|
|
||||||
for mode in ["completion", "chat"]:
|
|
||||||
self.run_cancel_batch(mode)
|
|
||||||
|
|
||||||
def test_regex(self):
|
def test_regex(self):
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
@@ -559,6 +410,18 @@ The SmartHome Mini is a compact smart home assistant available in black or white
|
|||||||
assert len(models) == 1
|
assert len(models) == 1
|
||||||
assert isinstance(getattr(models[0], "max_model_len", None), int)
|
assert isinstance(getattr(models[0], "max_model_len", None), int)
|
||||||
|
|
||||||
|
def test_retrieve_model(self):
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
|
# Test retrieving an existing model
|
||||||
|
retrieved_model = client.models.retrieve(self.model)
|
||||||
|
self.assertEqual(retrieved_model.id, self.model)
|
||||||
|
self.assertEqual(retrieved_model.root, self.model)
|
||||||
|
|
||||||
|
# Test retrieving a non-existent model
|
||||||
|
with self.assertRaises(openai.NotFoundError):
|
||||||
|
client.models.retrieve("non-existent-model")
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
# EBNF Test Class: TestOpenAIServerEBNF
|
# EBNF Test Class: TestOpenAIServerEBNF
|
||||||
@@ -684,6 +547,31 @@ class TestOpenAIEmbedding(CustomTestCase):
|
|||||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||||
self.assertTrue(len(response.data[1].embedding) > 0)
|
self.assertTrue(len(response.data[1].embedding) > 0)
|
||||||
|
|
||||||
|
def test_embedding_single_batch_str(self):
|
||||||
|
"""Test embedding with a List[str] and length equals to 1"""
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
response = client.embeddings.create(model=self.model, input=["Hello world"])
|
||||||
|
self.assertEqual(len(response.data), 1)
|
||||||
|
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||||
|
|
||||||
|
def test_embedding_single_int_list(self):
|
||||||
|
"""Test embedding with a List[int] or List[List[int]]]"""
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model=self.model,
|
||||||
|
input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
|
||||||
|
)
|
||||||
|
self.assertEqual(len(response.data), 1)
|
||||||
|
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||||
|
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model=self.model,
|
||||||
|
input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
|
||||||
|
)
|
||||||
|
self.assertEqual(len(response.data), 1)
|
||||||
|
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||||
|
|
||||||
def test_empty_string_embedding(self):
|
def test_empty_string_embedding(self):
|
||||||
"""Test embedding an empty string."""
|
"""Test embedding an empty string."""
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from transformers import (
|
|||||||
from sglang import Engine
|
from sglang import Engine
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.conversation import generate_chat_conv
|
from sglang.srt.conversation import generate_chat_conv
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
|
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
@@ -31,7 +32,6 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
MultimodalInputs,
|
MultimodalInputs,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from transformers import (
|
|||||||
|
|
||||||
from sglang import Engine
|
from sglang import Engine
|
||||||
from sglang.srt.conversation import generate_chat_conv
|
from sglang.srt.conversation import generate_chat_conv
|
||||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
|
|
||||||
TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user