[Feature] Add /tokenize and /detokenize OpenAI compatible endpoints (#9545)
Commit 7c3f07dbcb, parent edd86b8853 (committed via GitHub).
@@ -52,12 +52,14 @@ from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
    DetokenizeRequest,
    EmbeddingRequest,
    ErrorResponse,
    ModelCard,
    ModelList,
    ResponsesRequest,
    ScoringRequest,
    TokenizeRequest,
    V1RerankReqInput,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -65,6 +67,10 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
from sglang.srt.entrypoints.openai.serving_tokenize import (
    OpenAIServingDetokenize,
    OpenAIServingTokenize,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
    AbortReq,
@@ -229,6 +235,12 @@ async def lifespan(fast_api_app: FastAPI):
    fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
        _global_state.tokenizer_manager
    )
    fast_api_app.state.openai_serving_tokenize = OpenAIServingTokenize(
        _global_state.tokenizer_manager
    )
    fast_api_app.state.openai_serving_detokenize = OpenAIServingDetokenize(
        _global_state.tokenizer_manager
    )

    server_args: ServerArgs = fast_api_app.server_args
@@ -1070,6 +1082,42 @@ async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
    )


@app.post(
    "/v1/tokenize",
    response_class=ORJSONResponse,
    dependencies=[Depends(validate_json_request)],
)
@app.post(
    "/tokenize",
    response_class=ORJSONResponse,
    dependencies=[Depends(validate_json_request)],
    include_in_schema=False,
)
async def openai_v1_tokenize(request: TokenizeRequest, raw_request: Request):
    """OpenAI-compatible tokenization endpoint."""
    return await raw_request.app.state.openai_serving_tokenize.handle_request(
        request, raw_request
    )


@app.post(
    "/v1/detokenize",
    response_class=ORJSONResponse,
    dependencies=[Depends(validate_json_request)],
)
@app.post(
    "/detokenize",
    response_class=ORJSONResponse,
    dependencies=[Depends(validate_json_request)],
    include_in_schema=False,
)
async def openai_v1_detokenize(request: DetokenizeRequest, raw_request: Request):
    """OpenAI-compatible detokenization endpoint."""
    return await raw_request.app.state.openai_serving_detokenize.handle_request(
        request, raw_request
    )


@app.get("/v1/models", response_class=ORJSONResponse)
async def available_models():
    """Show available models. OpenAI-compatible endpoint."""
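For reference, a minimal client-side sketch of calling the new endpoints. The host/port (localhost:30000) and model name are placeholders, not part of this diff, and the response is assumed to serialize the TokenizeResponse/DetokenizeResponse fields at the top level:

# Illustrative client call; server address and model name are assumptions.
import requests

base_url = "http://localhost:30000"  # placeholder SGLang server address

tok = requests.post(
    f"{base_url}/v1/tokenize",
    json={"model": "default", "prompt": "Hello, world!", "add_special_tokens": True},
).json()
print(tok["tokens"], tok["count"], tok["max_model_len"])

detok = requests.post(
    f"{base_url}/v1/detokenize",
    json={"model": "default", "tokens": tok["tokens"], "skip_special_tokens": True},
).json()
print(detok["text"])

The unversioned /tokenize and /detokenize routes registered above accept the same payloads; they are simply hidden from the OpenAPI schema.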
@@ -801,12 +801,50 @@ class RerankResponse(BaseModel):
    meta_info: Optional[dict] = None


class TokenizeRequest(BaseModel):
    """Request schema for the /tokenize endpoint."""

    model: str = DEFAULT_MODEL_NAME
    prompt: Union[str, List[str]]
    add_special_tokens: bool = Field(
        default=True,
        description="whether to add model-specific special tokens (e.g. BOS/EOS) during encoding.",
    )


class TokenizeResponse(BaseModel):
    """Response schema for the /tokenize endpoint."""

    tokens: Union[List[int], List[List[int]]]
    count: Union[int, List[int]]
    max_model_len: int


class DetokenizeRequest(BaseModel):
    """Request schema for the /detokenize endpoint."""

    model: str = DEFAULT_MODEL_NAME
    tokens: Union[List[int], List[List[int]]]
    skip_special_tokens: bool = Field(
        default=True,
        description="whether to exclude special tokens (e.g. padding or EOS) during decoding.",
    )


class DetokenizeResponse(BaseModel):
    """Response schema for the /detokenize endpoint."""

    text: Union[str, List[str]]


OpenAIServingRequest = Union[
    ChatCompletionRequest,
    CompletionRequest,
    EmbeddingRequest,
    ScoringRequest,
    V1RerankReqInput,
    TokenizeRequest,
    DetokenizeRequest,
]
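To make the new schemas concrete, a small sketch constructing them directly; the token IDs and max_model_len below are made-up illustration values:

# Hypothetical values; real token IDs depend on the served model's tokenizer.
from sglang.srt.entrypoints.openai.protocol import (
    DetokenizeRequest,
    TokenizeRequest,
    TokenizeResponse,
)

# A batched prompt produces per-prompt token lists and counts.
req = TokenizeRequest(prompt=["Hello", "Hi there"], add_special_tokens=True)
resp = TokenizeResponse(tokens=[[9906], [13347, 1070]], count=[1, 2], max_model_len=8192)

# Detokenizing the same nested structure yields a list of strings,
# i.e. DetokenizeResponse(text=["Hello", "Hi there"]).
detok_req = DetokenizeRequest(tokens=[[9906], [13347, 1070]], skip_special_tokens=True)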
python/sglang/srt/entrypoints/openai/serving_tokenize.py (new file, 144 lines)
@@ -0,0 +1,144 @@
import logging
from http import HTTPStatus
from typing import List, Union

from fastapi import Request

from sglang.srt.entrypoints.openai.protocol import (
    DetokenizeRequest,
    DetokenizeResponse,
    ErrorResponse,
    TokenizeRequest,
    TokenizeResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase

logger = logging.getLogger(__name__)


class OpenAIServingTokenize(OpenAIServingBase):
    """Handler for /v1/tokenize requests"""

    def _request_id_prefix(self) -> str:
        return "tok-"

    def _convert_to_internal_request(
        self, request: TokenizeRequest, raw_request: Request
    ) -> tuple[TokenizeRequest, TokenizeRequest]:
        return request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: TokenizeRequest,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        try:
            tokenizer = self.tokenizer_manager.tokenizer
            max_model_len = getattr(tokenizer, "model_max_length", -1)

            if isinstance(request.prompt, str):
                token_ids = tokenizer.encode(
                    request.prompt,
                    add_special_tokens=request.add_special_tokens,
                )
                tokens = token_ids
                count = len(token_ids)
            elif isinstance(request.prompt, list):
                token_ids_list = [
                    tokenizer.encode(
                        text, add_special_tokens=request.add_special_tokens
                    )
                    for text in request.prompt
                ]
                tokens = token_ids_list
                count = [len(ids) for ids in token_ids_list]
            else:
                return self.create_error_response(
                    f"Invalid prompt type: {type(request.prompt)}. Expected str or List[str]."
                )

            return TokenizeResponse(
                tokens=tokens, count=count, max_model_len=max_model_len
            )
        except Exception as e:
            logger.error("Error during tokenization", exc_info=True)
            return self.create_error_response(
                f"Internal server error during tokenization: {e}",
                err_type="InternalServerError",
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            )


class OpenAIServingDetokenize(OpenAIServingBase):
    """Handler for /v1/detokenize requests"""

    def _request_id_prefix(self) -> str:
        return "detok-"

    def _convert_to_internal_request(
        self, request: DetokenizeRequest, raw_request: Request
    ) -> tuple[DetokenizeRequest, DetokenizeRequest]:
        return request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: DetokenizeRequest,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        try:
            tokenizer = self.tokenizer_manager.tokenizer

            if (
                isinstance(request.tokens, list)
                and request.tokens
                and isinstance(request.tokens[0], int)
            ):
                if not all(isinstance(t, int) for t in request.tokens):
                    return self.create_error_response(
                        "Invalid input: 'tokens' must be a list of integers."
                    )
                tokens_to_decode = [int(t) for t in request.tokens]
                text = tokenizer.decode(
                    tokens_to_decode, skip_special_tokens=request.skip_special_tokens
                )
                text_out: Union[str, List[str]] = text
            elif (
                isinstance(request.tokens, list)
                and request.tokens
                and isinstance(request.tokens[0], list)
            ):
                texts: List[str] = []
                for token_list in request.tokens:
                    if not all(isinstance(t, int) for t in token_list):
                        return self.create_error_response(
                            f"Invalid input: Sublist in 'tokens' must contain only integers. Found: {token_list}"
                        )
                    decoded_text = tokenizer.decode(
                        [int(t) for t in token_list],
                        skip_special_tokens=request.skip_special_tokens,
                    )
                    texts.append(decoded_text)
                text_out = texts
            elif isinstance(request.tokens, list) and not request.tokens:
                text_out = ""
            else:
                return self.create_error_response(
                    f"Invalid tokens type: {type(request.tokens)}. Expected List[int] or List[List[int]]."
                )

            return DetokenizeResponse(text=text_out)
        except Exception as e:
            logger.error("Error during detokenization", exc_info=True)
            if "decode" in str(e).lower():
                return self.create_error_response(
                    f"Error decoding tokens: {e}. Input tokens might be invalid for the model.",
                    err_type="DecodeError",
                    status_code=HTTPStatus.BAD_REQUEST,
                )
            return self.create_error_response(
                f"Internal server error during detokenization: {e}",
                err_type="InternalServerError",
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            )
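For reference, a standalone sketch of the encode/decode semantics these handlers rely on, using a Hugging Face tokenizer; "gpt2" is an arbitrary example model, not tied to this commit:

# Standalone illustration of add_special_tokens / skip_special_tokens behavior.
# "gpt2" is an arbitrary example model.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

ids = tokenizer.encode("Hello, world!", add_special_tokens=True)
print(ids, len(ids))  # the token IDs and per-prompt count /v1/tokenize would report

text = tokenizer.decode(ids, skip_special_tokens=True)
print(text)  # decodes back to (approximately) the original string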