Files
sglang/python/sglang/srt/entrypoints/openai/serving_tokenize.py

145 lines
5.4 KiB
Python

import logging
from http import HTTPStatus
from typing import List, Union
from fastapi import Request
from sglang.srt.entrypoints.openai.protocol import (
DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
TokenizeRequest,
TokenizeResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
logger = logging.getLogger(__name__)
class OpenAIServingTokenize(OpenAIServingBase):
    """Handler for /v1/tokenize requests"""

    def _request_id_prefix(self) -> str:
        # Request ids produced for this endpoint look like "tok-<id>".
        return "tok-"

    def _convert_to_internal_request(
        self, request: TokenizeRequest, raw_request: Request
    ) -> tuple[TokenizeRequest, TokenizeRequest]:
        # Tokenize requests need no adaptation; pass the request through.
        return request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: TokenizeRequest,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Encode the request prompt(s) into token ids.

        Accepts a single string prompt (returns one id list and its
        length) or a list of strings (returns one id list and one count
        per prompt). Any other prompt type produces an error response.
        """
        try:
            tokenizer = self.tokenizer_manager.tokenizer
            # -1 signals "unknown" when the tokenizer exposes no limit.
            max_model_len = getattr(tokenizer, "model_max_length", -1)

            prompt = request.prompt
            if isinstance(prompt, str):
                encoded = tokenizer.encode(
                    prompt, add_special_tokens=request.add_special_tokens
                )
                tokens, count = encoded, len(encoded)
            elif isinstance(prompt, list):
                encoded_batch = [
                    tokenizer.encode(
                        item, add_special_tokens=request.add_special_tokens
                    )
                    for item in prompt
                ]
                tokens = encoded_batch
                count = [len(ids) for ids in encoded_batch]
            else:
                return self.create_error_response(
                    f"Invalid prompt type: {type(request.prompt)}. Expected str or List[str]."
                )

            return TokenizeResponse(
                tokens=tokens, count=count, max_model_len=max_model_len
            )
        except Exception as e:
            logger.error("Error during tokenization", exc_info=True)
            return self.create_error_response(
                f"Internal server error during tokenization: {e}",
                err_type="InternalServerError",
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            )
class OpenAIServingDetokenize(OpenAIServingBase):
    """Handler for /v1/detokenize requests"""

    def _request_id_prefix(self) -> str:
        # Request ids produced for this endpoint look like "detok-<id>".
        return "detok-"

    def _convert_to_internal_request(
        self, request: DetokenizeRequest, raw_request: Request
    ) -> tuple[DetokenizeRequest, DetokenizeRequest]:
        # Detokenize requests need no adaptation; pass the request through.
        return request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: DetokenizeRequest,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Decode token ids back into text.

        `request.tokens` may be a flat List[int] (decoded to one string)
        or a List[List[int]] (each sublist decoded independently). An
        empty list decodes to the empty string; any other shape produces
        an error response.
        """
        try:
            tokenizer = self.tokenizer_manager.tokenizer
            tokens = request.tokens
            is_list = isinstance(tokens, list)

            if is_list and tokens and isinstance(tokens[0], int):
                # Flat list of ids -> a single decoded string. Validate
                # every element, not just the first one inspected above.
                if not all(isinstance(t, int) for t in tokens):
                    return self.create_error_response(
                        "Invalid input: 'tokens' must be a list of integers."
                    )
                decoded = tokenizer.decode(
                    [int(t) for t in tokens],
                    skip_special_tokens=request.skip_special_tokens,
                )
                text_out: Union[str, List[str]] = decoded
            elif is_list and tokens and isinstance(tokens[0], list):
                # Batched form: decode each sublist to its own string.
                decoded_batch: List[str] = []
                for token_list in tokens:
                    if not all(isinstance(t, int) for t in token_list):
                        return self.create_error_response(
                            f"Invalid input: Sublist in 'tokens' must contain only integers. Found: {token_list}"
                        )
                    decoded_batch.append(
                        tokenizer.decode(
                            [int(t) for t in token_list],
                            skip_special_tokens=request.skip_special_tokens,
                        )
                    )
                text_out = decoded_batch
            elif is_list and not tokens:
                # Empty input is valid and decodes to the empty string.
                text_out = ""
            else:
                return self.create_error_response(
                    f"Invalid tokens type: {type(request.tokens)}. Expected List[int] or List[List[int]]."
                )

            return DetokenizeResponse(text=text_out)
        except Exception as e:
            logger.error("Error during detokenization", exc_info=True)
            # Heuristic: messages mentioning "decode" are treated as
            # client-side errors (bad token ids), not server faults.
            if "decode" in str(e).lower():
                return self.create_error_response(
                    f"Error decoding tokens: {e}. Input tokens might be invalid for the model.",
                    err_type="DecodeError",
                    status_code=HTTPStatus.BAD_REQUEST,
                )
            return self.create_error_response(
                f"Internal server error during detokenization: {e}",
                err_type="InternalServerError",
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            )