145 lines
5.4 KiB
Python
145 lines
5.4 KiB
Python
import logging
|
|
from http import HTTPStatus
|
|
from typing import List, Union
|
|
|
|
from fastapi import Request
|
|
|
|
from sglang.srt.entrypoints.openai.protocol import (
|
|
DetokenizeRequest,
|
|
DetokenizeResponse,
|
|
ErrorResponse,
|
|
TokenizeRequest,
|
|
TokenizeResponse,
|
|
)
|
|
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OpenAIServingTokenize(OpenAIServingBase):
|
|
"""Handler for /v1/tokenize requests"""
|
|
|
|
def _request_id_prefix(self) -> str:
|
|
return "tok-"
|
|
|
|
def _convert_to_internal_request(
|
|
self, request: TokenizeRequest, raw_request: Request
|
|
) -> tuple[TokenizeRequest, TokenizeRequest]:
|
|
return request, request
|
|
|
|
async def _handle_non_streaming_request(
|
|
self,
|
|
adapted_request: TokenizeRequest,
|
|
request: TokenizeRequest,
|
|
raw_request: Request,
|
|
) -> Union[TokenizeResponse, ErrorResponse]:
|
|
try:
|
|
tokenizer = self.tokenizer_manager.tokenizer
|
|
max_model_len = getattr(tokenizer, "model_max_length", -1)
|
|
|
|
if isinstance(request.prompt, str):
|
|
token_ids = tokenizer.encode(
|
|
request.prompt,
|
|
add_special_tokens=request.add_special_tokens,
|
|
)
|
|
tokens = token_ids
|
|
count = len(token_ids)
|
|
elif isinstance(request.prompt, list):
|
|
token_ids_list = [
|
|
tokenizer.encode(
|
|
text, add_special_tokens=request.add_special_tokens
|
|
)
|
|
for text in request.prompt
|
|
]
|
|
tokens = token_ids_list
|
|
count = [len(ids) for ids in token_ids_list]
|
|
else:
|
|
return self.create_error_response(
|
|
f"Invalid prompt type: {type(request.prompt)}. Expected str or List[str]."
|
|
)
|
|
|
|
return TokenizeResponse(
|
|
tokens=tokens, count=count, max_model_len=max_model_len
|
|
)
|
|
except Exception as e:
|
|
logger.error("Error during tokenization", exc_info=True)
|
|
return self.create_error_response(
|
|
f"Internal server error during tokenization: {e}",
|
|
err_type="InternalServerError",
|
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
|
)
|
|
|
|
|
|
class OpenAIServingDetokenize(OpenAIServingBase):
|
|
"""Handler for /v1/detokenize requests"""
|
|
|
|
def _request_id_prefix(self) -> str:
|
|
return "detok-"
|
|
|
|
def _convert_to_internal_request(
|
|
self, request: DetokenizeRequest, raw_request: Request
|
|
) -> tuple[DetokenizeRequest, DetokenizeRequest]:
|
|
return request, request
|
|
|
|
async def _handle_non_streaming_request(
|
|
self,
|
|
adapted_request: DetokenizeRequest,
|
|
request: DetokenizeRequest,
|
|
raw_request: Request,
|
|
) -> Union[DetokenizeResponse, ErrorResponse]:
|
|
try:
|
|
tokenizer = self.tokenizer_manager.tokenizer
|
|
|
|
if (
|
|
isinstance(request.tokens, list)
|
|
and request.tokens
|
|
and isinstance(request.tokens[0], int)
|
|
):
|
|
if not all(isinstance(t, int) for t in request.tokens):
|
|
return self.create_error_response(
|
|
"Invalid input: 'tokens' must be a list of integers."
|
|
)
|
|
tokens_to_decode = [int(t) for t in request.tokens]
|
|
text = tokenizer.decode(
|
|
tokens_to_decode, skip_special_tokens=request.skip_special_tokens
|
|
)
|
|
text_out: Union[str, List[str]] = text
|
|
elif (
|
|
isinstance(request.tokens, list)
|
|
and request.tokens
|
|
and isinstance(request.tokens[0], list)
|
|
):
|
|
texts: List[str] = []
|
|
for token_list in request.tokens:
|
|
if not all(isinstance(t, int) for t in token_list):
|
|
return self.create_error_response(
|
|
f"Invalid input: Sublist in 'tokens' must contain only integers. Found: {token_list}"
|
|
)
|
|
decoded_text = tokenizer.decode(
|
|
[int(t) for t in token_list],
|
|
skip_special_tokens=request.skip_special_tokens,
|
|
)
|
|
texts.append(decoded_text)
|
|
text_out = texts
|
|
elif isinstance(request.tokens, list) and not request.tokens:
|
|
text_out = ""
|
|
else:
|
|
return self.create_error_response(
|
|
f"Invalid tokens type: {type(request.tokens)}. Expected List[int] or List[List[int]]."
|
|
)
|
|
|
|
return DetokenizeResponse(text=text_out)
|
|
except Exception as e:
|
|
logger.error("Error during detokenization", exc_info=True)
|
|
if "decode" in str(e).lower():
|
|
return self.create_error_response(
|
|
f"Error decoding tokens: {e}. Input tokens might be invalid for the model.",
|
|
err_type="DecodeError",
|
|
status_code=HTTPStatus.BAD_REQUEST,
|
|
)
|
|
return self.create_error_response(
|
|
f"Internal server error during detokenization: {e}",
|
|
err_type="InternalServerError",
|
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
|
)
|